coreylowman / cudarc

Safe rust wrapper around CUDA toolkit

More cudnn ops

LaurentMazare opened this issue

Thanks for this amazing crate, it's been instrumental to candle. We've recently added a feature that uses the cudnn conv2d, which sped things up a lot compared to our handcrafted kernel, and we would like to have cudnn support for more ops. We're mostly interested in:

  • Conv2d backprop.
  • Conv1d forward + backward.
  • Maybe flash-attention/softmax/...

Are there any plans to add these to the cudnn safe API? If not, would you be OK with people making PRs to add them?

You can already do conv2d backprop with the existing methods (see https://github.com/coreylowman/dfdx/blob/main/src/tensor_ops/conv2d/cudnn_kernel.rs#L91).
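For reference, the backward-data step that the linked dfdx kernel hands off to cudnn can be sanity-checked on the CPU: for a "valid" cross-correlation, the input gradient is a "full" correlation of the output gradient with the kernel. A minimal single-channel plain-Rust sketch of that math (illustrative only, not the cudnn call itself):

```rust
/// Forward: "valid" cross-correlation of x (H×W) with kernel w (K×K).
fn conv2d_valid(x: &[Vec<f32>], w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (h, wd, k) = (x.len(), x[0].len(), w.len());
    let (oh, ow) = (h - k + 1, wd - k + 1);
    let mut y = vec![vec![0.0; ow]; oh];
    for i in 0..oh {
        for j in 0..ow {
            for p in 0..k {
                for q in 0..k {
                    y[i][j] += x[i + p][j + q] * w[p][q];
                }
            }
        }
    }
    y
}

/// Backward-data: dL/dx[a][b] = Σ_{p,q} dy[a-p][b-q] · w[p][q],
/// skipping terms where the dy index falls out of range.
/// This is what cudnn's backward-data convolution computes.
fn conv2d_backward_data(dy: &[Vec<f32>], w: &[Vec<f32>], h: usize, wd: usize) -> Vec<Vec<f32>> {
    let k = w.len();
    let (oh, ow) = (dy.len(), dy[0].len());
    let mut dx = vec![vec![0.0; wd]; h];
    for a in 0..h {
        for b in 0..wd {
            for p in 0..k {
                for q in 0..k {
                    if a >= p && b >= q && a - p < oh && b - q < ow {
                        dx[a][b] += dy[a - p][b - q] * w[p][q];
                    }
                }
            }
        }
    }
    dx
}

fn main() {
    // 3×3 input, 2×2 kernel, upstream gradient of ones over the 2×2 output.
    let w = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let dy = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
    let dx = conv2d_backward_data(&dy, &w, 3, 3);
    // The center input pixel touches every output, so it accumulates
    // every kernel weight: 1 + 2 + 3 + 4 = 10.
    assert_eq!(dx[1][1], 10.0);
    println!("{:?}", dx);
}
```

The real kernel of course dispatches to cudnn's backward-data routine on the GPU; the point here is only the indexing relationship between forward and backward.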

Conv1d - I'm not sure this exists in cudnn? I'm not sure flash attention exists there either (or at least it didn't when I last checked cudnn).

But yes open to any contributions here!

Ah great that the conv2d backward step is already there, we'll add it to candle.

For the conv1d, there is some support for Nd convolution I think, e.g. in the cudarc ffi, so hopefully having a safe API around it would enable 1d convolutions.
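The reason the Nd path should cover conv1d can be checked on the CPU: a 1d "valid" convolution is exactly a 2d one over a 1×N "image" with a 1×K kernel (the unit dimension collapses away). A plain-Rust sketch of that equivalence (illustrative only, no cudnn calls):

```rust
/// Direct 1d "valid" cross-correlation.
fn conv1d_valid(x: &[f32], w: &[f32]) -> Vec<f32> {
    let (n, k) = (x.len(), w.len());
    (0..n - k + 1)
        .map(|i| (0..k).map(|p| x[i + p] * w[p]).sum())
        .collect()
}

/// 2d "valid" cross-correlation (single channel).
fn conv2d_valid(x: &[Vec<f32>], w: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (h, wd, kh, kw) = (x.len(), x[0].len(), w.len(), w[0].len());
    let (oh, ow) = (h - kh + 1, wd - kw + 1);
    let mut y = vec![vec![0.0; ow]; oh];
    for i in 0..oh {
        for j in 0..ow {
            for p in 0..kh {
                for q in 0..kw {
                    y[i][j] += x[i + p][j + q] * w[p][q];
                }
            }
        }
    }
    y
}

/// conv1d through the 2d path: treat the signal as a 1×N image and the
/// kernel as a 1×K filter — the same trick that lets Nd convolution
/// descriptors express the 1d case.
fn conv1d_via_2d(x: &[f32], w: &[f32]) -> Vec<f32> {
    conv2d_valid(&[x.to_vec()], &[w.to_vec()])[0].clone()
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let w = [1.0, 0.5];
    assert_eq!(conv1d_valid(&x, &w), conv1d_via_2d(&x, &w));
    println!("{:?}", conv1d_valid(&x, &w)); // [2.0, 3.5, 5.0]
}
```

In cudnn terms this would mean building the Nd tensor/filter descriptors with a unit spatial dimension; the exact safe-API surface for that is what this issue is asking about.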

For flash attention, I meant this fused flash attn fprop, though I've actually never used it.

Oh sweet I missed the Nd convolution, nice! Should be able to add that.

If I'm understanding the flash attn thing correctly, it seems like that is something detected at runtime when you're using the cudnn graph API?