Add random.orthogonal and random.unitary
carlosgmartin opened this issue
Carlos Martin commented
Add functions that sample uniformly (with respect to the Haar measure) from the orthogonal group O(n) and the unitary group U(n), as described in Mezzadri's paper "How to generate random matrices from the classical compact groups".
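The method in the paper reduces to one correction applied after a QR decomposition. A short sketch of why it works, using the same names as the code (Z is the Gaussian matrix, Z = QR):

$$
Z = QR, \qquad
\Lambda = \operatorname{diag}\!\left(\frac{r_{11}}{|r_{11}|}, \dots, \frac{r_{nn}}{|r_{nn}|}\right), \qquad
Z = (Q\Lambda)\,(\Lambda^{-1}R).
$$

Every diagonal entry of Λ has modulus 1, so QΛ is still orthogonal (or unitary), and Λ⁻¹R is upper triangular with diagonal entries |r_jj| > 0. Because Z is invertible almost surely, the QR factorization with a positive diagonal in R is unique, so Q' = QΛ no longer depends on the sign (or phase) convention of the QR routine. For Z with i.i.d. real (or complex) standard Gaussian entries, Q' is Haar-distributed on O(n) (or U(n)). In the code, `q * d / abs(d)` computes QΛ by broadcasting: column j of `q` is scaled by d_j / |d_j|.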
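```python
from jax import random, numpy as jnp

def orthogonal(key, n):
    # Real Gaussian matrix -> QR -> fix column signs so Q is Haar on O(n).
    z = random.normal(key, (n, n))
    q, r = jnp.linalg.qr(z)
    d = jnp.diag(r)
    # Scale column j of q by sign(r_jj), undoing the QR sign convention.
    return q * d / abs(d)

def unitary(key, n):
    # Complex Gaussian matrix with i.i.d. real and imaginary parts.
    a, b = random.normal(key, (2, n, n))
    z = a + b * 1j
    q, r = jnp.linalg.qr(z)
    d = jnp.diag(r)
    # Scale column j of q by the phase r_jj / |r_jj|.
    return q * d / abs(d)

# Randomized check that samples are orthogonal / unitary (runs until interrupted).
n = 5
key = random.PRNGKey(0)
while True:
    key, subkey = random.split(key)
    q = orthogonal(subkey, n)
    assert jnp.allclose(q @ q.T, jnp.eye(n), atol=1e-6)
    u = unitary(subkey, n)
    assert jnp.allclose(u @ jnp.conj(u.T), jnp.eye(n), atol=1e-6)
    print('ok')
```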
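Samplers in `jax.random` usually accept a `shape` argument for batch dimensions. Below is a minimal batched sketch of the same construction. The name `haar_matrices` and its `shape`/`complex_valued` parameters are hypothetical (not an existing JAX API); the sketch relies on `jnp.linalg.qr` accepting stacked matrices of shape `(..., n, n)` and on `jnp.diagonal` extracting the per-matrix diagonals.

```python
from jax import random, numpy as jnp

def haar_matrices(key, n, shape=(), complex_valued=False):
    """Hypothetical batched sampler: returns an array of shape (*shape, n, n)
    whose trailing n x n matrices are Haar-distributed on O(n) (real case)
    or U(n) (complex case)."""
    if complex_valued:
        # Independent keys for the real and imaginary parts.
        re_key, im_key = random.split(key)
        z = (random.normal(re_key, (*shape, n, n))
             + 1j * random.normal(im_key, (*shape, n, n)))
    else:
        z = random.normal(key, (*shape, n, n))
    # jnp.linalg.qr operates on the last two axes, so batches work directly.
    q, r = jnp.linalg.qr(z)
    # Diagonal of each R, shape (*shape, n).
    d = jnp.diagonal(r, axis1=-2, axis2=-1)
    # Insert a row axis so column j of every Q is scaled by d_j / |d_j|.
    return q * (d / jnp.abs(d))[..., None, :]
```

For example, `haar_matrices(random.PRNGKey(0), 4, shape=(8,))` would return eight 4×4 orthogonal matrices in a single call, instead of a Python loop over keys.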
Among other things, this is useful for orthogonal weight initialization.
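Weight matrices are often not square. A minimal sketch of how the same sign-corrected QR could produce a `(rows, cols)` weight with orthonormal columns (when `rows >= cols`) or orthonormal rows (when `rows < cols`); the helper `orthogonal_init` and its `scale` parameter are hypothetical. JAX already ships `jax.nn.initializers.orthogonal` for this use case; the sketch only illustrates the idea.

```python
from jax import random, numpy as jnp

def orthogonal_init(key, shape, scale=1.0):
    """Hypothetical helper: (rows, cols) matrix with orthonormal columns
    if rows >= cols, otherwise orthonormal rows, multiplied by `scale`."""
    rows, cols = shape
    # Sample a tall Gaussian matrix so the reduced QR gives orthonormal columns.
    tall = (max(rows, cols), min(rows, cols))
    z = random.normal(key, tall)
    q, r = jnp.linalg.qr(z)  # default 'reduced' mode: q has shape `tall`
    d = jnp.diag(r)
    # Same sign correction as in the square case.
    q = q * (d / jnp.abs(d))
    if rows < cols:
        # Transpose so the rows, rather than the columns, are orthonormal.
        q = q.T
    return scale * q
```

Sampling the tall factor directly avoids drawing a full max(rows, cols) × max(rows, cols) orthogonal matrix and then slicing it.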
Jake Vanderplas commented
Thanks for the suggestion! I think this would probably be a welcome addition to the `jax.random` namespace. Are you interested in contributing your implementations?
Carlos Martin commented