pydata / sparse

Sparse multi-dimensional arrays for the PyData ecosystem

Home Page: https://sparse.pydata.org



Faster matmul for case sparse.coo * np.ndarray

HatPdotS opened this issue

I have been using this library heavily in a recent project and found that sparse.matmul could be faster.
I have implemented a faster version of matmul(sparse.COO, np.ndarray) with Numba for 2-D arrays (it could be generalized), as well as a parallelized version. The single-threaded version is roughly 3 times faster than sparse.matmul in this case and scales linearly with the number of stored (nonzero) entries in the sparse matrix. The multi-threaded version has some overhead but scales well for very large matrices (10,000,000+ entries).
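The core idea is a scatter-add over the COO triples (row, column, value): each stored value adds value × (row `column` of the dense operand) to row `row` of the output. A minimal pure-NumPy sketch of the same computation, useful as a correctness reference when testing faster kernels (`coo_matmul_reference` is a hypothetical name, not part of sparse):

```python
import numpy as np

def coo_matmul_reference(rows, cols, data, shape, B):
    """Compute A @ B for a 2-D matrix A given as COO triples (rows, cols, data) with A.shape == shape."""
    out = np.zeros((shape[0], B.shape[1]), dtype=np.result_type(data, B))
    # Each stored value data[k] adds data[k] * B[cols[k]] to out[rows[k]].
    # np.add.at is unbuffered, so repeated row indices accumulate instead of overwriting.
    np.add.at(out, rows, data[:, None] * B[cols])
    return out

rng = np.random.default_rng(0)
A = rng.random((50, 40)) * (rng.random((50, 40)) < 0.1)  # dense array that is mostly zeros
rows, cols = np.nonzero(A)
B = rng.random((40, 7))
C = coo_matmul_reference(rows, cols, A[rows, cols], A.shape, B)
print(np.allclose(C, A @ B))  # True
```

This version materializes a temporary of shape (nnz, k); a compiled loop like the Numba one below processes one stored value at a time instead.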

Here is the code:

import numba
import numpy as np

@numba.njit()
def matmul_fast(dim1: np.ndarray, dim2: np.ndarray, V: np.ndarray, Arr: np.ndarray, out: np.ndarray):
    # Scatter-add over the nonzeros: entry i adds V[i] * Arr[dim2[i]] to row dim1[i] of out
    for i in range(V.shape[0]):
        out[dim1[i]] += V[i] * Arr[dim2[i]]
    return out

@numba.njit(parallel=True, fastmath=True, boundscheck=False)
def matmul_fast_parallel(dim1: np.ndarray, dim2: np.ndarray, V: np.ndarray, Arr: np.ndarray, out: np.ndarray, n: np.int64 = 100000):
    f = V.shape[0]
    n_chunks = (f + n - 1) // n  # number of chunks of at most n nonzeros
    acc = np.zeros_like(out)
    for c in numba.prange(n_chunks):
        # Each chunk accumulates into its own buffer, so threads never write the same memory
        out_loc = np.zeros_like(out)
        for g in range(c * n, min((c + 1) * n, f)):
            out_loc[dim1[g]] += V[g] * Arr[dim2[g]]
        # A whole-array "+=" inside prange is handled by Numba as a reduction
        acc += out_loc
    out += acc
    return out
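The parallel version splits the nonzeros into chunks of `n`, accumulates each chunk into a private buffer, and sums the buffers. Because addition is order-independent (up to floating-point rounding), the result should not depend on the chunk size or on which chunk finishes first. Each chunk also allocates a full m × k buffer, which is presumably part of the overhead mentioned above. A sequential NumPy model of that decomposition, for checking the idea without Numba (`chunked_scatter_add` is hypothetical, and `np.add.at` stands in for the inner Numba loop):

```python
import numpy as np

def chunked_scatter_add(rows, cols, data, B, n_rows, n=100_000):
    """Sequential model of the chunked reduction (hypothetical helper, NumPy only)."""
    out = np.zeros((n_rows, B.shape[1]), dtype=np.result_type(data, B))
    nnz = data.shape[0]
    for start in range(0, nnz, n):
        stop = min(start + n, nnz)
        # Private buffer for this chunk; in the Numba version each prange iteration owns one
        out_loc = np.zeros_like(out)
        np.add.at(out_loc, rows[start:stop], data[start:stop, None] * B[cols[start:stop]])
        out += out_loc  # combine partial sums; order does not matter up to rounding
    return out

rng = np.random.default_rng(1)
A = rng.random((30, 20)) * (rng.random((30, 20)) < 0.2)  # dense array that is mostly zeros
rows, cols = np.nonzero(A)
data = A[rows, cols]
B = rng.random((20, 5))
for n in (1, 7, 1000):
    print(n, np.allclose(chunked_scatter_add(rows, cols, data, B, A.shape[0], n=n), A @ B))
```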

The two Numba functions above can be driven by a plain Python (non-Numba) wrapper that unpacks the sparse array object into its coordinates, data and shape:

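def matmul_sparse(Arr_sparse, Arr):
    # Arr_sparse: 2-D sparse.COO of shape (m, n); Arr: 2-D np.ndarray of shape (n, k)
    out = np.zeros((Arr_sparse.shape[0], Arr.shape[1]))
    # coords[0] and coords[1] hold the row and column index of each value in data
    return matmul_fast_parallel(Arr_sparse.coords[0], Arr_sparse.coords[1], Arr_sparse.data, Arr, out)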
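`matmul_sparse` assumes both operands are 2-D, that their shapes line up, and that a float64 output is appropriate. A minimal sketch of checks a wrapper might run before calling the kernels, assuming only that the sparse operand exposes `.coords` (shape `(ndim, nnz)`), `.data` and `.shape` as `sparse.COO` does. `prepare_coo_matmul` is hypothetical; it treats a 1-D dense operand as an (n, 1) matrix, picks the output dtype with `np.result_type` instead of forcing float64, and `np.add.at` stands in for the Numba kernel:

```python
import numpy as np
from types import SimpleNamespace

def prepare_coo_matmul(A, B):
    """Hypothetical helper: validate shapes and unpack a 2-D COO-like array A for A @ B."""
    if len(A.shape) != 2:
        raise ValueError("only 2-D sparse arrays are handled here")
    B = np.asarray(B)
    is_vector = B.ndim == 1
    B2 = B[:, None] if is_vector else B  # treat a vector as an (n, 1) matrix
    if B2.ndim != 2 or B2.shape[0] != A.shape[1]:
        raise ValueError(f"shapes {A.shape} and {B.shape} not aligned")
    out = np.zeros((A.shape[0], B2.shape[1]), dtype=np.result_type(A.data, B2))
    return A.coords[0], A.coords[1], A.data, B2, out, is_vector

# Stand-in for a sparse.COO holding [[0, 2], [3, 0]]: coords is (ndim, nnz), data is (nnz,)
A = SimpleNamespace(coords=np.array([[0, 1], [1, 0]]), data=np.array([2.0, 3.0]), shape=(2, 2))
rows, cols, data, B2, out, is_vector = prepare_coo_matmul(A, np.array([1.0, 10.0]))
np.add.at(out, rows, data[:, None] * B2[cols])  # NumPy stand-in for the Numba kernel
result = out[:, 0] if is_vector else out
print(result)
```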

Are there any planned optimizations to improve the performance of sparse.matmul?
If not, would there be interest in adding specialized implementations like this one as an option alongside the current implementation?

Hello, there's no precedent for multithreaded ops in this library, but please don't hesitate to make a PR for the single-threaded version. I'd be quite happy to accept it.