.at[] with complex numbers
mariogeiger opened this issue
Mario Geiger commented
```python
import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

f()
```
Output (expected `[1.+0.j]`):
```
DeviceArray([3.6602185e-37+0.j], dtype=complex64)
```
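A minimal way to check whether a given JAX/jaxlib install is affected is to compare the jitted result with op-by-op execution and with the expected value. This is a sketch that assumes a recent JAX, where arrays convert with `np.asarray`. The function `f` is reused from the report. The comparison itself is an editorial addition and is not part of the original issue.

```python
import numpy as np
import jax
import jax.numpy as jnp


def f():
    # Same computation as the report: scatter the Python constant 1 into complex64 zeros.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)


expected = np.array([1 + 0j], dtype=np.complex64)
eager = np.asarray(f())            # op-by-op dispatch
jitted = np.asarray(jax.jit(f)())  # whole function compiled, so the compiler may constant-fold it

print("eager: ", eager)
print("jitted:", jitted)
print("jitted matches expected:", np.allclose(jitted, expected))
```

On an affected jaxlib the last line would report a mismatch. On a fixed one both results equal `[1.+0.j]`.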
You Jiacheng commented
Some cases:
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(x)

print(f(1))  # correct
```
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # x is unused; the constant 1 is set instead
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

print(f(1))  # wrong
```
```python
import jax
import jax.numpy as jnp

def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

print(f())  # correct (no jit)
```
```python
import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].add(1)

print(f())  # correct (.add instead of .set)
```
```python
import jax
import jax.numpy as jnp

def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

print(jax.jit(f, backend='cpu')())  # wrong
```
```python
import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)

# Print the optimized HLO of the compiled function
print(f.lower().compile().compiler_ir()[0].to_string())
```
```
HloModule jit_f.0

ENTRY %main.2 () -> c64[1] {
  %constant_1 = c64[1]{0} constant({(-245.248779, 3.08454e-41)})
  ROOT %copy = c64[1]{0} copy(c64[1]{0} %constant_1)
}
```
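Taken together, the cases above suggest that the wrong value appears only when a jitted function scatters a compile-time constant with `.set`. Passing the value as a traced argument (the first case) gave a correct result, and so did using `.add` on a zeros array (the fourth case). The sketch below shows those two workarounds for affected jaxlib versions. The function names are hypothetical. `.add` matches `.set` here only because the target slot starts at zero.

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_traced(value):
    # Workaround 1: `value` is a traced argument, so it is unknown at compile time
    # and the scatter cannot be folded into a single constant.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(value)


@jax.jit
def set_via_add():
    # Workaround 2: on a zeros array, adding 1 at index 0 gives the same result as setting it.
    # This equivalence does NOT hold if the base array is nonzero at the updated index.
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].add(1)


print(set_traced(1))
print(set_via_add())
```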
Peter Hawkins commented
The MHLO canonicalizer is mishandling the simplification of complex-valued scatters. This will most likely need a fix in upstream MLIR.
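To see where such a simplification happens, you can compare the IR that JAX emits before compilation with the optimized HLO the compiler returns. In the HLO dump above, the function body is already a single (wrong) constant. The sketch below assumes a recent JAX release with two behaviors. First, `jax.jit(f).lower()` returns a `Lowered` object whose `as_text()` gives StableHLO rather than the MHLO discussed in this thread. Second, `.compile().as_text()` gives the optimized HLO. The exact text varies by version and backend.

```python
import jax
import jax.numpy as jnp


def f():
    return jnp.zeros((1,), dtype=jnp.complex64).at[0].set(1)


lowered = jax.jit(f).lower()    # trace and lower; no compiler optimizations applied yet
compiled = lowered.compile()    # run the compiler pipeline, including simplification passes

pre_text = lowered.as_text()    # module as emitted by JAX
post_text = compiled.as_text()  # optimized HLO module

print(pre_text)
print(post_text)
```

If the optimized module contains a constant that differs from what the emitted IR computes, the problem lies in a compiler pass rather than in JAX's tracing.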
Mario Geiger commented
You are always so efficient at fixing issues ❤️
Peter Hawkins commented
This is fixed, but needs a new jaxlib release.
Peter Hawkins commented
Jaxlib 0.3.5 was released, which should fix this issue.
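Because the thread reports the fix in jaxlib 0.3.5, code that must run on older environments could check the installed version before relying on jitted complex scatters. This is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper names are hypothetical. The parser keeps only the leading numeric part of the first three components, so suffixes such as `rc1` or a trailing `.dev…` are ignored.

```python
import re
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (0, 3, 5)  # jaxlib release that, per this thread, should contain the fix


def parse_version(text):
    """Return the first three numeric components, e.g. '0.4.36.dev2024' -> (0, 4, 36)."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_complex_scatter_fix(installed):
    # Compare integer tuples, not strings: as strings, "0.3.10" < "0.3.5".
    return parse_version(installed) >= FIXED_IN


try:
    installed = version("jaxlib")
    print(installed, "fixed" if has_complex_scatter_fix(installed) else "affected")
except PackageNotFoundError:
    print("jaxlib is not installed")
```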