jafioti / luminal

Deep learning at the speed of light.

Home Page: https://luminalai.com

Access a tensor's `Shape`

swfsql opened this issue

Hi, this issue mostly revolves around a specific case of implementing a model. Is it possible to support this?

I'm trying to define a generic LinearBiased model, which could either wrap the current (unbiased) Linear or build on it by using the matmul operation directly (the problem is the same in both cases).

Example:

pub struct LinearBiased<const A: usize, const B: usize> {
    pub linear: Linear<A, B>,
    pub bias: GraphTensor<R1<B>>,
}

When trying to write a single generic forward implementation, one difficulty is accessing the tensor's compile-time Shape information after the matmul (or unbiased linear) operation, because the intermediate type (the one the bias gets added to) is only known as the associated <Linear as Module>::Output type.

One option, which I haven't explored or tested much, is a trait that gives access to a tensor's shape:

pub trait TensorHasShape {
    // The tensor's compile-time shape.
    type Shape: Shape;
    // The same tensor type, re-parameterized with a new shape.
    type WithShape<New: Shape>: TensorHasShape<Shape = New>;
    // Runtime access to the underlying shape data.
    fn shape_tracker(&self) -> &ShapeTracker;
}

impl<S: Shape> TensorHasShape for GraphTensor<S> {
    type Shape = S;
    type WithShape<New: Shape> = GraphTensor<New>;
    fn shape_tracker(&self) -> &ShapeTracker {
        &self.shape
    }
}
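
With that impl, downstream code can stay generic over "anything that has a shape" instead of naming GraphTensor directly. A minimal sketch of what that could look like (runtime_rank is a hypothetical helper, and len() on ShapeTracker is an assumption, not verified luminal API):

// Hypothetical helper: works for anything exposing its shape through the
// trait. Assumes ShapeTracker has a len() returning the number of dimensions.
fn runtime_rank<T: TensorHasShape>(t: &T) -> usize {
    t.shape_tracker().len()
}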

Note: I've run into some weird compilation errors in dfdx when HasShape was unified across both shapes and tensors, so if this gets adopted I'd recommend keeping the two split apart.

With this, it's possible to access <Linear as Module<Input>>::Output, then <Output as TensorHasShape>::Shape, and use those to state the necessary BroadcastShapeTo and std::ops::Add bounds for the bias.

Thanks in advance!

I think you're on the right track with the TensorHasShape trait. What would the full bound for LinearBiased look like if you had access to that trait?

It's somewhat bloated, but when wrapping Linear it looks like this:

impl<const A: usize, const B: usize, Input> Module<Input> for LinearBiased<A, B>
where
    Linear<A, B>: Module<Input>,
    <Linear<A, B> as Module<Input>>::Output: TensorHasShape,
    <<Linear<A, B> as Module<Input>>::Output as TensorHasShape>::Shape: Shape,
    R1<B>: BroadcastShapeTo<
        <<Linear<A, B> as Module<Input>>::Output as TensorHasShape>::Shape,
        <<<Linear<A, B> as Module<Input>>::Output as TensorHasShape>::Shape as Shape>::AllButLast,
    >,
    <Linear<A, B> as Module<Input>>::Output:
        std::ops::Add<GraphTensor<<<Linear<A, B> as Module<Input>>::Output as TensorHasShape>::Shape>>,
{
    type Output = <<Linear<A, B> as Module<Input>>::Output as std::ops::Add<
        GraphTensor<<<Linear<A, B> as Module<Input>>::Output as TensorHasShape>::Shape>,
    >>::Output;
    fn forward(&self, x: Input) -> Self::Output {
        let x = self.linear.forward(x);
        x + self.bias.clone().expand()
    }
}

Using the bound-expansion experiment I mentioned on Discord, it looks like this:

impl<const A: usize, const B: usize, Input> Module<Input> for LinearBiased<A, B> {
    #![tag(
        add(name = L1, ty = Linear::<A, B>, bound = Module::<Input>),
        add(name = L1TensorShape, ty = L1::Output, bound = TensorHasShape),
        add(name = L1Shape, ty = L1TensorShape::Shape, bound = Shape),
        add(name = _BiasExpand, ty = R1::<B>, bound = BroadcastShapeTo::<L1TensorShape::Shape, L1Shape::AllButLast>),
        add(name = L2, ty = L1::Output, bound = std::ops::Add::<GraphTensor::<L1TensorShape::Shape>>),
    )]
    #[tag(expand)]
    type Output = L2::Output;
    fn forward(&self, x: Input) -> Self::Output {
        let x = self.linear.forward(x);
        x + self.bias.clone().expand()
    }
}

(the former is just the expanded version of the proc-macro one)

Both versions use this struct:

pub struct LinearBiased<const A: usize, const B: usize> {
    pub linear: Linear<A, B>,
    pub bias: GraphTensor<R1<B>>,
}

I've just noticed that implementing Module over generic inputs also requires more information on the TensorHasShape trait, such as WithShape<New: Shape>. I've edited the first post to reflect this.
I also left a shape_tracker method in there. dfdx is able to return a reference to the actual shape (a reference to a type that implements Shape), but I'm not sure whether that's possible in luminal - I always get somewhat confused when dealing with shapes.

Finally, the GraphTensor methods would also need to come from traits, since the forward impl can't know it's dealing with, e.g., a GraphTensor specifically - i.e. to use sum_reduce there would need to be a SumReduce trait that GraphTensor implements.

I'd like to experiment and see how it would turn out - I'm aware this isn't necessarily the direction luminal wants to take, but it's worth noting this would also be closer to how dfdx structured these implementations.

It's also possible to wrap the current methods, in which case no changes to luminal itself are needed - for example this, based on dfdx:

// this code can live client-side, i.e. outside of luminal

pub trait SumReduceTo: TensorHasShape {
    fn sum_reduce<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>;
}

impl<S: Shape> SumReduceTo for GraphTensor<S> {
    fn sum_reduce<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>,
    {
        // delegate to the inherent GraphTensor::sum_reduce method in luminal
        <GraphTensor<S>>::sum_reduce(self)
    }
}
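
With that wrapper, a generic forward impl can reduce without ever naming GraphTensor. A minimal sketch against the trait definitions above (collapse is a hypothetical helper, not part of luminal):

// Hypothetical generic helper: sum-reduce any shape-carrying tensor
// along axes Ax down to the destination shape Dst.
fn collapse<T, Dst: Shape, Ax: Axes>(t: T) -> T::WithShape<Dst>
where
    T: SumReduceTo,
    T::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>,
{
    t.sum_reduce::<Dst, Ax>()
}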

This has now been resolved with runtime shapes. The new Linear has built-in support for an optional bias.

Forward function:

impl Module<GraphTensor> for Linear {
    type Output = GraphTensor;

    fn forward(&self, input: GraphTensor) -> Self::Output {
        // Optionally transpose the weight before the matmul
        let mut output = input.matmul(if self.permute {
            self.weight.permute((1, 0))
        } else {
            self.weight
        });
        // With runtime shapes the bias is simply broadcast to whatever
        // shape the matmul produced - no compile-time bounds needed
        if let Some(bias) = self.bias {
            output += bias.expand_to(output.shape);
        }
        output
    }
}
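
For reference, usage would look roughly like the sketch below. The constructor signature and graph calls are assumptions about the runtime-shape API rather than verified luminal code:

// Hypothetical usage sketch; Linear::new's exact signature and the graph
// methods used here are assumptions, not verified against luminal's API.
let mut cx = Graph::new();
let model = Linear::new(3, 4, true, &mut cx); // 3 inputs, 4 outputs, with bias
let input = cx.tensor(3).set(vec![1.0, 2.0, 3.0]);
let mut output = model.forward(input).retrieve();
cx.execute();
println!("{:?}", output.data());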