cudarc CUStream_st and `cuda_runtime`
jeromeku opened this issue · comments
Thanks for the great crate!
2 quick questions:
- I've generated bindings for a CUDA C library, and the output includes its own `CUstream_st` and `CUstream` types because `cuda_runtime` is included in the originating header.
Briefly, the library is used to launch kernels such as:
extern "C" {
    pub fn add_kernel_default(
        stream: CUstream,
        x_ptr: CUdeviceptr,
        y_ptr: CUdeviceptr,
        output_ptr: CUdeviceptr,
        n_elements: i32,
    ) -> CUresult;
}
When calling this function, the stream parameter has the `CUstream` type defined inside the generated bindgen bindings. This conflicts with the cudarc `CUstream_st` and `CUstream` types that I get through `device.cu_stream()`, even though the definitions are identical.
What am I doing wrong? This seems like an obvious blunder on my part.
- I can fix the above by manually pasting the generated bindings and replacing the aforementioned definitions with the `cudarc` versions, but this seems unnecessarily hacky.
Thoughts?
Here is the complete code (where `launcher` contains the generated bindings for the CUDA C lib):
use anyhow::Result;
use cudarc::driver::CudaDevice;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;
use triton_rs::launcher::{add_kernel_default, load_add_kernel};

fn main() -> Result<()> {
    const N: usize = 1024;
    let device = CudaDevice::new(0).unwrap();
    let mut a_buf: [f32; N] = [0.0; N];
    let x = device.htod_copy(vec![1.0_f32; N]).unwrap();
    let y = device.htod_copy(vec![1.0_f32; N]).unwrap();
    let mut out = device.alloc_zeros::<f32>(N).unwrap();
    let stream = *device.cu_stream();
    unsafe {
        load_add_kernel();
    }
    unsafe {
        add_kernel_default(
            stream,
            *x.device_ptr(),
            *y.device_ptr(),
            *out.device_ptr_mut(),
            N.try_into().unwrap(),
        );
    }
    device.dtoh_sync_copy_into(&out, &mut a_buf)?;
    let sum = a_buf.into_iter().sum::<f32>();
    println!("{:?}", sum);
    Ok(())
}
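You aren't doing anything incorrectly. This is just a consequence of using multiple sets of bindgen-generated bindings: each one defines its own copy of the types, and Rust treats them as distinct even when the definitions match. bindgen might have a built-in way to handle this, but I haven't looked into it in depth.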
Since most (if not all?) of the CUDA types are pointers under the hood, you can just cast the value to the type your bindgen output expects.
Try this:
add_kernel_default(
    stream as _, // Rust should infer the correct pointer type here
    *x.device_ptr() as _,
    *y.device_ptr() as _,
    *out.device_ptr_mut() as _,
    N.try_into().unwrap(),
);
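To see why the cast works without a GPU, here is a minimal sketch that mimics the situation in plain Rust. The two modules `cudarc_sys` and `launcher` below are hypothetical stand-ins for the two sets of generated bindings (they are not the real crates), and `stream_addr` is a made-up function standing in for an `extern "C"` entry point. Each module declares its own opaque `CUstream_st`, so the two `CUstream` aliases are different types to the compiler, and `as _` converts between them.

```rust
// Stand-in for the types cudarc exposes (hypothetical, not the real crate).
#[allow(non_camel_case_types, dead_code)]
mod cudarc_sys {
    #[repr(C)]
    pub struct CUstream_st {
        _unused: [u8; 0],
    }
    pub type CUstream = *mut CUstream_st;
}

// Stand-in for a second, independently generated set of bindings.
#[allow(non_camel_case_types, dead_code)]
mod launcher {
    #[repr(C)]
    pub struct CUstream_st {
        _unused: [u8; 0],
    }
    pub type CUstream = *mut CUstream_st;

    // Hypothetical replacement for an extern "C" function:
    // it only reports the address of the handle it received.
    pub fn stream_addr(stream: CUstream) -> usize {
        stream as usize
    }
}

// Builds a cudarc-style handle from an address, then passes it to the
// launcher-style function. Passing `stream` without the cast is a
// mismatched-types error, because the two CUstream_st structs are distinct.
fn roundtrip(addr: usize) -> usize {
    let stream: cudarc_sys::CUstream = addr as cudarc_sys::CUstream;
    // `as _` lets the compiler pick the target pointer type from the signature.
    launcher::stream_addr(stream as _)
}

fn main() {
    let addr = 0x1000;
    // The cast changes only the static type; the address is unchanged.
    assert_eq!(roundtrip(addr), addr);
    println!("{:#x}", roundtrip(addr));
}
```

The cast is a pointer-to-pointer conversion, so it is only sound when both sides really describe the same underlying C type, as they do when both bindings come from the same CUDA headers.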