coreylowman / cudarc

Safe rust wrapper around CUDA toolkit



Modify htod_copy assert_eq

soloist-v opened this issue

    pub fn htod_sync_copy_into<T: DeviceRepr, Dst: DevicePtrMut<T>>(
        self: &Arc<Self>,
        src: &[T],
        dst: &mut Dst,
    ) -> Result<(), result::DriverError> {
        assert_eq!(src.len(), dst.len());
        unsafe { result::memcpy_htod_async(*dst.device_ptr_mut(), src, self.stream) }?;
        self.synchronize()
    }

should become:

    pub fn htod_sync_copy_into<T: DeviceRepr, Dst: DevicePtrMut<T>>(
        self: &Arc<Self>,
        src: &[T],
        dst: &mut Dst,
    ) -> Result<(), result::DriverError> {
        assert!(src.len() <= dst.len());
        unsafe { result::memcpy_htod_async(*dst.device_ptr_mut(), src, self.stream) }?;
        self.synchronize()
    }
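To make the effect of the proposed `<=` check concrete, here is a host-only analogy using plain Rust slices instead of device memory. `copy_into_relaxed` is a hypothetical helper, not part of cudarc. It only assumes that a relaxed check would write `src` into the first `src.len()` elements of `dst` and leave the rest alone, which is what a memcpy of `src.len()` elements starting at `dst`'s pointer would do.

```rust
/// Host-only sketch of the *proposed* relaxed check.
/// Hypothetical helper for illustration; not a cudarc API.
fn copy_into_relaxed<T: Copy>(src: &[T], dst: &mut [T]) {
    // Proposed check: only require that `src` fits inside `dst`.
    assert!(src.len() <= dst.len());
    // Only the first `src.len()` elements of `dst` are overwritten;
    // the tail keeps whatever it held before.
    dst[..src.len()].copy_from_slice(src);
}

fn main() {
    let src = [1, 2, 3];
    let mut dst = [0; 5];
    copy_into_relaxed(&src, &mut dst);
    // Nothing at the call site says that only 3 of the 5 elements changed.
    println!("{:?}", dst); // prints [1, 2, 3, 0, 0]
}
```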

Not going to do this. You can use the slice_mut() method of CudaSlice/CudaView to make the slices the same length, which makes it explicit in the code exactly how many elements are being copied.

Otherwise, it can be unclear at the call site exactly what is being copied.
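The suggested pattern can be sketched on the host with plain slices: keep the strict equal-length check and have the caller narrow the destination explicitly. `copy_into_strict` is a hypothetical stand-in for `htod_sync_copy_into`, and `&mut dst[..src.len()]` plays the role of `slice_mut()`. In cudarc itself the call would presumably look like `dev.htod_sync_copy_into(&src, &mut dst.slice_mut(0..src.len()))`, assuming `dst` is a `CudaSlice` and that the returned view implements `DevicePtrMut` in the version you use.

```rust
/// Host-only sketch of the existing strict check.
/// Hypothetical stand-in for `htod_sync_copy_into`; not a cudarc API.
fn copy_into_strict<T: Copy>(src: &[T], dst: &mut [T]) {
    // Lengths must match exactly, mirroring the `assert_eq!` in cudarc.
    assert_eq!(src.len(), dst.len());
    dst.copy_from_slice(src);
}

fn main() {
    let src = [1, 2, 3];
    let mut dst = [0; 5];
    // The caller narrows `dst` explicitly, so the copy length is visible here
    // (the analogue of calling `slice_mut(0..src.len())` on a device slice).
    copy_into_strict(&src, &mut dst[..src.len()]);
    println!("{:?}", dst); // prints [1, 2, 3, 0, 0]
}
```

The result is the same as with the relaxed check, but the sub-slice at the call site states exactly which part of the destination is written.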

Thanks for the suggestion though 👍