NVIDIA / warp

A Python framework for high-performance GPU simulation and graphics

Home page: https://nvidia.github.io/warp/

[feature request] Cooperative groups

sweeneychris opened this issue

Thanks so much for providing such a great library. I've found Warp easy to use, easy to integrate into existing PyTorch code, and, most of all, quite performant! One of the few limitations I've hit so far is that there is no clear way to use cooperative groups from Warp. This is essentially the only functionality I rely on in C++ that I can't yet recreate with Warp, and it has an appreciable effect on the overall speed of my kernels. My assumption is that custom native functions would not compile with cooperative groups because they lack #include <cooperative_groups.h>, but please correct me if I'm wrong!

Thinking about it more, though, I believe a small API for cooperative groups would be an extremely useful addition to the Warp library. Even a basic ability to get the thread block and call .sync() within a kernel would be immensely useful. Is this on the Warp feature roadmap?
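For readers less familiar with CUDA, the "get the thread block and call .sync()" pattern can be illustrated with a CPU-only simulation. This is not Warp or CUDA code: threading.Barrier stands in for cooperative_groups::this_thread_block().sync() (or __syncthreads()), a plain list stands in for shared memory, and BLOCK_DIM and run_block are hypothetical names chosen for this sketch.

```python
import threading

BLOCK_DIM = 8  # hypothetical block size, chosen for illustration only


def run_block(inputs):
    """CPU simulation (not Warp or CUDA code) of one thread block.

    Each "thread" stages a value in shared memory, waits at a block-wide
    barrier, then reads a neighbor's slot. threading.Barrier stands in for
    cooperative_groups::this_thread_block().sync() / __syncthreads().
    """
    if len(inputs) != BLOCK_DIM:
        raise ValueError("expected exactly one input per simulated thread")
    shared = [0.0] * BLOCK_DIM               # stands in for __shared__ storage
    out = [0.0] * BLOCK_DIM                  # stands in for global-memory output
    barrier = threading.Barrier(BLOCK_DIM)   # one party per simulated thread

    def thread_body(tid):
        shared[tid] = inputs[tid] * 2.0      # phase 1: write own shared slot
        barrier.wait()                       # no thread reads until all have written
        neighbor = shared[(tid + 1) % BLOCK_DIM]
        out[tid] = shared[tid] + neighbor    # phase 2: read another thread's slot

    threads = [threading.Thread(target=thread_body, args=(t,)) for t in range(BLOCK_DIM)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out
```

Calling run_block([float(i) for i in range(8)]) returns [2.0, 6.0, 10.0, 14.0, 18.0, 22.0, 26.0, 14.0]. Without the barrier.wait() call, the result would depend on thread scheduling; that read-before-write race is exactly what a block-level sync prevents.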

Hi @sweeneychris, thanks for the feedback - you're the first person I've had ask for cooperative groups!

In general, we want to support more cooperative operations in Warp; the challenge is finding the right abstraction. Can you tell me a little more about what you'd need? I'm guessing just __syncthreads() would not be enough?

I'm also curious whether you're using CG for, e.g., cooperative reductions. Perhaps we could provide a higher-level API for these types of common computations?

Thanks,
Miles
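As a point of reference for what a "cooperative reduction" usually means here: the classic CUDA approach is a tree reduction in shared memory, where the number of active threads halves each step and a block-wide sync separates consecutive steps. The sketch below is a sequential Python simulation of that pattern, not an API Warp provides; block_reduce_sum is a hypothetical name, and a power-of-two block size is assumed, as in the textbook version.

```python
def block_reduce_sum(values):
    """Sequential simulation of a shared-memory tree reduction over one block.

    Each outer-loop iteration is one step in which every "thread" with
    tid < stride adds shared[tid + stride] into shared[tid]. On a GPU those
    adds run in parallel, and a block-wide sync follows each step.
    """
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("block size must be a power of two")
    shared = list(values)            # stands in for __shared__ storage
    stride = n // 2
    while stride > 0:
        # reads come from [stride, 2*stride) and writes go to [0, stride),
        # so running these in parallel would not race within one step
        for tid in range(stride):
            shared[tid] += shared[tid + stride]
        # GPU: block-wide sync here before the next, smaller step
        stride //= 2
    return shared[0]
```

For example, block_reduce_sum([1, 2, 3, 4, 5, 6, 7, 8]) returns 36 after log2(8) = 3 steps, instead of the 7 dependent additions a single thread would perform.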

@mmacklin Yes, indeed, cooperative reductions would be hugely useful, and it may be simpler to expose an API around those rather than exposing all of the block objects, etc.

In general, the patterns where I use shared memory in ML applications for massive speedups look like this:

Forward passes: fetch data from shared memory, perform a per-thread computation using the shared-memory data, and write to the global-memory output.
Backward passes: fetch data from shared memory, perform a per-thread computation of gradients using the shared-memory data, perform a block-wise reduction of the gradient values, and (atomically) write the reduced gradient to the global-memory output.
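The two passes described above can be sketched as a sequential CPU simulation. The model is a hypothetical example chosen only to exercise the pattern: y[i] = sum_k (x[i] - c[k])**2, where the small parameter vector c plays the role of the data each block stages in shared memory. BLOCK_DIM, forward, and backward are illustrative names, not Warp APIs; the in-block reduction is written as a plain sum for brevity, and the per-block write marks where a GPU kernel would use an atomic add (in Warp, wp.atomic_add).

```python
BLOCK_DIM = 4  # hypothetical threads per block


def forward(c, x):
    """Forward pass: y[i] = sum_k (x[i] - c[k])**2.

    The parameter vector c is staged in "shared memory" once per block,
    then every thread in that block reads it.
    """
    n = len(x)
    y = [0.0] * n                                        # global-memory output
    for start in range(0, n, BLOCK_DIM):                 # one iteration per block
        shared_c = list(c)                               # block loads c into shared memory
        # GPU: block-wide sync here before threads read shared_c
        for i in range(start, min(start + BLOCK_DIM, n)):  # one "thread" per i
            y[i] = sum((x[i] - ck) ** 2 for ck in shared_c)
    return y


def backward(c, x, grad_y):
    """Backward pass: grad_c[k] = sum_i grad_y[i] * -2 * (x[i] - c[k]).

    Each thread computes its own contribution, the block reduces those
    contributions, and only one write per block reaches the global gradient.
    """
    n = len(x)
    grad_c = [0.0] * len(c)                              # global-memory gradient
    for start in range(0, n, BLOCK_DIM):
        shared_c = list(c)                               # block loads c into shared memory
        stop = min(start + BLOCK_DIM, n)
        for k, ck in enumerate(shared_c):
            # per-thread gradient contributions for parameter k
            partials = [grad_y[i] * -2.0 * (x[i] - ck) for i in range(start, stop)]
            # block-wise reduction (a shared-memory tree reduction on a GPU)
            block_sum = sum(partials)
            # GPU: one atomic add per block instead of one per thread
            grad_c[k] += block_sum
    return grad_c
```

With c = [0.0, 1.0] and x = [0.0, 1.0, 2.0, 3.0, 4.0] (two blocks, the second partially filled), forward returns [1.0, 1.0, 5.0, 13.0, 25.0] and backward with grad_y = [1.0] * 5 returns [-20.0, -10.0]. The design point is that the number of global atomic writes scales with the number of blocks rather than the number of threads.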

In general, I couldn't figure out a simple way to "mix and match" shared memory with Warp functions. Since shared memory is block-scoped, it wasn't easy to create a minimal example that uses shared memory for fast reads/reductions while using Warp for the rest of the code. It seems that if you want to use shared memory at all, you have to write the entire function in CUDA (which is fine! Just an opportunity for the Warp Python API). Any API solution that helps with the patterns above would be hugely helpful!