coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

Split `TryConcatAlong` into different traits

swfsql opened this issue

This issue is a request to separate the TryConcatAlong trait into two different traits, one for Tensors and another for Shapes.

edit: A small example on disc.


I'm experimenting with structs that represent tensor operations so they can be used as Modules inside model definitions. When I tried to create a Module for tensor concatenation, however, the Rust trait solver went haywire on the trait bounds. Usually the shape of a tensor is a generic parameter that really is a "shape", but when doing something like:

impl<A, B, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>>
    Module<(Tensor<A, E, D, T>, Tensor<B, E, D, R>)> for ConcatTensorAlong<Ax>
where
    (A, B): TryConcatAlong<Ax>, // <- this line is problematic
    // etc
{
    type Output = /**/;

    fn try_forward(
        &self,
        x: (Tensor<A, E, D, T>, Tensor<B, E, D, R>),
    ) -> Result<Self::Output, Error> {
        // etc
    }
}

In this case, ConcatTensorAlong would be the struct representing a Module.
But when the line marked as "problematic" is added*, the Rust type system (for some reason) insists on treating A and B as yet more Tensors, so it tries to verify the trait bounds of Tensor<Tensor<Tensor<...>>> recursively until it overflows.

*but only if I actually try to use ConcatTensorAlong inside a Model and call forward on it.
This is the error I get:

error[E0275]: overflow evaluating the requirement `((_, _, _, _), (..., ..., ..., ...)): dfdx::prelude::TryConcatAlong<...>`
   --> src/c4/w3/pa_02_image_segmentation.rs:211:35
    |
211 |         let prediction: _ = model.forward(x);
    |                                   ^^^^^^^
    |
    = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`coursera_exercises`)
    = note: required for `(dfdx::prelude::Tensor<(_, _, _, _), _, _, _>, dfdx::prelude::Tensor<(_, _, _, _), _, _, _>)` to implement `dfdx::prelude::TryConcatAlong<dfdx::prelude::Axis<1>>`
    = note: 123 redundant requirements hidden
    = note: required for `(Tensor<Tensor<Tensor<Tensor<..., ..., ..., ...>, ..., ..., ...>, ..., ..., ...>, ..., ..., ...>, ...)` to implement `dfdx::prelude::TryConcatAlong<dfdx::prelude::Axis<1>>`
    = note: the full type name has been written to '/workspaces/coursera-deep-learning-specialization/r/target/debug/deps/coursera_exercises-ec2247d6c3cc6475.long-type-15920087071787293129.txt'
    = note: required for `dfdx::nn::ops::ConcatTensorAlong<dfdx::prelude::Axis<1>>` to implement `dfdx::nn::Module<(dfdx::prelude::Tensor<dfdx::prelude::Tensor<dfdx::prelude::Tensor<...<dfdx::prelude::Tensor<(_, _, _, _), _, _, _>, _, _, _>, ...>, _, _, _>, ...`

Increasing the recursion limit doesn't help, either. At first I didn't read the message closely, assumed my model was simply too big, and raised the limit to 1024, but then the Rust compiler itself crashed instead.
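A toy model may help show where the recursion could come from. This is a sketch under stated assumptions, not dfdx's code: `Axis<I>`, `Shape2` and `Tensor<S>` are hypothetical stand-ins, and it assumes (without checking the crate's source) that the single `TryConcatAlong` trait is implemented both for pairs of shapes and for pairs of tensors, with the tensor impl requiring the same trait on the inner pair. With concrete shapes the bound resolves in two steps; while `A` and `B` are still unknown, the tensor impl's own where-clause can apparently be matched by the tensor impl again, which would explain the growing `Tensor<Tensor<...>>` in the error.

```rust
// Hypothetical toy model of ONE concat trait shared by shapes and tensors.
// Not dfdx's code: `Axis`, `Shape2` and `Tensor` are simplified stand-ins.

/// Marker for the axis to concatenate along (modeled on dfdx's `Axis<I>`).
#[derive(Clone, Copy, Debug)]
struct Axis<const I: usize>;

/// Toy 2-D shape: (rows, cols).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Shape2(usize, usize);

/// Toy tensor, generic over its shape `S`, with row-major data.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<S> {
    shape: S,
    data: Vec<f32>,
}

/// A single trait used for both shape pairs and tensor pairs.
trait TryConcatAlong<Ax> {
    type Output;
    fn try_concat_along(self, ax: Ax) -> Result<Self::Output, String>;
}

// Impl 1: a pair of shapes, concatenated along the row axis.
impl TryConcatAlong<Axis<0>> for (Shape2, Shape2) {
    type Output = Shape2;
    fn try_concat_along(self, _ax: Axis<0>) -> Result<Shape2, String> {
        let (a, b) = self;
        if a.1 != b.1 {
            return Err(format!("column mismatch: {} vs {}", a.1, b.1));
        }
        Ok(Shape2(a.0 + b.0, a.1))
    }
}

// Impl 2: a pair of tensors, delegating to the SAME trait for the inner pair.
// While `A` and `B` are still unknown, this bound could also be matched by
// Impl 2 itself (A = Tensor<A2>, B = Tensor<B2>, ...), i.e. the kind of
// `Tensor<Tensor<...>>` unfolding reported in the error.
impl<A, B> TryConcatAlong<Axis<0>> for (Tensor<A>, Tensor<B>)
where
    (A, B): TryConcatAlong<Axis<0>>,
{
    type Output = Tensor<<(A, B) as TryConcatAlong<Axis<0>>>::Output>;
    fn try_concat_along(self, ax: Axis<0>) -> Result<Self::Output, String> {
        let (a, b) = self;
        let shape = (a.shape, b.shape).try_concat_along(ax)?;
        // Row-major data: concatenating along axis 0 is a plain append.
        let mut data = a.data;
        data.extend(b.data);
        Ok(Tensor { shape, data })
    }
}

fn main() {
    // With concrete shapes the bound resolves via Impl 2, then Impl 1.
    let a = Tensor { shape: Shape2(2, 3), data: vec![1.0; 6] };
    let b = Tensor { shape: Shape2(1, 3), data: vec![2.0; 3] };
    let c = (a, b).try_concat_along(Axis::<0>).unwrap();
    println!("{:?}", c.shape); // prints Shape2(3, 3)
}
```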

I've found that if TryConcatAlong is split into two traits, one for Tensor and another for Shape, and the trait bounds throughout the code are adjusted accordingly, rustc no longer overflows or crashes for that kind of Module definition.
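The proposed split can be sketched with small toy types. This is a minimal sketch, not the actual change: `Axis<I>`, `Shape2`, `Tensor<S>`, the trait names `TryConcatShapeAlong` and `TryConcatTensorAlong`, and the simplified `Module` trait are hypothetical stand-ins rather than dfdx API. The point it illustrates is that the tensor trait's impl only asks for the *shape* trait on `(A, B)`, and the shape trait is never implemented for pairs of tensors, so no impl lets the solver wrap another `Tensor` around `A` or `B`.

```rust
// Hypothetical sketch of splitting the concat trait into a shape trait and a
// tensor trait. All names are illustrative stand-ins, not dfdx's API.

/// Marker for the axis to concatenate along (modeled on dfdx's `Axis<I>`).
#[derive(Clone, Copy, Debug)]
struct Axis<const I: usize>;

/// Toy 2-D shape: (rows, cols).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Shape2(usize, usize);

/// Toy tensor, generic over its shape `S`, with row-major data.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<S> {
    shape: S,
    data: Vec<f32>,
}

/// Shape-only trait: implemented for pairs of shapes, never for tensors.
trait TryConcatShapeAlong<Ax> {
    type Output;
    fn try_concat_shape_along(self, ax: Ax) -> Result<Self::Output, String>;
}

/// Tensor-only trait: implemented for pairs of tensors.
trait TryConcatTensorAlong<Ax> {
    type Output;
    fn try_concat_tensor_along(self, ax: Ax) -> Result<Self::Output, String>;
}

impl TryConcatShapeAlong<Axis<0>> for (Shape2, Shape2) {
    type Output = Shape2;
    fn try_concat_shape_along(self, _ax: Axis<0>) -> Result<Shape2, String> {
        let (a, b) = self;
        if a.1 != b.1 {
            return Err(format!("column mismatch: {} vs {}", a.1, b.1));
        }
        Ok(Shape2(a.0 + b.0, a.1))
    }
}

// The bound is on the SHAPE trait, so it cannot resolve back into this impl.
impl<A, B> TryConcatTensorAlong<Axis<0>> for (Tensor<A>, Tensor<B>)
where
    (A, B): TryConcatShapeAlong<Axis<0>>,
{
    type Output = Tensor<<(A, B) as TryConcatShapeAlong<Axis<0>>>::Output>;
    fn try_concat_tensor_along(self, ax: Axis<0>) -> Result<Self::Output, String> {
        let (a, b) = self;
        let shape = (a.shape, b.shape).try_concat_shape_along(ax)?;
        // Row-major data: concatenating along axis 0 is a plain append.
        let mut data = a.data;
        data.extend(b.data);
        Ok(Tensor { shape, data })
    }
}

/// Simplified stand-in for a Module trait (hypothetical).
trait Module<Input> {
    type Output;
    fn try_forward(&self, x: Input) -> Result<Self::Output, String>;
}

/// Module that concatenates its two input tensors along `Ax`.
#[derive(Clone, Copy, Debug)]
struct ConcatTensorAlong<Ax>(Ax);

impl<A, B, Ax: Copy> Module<(Tensor<A>, Tensor<B>)> for ConcatTensorAlong<Ax>
where
    (Tensor<A>, Tensor<B>): TryConcatTensorAlong<Ax>,
{
    type Output = <(Tensor<A>, Tensor<B>) as TryConcatTensorAlong<Ax>>::Output;
    fn try_forward(&self, x: (Tensor<A>, Tensor<B>)) -> Result<Self::Output, String> {
        x.try_concat_tensor_along(self.0)
    }
}

fn main() {
    let concat = ConcatTensorAlong(Axis::<0>);
    let a = Tensor { shape: Shape2(2, 3), data: vec![1.0; 6] };
    let b = Tensor { shape: Shape2(4, 3), data: vec![0.5; 12] };
    let y = concat.try_forward((a, b)).unwrap();
    println!("{:?}", y.shape); // prints Shape2(6, 3)
}
```

Keeping the two traits apart means a bound written against shapes can only ever be satisfied by shape impls, which is the property the Module definition above needs.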