coreylowman / cudarc

Safe rust wrapper around CUDA toolkit

cudarc CUStream_st and `cuda_runtime`

jeromeku opened this issue:

Thanks for the great crate!

2 quick questions:

  • I've generated bindings for a CUDA C library, and the bindings contain their own CUstream_st and CUstream types because cuda_runtime is included in the originating header.

Briefly, the library is used to launch kernels such as:

extern "C" {
    pub fn add_kernel_default(
        stream: CUstream,
        x_ptr: CUdeviceptr,
        y_ptr: CUdeviceptr,
        output_ptr: CUdeviceptr,
        n_elements: i32,
    ) -> CUresult;
}

When calling this function, the stream type CUstream is the one defined inside the bindgen-generated bindings. It conflicts with the cudarc CUstream_st and CUstream types that I get through device.cu_stream(), even though the definitions are identical.
What am I doing incorrectly? This seems like an obvious blunder on my part.

  • I can fix the above by manually pasting the generated bindings and replacing the aforementioned definitions with the cudarc versions, but this seems unnecessarily hacky.

Thoughts?

Here is the complete code (where launcher contains the generated bindings for the C CUDA lib):

use anyhow::Result;
use cudarc::driver::CudaDevice;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;

use triton_rs::launcher::{add_kernel_default, load_add_kernel};

fn main() -> Result<()> {
    const N: usize = 1024;
    let device = CudaDevice::new(0).unwrap();

    let mut a_buf: [f32; N] = [0.0; N];

    let x = device.htod_copy(vec![1.0_f32; N]).unwrap();
    let y = device.htod_copy(vec![1.0_f32; N]).unwrap();
    let mut out = device.alloc_zeros::<f32>(N).unwrap();
    let stream = *device.cu_stream();

    unsafe {
        load_add_kernel();
    }

    unsafe {
        add_kernel_default(
            stream,
            *x.device_ptr(),
            *y.device_ptr(),
            *out.device_ptr_mut(),
            N.try_into().unwrap(),
        );
    }

    device.dtoh_sync_copy_into(&out, &mut a_buf)?;
    let sum = a_buf.into_iter().sum::<f32>();
    println!("{:?}", sum);

    Ok(())
}

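You aren't doing anything incorrectly. This is just a consequence of having two sets of bindgen-generated bindings: each one defines its own copy of the types, and Rust treats them as distinct even when the definitions are identical. bindgen might have a built-in way to handle this, but I haven't looked into it in depth.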

Since most (if not all?) of the CUDA types are pointers under the hood, you can just cast each value to the corresponding bindgen type. Try this:

        add_kernel_default(
            stream as _, // rust should auto infer the correct pointer type here
            *x.device_ptr() as _,
            *y.device_ptr() as _,
            *out.device_ptr_mut() as _,
            N.try_into().unwrap(),
        );
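
To show why the cast is enough, here is a minimal sketch that runs without CUDA. The modules `cudarc_sys` and `launcher` and the function `launch` are hypothetical stand-ins, not the real cudarc or generated bindings; they only mimic the shape bindgen typically emits for an opaque handle. The point is that two identically defined structs are still different types to the compiler, while a raw-pointer cast with `as _` converts between them without changing the address.

```rust
// Stand-in for the types exposed by cudarc (hypothetical module, not the real crate).
#[allow(non_camel_case_types, dead_code)]
mod cudarc_sys {
    #[repr(C)]
    pub struct CUstream_st {
        _unused: [u8; 0],
    }
    pub type CUstream = *mut CUstream_st;
    pub type CUdeviceptr = u64;
}

// Stand-in for the second, separately generated set of bindings.
#[allow(non_camel_case_types, dead_code)]
mod launcher {
    #[repr(C)]
    pub struct CUstream_st {
        _unused: [u8; 0],
    }
    pub type CUstream = *mut CUstream_st;
    pub type CUdeviceptr = u64;

    // Hypothetical replacement for the extern "C" kernel launcher.
    // It never dereferences the stream; it only reports what it received.
    pub fn launch(stream: CUstream, x_ptr: CUdeviceptr) -> (usize, u64) {
        (stream as usize, x_ptr)
    }
}

fn launch_with_cudarc_stream(
    stream: cudarc_sys::CUstream,
    x_ptr: cudarc_sys::CUdeviceptr,
) -> (usize, u64) {
    // `launcher::launch(stream, x_ptr)` would fail with "mismatched types":
    // `*mut cudarc_sys::CUstream_st` is not `*mut launcher::CUstream_st`.
    // `as _` lets the compiler infer the target pointer type from the signature.
    launcher::launch(stream as _, x_ptr)
}

fn main() {
    // A fake address used purely for illustration; it is never dereferenced.
    let stream = 0x1000usize as cudarc_sys::CUstream;
    let (addr, x_ptr) = launch_with_cudarc_stream(stream, 42);
    println!("launcher saw stream at {addr:#x}, device pointer {x_ptr}");
}
```

In this sketch the device pointer needs no cast because both modules alias it to the same integer type. With real bindings, whether a cast is needed depends on how each side defines the type, so adding `as _` to every argument is a harmless default.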