coreylowman / cudarc

Safe Rust wrapper around the CUDA toolkit

cuDNN support

coreylowman opened this issue

See the initial discussion in #1 and #16 for work in progress.

On thread safety:

The cuDNN library is thread-safe. Its functions can be called from multiple host threads, so long as the threads do not share the same cuDNN handle simultaneously.
Source

Re thread safety: I think if we enforce that handles are wrapped in Arc<Mutex<CudnnHandle>>, we should be fine.
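A minimal sketch of the Arc<Mutex<...>> approach. `CudnnHandle` here is a hypothetical stand-in struct that makes no real cuDNN calls (it just counts operations), so the example runs without a GPU; a real wrapper would own the raw `cudnnHandle_t`. The mutex ensures only one thread uses the handle at a time, which is the condition the cuDNN documentation quoted above places on sharing a handle.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for a safe cuDNN handle wrapper.
/// A real version would own a raw `cudnnHandle_t` created with `cudnnCreate`.
struct CudnnHandle {
    ops_run: usize,
}

impl CudnnHandle {
    fn new() -> Self {
        CudnnHandle { ops_run: 0 }
    }

    /// Placeholder for a cuDNN call that must not run concurrently on one handle.
    fn run_op(&mut self) {
        self.ops_run += 1;
    }
}

/// Spawns `n_threads` threads that each run `ops_per_thread` operations
/// through one shared handle; the mutex serializes every use of it.
fn run_shared(n_threads: usize, ops_per_thread: usize) -> usize {
    let handle = Arc::new(Mutex::new(CudnnHandle::new()));
    let workers: Vec<_> = (0..n_threads)
        .map(|_| {
            let handle = Arc::clone(&handle);
            thread::spawn(move || {
                for _ in 0..ops_per_thread {
                    // Holding the lock guarantees exclusive use of the handle.
                    handle.lock().unwrap().run_op();
                }
            })
        })
        .collect();
    for w in workers {
        w.join().unwrap();
    }
    let total = handle.lock().unwrap().ops_run;
    total
}

fn main() {
    println!("{}", run_shared(4, 100));
}
```

One caveat: a real wrapper that stores a raw pointer is not `Send` by default, so `Arc<Mutex<CudnnHandle>>` would only compile across threads with an `unsafe impl Send for CudnnHandle`, which is justified precisely because the mutex prevents simultaneous use.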

Or just let each thread create its own handle (no data is associated with a handle).
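A rough sketch of the handle-per-thread alternative, again using a hypothetical stand-in `CudnnHandle` (no real cuDNN calls; `HANDLES_CREATED` is an illustrative counter). Because each thread owns its handle outright, no lock is needed around the calls.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

/// Counts how many hypothetical handles have been created.
static HANDLES_CREATED: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for a cuDNN handle; a real one would call
/// `cudnnCreate` in `new` and `cudnnDestroy` in a `Drop` impl.
struct CudnnHandle {
    ops_run: usize,
}

impl CudnnHandle {
    fn new() -> Self {
        HANDLES_CREATED.fetch_add(1, Ordering::SeqCst);
        CudnnHandle { ops_run: 0 }
    }

    fn run_op(&mut self) {
        self.ops_run += 1;
    }
}

/// Each thread creates and owns its own handle, so no synchronization
/// is needed around the (placeholder) cuDNN calls.
fn run_per_thread(n_threads: usize, ops_per_thread: usize) -> Vec<usize> {
    let workers: Vec<_> = (0..n_threads)
        .map(|_| {
            thread::spawn(move || {
                let mut handle = CudnnHandle::new();
                for _ in 0..ops_per_thread {
                    handle.run_op();
                }
                handle.ops_run
            })
        })
        .collect();
    workers.into_iter().map(|w| w.join().unwrap()).collect()
}

fn main() {
    let counts = run_per_thread(3, 10);
    println!("{:?} {}", counts, HANDLES_CREATED.load(Ordering::SeqCst));
}
```

The trade-off: this avoids lock contention, but `cudnnCreate` allocates host and device resources, so in practice each thread would likely create its handle once and reuse it rather than creating one per call.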

@M1ngXU let's split the cuDNN PR into a couple of separate PRs, so we have just what we need for dfdx:

  1. Add sys.rs/bindgen.sh/result.rs and link to cuDNN in build.rs (you already have this in the big PR; it just needs to be split out)
    1. I think to start we just need handle create/destroy and cudnnSetStream
  2. Add tensor descriptors: cudnnCreateTensorDescriptor, cudnnDestroyTensorDescriptor, cudnnSetTensorNdDescriptor
  3. Add reduction support
  4. Add conv2d support
  5. Add pool2d support
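For steps 1 and 2, a rough sketch of two small pieces a result.rs layer and a tensor-descriptor wrapper would likely need: turning a raw status into a `Result`, and computing the packed (row-major) strides that `cudnnSetTensorNdDescriptor` takes in its `strideA` argument alongside `dimA`. It assumes statuses arrive as plain `i32` values (in cuDNN, `CUDNN_STATUS_SUCCESS` is 0; bindgen may generate an enum instead). The names `CudnnError`, `check`, and `contiguous_strides` are hypothetical, not cudarc's API.

```rust
/// Hypothetical error type wrapping a raw cuDNN status code.
/// Assumes statuses arrive as plain `i32` values.
#[derive(Debug, PartialEq, Eq)]
pub struct CudnnError(pub i32);

/// `CUDNN_STATUS_SUCCESS` is 0 in the cuDNN headers.
const CUDNN_STATUS_SUCCESS: i32 = 0;

/// Converts a raw status into a `Result`, the way a result.rs layer might.
pub fn check(status: i32) -> Result<(), CudnnError> {
    if status == CUDNN_STATUS_SUCCESS {
        Ok(())
    } else {
        Err(CudnnError(status))
    }
}

/// Computes fully packed row-major strides for `dims`, suitable for the
/// `strideA` argument of `cudnnSetTensorNdDescriptor`. The innermost
/// dimension has stride 1; each outer stride is the product of all
/// inner dimensions.
pub fn contiguous_strides(dims: &[i32]) -> Vec<i32> {
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    // NCHW tensor with N=2, C=3, H=4, W=5.
    let dims = [2, 3, 4, 5];
    println!("{:?}", contiguous_strides(&dims)); // [60, 20, 5, 1]
    println!("{:?}", check(0)); // Ok(())
}
```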

OK, I'll do number one now.