williamberman / onenormest

matrix one norm estimation in jax

Matrix One Norm Estimation

This is an implementation of matrix one norm estimation in jax, following the algorithm specified in http://eprints.maths.manchester.ac.uk/321/1/35608.pdf
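For reference, the quantity being estimated is the largest absolute column sum, ||A||_1 = max_j sum_i |a_ij|. The minimal jax sketch below (helper names are illustrative, not part of this repository) computes it exactly and shows why products with unit vectors give lower bounds: ||A e_j||_1 is the one norm of column j, which can never exceed the maximum over all columns.

```python
import jax.numpy as jnp

def one_norm(A):
    # Exact matrix one norm: the largest absolute column sum.
    return jnp.abs(A).sum(axis=0).max()

def unit_vector_bound(A, j):
    # ||A e_j||_1 with ||e_j||_1 = 1, a lower bound on ||A||_1 for every j.
    e_j = jnp.zeros(A.shape[1], dtype=A.dtype).at[j].set(1)
    return jnp.abs(A @ e_j).sum()

A = jnp.array([[1.0, -4.0], [2.0, 3.0]])
print(float(one_norm(A)))              # 7.0
print(float(unit_vector_bound(A, 0)))  # 3.0
```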

Testing

The implementation passes the scipy test suite with a few minor relaxations, such as the number of column resamples allowed. The relaxed tests are documented in ./test_onenormest.py.

Benchmarks

Basic benchmarks on a GPU from the Google Colab free tier show a ~8x speedup over the scipy CPU implementation for 4096x4096 matrices.

Implementation details

There are existing implementations in scipy (scipy.sparse.linalg.onenormest) and octave.

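The algorithm as specified is imperative and heavy on control flow, and a few of its variables have non-constant dimensions. Because jax.jit traces a function with fixed array shapes and cannot branch on traced values with ordinary Python control flow, this implementation needs a few workarounds to compile.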
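To illustrate why workarounds are needed, the sketch below shows two patterns that run eagerly but fail under jax.jit: a Python `if` on a traced value, and an output whose length depends on the data. The function names are hypothetical, and the exact exception classes raised vary across jax versions, so the helper only reports the class name.

```python
import jax
import jax.numpy as jnp

def python_branch(x):
    # Imperative control flow on a traced value.
    if x.sum() > 0:
        return x
    return -x

def dynamic_shape(mask):
    # Output length depends on the data, not just on the input shape.
    return jnp.nonzero(mask)[0]

def fails_under_jit(f, arg):
    # Returns the exception class name if jit tracing fails, else None.
    try:
        jax.jit(f)(arg)
    except Exception as err:
        return type(err).__name__
    return None

print(fails_under_jit(python_branch, jnp.ones(3)))
print(fails_under_jit(dynamic_shape, jnp.array([True, False, True])))
```

Both functions work when called eagerly; it is only tracing under jit that rejects them.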

Early loop breaks

The main loop has many conditional early breaks. We handle these with manual continuation passing: the rest of the loop body after each break point is passed as one branch of a jax.lax.cond.
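A minimal sketch of the pattern, assuming a toy loop that stops as soon as a hypothetical per-iteration estimate (`toy_estimate`) stops increasing. The remainder of the iteration is written as a continuation function and handed to jax.lax.cond as one branch; the other branch records the break by setting a `done` flag that the enclosing jax.lax.while_loop checks. All function names here are illustrative, not taken from this repository.

```python
import jax
import jax.numpy as jnp
from jax import lax

def toy_estimate(k):
    # Hypothetical estimate: rises until k == 3, then falls.
    return -(k - 3.0) ** 2

def continue_iteration(k, est_old, est):
    # Continuation: everything after the break check goes here.
    return k + 1, est, jnp.array(False)

def break_out(k, est_old, est):
    # Early-exit branch: keep the previous best estimate and flag completion.
    return k, est_old, jnp.array(True)

@jax.jit
def run(maxiter):
    def cond_fun(state):
        k, _, done = state
        return jnp.logical_and(jnp.logical_not(done), k < maxiter)

    def body_fun(state):
        k, est_old, _ = state
        est = toy_estimate(k.astype(jnp.float32))
        # Python `if est <= est_old: break` becomes a traced branch.
        return lax.cond(est <= est_old, break_out, continue_iteration,
                        k, est_old, est)

    init = (jnp.int32(0), jnp.float32(-jnp.inf), jnp.array(False))
    return lax.while_loop(cond_fun, body_fun, init)
```

With this toy estimate, `run(10)` breaks at k = 4 with an estimate of 0.0 and `done` set, while `run(2)` hits the iteration limit first.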

Fixed dimensions

ind_hist and ind must have fixed dimensions.

In the scipy implementation and in Higham, ind_hist is a growable array that stores the indices of the unit vectors already used. In the octave implementation, ind_hist is a fixed-size array in which entry j is set to 1 once e_j has been used. We follow the octave approach so that the array keeps a fixed size.
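A minimal sketch of the fixed-size (octave-style) history, with hypothetical helper names: marking a vector as used is a scatter into a length-n array, and checking whether e_j was used is a gather, so neither operation changes the array's shape under jit.

```python
import jax
import jax.numpy as jnp

n = 6

@jax.jit
def mark_used(ind_hist, ind):
    # Write 1 into position j for every j in ind; the shape never changes.
    return ind_hist.at[ind].set(1)

@jax.jit
def is_used(ind_hist, ind):
    # Membership test is a gather rather than a search through a list.
    return ind_hist[ind] == 1

ind_hist = jnp.zeros(n, dtype=jnp.int32)
ind_hist = mark_used(ind_hist, jnp.array([2, 4]))
print(is_used(ind_hist, jnp.array([0, 2, 4, 5])))
```

A growable list of used indices would instead change length every iteration, which jit cannot trace.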

In Higham, ind has shape (n,), but only its first t values are ever read: they are written to ind_hist, and ind is indexed with column indices of Y, which has shape (n, t). Because each elementary vector is tested at most once, an iteration is not guaranteed to have t untested elementary vectors available. We handle this by filling the unused elements of ind with the sentinel value n. Wherever ind holds n, the corresponding column of X is filled with the zero vector instead of an elementary vector. A zero column yields a norm estimate of 0, which is always a valid underestimate of the one norm.

Because ind can hold the extra sentinel value n, ind_hist must be extended to length n + 1. Recording in ind_hist that the sentinel has been used has no effect.
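A minimal sketch of the sentinel scheme under stated assumptions: `next_unit_vectors` is a hypothetical helper, not this repository's code; it ranks candidates by a score vector h (largest first), keeps ind at a fixed length t rather than n, and relies on `jnp.nonzero(..., size=t, fill_value=n)` to pad with the sentinel when fewer than t untested vectors remain.

```python
import jax
import jax.numpy as jnp

def next_unit_vectors(h, ind_hist, t):
    """Pick up to t untested unit vectors, padding with the sentinel n.

    h: shape (n,) scores used to rank candidate indices (largest first).
    ind_hist: shape (n + 1,); ind_hist[j] == 1 once e_j has been used,
              and the extra slot n absorbs writes of the sentinel.
    """
    n = h.shape[0]
    order = jnp.argsort(-h)                 # candidate indices, best first
    unused = ind_hist[order] == 0           # skip vectors already tested
    # Positions (into order) of the first t unused candidates; pad with n.
    pos = jnp.nonzero(unused, size=t, fill_value=n)[0]
    ind = jnp.append(order, n)[pos]         # padded position n maps to sentinel n
    # Column j of X is e_{ind[j]}, or the zero vector when ind[j] == n.
    X = (jnp.arange(n)[:, None] == ind[None, :]).astype(h.dtype)
    ind_hist = ind_hist.at[ind].set(1)      # writing slot n is harmless
    return ind, X, ind_hist

next_unit_vectors = jax.jit(next_unit_vectors, static_argnums=2)

# e_1 and e_3 were tested earlier, so only two real candidates remain for t = 3.
h = jnp.array([0.5, 3.0, 1.0, 2.0])
ind_hist = jnp.zeros(5, dtype=jnp.int32).at[jnp.array([1, 3])].set(1)
ind, X, ind_hist = next_unit_vectors(h, ind_hist, 3)
```

Here ind comes out as [2, 0, 4]: the last slot holds the sentinel, so the third column of X is zero and the product A @ X gives a zero column for it.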
