google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

.at[] with complex numbers

mariogeiger opened this issue

import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

f()

returns

DeviceArray([3.6602185e-37+0.j], dtype=complex64)

instead of the expected DeviceArray([1.+0.j], dtype=complex64).
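Because eager (un-jitted) execution gives the right answer, as the cases below show, one way to catch this kind of miscompilation is to compare the jitted result against the eager one. A minimal sketch, assuming a recent JAX install; the helper name `jit_matches_eager` is hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np


def f():
    # The reproducer from above, left un-jitted so both paths can be compared.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)


def jit_matches_eager(fn, *args):
    """Hypothetical helper: True if jax.jit(fn) and fn give identical results."""
    eager = np.asarray(fn(*args))
    jitted = np.asarray(jax.jit(fn)(*args))
    return bool(np.array_equal(eager, jitted))


expected = np.array([1 + 0j], dtype=np.complex64)
print(np.array_equal(np.asarray(f()), expected))  # eager path
print(jit_matches_eager(f))  # False on jaxlib versions affected by this bug
```

On an unaffected jaxlib both lines print `True`.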

Some variations that narrow the bug down (each snippet is a separate script):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(x)

print(f(1))  # correct: the value is passed in as a traced argument

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

print(f(1))  # wrong: the constant 1 is set; the argument x is unused

import jax
import jax.numpy as jnp

def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

print(f())  # correct: no jax.jit

import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].add(1)

print(f())  # correct: .add instead of .set

import jax
import jax.numpy as jnp

def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

print(jax.jit(f, backend='cpu')())  # wrong: the CPU backend is affected too

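The correct/wrong cases above can be collected into one table-driven check. A sketch, assuming a recent JAX install; all function names are hypothetical, and the `backend='cpu'` case is left out to keep the sketch backend-agnostic:

```python
import jax
import jax.numpy as jnp
import numpy as np

EXPECTED = np.array([1 + 0j], dtype=np.complex64)


def set_traced(x):
    # The update value is a traced argument.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(x)


def set_const_unused_arg(x):
    # Takes an argument but writes the constant 1.
    del x
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)


def set_const():
    # Writes the constant 1; no arguments.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)


def add_const():
    # Same as set_const, but with .add instead of .set.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].add(1)


# Each case maps to (callable, positional args).
CASES = {
    "jit, traced value": (jax.jit(set_traced), (1,)),
    "jit, constant value, unused arg": (jax.jit(set_const_unused_arg), (1,)),
    "jit, constant value": (jax.jit(set_const), ()),
    "no jit, constant value": (set_const, ()),
    "jit, .add": (jax.jit(add_const), ()),
}


def run_cases():
    """Map each case name to True if its result equals EXPECTED."""
    return {
        name: bool(np.array_equal(np.asarray(fn(*args)), EXPECTED))
        for name, (fn, args) in CASES.items()
    }


for name, ok in run_cases().items():
    print(f"{name:32s} {'correct' if ok else 'wrong'}")
```

On an affected jaxlib the two "constant value" rows under jit would report `wrong`; with the fix every row reports `correct`.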
import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

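print(f.lower().compile().compiler_ir()[0].to_string())

which prints the optimized HLO below. The scatter has been folded into a constant, but the constant holds garbage instead of (1, 0):
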
HloModule jit_f.0

ENTRY %main.2 () -> c64[1] {
  %constant_1 = c64[1]{0} constant({(-245.248779, 3.08454e-41)})
  ROOT %copy = c64[1]{0} copy(c64[1]{0} %constant_1)
}
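In recent JAX releases the staged-out program can also be printed with `Lowered.as_text()` (the IR JAX emits, before XLA optimizations) and `Compiled.as_text()` (the HLO after XLA's optimizations); the `compiler_ir()[0].to_string()` form above is the one used in this issue. A sketch, assuming a recent JAX install:

```python
import jax
import jax.numpy as jnp


def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)


lowered = jax.jit(f).lower()
print(lowered.as_text())            # IR emitted by JAX, before XLA optimizations
print(lowered.compile().as_text())  # optimized HLO, where constant folding shows up
```

Comparing the two outputs shows whether a wrong value is already present in what JAX emits or only appears after optimization, which is how this bug was localized to a simplification pass.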

The MHLO canonicalizer mishandles the simplification of scatters on complex-valued arrays. This will most likely need a fix in upstream MLIR.
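For reference, when every input to a scatter is a constant, a correct simplification just evaluates it ahead of time. A minimal NumPy sketch of the constant that this particular `.at[0].set(1)` should fold to; `fold_scatter_set` is a hypothetical helper, not MHLO's implementation:

```python
import numpy as np


def fold_scatter_set(operand, index, update):
    """Hypothetical reference fold: the constant a scatter-set with all-constant
    inputs should simplify to. Not MHLO's actual implementation."""
    result = np.array(operand, copy=True)
    # Cast the update to the operand dtype before writing it.
    result[index] = np.asarray(update).astype(result.dtype)
    return result


folded = fold_scatter_set(np.zeros((1,), np.complex64), 0, 1)
# HLO prints complex constants as (real, imag) pairs.
print([(float(z.real), float(z.imag)) for z in folded])  # [(1.0, 0.0)]
```

The HLO dump above shows `(-245.248779, 3.08454e-41)` where this fold gives `(1.0, 0.0)`.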

You are always so quick to fix issues ❤️

This is fixed, but the fix requires a new jaxlib release.

jaxlib 0.3.5 has been released and should fix this issue.
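To check whether an environment already includes the fix, the installed jaxlib version can be compared against 0.3.5. A sketch using only the standard library; `parse_release` and `jaxlib_has_fix` are hypothetical helpers, and pre-release suffixes such as `rc1` or `.dev` are ignored:

```python
import re
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (0, 3, 5)  # per this thread, jaxlib 0.3.5 should include the fix


def parse_release(version_str):
    """Return the leading numeric components, e.g. '0.3.10.dev1' -> (0, 3, 10)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version_str)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_str!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def jaxlib_has_fix(version_str):
    # Compare numerically so that 0.3.10 sorts after 0.3.5 (a string compare would not).
    return parse_release(version_str) >= FIXED_IN


try:
    installed = version("jaxlib")
    print(installed, jaxlib_has_fix(installed))
except PackageNotFoundError:
    print("jaxlib is not installed")
```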