coreylowman / cudarc

Safe Rust wrapper around the CUDA toolkit

The return type of CudaSlice::device_ptr_mut is semantically imprecise.

soloist-v opened this issue.

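// Assumes `input_data_device` / `output_data_device` are cudarc device slices
// (`output_data_device` bound as `mut`), and `context` is a TensorRT execution
// context from the caller's own bindings (not part of cudarc).
let input_data_arr: [*mut std::ffi::c_void; 2] = [
    // `device_ptr()` yields `&CUdeviceptr`; dereference to copy out the address.
    *input_data_device.device_ptr() as *mut std::ffi::c_void,
    // `device_ptr_mut()` yields `&mut CUdeviceptr`: a mutable reference to the
    // handle itself, not a pointer to mutable device memory.
    *output_data_device.device_ptr_mut() as *mut std::ffi::c_void,
];
let cu_stream = device.cu_stream();
let ok = context.enqueue_v2(&input_data_arr, cu_stream.cast());
println!("infer ok: {}", ok);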

I think it would be better if device_ptr_mut returned sys::CUdeviceptr directly.
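
To make the suggestion concrete, here is a minimal sketch contrasting the current reference-returning shape with a value-returning one. MockSlice and device_ptr_mut_value are hypothetical names used only so both shapes can coexist in one example, and CUdeviceptr is modeled as a plain u64 alias; none of this is cudarc's actual code.

```rust
use std::ffi::c_void;

// Assumption: `sys::CUdeviceptr` is modeled as a plain u64 alias.
type CUdeviceptr = u64;

// Hypothetical stand-in for a device allocation; not cudarc's real type.
struct MockSlice {
    ptr: CUdeviceptr,
}

impl MockSlice {
    // Current shape: a mutable reference to the stored handle.
    fn device_ptr_mut(&mut self) -> &mut CUdeviceptr {
        &mut self.ptr
    }

    // Proposed shape (hypothetical name): the handle copied out by value.
    fn device_ptr_mut_value(&mut self) -> CUdeviceptr {
        self.ptr
    }
}

fn main() {
    let mut slice = MockSlice { ptr: 0x7f00_0000_1000 };

    // With a reference, the caller dereferences before casting.
    let a = *slice.device_ptr_mut() as *mut c_void;
    // With a value, the cast applies directly.
    let b = slice.device_ptr_mut_value() as *mut c_void;

    assert_eq!(a, b);
}
```

One trade-off worth noting: returning by value ends the borrow of the slice as soon as the call returns, so the compiler no longer ties the address to an exclusive borrow of the allocation.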

I see what you're saying. However, this is definitely a breaking change, and I'm not sure the gains would be worth it.
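
For context on the breaking-change concern: the usual non-breaking path in Rust is to keep the old signature, mark it #[deprecated], and expose the new behavior under a new name, so existing callers get a warning instead of a compile error. A minimal sketch under assumptions: MockSlice and raw_device_ptr_mut are hypothetical names (not cudarc API), and CUdeviceptr is modeled as a u64 alias.

```rust
// Assumption: `sys::CUdeviceptr` is modeled as a plain u64 alias.
type CUdeviceptr = u64;

// Hypothetical slice type; the method names are illustrative only.
struct MockSlice {
    ptr: CUdeviceptr,
}

impl MockSlice {
    // Old signature, kept so existing call sites still compile.
    #[deprecated(note = "use `raw_device_ptr_mut`, which returns the handle by value")]
    fn device_ptr_mut(&mut self) -> &mut CUdeviceptr {
        &mut self.ptr
    }

    // New behavior under a new (hypothetical) name.
    fn raw_device_ptr_mut(&mut self) -> CUdeviceptr {
        self.ptr
    }
}

fn main() {
    let mut slice = MockSlice { ptr: 0x1000 };

    // Old call sites keep working; the compiler only warns.
    // `#[allow(deprecated)]` silences that warning for this demo.
    #[allow(deprecated)]
    let old = *slice.device_ptr_mut();

    let new = slice.raw_device_ptr_mut();
    assert_eq!(old, new);
}
```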

The old method could be marked with #[deprecated], so existing code keeps compiling.
How about this approach?

use std::ops::{Deref, DerefMut};

struct DeviceSlice {
    ptr: std::ffi::c_ulonglong,
}

impl DeviceSlice {
    fn device_ptr(&self) -> DevicePtr {
        DevicePtr {
            raw: self.ptr,
            _marker: Default::default(),
        }
    }

    // Takes `&mut self` so the returned guard holds an exclusive borrow.
    fn device_ptr_mut(&mut self) -> DeviceMutPtr<'_> {
        DeviceMutPtr {
            raw: self.ptr,
            _marker: Default::default(),
        }
    }
}

#[derive(Debug)]
#[repr(transparent)]
struct DevicePtr<'a> {
    raw: std::ffi::c_ulonglong,
    _marker: std::marker::PhantomData<&'a ()>,
}

impl Deref for DevicePtr<'_> {
    type Target = std::ffi::c_ulonglong;

    fn deref(&self) -> &Self::Target {
        &self.raw
    }
}

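impl DevicePtr<'_> {
    // Reinterprets the stored device address as a typed pointer.
    // (Transmuting `self` would yield the host address of this struct.)
    fn cast<T>(&self) -> *const T {
        self.raw as *const T
    }
}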

#[derive(Debug)]
#[repr(transparent)]
struct DeviceMutPtr<'a> {
    raw: std::ffi::c_ulonglong,
    _marker: std::marker::PhantomData<&'a mut ()>,
}

impl Deref for DeviceMutPtr<'_> {
    type Target = std::ffi::c_ulonglong;

    fn deref(&self) -> &Self::Target {
        &self.raw
    }
}

impl DerefMut for DeviceMutPtr<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.raw
    }
}

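impl DeviceMutPtr<'_> {
    // Reinterprets the stored device address as a typed mutable pointer.
    // (Transmuting `self` would yield the host address of this struct.)
    fn cast<T>(&mut self) -> *mut T {
        self.raw as *mut T
    }
}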
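
A condensed usage sketch of the proposal, assuming two adjustments: device_ptr_mut takes &mut self, and cast() returns the stored device address rather than the host address of the guard. The types are trimmed copies of the proposal (no Deref impls), with addresses as made-up u64 values; it only illustrates how the lifetimes track borrows.

```rust
use std::ffi::c_void;
use std::marker::PhantomData;

// Trimmed, hypothetical copies of the proposal's types.
struct DeviceSlice {
    ptr: u64,
}

struct DevicePtr<'a> {
    raw: u64,
    _marker: PhantomData<&'a ()>,
}

struct DeviceMutPtr<'a> {
    raw: u64,
    _marker: PhantomData<&'a mut ()>,
}

impl DeviceSlice {
    fn device_ptr(&self) -> DevicePtr<'_> {
        DevicePtr { raw: self.ptr, _marker: PhantomData }
    }

    fn device_ptr_mut(&mut self) -> DeviceMutPtr<'_> {
        DeviceMutPtr { raw: self.ptr, _marker: PhantomData }
    }
}

impl DevicePtr<'_> {
    fn cast<T>(&self) -> *const T {
        self.raw as *const T
    }
}

impl DeviceMutPtr<'_> {
    fn cast<T>(&mut self) -> *mut T {
        self.raw as *mut T
    }
}

fn main() {
    let input = DeviceSlice { ptr: 0x1000 };
    let mut output = DeviceSlice { ptr: 0x2000 };

    let src = input.device_ptr();
    let mut dst = output.device_ptr_mut();
    let bindings: [*mut c_void; 2] = [
        src.cast::<c_void>() as *mut c_void,
        dst.cast::<c_void>(),
    ];
    assert_eq!(bindings[0] as usize, 0x1000);
    assert_eq!(bindings[1] as usize, 0x2000);

    // `output` stays exclusively borrowed from the creation of `dst` until
    // this drop; a second `output.device_ptr_mut()` in between is rejected.
    drop(dst);
    let _again = output.device_ptr_mut();
}
```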

I've thought about it further: the approach above is equivalent to returning a reference, and the reference is more concise, so it is the better choice. I don't think any change is needed after all.
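
For comparison, a minimal sketch of the reference-returning shape the thread settles on, using a hypothetical MockSlice with CUdeviceptr modeled as a u64 alias. The &self and &mut self receivers give the same shared/exclusive borrow tracking that the PhantomData<&'a ()> and PhantomData<&'a mut ()> markers encode in the proposal; the caller just dereferences to copy out the address.

```rust
use std::ffi::c_void;

// Assumption: `sys::CUdeviceptr` is modeled as a plain u64 alias.
type CUdeviceptr = u64;

// Hypothetical slice type returning plain references.
struct MockSlice {
    ptr: CUdeviceptr,
}

impl MockSlice {
    fn device_ptr(&self) -> &CUdeviceptr {
        &self.ptr
    }

    fn device_ptr_mut(&mut self) -> &mut CUdeviceptr {
        &mut self.ptr
    }
}

fn main() {
    let input = MockSlice { ptr: 0x1000 };
    let mut output = MockSlice { ptr: 0x2000 };

    // Shared borrow of `input`, exclusive borrow of `output`: the same
    // guarantees as the wrapper types, without extra types.
    let src: &CUdeviceptr = input.device_ptr();
    let dst: &mut CUdeviceptr = output.device_ptr_mut();

    // Dereference to copy out each address, then cast for an FFI array.
    let bindings: [*mut c_void; 2] = [*src as *mut c_void, *dst as *mut c_void];

    assert_eq!(bindings[0] as usize, 0x1000);
    assert_eq!(bindings[1] as usize, 0x2000);
}
```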