google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Add random.orthogonal and random.unitary

carlosgmartin opened this issue

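Add functions that sample uniformly (i.e., from the Haar measure) from the orthogonal group O(n) and the unitary group U(n), using the method described in F. Mezzadri's paper "How to generate random matrices from the classical compact groups": QR-decompose a Gaussian random matrix, then multiply each column of Q by the phase (for real matrices, the sign) of the corresponding diagonal entry of R.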
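Written out, the correction step looks like this. The notation is assumed here for illustration: Z is the n×n Gaussian matrix, Z = QR is the factorization returned by the QR routine, and r_jj are the diagonal entries of R.

```latex
\begin{aligned}
Z &= QR, \qquad
\Lambda = \operatorname{diag}\!\left(\frac{r_{11}}{|r_{11}|}, \dots, \frac{r_{nn}}{|r_{nn}|}\right), \\
Z &= (Q\Lambda)(\Lambda^{-1}R), \qquad
(\Lambda^{-1}R)_{jj} = \frac{\overline{r_{jj}}}{|r_{jj}|}\, r_{jj} = |r_{jj}| > 0, \\
(Q\Lambda)_{:,j} &= \frac{r_{jj}}{|r_{jj}|}\, Q_{:,j}.
\end{aligned}
```

Because Λ is diagonal with unit-modulus entries, QΛ is still orthogonal (or unitary) and Λ⁻¹R is still upper triangular, now with a positive diagonal. For an invertible Z (which holds with probability 1) that factorization is unique, so QΛ no longer depends on the sign convention of the particular QR routine; the paper shows that this normalized factor is Haar-distributed. The last line of each function below computes QΛ by broadcasting the vector d/|d| across the columns of q.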

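from jax import random, numpy as jnp

def orthogonal(key, n):
    # QR-decompose an n x n matrix of i.i.d. standard normals.
    z = random.normal(key, (n, n))
    q, r = jnp.linalg.qr(z)
    # Multiply column j of q by sign(r[j, j]) so the factorization has a
    # positive diagonal in R; per the paper, the raw q is generally not Haar.
    d = jnp.diag(r)
    return q * (d / jnp.abs(d))

def unitary(key, n):
    # Same recipe with a complex Gaussian matrix (independent real and imaginary parts).
    a, b = random.normal(key, (2, n, n))
    z = a + 1j * b
    q, r = jnp.linalg.qr(z)
    # Multiply column j of q by the phase r[j, j] / |r[j, j]|.
    d = jnp.diag(r)
    return q * (d / jnp.abs(d))

n = 5
key = random.PRNGKey(0)
for _ in range(100):  # a bounded number of trials instead of an infinite loop
    # Use a separate key for each sampler rather than reusing one subkey.
    key, key_o, key_u = random.split(key, 3)

    q = orthogonal(key_o, n)
    assert jnp.allclose(q @ q.T, jnp.eye(n), atol=1e-5)

    u = unitary(key_u, n)
    assert jnp.allclose(u @ jnp.conj(u.T), jnp.eye(n), atol=1e-5)

print('ok')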
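The loop above only checks that each sample is orthogonal or unitary, which any QR factor satisfies with or without the sign correction. A slightly stronger sanity check looks at the distribution. For a Haar-distributed Q in O(n) or U(n), E[tr Q] = 0 and E[|tr Q|²] = 1: multiplying row i by -1 (or by a unit phase) leaves the Haar measure unchanged, so every E[Q_ii] and every cross term E[Q_ii conj(Q_jj)] with i ≠ j vanishes, while each E[|Q_ii|²] = 1/n. The sketch below estimates both moments with `jax.vmap`; `trace_moments` is a hypothetical helper, and the sample size and tolerances are assumptions. Passing this check is consistent with Haar sampling but does not prove it.

```python
import jax
from jax import random, numpy as jnp


def orthogonal(key, n):
    # Haar-random element of O(n): QR of a real Gaussian matrix, column signs fixed.
    z = random.normal(key, (n, n))
    q, r = jnp.linalg.qr(z)
    d = jnp.diag(r)
    return q * (d / jnp.abs(d))


def unitary(key, n):
    # Haar-random element of U(n): QR of a complex Gaussian matrix, column phases fixed.
    a, b = random.normal(key, (2, n, n))
    q, r = jnp.linalg.qr(a + 1j * b)
    d = jnp.diag(r)
    return q * (d / jnp.abs(d))


def trace_moments(sampler, key, n, num_samples):
    # Hypothetical helper: Monte Carlo estimates of E[tr M] and E[|tr M|^2]
    # from num_samples draws, vectorized over keys with jax.vmap.
    keys = random.split(key, num_samples)
    mats = jax.vmap(lambda k: sampler(k, n))(keys)
    traces = jnp.trace(mats, axis1=-2, axis2=-1)
    return jnp.mean(traces), jnp.mean(jnp.abs(traces) ** 2)


key_o, key_u = random.split(random.PRNGKey(0))
mean_o, second_o = trace_moments(orthogonal, key_o, n=5, num_samples=20_000)
mean_u, second_u = trace_moments(unitary, key_u, n=5, num_samples=20_000)
# Both means should be close to 0 and both second moments close to 1.
print(mean_o, second_o)
print(mean_u, second_u)
```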

Among other uses, this is useful for orthogonal weight initialization in neural networks.
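As an illustration of the initialization use case, a square sampler can also serve rectangular weight matrices by taking the top-left block of a max(fan_in, fan_out) orthogonal matrix: the result has orthonormal rows when it is wide and orthonormal columns when it is tall. The helper `orthogonal_init` and its `scale` parameter are hypothetical names used only for this sketch; JAX itself already ships `jax.nn.initializers.orthogonal` for this purpose.

```python
from jax import random, numpy as jnp


def orthogonal(key, n):
    # Haar-random n x n orthogonal matrix (QR of a Gaussian, sign-corrected).
    z = random.normal(key, (n, n))
    q, r = jnp.linalg.qr(z)
    d = jnp.diag(r)
    return q * (d / jnp.abs(d))


def orthogonal_init(key, fan_in, fan_out, scale=1.0):
    # Hypothetical helper: top-left fan_in x fan_out block of a square
    # orthogonal matrix. The smaller dimension gets orthonormal vectors.
    n = max(fan_in, fan_out)
    return scale * orthogonal(key, n)[:fan_in, :fan_out]


k1, k2 = random.split(random.PRNGKey(1))
w_wide = orthogonal_init(k1, 3, 8)  # 3 x 8: orthonormal rows
w_tall = orthogonal_init(k2, 8, 3)  # 8 x 3: orthonormal columns
print(jnp.allclose(w_wide @ w_wide.T, jnp.eye(3), atol=1e-5))  # True
print(jnp.allclose(w_tall.T @ w_tall, jnp.eye(3), atol=1e-5))  # True
```

Slicing rows keeps them orthonormal because the rows of a square orthogonal matrix are orthonormal; slicing columns works the same way for the tall case.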

Thanks for the suggestion! I think this would be a welcome addition to the jax.random namespace. Are you interested in contributing your implementations?
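A contributed version would presumably need to produce batches of matrices, like other jax.random samplers. One minimal sketch, assuming a hypothetical `shape` keyword (not an existing signature) that prepends batch dimensions, relies on `jnp.linalg.qr` batching over leading axes and on `jnp.diagonal` with negative axes:

```python
from jax import random, numpy as jnp


def orthogonal(key, n, shape=()):
    # Hypothetical batched variant: returns an array of shape (*shape, n, n).
    z = random.normal(key, (*shape, n, n))
    q, r = jnp.linalg.qr(z)  # jnp.linalg.qr batches over leading axes
    d = jnp.diagonal(r, axis1=-2, axis2=-1)  # shape (*shape, n)
    # Insert an axis so sign(d_j) scales column j of each q.
    return q * (d / jnp.abs(d))[..., None, :]


def unitary(key, n, shape=()):
    # Hypothetical batched variant over the complex numbers.
    key_re, key_im = random.split(key)
    z = (random.normal(key_re, (*shape, n, n))
         + 1j * random.normal(key_im, (*shape, n, n)))
    q, r = jnp.linalg.qr(z)
    d = jnp.diagonal(r, axis1=-2, axis2=-1)
    return q * (d / jnp.abs(d))[..., None, :]


def conj_transpose(m):
    # Conjugate transpose over the last two axes.
    return jnp.conj(jnp.swapaxes(m, -1, -2))


key_o, key_u = random.split(random.PRNGKey(0))
qs = orthogonal(key_o, 4, shape=(3,))
us = unitary(key_u, 4, shape=(2, 3))
print(qs.shape, us.shape)  # (3, 4, 4) (2, 3, 4, 4)
print(jnp.allclose(qs @ conj_transpose(qs), jnp.eye(4), atol=1e-5))
print(jnp.allclose(us @ conj_transpose(us), jnp.eye(4), atol=1e-5))
```

Using `jnp.diagonal(..., axis1=-2, axis2=-1)` instead of `jnp.diag` is what lets the same code handle any number of batch dimensions, since `jnp.diag` only accepts 1-D or 2-D input.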