google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

gradient of jnp.prod on a three-element array of complex dtype behaves incorrectly when jitted

FlorianH-1QBit opened this issue

It looks like there is an issue that surfaces when using jax.grad on a jitted function which takes the product of a complex three-element array:

import jax
import jaxlib
import jax.numpy as jnp

print(f'{jax.__version__=}')
print(f'{jaxlib.__version__=}')

@jax.jit
def fn(x):
    return jnp.abs(jnp.prod(jnp.array([x, 1, 1], dtype=jnp.complex64)))

x = jnp.array(1.)
with jax.disable_jit():
    print('disable_jit:', jax.grad(fn)(x))  # correct result

print('with jit:', jax.grad(fn)(x))  # incorrect, indeterminate result

gives

jax.__version__='0.3.5'
jaxlib.__version__='0.3.5'
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
disable_jit: 1.0
with jit: nan

The last line sometimes prints nan, sometimes 0.0, and sometimes some other arbitrary value.

If [x, 1, 1] is replaced with a list of a different length, it works fine. If the dtype is real, it's also fine.
The problem occurs both on a MacBook and on an Ubuntu Google Cloud machine.
I'm working around this issue by replacing jnp.prod with a for-loop.
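
For anyone hitting the same problem, here is a minimal sketch of what such a for-loop workaround might look like. The helper name prod_loop is hypothetical and not part of JAX. It multiplies the elements one at a time instead of building an array and calling jnp.prod on it. I have not verified that this avoids the bug on every affected jaxlib version.

```python
import jax
import jax.numpy as jnp


def prod_loop(values):
    # Hypothetical helper: multiply the elements one by one in a Python loop
    # instead of calling jnp.prod on a complex array.
    result = jnp.asarray(1.0, dtype=jnp.complex64)
    for v in values:
        result = result * v
    return result


@jax.jit
def fn(x):
    return jnp.abs(prod_loop([x, 1, 1]))


x = jnp.array(1.)
grad_value = jax.grad(fn)(x)
print('with jit, loop workaround:', grad_value)
```

The Python loop is unrolled at trace time, so this is only practical for short, fixed-length inputs like the three-element case here.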

I believe I have narrowed it down to an issue with mhlo.pad.
After experimenting with the MHLO module produced by the code above, I arrived at this more minimal example:

import jax
import jax.numpy as jnp
from jax.interpreters.mlir import ReplicaAxisContext, lower_jaxpr_to_module, module_to_string
from jax.interpreters.xla import AxisEnv


def fn():
    arr = jnp.array([2.+0j])
    val = jnp.array(1.+0j)
    return jax.lax.pad(arr, val, ((0,1,0),))


jaxpr = jax.make_jaxpr(fn)()
print('JAXPR:')
print(jaxpr)
print()

m = lower_jaxpr_to_module('fn', jaxpr, platform='cpu', axis_context=ReplicaAxisContext(AxisEnv(0,(),())),
                          name_stack='xla_computation(fn)/', donated_args=())
print('MHLO_MODULE:')
print(module_to_string(m))
print()

xla_comp = jax.xla_computation(fn)()
print('HLO_TEXT:')
print(xla_comp.as_hlo_text())
print()

fn_jitted = jax.jit(fn)
print('OUTPUT:')
print(fn_jitted())

Output:

JAXPR:
{ lambda a:c64[1]; . let
    b:c64[2] = pad[padding_config=((0, 1, 0),)] a (1+0j)
  in (b,) }

MHLO_MODULE:
module @fn.0 {
  func public @main() -> tensor<2xcomplex<f32>> {
    %0 = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<1xcomplex<f32>> loc(#loc0)
    %1 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc0)
    %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<1xcomplex<f32>>, tensor<complex<f32>>) -> tensor<2xcomplex<f32>> loc(#loc1)
    return %2 : tensor<2xcomplex<f32>> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc0 = loc(unknown)
#loc1 = loc("xla_computation(fn)/jit(main)/pad[padding_config=((0, 1, 0),)]"("/Users/florianhopfmueller/code/debug_pad.py":10:1))


HLO_TEXT:
HloModule xla_computation_fn.1

ENTRY main.3 {
  constant.1 = c64[2]{0} constant({(-1.20104458e-28, 4.58981e-41), (2.8026e-45, 0)})
  ROOT tuple.2 = (c64[2]{0}) tuple(constant.1)
}



OUTPUT:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[-6.8912085e-28+4.5898e-41j  2.8025969e-45+0.0000e+00j]

I followed jax.xla_computation in extracting the MHLO module and HLO text. The MHLO module looks right, but a strange constant appears in the HLO text. I don't know enough about the chain of compilers to pursue this further, but I hope this helps!
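
The script above relies on internal modules (jax.interpreters.mlir, jax.interpreters.xla) whose signatures change between releases. As a rough sketch, assuming a JAX version where jitted functions expose the public lower() method, the same before-and-after comparison can be made without them. Which dialect the lowered text is printed in (MHLO or something else) depends on the installed version, so treat the output format as an assumption.

```python
import jax
import jax.numpy as jnp


def fn():
    arr = jnp.array([2.+0j])
    val = jnp.array(1.+0j)
    return jax.lax.pad(arr, val, ((0, 1, 0),))


# Stage the function out without running it.
lowered = jax.jit(fn).lower()

# Module text before XLA optimizations; the pad op should still be visible here.
lowered_text = lowered.as_text()
print('LOWERED:')
print(lowered_text)

# HLO text after compilation; a wrongly folded constant would show up here.
compiled_text = lowered.compile().as_text()
print('COMPILED:')
print(compiled_text)
```

Comparing the two printouts shows whether the constant is already wrong before compilation or only after the compiler has folded the pad.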

This looks similar to #10159, although that one was fixed in jaxlib 0.3.5. Both of the issues above are still present with jaxlib 0.3.7 and jaxlib 0.3.10.

Thanks, you are right, this looks like another complex-number bug in the MHLO canonicalizer. I reported it to the folks who can fix it.

tensorflow/tensorflow@39ad008 fixed this, so the fix should be in the next jaxlib release.
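
To check whether a given installation is affected, something like the following could be run. It is a sketch based on the two reproducers in this thread. The expected values follow from the pad configuration (one element of high padding with value 1+0j) and from the disable_jit result reported above. The names pad_ok and grad_ok are just for illustration.

```python
import jax
import jaxlib
import jax.numpy as jnp
import numpy as np

print(f'{jax.__version__=}')
print(f'{jaxlib.__version__=}')


def pad_complex():
    arr = jnp.array([2.+0j])
    val = jnp.array(1.+0j)
    return jax.lax.pad(arr, val, ((0, 1, 0),))


@jax.jit
def fn(x):
    return jnp.abs(jnp.prod(jnp.array([x, 1, 1], dtype=jnp.complex64)))


# Padding [2+0j] with a single trailing 1+0j should give [2+0j, 1+0j].
expected_pad = np.array([2.+0j, 1.+0j], dtype=np.complex64)
pad_ok = np.allclose(np.asarray(jax.jit(pad_complex)()), expected_pad)

# The gradient at x = 1 should match the disable_jit result of 1.0.
grad_ok = np.isclose(float(jax.grad(fn)(jnp.array(1.))), 1.0)

print('pad under jit correct:', pad_ok)
print('grad under jit correct:', grad_ok)
```

Because the wrong results were not deterministic, a single passing run on an old jaxlib does not prove the bug is absent. Running it a few times is more convincing.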