Improve diagnostic for compile-time checks
npuichigo opened this issue
First of all, thanks for your great work. However, the current error messages for compile-time shape checks are hard to read. When I intentionally introduced a shape mistake into `examples/simple`, the message was confusing:
```
error[E0277]: the trait bound `GraphTensor<(luminal::shape::Const<4>,)>: Matmul<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` is not satisfied
  --> examples/simple/src/main.rs:12:27
   |
12 |     let b = model.forward(a).retrieve();
   |                   ------- ^ the trait `Matmul<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` is not implemented for `GraphTensor<(luminal::shape::Const<4>,)>`, which is required by `Linear<5, 5>: luminal::module::Module<GraphTensor<_>>`
   |                   |
   |                   required by a bound introduced by this call
   |
   = help: the following other types implement trait `Matmul<S>`:
             <GraphTensor<(A,)> as Matmul<(A, B)>>
             <GraphTensor<(A, B)> as Matmul<(B, C)>>
             <GraphTensor<(A, B, C)> as Matmul<(A, C, D)>>
             <GraphTensor<(A, B, C)> as Matmul<(C, D)>>
             <GraphTensor<(A, B, C, D)> as Matmul<(A, B, D, E)>>
             <GraphTensor<(A, B, C, D, E)> as Matmul<(A, B, C, E, F)>>
   = note: required for `Linear<5, 5>` to implement `luminal::module::Module<GraphTensor<(luminal::shape::Const<4>,)>>`
```
Rust 1.78 stabilized diagnostic attributes, which let a library customize the compiler's error message when one of its traits is not implemented: https://blog.rust-lang.org/2024/05/02/Rust-1.78.0.html#diagnostic-attributes. Let me give it a try:
```rust
#[diagnostic::on_unimplemented(
    message = "`{Self}` and `GraphTensor<{S}>` shapes cannot be multiplied",
    label = "Input tensor {Self}",
)]
pub trait Matmul<S: Shape> {
    type Output;
    fn matmul(self, rhs: GraphTensor<S>) -> Self::Output;
}
```
Now the message is much more readable.
```
error[E0277]: `GraphTensor<(luminal::shape::Const<4>,)>` and `GraphTensor<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` shapes cannot be multiplied
  --> examples/simple/src/main.rs:12:27
   |
12 |     let b = model.forward(a).retrieve();
   |                   ------- ^ Input tensor GraphTensor<(luminal::shape::Const<4>,)>
   |                   |
   |                   required by a bound introduced by this call
   |
   = help: the trait `Matmul<(luminal::shape::Const<5>, luminal::shape::Const<5>)>` is not implemented for `GraphTensor<(luminal::shape::Const<4>,)>`, which is required by `Linear<5, 5>: luminal::module::Module<GraphTensor<_>>`
   = help: the following other types implement trait `Matmul<S>`:
             <GraphTensor<(A,)> as Matmul<(A, B)>>
             <GraphTensor<(A, B)> as Matmul<(B, C)>>
             <GraphTensor<(A, B, C)> as Matmul<(A, C, D)>>
             <GraphTensor<(A, B, C)> as Matmul<(C, D)>>
             <GraphTensor<(A, B, C, D)> as Matmul<(A, B, D, E)>>
             <GraphTensor<(A, B, C, D, E)> as Matmul<(A, B, C, E, F)>>
   = note: required for `Linear<5, 5>` to implement `luminal::module::Module<GraphTensor<(luminal::shape::Const<4>,)>>`
```
So I opened this issue to pursue more user-friendly diagnostic messages for compile-time checks.
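For anyone who wants to try the attribute in isolation, here is a minimal standalone sketch (Rust 1.78+). `Const`, `Tensor`, `Matmul`, and the `assert_matmul` helper are hypothetical stand-ins that only mimic the shape-in-the-type idea visible in the error output above; they are not luminal's actual definitions. Only the vector-times-matrix pairing is implemented, so every other pairing has no impl.

```rust
use std::marker::PhantomData;

// Hypothetical stand-ins for luminal's shape and tensor types (NOT the real
// definitions); the shape is carried in the type parameter `S`.
pub struct Const<const N: usize>;

pub struct Tensor<S> {
    data: Vec<f32>,
    _shape: PhantomData<S>,
}

// Requires Rust 1.78+. When no impl matches, rustc uses these strings,
// substituting the concrete types for `{Self}` and `{S}`.
#[diagnostic::on_unimplemented(
    message = "`{Self}` and `Tensor<{S}>` shapes cannot be multiplied",
    label = "Input tensor {Self}"
)]
pub trait Matmul<S> {
    type Output;
    fn matmul(self, rhs: Tensor<S>) -> Self::Output;
}

// The only pairing in this sketch: a length-A vector times an A x B matrix
// (row-major) gives a length-B vector.
impl<const A: usize, const B: usize> Matmul<(Const<A>, Const<B>)> for Tensor<(Const<A>,)> {
    type Output = Tensor<(Const<B>,)>;

    fn matmul(self, rhs: Tensor<(Const<A>, Const<B>)>) -> Self::Output {
        let mut out = vec![0.0f32; B];
        for i in 0..A {
            for j in 0..B {
                out[j] += self.data[i] * rhs.data[i * B + j];
            }
        }
        Tensor { data: out, _shape: PhantomData }
    }
}

// Hypothetical helper: compiles only if `L: Matmul<R>` holds.
fn assert_matmul<L: Matmul<R>, R>() {}

fn main() {
    assert_matmul::<Tensor<(Const<2>,)>, (Const<2>, Const<3>)>();

    let v: Tensor<(Const<2>,)> = Tensor { data: vec![1.0, 2.0], _shape: PhantomData };
    let m: Tensor<(Const<2>, Const<3>)> = Tensor {
        data: vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0],
        _shape: PhantomData,
    };
    println!("{:?}", v.matmul(m).data); // [1.0, 2.0, 3.0]

    // Uncomment to see the custom message: no impl covers (4,) x (5, 5).
    // assert_matmul::<Tensor<(Const<4>,)>, (Const<5>, Const<5>)>();
}
```

Uncommenting the last line should fail with E0277 headed by the custom `message` instead of the default "the trait bound ... is not satisfied" wording. The turbofish helper is used so that both types are fully known when the bound is checked.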
That looks awesome, I didn't even know that was possible! I'll look into adding these where applicable. It won't be as straightforward everywhere, since most operations aren't implemented through custom traits the way matmul is, but for the ones that are, it will be helpful.
Do you know if this is possible without a custom trait?
Any example? I'm not quite sure what you mean.
As in, this works for matmul because `Matmul` is a trait that gets implemented for valid pairs of tensor shapes. Most operations, like `Add`, are elementwise, so they require both operands to have the same shape. Those are implemented through the standard library's `Add` trait, and since that isn't a luminal trait, there's no way to attach a custom error message to it.
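A minimal sketch of the distinction, using hypothetical types (not luminal's real `Add` implementation): the same-shape requirement falls out naturally from implementing the standard `Add` trait only for `Tensor<S> + Tensor<S>`, but the message for a mismatch belongs to `core::ops::Add`, because `#[diagnostic::on_unimplemented]` is placed on the trait definition itself.

```rust
use std::marker::PhantomData;
use std::ops::Add;

// Hypothetical stand-ins, not luminal's real types.
pub struct Const<const N: usize>;

pub struct Tensor<S> {
    data: Vec<f32>,
    _shape: PhantomData<S>,
}

// Elementwise add via the standard `Add` trait: the impl exists only when both
// operands have the same shape `S`, so a mismatch is rejected at compile time.
// The wording of that error is governed by `core::ops::Add`'s definition, since
// a downstream crate cannot put `#[diagnostic::on_unimplemented]` on it.
impl<S> Add for Tensor<S> {
    type Output = Tensor<S>;

    fn add(self, rhs: Tensor<S>) -> Tensor<S> {
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        Tensor { data, _shape: PhantomData }
    }
}

fn main() {
    let a: Tensor<(Const<3>,)> = Tensor { data: vec![1.0, 2.0, 3.0], _shape: PhantomData };
    let b: Tensor<(Const<3>,)> = Tensor { data: vec![10.0, 20.0, 30.0], _shape: PhantomData };
    println!("{:?}", (a + b).data); // [11.0, 22.0, 33.0]

    // `a + c` with `c: Tensor<(Const<2>,)>` would not compile, and this crate
    // has no hook to replace the compiler's message for it.
}
```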
That's a pity. I'll look on the official forum for this use case.
OK, for now I'll close this, since I believe we can't add custom error messages to ops that don't go through a luminal-defined trait. Feel free to reopen it if anything else turns up.