rust-ndarray / ndarray

ndarray: an N-dimensional array with array views, multidimensional slicing, and efficient operations

Home Page:https://docs.rs/ndarray/

How to impl SliceArg for certain SliceInfo?

HyeokSuLee opened this issue · comments

I want to build a SliceInfo for a generic dimension ObsD.
For example, if ObsD == Ix1 then call Array1::slice_mut(slice_info_with_Ix1), and if ObsD == Ix2 then call Array2::slice_mut(slice_info_with_Ix2).
The argument of slice_mut seems to need the bound SliceArg<D>.
In slice.rs, SliceArg seems to be implemented for every input dimension through macros.
Like this:

impl<T, Dout> SliceArg<$in_dim> for SliceInfo<T, $in_dim, Dout>
        where
            T: AsRef<[SliceInfoElem]>,
            Dout: Dimension,

So I wrote this code:

pub struct PrioritizedReplayBuffer<ObsD, ObsT> {
    pub data: Array<ObsT, ObsD>,
}

impl<ObsD, ObsT> Store<(ObsT, ObsT)> for PrioritizedReplayBuffer<ObsD, ObsT, ActD, ActT>
where
    ObsD: ndarray::Dimension + SliceAtIndex<ObsD, ObsD> + SliceArg<ObsD>,
    ObsT: PartialOrd + Default + Debug + Zero + Display,
{
    fn store<Sh>(&mut self, (obs): (Array<ObsT, ObsD>,))
    where
        Sh: ShapeBuilder<Dim = ObsD>,
    {
        let slice_info = ObsD::slice_at_index(idx); // <<<<<<<<<<<<<< I want to use slice info here.
        self.data.slice_mut(slice_info);
    }
}

trait SliceAtIndex<D: Dimension, SoutD: Dimension> {
    type Output: SliceArg<D, OutDim = SoutD>;
    fn slice_at_index(index: usize) -> Self::Output;
}

impl<D: Dimension, SoutD: Dimension> SliceAtIndex<D, SoutD> for Ix1 {
    type Output = SliceInfo<[SliceInfoElem; 1], Ix1, Ix1>;
    fn slice_at_index(index: usize) -> Self::Output {
        s![index..index + 1]
    }
}
impl<D: Dimension, SoutD: Dimension> SliceAtIndex<D, SoutD> for Ix2 {
    type Output = SliceInfo<[SliceInfoElem; 2], Ix2, Ix2>;
    fn slice_at_index(index: usize) -> Self::Output {
        s![index..index + 1, ..]
    }
}
impl<D: Dimension, SoutD: Dimension> SliceAtIndex<D, SoutD> for Ix3 {
    type Output = SliceInfo<[SliceInfoElem; 3], Ix3, Ix3>;
    fn slice_at_index(index: usize) -> Self::Output {
        s![index..index + 1, .., ..]
    }
}
impl<D: Dimension, SoutD: Dimension> SliceAtIndex<D, SoutD> for Ix4 {
    type Output = SliceInfo<[SliceInfoElem; 4], Ix4, Ix4>;
    fn slice_at_index(index: usize) -> Self::Output {
        s![index..index + 1, .., .., ..]
    }
}

But the compiler says

[E0277] the trait bound `SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>: SliceArg<D>` is not satisfied. [Note] the trait `SliceArg<D>` is not implemented for `SliceInfo<[SliceInfoElem; 1], Dim<[usize; 1]>, Dim<[usize; 1]>>`...

I cannot implement SliceArg<D> manually because it is private.

Can you give any help?

> I cannot impl SliceArg manually cause it is private.

This is intentional, see https://docs.rs/ndarray/latest/src/ndarray/slice.rs.html#309. An incorrect implementation could break safety invariants, and the project is not willing to commit to documenting and maintaining an unsafe interface at the moment. (This means the exact requirements for implementing SliceArg are not part of the API and can change at any point in time, even in patch releases.)

> Can you give any help?

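But doesn't your error come from implementing SliceAtIndex<D, SoutD> for arbitrary D and SoutD, instead of for a fixed D that matches Output? For example, the impl for Ix1 returns a SliceInfo<_, Ix1, Ix1>, which only implements SliceArg<Ix1>, not SliceArg<D> for every D.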

You have a point. I tried hard, but I couldn't make it work.

So I sliced with index_axis_mut() instead, and that solved it.

impl<ObsD, ObsT> Store<( Array<ObsT, ObsD>)> for PrioritizedReplayBuffer<ObsD, ObsT, ActD, ActT>
where
    ObsD: ndarray::Dimension + RemoveAxis,
    ObsT: PartialOrd + Default + Debug + Zero + Display,
{
    fn store(&mut self, (obs): (Array<ObsT, ObsD>))
    {
        self.data.index_axis_mut(Axis(0), idx).assign(&obs); //self.data is Array(ObsT,ObsD).
    }
}

But it is hard to create a new Array with generics.
I'm afraid of losing performance by using ArrayD (dynamic-dimension array).
But maybe that is the cleaner, higher-level way.