coreylowman / cudarc

Safe rust wrapper around CUDA toolkit


Issue with thread safety intrinsics of `Arc<CudaDevice>`

Narsil opened this issue · comments

I am not sure wrapping the CudaDevice within an Arc is living up to its full potential.

Here is a simple crate showcasing the issue:

https://github.com/Narsil/axum_cudarc

It seems that since it's an Arc, the device can be shared; however, allocation fails, meaning the CUDA context is not really thread safe.
That seems like a gotcha to me, and we shouldn't be able to do it.
I'm still not sure what's the best way to include GPU code in a webserver (a single thread owning the GPU and communicating with the webserver's CPU threads, vs. all webserver threads accessing the GPU and reusing 1 stream for ordering, if wanted).

In any case, it seems that the current behavior is not desirable.

Did I miss something?

i think that it's not thread safe if both threads use it at the same time, but since it's on the same stream, it's safe (at least earlier lol).

the issue that replaced Rc with Arc: #10

Hmm looking at the context management section from https://docs.nvidia.com/cuda/cuda-driver-api/driver-vs-runtime-api.html#driver-vs-runtime-api, it seems like we may need to do more work for this to work with the driver api.

If i'm understanding the docs correctly, I think we'd just need to call cuCtxSetCurrent with the primary ctx on each thread.

cuCtxGetCurrent supposedly returns null when there is no context bound to the current thread. Perhaps some of the API calls can do that check and bind the primary ctx to the thread if it's not already bound? I'm not sure about the overhead of doing this.

Yeah I was able to reproduce the error, and it's fixed if you call cuCtxSetCurrent on each thread.

My question is:

  1. Is it reasonable to expect people to call dev.bind_to_thread()?
  2. Do either of you know of any way to automatically do something when an object is sent to another thread?

I'm wondering if there's a way to automatically call cuCtxSetCurrent when the thread starts, without adding a check to all CudaDevice methods whether a context is set or not.

2. Do either of you know of any way to automatically do something when an object is sent to another thread?

maybe only pass a "fake" object between threads, which needs to be bound to get the "real" object?

I like that idea, but there are definitely ways around it which could get confusing.

We could start using the runtime API instead, but we'd still need to use the driver API to load modules.

Maybe it's not actually that big a deal to insert a self.bind_to_thread() call into all the CudaDevice functions. 🤔

this would decrease performance though (just sliiightly). how about a function "to_sendable" that returns an object that only has "bind_to_thread"? this would also make sure that it's impossible to "forget" a call to bind_to_thread in a function

I like the idea of having 2 objects, where the sendable object cannot be used without binding to the thread.

Means Arc can be removed too, iiuc. (Arc is only used to make CudaDevice sendable, right?)

struct CudaDevice { // Not Send
    handle: Arc<Handle>, // Send
}

impl Handle {
    fn bound(&self) -> CudaDevice;
}

Maybe ?

(Also, I'm still unsure if it's a good idea to control CUDA from 2 different threads altogether.) I worked around that by having a single thread owning the GPU and sending commands to it from the webserver.

It means it cannot queue more work on the GPU while sending data back to the caller, but if that part is fast enough it shouldn't really matter.

It'd be good to understand the different use cases for multi-threading, and how people might expect using the same vs different devices to work.

If all threads are using the same device (as all of our examples have so far), then I think the jobs each thread submits to the GPU will be executed sequentially. This is because, as far as I understand, they will all share the same underlying GPU context & stream.

If users would expect concurrent execution, we can work towards that (maybe just force people to instantiate a new device in each thread).

Arc/Rc around the device is actually there to enforce device lifetimes. This was done to avoid having CudaSlice<'a, T> everywhere (instead, everything that depends on CudaDevice holds a reference to it).

I originally just had Rc, but switched to Arc for thread stuff. We'd need one of the two at least.