ssghost / traceax

Stochastic trace estimation using JAX

Home Page: https://mancusolab.github.io/traceax/

Traceax

traceax is a Python library to perform stochastic trace estimation for linear operators. Namely, given a square linear operator $\mathbf{A}$, traceax provides flexible routines that estimate,

$$\text{trace}(\mathbf{A}) = \sum_i \mathbf{A}_{ii},$$

using only matrix-vector products. traceax is heavily inspired by lineax as well as XTrace.
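
To illustrate what "using only matrix-vector products" means, here is a minimal sketch of the classical Hutchinson estimator written in plain JAX. It relies on the identity $\mathbb{E}[\mathbf{z}^\top \mathbf{A} \mathbf{z}] = \text{trace}(\mathbf{A})$ for random probes with $\mathbb{E}[\mathbf{z}\mathbf{z}^\top] = \mathbf{I}$. The function name `hutchinson_trace` and its arguments are hypothetical and chosen for illustration; this is not traceax's implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def hutchinson_trace(key, matvec, n, k):
    """Estimate trace(A) from k matrix-vector products (illustrative sketch)."""
    # draw k probe vectors with i.i.d. entries in {-1, +1}; shape (k, n)
    Z = rdm.rademacher(key, (k, n), dtype=jnp.float32)
    # apply the operator to every probe; each call is one matrix-vector product
    AZ = jax.vmap(matvec)(Z)
    # average the quadratic forms z^T A z; each is an unbiased estimate of trace(A)
    return jnp.mean(jnp.sum(Z * AZ, axis=1))


# toy operator: a diagonal matrix with entries 1..5, so trace = 15
d = jnp.arange(1.0, 6.0)
est = hutchinson_trace(rdm.PRNGKey(0), lambda v: d * v, n=5, k=4)
print(est)
```

For a diagonal operator every Rademacher quadratic form equals the trace exactly, which makes this a convenient sanity check. For general operators the estimate is noisy, and its variance shrinks as `k` grows.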

Installation | Example | Documentation | Notes | Support | Other Software


Installation

Users can download the latest repository and then use pip:

git clone https://github.com/mancusolab/traceax.git
cd traceax
pip install .

Get Started with Example

import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tx

# simulate simple symmetric matrix with exponential eigenvalue decay
seed = 0
N = 1000
key = rdm.PRNGKey(seed)
key, xkey = rdm.split(key)

X = rdm.normal(xkey, (N, N))
Q, R = jnp.linalg.qr(X)
U = jnp.power(0.7, jnp.arange(N))
A = (Q * U) @ Q.T

# should be numerically close
print(jnp.trace(A))  # 3.3333323
print(jnp.sum(U))  # 3.3333335

# setup linear operator
operator = lx.MatrixLinearOperator(A)

# number of matrix-vector products
k = 10

# split key for estimators
key, key1, key2, key3 = rdm.split(key, 4)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k))  # (Array(3.7297516, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k))  # (Array(3.9572973, dtype=float32), {})

# XTrace estimator; default samples uniformly on n-Sphere
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k))  # (Array(3.1775048, dtype=float32), {'std.err': Array(0.24185811, dtype=float32)})
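
For readers curious why a Hutch++-style estimator can do better than plain Hutchinson on matrices with fast eigenvalue decay, such as the one simulated above, the following is a rough plain-JAX sketch of the idea: part of the matrix-vector budget builds a low-rank basis whose contribution to the trace is computed exactly, and the remainder is spent on a Hutchinson estimate of what is left. The function `hutchpp_trace` and the even three-way split of `k` are assumptions made for illustration; this does not reproduce the internals of `tx.HutchPlusPlusEstimator`.

```python
import jax.numpy as jnp
import jax.random as rdm


def hutchpp_trace(key, A, k):
    """Hutch++-style estimate of trace(A) with about k matrix-vector products."""
    n = A.shape[0]
    m = k // 3  # assumed budget split: sketch, projection, residual
    skey, gkey = rdm.split(key)
    S = rdm.rademacher(skey, (n, m), dtype=A.dtype)
    G = rdm.rademacher(gkey, (n, m), dtype=A.dtype)
    # orthonormal basis for the sketched range of A (m products)
    Q, _ = jnp.linalg.qr(A @ S)
    # exact trace of A restricted to span(Q), i.e. trace(Q^T A Q) (m products)
    low_rank_part = jnp.sum(Q * (A @ Q))
    # Hutchinson estimate on the component orthogonal to span(Q) (m products)
    G = G - Q @ (Q.T @ G)
    residual_part = jnp.sum(G * (A @ G)) / m
    return low_rank_part + residual_part


# same construction as the example above, at a smaller size
N = 200
xkey, ekey = rdm.split(rdm.PRNGKey(0))
Q0, _ = jnp.linalg.qr(rdm.normal(xkey, (N, N)))
U = jnp.power(0.7, jnp.arange(N))
A = (Q0 * U) @ Q0.T
print(jnp.trace(A), hutchpp_trace(ekey, A, k=30))
```

When the rank of the matrix is at most `k // 3`, the basis `Q` captures its whole range and the residual term vanishes, so the estimate matches the trace up to floating-point error.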

Documentation

Documentation is available at https://mancusolab.github.io/traceax/.

Notes

  • traceax uses JAX with just-in-time (JIT) compilation to achieve high-speed computation. However, JAX has some known issues on Macs with the M1 chip. To work around them, users need to initialize conda using miniforge and then install traceax with pip in the desired environment.

Support

Please report any bugs or feature requests in the Issue Tracker. If users have any questions or comments, please contact Linda Serafin (lserafin@usc.edu) or Nicholas Mancuso (nmancuso@usc.edu).

Other Software

Feel free to use other software developed by Mancuso Lab:

  • SuShiE: a Bayesian fine-mapping framework for molecular QTL data across multiple ancestries.
  • MA-FOCUS: a Bayesian fine-mapping framework using TWAS statistics across multiple ancestries to identify the causal genes for complex traits.
  • SuSiE-PCA: a scalable Bayesian variable selection technique for sparse principal component analysis.
  • twas_sim: Python software to simulate TWAS statistics.
  • FactorGo: a scalable variational factor analysis model that learns pleiotropic factors from GWAS summary statistics.
  • HAMSTA: Python software to estimate heritability explained by local ancestry data from admixture mapping summary statistics.

traceax is distributed under the terms of the Apache-2.0 license.


This project has been set up using Hatch. For details and usage information on Hatch see https://github.com/pypa/hatch.
