rust-ndarray / ndarray

ndarray: an N-dimensional array with array views, multidimensional slicing, and efficient operations

Home Page: https://docs.rs/ndarray/

Implement numpy's `triu` and `tril` methods

geetmankar opened this issue.

Hi, I want to propose an implementation of the numpy.triu and numpy.tril methods (link here) for arrays with at least two dimensions.

import numpy as np

np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], k=0)

returns:

array([[ 1,  2,  3],
       [ 0,  5,  6],
       [ 0,  0,  9],
       [ 0,  0,  0]])

and

np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], k=1)

returns:

array([[ 0,  2,  3],
       [ 0,  0,  6],
       [ 0,  0,  0],
       [ 0,  0,  0]])

and

np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], k=-1)

returns:

array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]])
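To pin down what k means in these examples: numpy's triu keeps entry (i, j) when j - i >= k and zeroes the rest, so k = 0 starts at the main diagonal, k = 1 one step above it, and k = -1 one step below it. Below is a minimal std-only Rust sketch of that rule. The helper `triu_vec` is hypothetical (not an ndarray API) and works on `Vec<Vec<T>>` so it runs without dependencies; it takes a signed `k` and does not assume a square matrix.

```rust
/// Hypothetical helper (not an ndarray API): numpy-style `triu` on a
/// row-major matrix stored as Vec<Vec<T>>. Entry (i, j) is kept when
/// j - i >= k and replaced by T::default() (0 for numeric types) otherwise.
/// `k` is signed, so negative offsets work, and the matrix need not be square.
pub fn triu_vec<T: Copy + Default>(m: &[Vec<T>], k: isize) -> Vec<Vec<T>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| {
                    // Signed arithmetic so that j - i can go negative.
                    if (j as isize - i as isize) >= k {
                        x
                    } else {
                        T::default()
                    }
                })
                .collect::<Vec<T>>()
        })
        .collect()
}
```

Applied to the 4x3 matrix from the numpy calls with k = 0, 1 and -1, this reproduces the three arrays shown above.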

I implemented a basic version of np.triu and np.tril for 2D owned f64 square ndarrays for personal use; however, it does not support negative values of k.

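use ndarray::{s, Array2, ArrayBase, Axis, Ix2, OwnedRepr, ShapeBuilder};

/// Upper/lower triangle extraction, modelled on numpy.triu / numpy.tril.
/// Only non-negative diagonal offsets k are supported.
pub trait Tri {
    fn triu(&self, k: usize) -> Self;
    fn tril(&self, k: usize) -> Self;
}

impl Tri for ArrayBase<OwnedRepr<f64>, Ix2> {
    /// Keeps elements with column j >= row i + k and zeroes the rest.
    /// Assumes a square matrix.
    fn triu(&self, k: usize) -> Self {
        let cols = self.len_of(Axis(1));
        // Zero-filled output in column-major (Fortran) order.
        let mut result_arr = Array2::<f64>::zeros((cols, cols).f());

        for (i, row) in self.axis_iter(Axis(0)).enumerate() {
            // Clamp so that a large k gives an empty range instead of a panic.
            let start = (i + k).min(cols);
            result_arr
                .slice_mut(s![i, start..cols])
                .assign(&row.slice(s![start..cols]));
        }

        result_arr
    }

    /// triu of the transpose, transposed back. Because k is applied to the
    /// transposed matrix, this keeps elements with j <= i - k.
    fn tril(&self, k: usize) -> Self {
        self.t().to_owned().triu(k).t().to_owned()
    }
}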

A test of this code:

use ndarray::{Array, ShapeBuilder};

// ... paste the `Tri` trait and implementation here (merging its `use` line
// with the one above), or import it from your own module

fn main() {
    // 3x3 matrix of ones, stored in column-major (Fortran) order
    let a = Array::<f64, _>::ones((3, 3).f());
    // triu/tril take &self, so no clone is needed
    println!("a.triu(0) = {:?}\n", a.triu(0));
    println!("a.triu(1) = {:?}\n\n", a.triu(1));
    println!("a.tril(0) = {:?}\n", a.tril(0));
    println!("a.tril(1) = {:?}\n", a.tril(1));
}

The result is:

Compiling nbody_code v0.1.0 (C:\Users\Rebel1\projects\nbody_code)
    Finished dev [unoptimized + debuginfo] target(s) in 1.40s
     Running `target\debug\nbody_code.exe`
a.triu(0) = [[1.0, 1.0, 1.0],
             [0.0, 1.0, 1.0],
             [0.0, 0.0, 1.0]], shape=[3, 3], strides=[1, 3], layout=Ff (0xa), const ndim=2

a.triu(1) = [[0.0, 1.0, 1.0],
             [0.0, 0.0, 1.0],
             [0.0, 0.0, 0.0]], shape=[3, 3], strides=[1, 3], layout=Ff (0xa), const ndim=2


a.tril(0) = [[1.0, 0.0, 0.0],
             [1.0, 1.0, 0.0],
             [1.0, 1.0, 1.0]], shape=[3, 3], strides=[3, 1], layout=Cc (0x5), const ndim=2  

a.tril(1) = [[0.0, 0.0, 0.0],
             [1.0, 0.0, 0.0],
             [1.0, 1.0, 0.0]], shape=[3, 3], strides=[3, 1], layout=Cc (0x5), const ndim=2  
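In the output above, a.tril(1) is the strictly lower triangle, which numpy would produce with k=-1. numpy's tril(k) keeps entry (i, j) when j - i <= k, whereas the transpose-based tril(k) in the snippet applies k to the transposed matrix and so keeps j - i <= -k. A direct version of numpy's rule, as a minimal std-only sketch with a hypothetical `tril_vec` helper (not an ndarray API), could look like this:

```rust
/// Hypothetical helper (not an ndarray API): numpy-style `tril`, keeping
/// entry (i, j) when j - i <= k and writing T::default() elsewhere.
/// `k` is signed and the matrix need not be square.
pub fn tril_vec<T: Copy + Default>(m: &[Vec<T>], k: isize) -> Vec<Vec<T>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| {
                    // Keep entries on or below the k-th diagonal.
                    if (j as isize - i as isize) <= k {
                        x
                    } else {
                        T::default()
                    }
                })
                .collect::<Vec<T>>()
        })
        .collect()
}
```

With a signed k, tril_vec(&a, -1) gives the strictly lower triangle printed for a.tril(1). Since tril(k) keeps j - i <= k and triu(k + 1) keeps j - i >= k + 1, the two masks are complementary and add up element-wise to the original matrix.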

ndarray should probably have those functions.

Making it generic, with fewer to_owned and clone calls, will require some work. Thank you for writing the proposal.
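One way to cut down on clone and to_owned is to zero the unwanted entries in place instead of building a new array. The sketch below assumes a plain row-major `&mut [T]` buffer rather than an ndarray type so that it stays dependency-free; `triu_in_place` is a hypothetical name, not an existing ndarray method. With ndarray, the same per-row loop could presumably run over `axis_iter_mut(Axis(0))`.

```rust
/// Hypothetical sketch (not an ndarray API): zero everything below the k-th
/// diagonal in place, so no copy of the matrix is made. The matrix is stored
/// row-major in a flat slice with `cols` columns; it need not be square, and
/// `k` may be negative. T::default() is used as zero (0 for numeric types).
pub fn triu_in_place<T: Default>(data: &mut [T], cols: usize, k: isize) {
    assert!(cols > 0 && data.len() % cols == 0, "data must hold whole rows");
    for (i, row) in data.chunks_mut(cols).enumerate() {
        // Columns j with j < i + k lie below the k-th diagonal.
        let end = (i as isize + k).clamp(0, cols as isize) as usize;
        for x in &mut row[..end] {
            *x = T::default();
        }
    }
}
```

Bounding T only by Default keeps the function generic over element types; a caller that still needs the original matrix can copy it once up front and run the function on the copy.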

Sweet, I'll be waiting for the upgrade.

I'll be waiting for the upgrade.

Just to be clear, we're mostly offering bug fixes and support at the moment because all maintainers are busy. If you are actually waiting for it, I advise you to be patient :)

No worries, the code snippet I wrote is currently good enough for my pet project, so I won't need to hold my breath. Thanks for the heads-up though!