jafioti / luminal

Deep learning at the speed of light.

Home Page: https://luminalai.com

Improve diagnostic for compile-time checks

npuichigo opened this issue

First, thanks for your great work. The default error messages for compile-time shape checks are hard to read, though. When I intentionally introduced a shape mistake in examples/simple, the resulting message was confusing:

error[E0277]: the trait bound `GraphTensor<(luminal::shape::Const<4>,)>: Matmul<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` is not satisfied
  --> examples/simple/src/main.rs:12:27
   |
12 |     let b = model.forward(a).retrieve();
   |                   ------- ^ the trait `Matmul<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` is not implemented for `GraphTensor<(luminal::shape::Const<4>,)>`, which is required by `Linear<5, 5>: luminal::module::Module<GraphTensor<_>>`
   |                   |
   |                   required by a bound introduced by this call
   |
   = help: the following other types implement trait `Matmul<S>`:
             <GraphTensor<(A,)> as Matmul<(A, B)>>
             <GraphTensor<(A, B)> as Matmul<(B, C)>>
             <GraphTensor<(A, B, C)> as Matmul<(A, C, D)>>
             <GraphTensor<(A, B, C)> as Matmul<(C, D)>>
             <GraphTensor<(A, B, C, D)> as Matmul<(A, B, D, E)>>
             <GraphTensor<(A, B, C, D, E)> as Matmul<(A, B, C, E, F)>>
   = note: required for `Linear<5, 5>` to implement `luminal::module::Module<GraphTensor<(luminal::shape::Const<4>,)>>`

Rust 1.78 added support for custom diagnostic messages through diagnostic attributes: https://blog.rust-lang.org/2024/05/02/Rust-1.78.0.html#diagnostic-attributes. I gave it a try:

#[diagnostic::on_unimplemented(
    message = "`{Self}` and `GraphTensor<{S}>` shapes cannot be multiplied",
    label = "Input tensor {Self}",
)]
pub trait Matmul<S: Shape> {
    type Output;
    fn matmul(self, rhs: GraphTensor<S>) -> Self::Output;
}

Now the message is much more readable.

error[E0277]: `GraphTensor<(luminal::shape::Const<4>,)>` and `GraphTensor<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` shapes cannot be multiplied
  --> examples/simple/src/main.rs:12:27
   |
12 |     let b = model.forward(a).retrieve();
   |                   ------- ^ Input tensor GraphTensor<(luminal::shape::Const<4>,)>
   |                   |
   |                   required by a bound introduced by this call
   |
   = help: the trait `Matmul<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` is not implemented for `GraphTensor<(luminal::shape::Const<4>,)>`, which is required by `Linear<5, 5>: luminal::module::Module<GraphTensor<_>>`
   = help: the following other types implement trait `Matmul<S>`:
             <GraphTensor<(A,)> as Matmul<(A, B)>>
             <GraphTensor<(A, B)> as Matmul<(B, C)>>
             <GraphTensor<(A, B, C)> as Matmul<(A, C, D)>>
             <GraphTensor<(A, B, C)> as Matmul<(C, D)>>
             <GraphTensor<(A, B, C, D)> as Matmul<(A, B, D, E)>>
             <GraphTensor<(A, B, C, D, E)> as Matmul<(A, B, C, E, F)>>
   = note: required for `Linear<5, 5>` to implement `luminal::module::Module<GraphTensor<(luminal::shape::Const<4>,)>>`
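
For anyone who wants to try the attribute outside luminal, here is a minimal self-contained sketch. `Const`, `Tensor`, `apply`, and `valid_product` are hypothetical stand-ins invented for illustration, not luminal's real types, and only two of the matmul shape pairs are reproduced. It assumes Rust 1.78 or newer.

```rust
use std::marker::PhantomData;

// Hypothetical stand-ins for luminal's shape and tensor types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Const<const N: usize>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tensor<S>(PhantomData<S>);

impl<S> Tensor<S> {
    pub fn new() -> Self {
        Tensor(PhantomData)
    }
}

// The attribute sits on the trait definition; `{Self}` and `{S}` are
// filled in with the concrete types of the failing bound.
#[diagnostic::on_unimplemented(
    message = "`{Self}` and `Tensor<{S}>` shapes cannot be multiplied",
    label = "left-hand tensor `{Self}`"
)]
pub trait Matmul<S> {
    type Output;
    fn matmul(self, rhs: Tensor<S>) -> Self::Output;
}

// Vector (A,) times matrix (A, B) gives vector (B,).
impl<A, B> Matmul<(A, B)> for Tensor<(A,)> {
    type Output = Tensor<(B,)>;
    fn matmul(self, _rhs: Tensor<(A, B)>) -> Self::Output {
        Tensor::new()
    }
}

// Matrix (A, B) times matrix (B, C) gives matrix (A, C).
impl<A, B, C> Matmul<(B, C)> for Tensor<(A, B)> {
    type Output = Tensor<(A, C)>;
    fn matmul(self, _rhs: Tensor<(B, C)>) -> Self::Output {
        Tensor::new()
    }
}

// A generic entry point whose `where` clause introduces the bound,
// similar in spirit to `Module::forward` in the report above.
pub fn apply<L, S>(lhs: L, rhs: Tensor<S>) -> L::Output
where
    L: Matmul<S>,
{
    lhs.matmul(rhs)
}

// A valid product: (4,) x (4, 5) -> (5,).
pub fn valid_product() -> Tensor<(Const<5>,)> {
    let v = Tensor::<(Const<4>,)>::new();
    let m = Tensor::<(Const<4>, Const<5>)>::new();
    apply(v, m)
}

// Uncommenting the next function should fail to compile, and the
// unsatisfied `Matmul` bound is where the custom message is expected:
// pub fn invalid_product() {
//     let v = Tensor::<(Const<4>,)>::new();
//     let m = Tensor::<(Const<5>, Const<5>)>::new();
//     let _ = apply(v, m);
// }
```

The call goes through a generic function with an explicit `L: Matmul<S>` bound so that the failure surfaces as an unsatisfied trait bound, which is the kind of error the attribute customizes.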

So I'm opening this issue to pursue more user-friendly diagnostic messages for compile-time checks.

That looks awesome, I didn't even know that was possible! I'll look into adding these wherever applicable. I don't think it will be straightforward, since most operations aren't implemented with custom traits the way matmul is, but for the ones where that applies it will be helpful.

Do you know if this is possible without a custom trait?

Any example? I'm not quite sure what you mean.

As in, this works for matmul because Matmul is a trait that gets implemented for valid pairs of tensors. Most operations are elementwise, so they require the same shape, like Add. Those are implemented through the std Add trait, and since that's not a luminal trait, there's no way to attach a custom error message.
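
To make the distinction concrete, here is a rough sketch of the elementwise case. `Const`, `Tensor`, and `same_shape_sum` are hypothetical stand-ins invented for illustration, not luminal's actual definitions.

```rust
use std::marker::PhantomData;
use std::ops::Add;

// Hypothetical stand-ins for luminal's shape and tensor types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Const<const N: usize>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tensor<S>(PhantomData<S>);

impl<S> Tensor<S> {
    pub fn new() -> Self {
        Tensor(PhantomData)
    }
}

// Elementwise addition: both operands must carry the same shape `S`.
// `Add` is defined in the standard library, and
// `#[diagnostic::on_unimplemented]` is placed on a trait definition,
// so a downstream crate has no place to attach its own message here.
impl<S> Add for Tensor<S> {
    type Output = Tensor<S>;
    fn add(self, _rhs: Tensor<S>) -> Tensor<S> {
        Tensor::new()
    }
}

// Same shapes on both sides, so this compiles.
pub fn same_shape_sum() -> Tensor<(Const<4>,)> {
    Tensor::<(Const<4>,)>::new() + Tensor::<(Const<4>,)>::new()
}

// Mixing shapes, e.g. Tensor<(Const<4>,)> + Tensor<(Const<5>,)>, is
// rejected at compile time, but the wording of that error comes from
// the compiler and std rather than from this crate.
```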

That's a pity. I'll check the official forum for this use case.

OK, for now I'll close this, since I believe we can't add custom error messages to non-trait ops. Feel free to reopen it if anything else is found.