The meaning expressed by CudaSlice::device_ptr_mut is imprecise.
soloist-v opened this issue · comments
// Build the two-element binding array handed to `enqueue_v2` below:
// slot 0 comes from the input buffer, slot 1 from the output buffer.
// Each accessor result is dereferenced and the value is transmuted to
// `*mut c_void`, which is the awkward call pattern this issue is about.
// NOTE(review): presumably both accessors return a reference to the raw
// CUDA device pointer value — confirm against the library signatures.
let input_data_arr: [*mut std::ffi::c_void; 2] = unsafe {
[
std::mem::transmute(*input_data_device.device_ptr()),
std::mem::transmute(*output_data_device.device_ptr_mut()),
]
};
// Fetch the device's stream handle; it is cast with `.cast()` when passed on.
let cu_stream = device.cu_stream();
// Submit the bindings on that stream and print the value returned by the call.
let ok = context.enqueue_v2(&input_data_arr, cu_stream.cast());
println!("infer ok: {}", ok);
I think it would be better if device_ptr_mut directly returns sys::CUdeviceptr.
I see what you're saying. However this is definitely a breaking change, and I'm not sure if the gains would be worth that?
Mark the old method using #[deprecated].
How about this approach?
use std::ops::{Deref, DerefMut};
/// Sketch of a slice handle that stores a single raw address value and
/// hands it out through typed, lifetime-bound wrapper structs.
struct DeviceSlice {
    /// Raw address value copied into every wrapper the accessors create.
    ptr: std::ffi::c_ulonglong,
}

impl DeviceSlice {
    /// Returns a [`DevicePtr`] carrying a copy of the stored address.
    ///
    /// The wrapper's elided lifetime is tied to the borrow of `self`.
    fn device_ptr(&self) -> DevicePtr {
        let raw = self.ptr;
        DevicePtr {
            raw,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns a [`DeviceMutPtr`] carrying a copy of the stored address.
    ///
    /// NOTE(review): this takes `&self`, so the "mutable" wrapper is created
    /// from a shared borrow — confirm whether `&mut self` was intended.
    fn device_ptr_mut(&self) -> DeviceMutPtr {
        let raw = self.ptr;
        DeviceMutPtr {
            raw,
            _marker: std::marker::PhantomData,
        }
    }
}
/// Read-only wrapper around a raw address value, bound to lifetime `'a`.
///
/// `#[repr(transparent)]` keeps the layout identical to the wrapped integer.
#[derive(Debug)]
#[repr(transparent)]
struct DevicePtr<'a> {
    /// The address, stored as an integer.
    raw: std::ffi::c_ulonglong,
    /// Zero-sized marker that makes the wrapper behave like a shared borrow
    /// for `'a` without storing a reference.
    _marker: std::marker::PhantomData<&'a ()>,
}

/// Lets `*ptr` read the wrapped integer address directly.
impl Deref for DevicePtr<'_> {
    type Target = std::ffi::c_ulonglong;

    fn deref(&self) -> &Self::Target {
        &self.raw
    }
}

impl DevicePtr<'_> {
    /// Reinterprets the stored address as a `*const T`.
    ///
    /// The result is built from the value of `raw`. The previous
    /// implementation transmuted `&self`, which produced the address of the
    /// wrapper itself rather than the address it wraps.
    ///
    /// The value goes through `usize`, so on targets where pointers are
    /// narrower than `c_ulonglong` the upper bits are discarded.
    fn cast<T>(&self) -> *const T {
        self.raw as usize as *const T
    }
}
/// Mutable counterpart of the read-only pointer wrapper: a raw address value
/// bound to lifetime `'a`.
///
/// `#[repr(transparent)]` keeps the layout identical to the wrapped integer.
#[derive(Debug)]
#[repr(transparent)]
struct DeviceMutPtr<'a> {
    /// The address, stored as an integer.
    raw: std::ffi::c_ulonglong,
    /// Zero-sized marker that makes the wrapper behave like an exclusive
    /// borrow for `'a` without storing a reference.
    _marker: std::marker::PhantomData<&'a mut ()>,
}

/// Lets `*ptr` read the wrapped integer address directly.
impl Deref for DeviceMutPtr<'_> {
    type Target = std::ffi::c_ulonglong;

    fn deref(&self) -> &Self::Target {
        &self.raw
    }
}

/// Lets `*ptr = value` overwrite the wrapped integer address.
impl DerefMut for DeviceMutPtr<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.raw
    }
}

impl DeviceMutPtr<'_> {
    /// Reinterprets the stored address as a `*mut T`.
    ///
    /// The result is built from the value of `raw`. The previous
    /// implementation transmuted `&mut self`, which produced the address of
    /// the wrapper itself rather than the address it wraps.
    ///
    /// The value goes through `usize`, so on targets where pointers are
    /// narrower than `c_ulonglong` the upper bits are discarded.
    fn cast<T>(&mut self) -> *mut T {
        self.raw as usize as *mut T
    }
}
I've thought about it, and the approach mentioned above is equivalent to returning a reference, but returning a reference is more concise, which makes it the better choice. Therefore, I think there is no need for any changes.