google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/

`TracerArrayConversionError` with numpy in the function for autodiff

Fred-Wu opened this issue:

I am trying out the autodiff functionality in JAX. In the function being differentiated, I have to use `jax.numpy` for math operations such as `exp` instead of the `numpy` versions. Otherwise, I get the error below:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Model written with plain NumPy operations.
def model_np(W):
    comp1 = np.exp(W[0])
    comp2 = np.exp(-np.exp(W[1]))
    comp3 = np.exp(W[2])
    comp4 = np.exp(-np.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

val = np.array([0.69, 0.69, -1.6, -1.6])
model_grad = jax.jit(jax.jacfwd(model_np, argnums=0))
model_grad(val)  # raises TracerArrayConversionError
```

```
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[])>with<BatchTrace(level=1/1)> with
    val = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
    batch_dim = 0
```

With `jax.numpy` inside the function instead, the same call works:

```python
# Same model, written with jax.numpy operations.
def model_jnp(W):
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]))
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

val = np.array([0.69, 0.69, -1.6, -1.6])
model_grad = jax.jit(jax.jacfwd(model_jnp, argnums=0))
model_grad(val)
```

```
DeviceArray([ 0.2715211 , -0.5413358 ,  0.1649857 , -0.03331004], dtype=float32)
```

The code that calls the autodiff function uses `numpy` everywhere else, and all other results are computed without issue when only `model_np` is switched to `jax.numpy` while the rest keeps using `numpy`. So I am not sure whether this is a bug or intended behavior.

Yes, that is correct: JAX transforms are only compatible with JAX functions, not NumPy functions, so you should use jax.numpy instead of numpy if you wish to use jax.jit, jax.jacfwd, and other function transformations.
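
A minimal sketch of why this fails, assuming a recent JAX release that exposes the error class as `jax.errors.TracerArrayConversionError`. Under a transformation, the function's argument is a tracer rather than a concrete array; a NumPy ufunc such as `np.exp` tries to convert it to a `numpy.ndarray`, which JAX refuses. The names `f_numpy`, `f_jax`, and `try_grad` are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f_numpy(x):
    # np.exp tries to turn the traced x into a numpy.ndarray via __array__()
    return np.exp(x)

def f_jax(x):
    # jnp.exp accepts tracers, so JAX can record the operation and differentiate it
    return jnp.exp(x)

def try_grad(f, x):
    """Return the gradient of f at x, or the name of the error raised."""
    try:
        return jax.jit(jax.grad(f))(x)
    except jax.errors.TracerArrayConversionError as err:
        return type(err).__name__

print(try_grad(f_numpy, 0.69))  # TracerArrayConversionError
print(try_grad(f_jax, 0.69))    # float32 value close to exp(0.69), about 1.9937
```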

Since the input does not need to be a `jax.numpy` array (in my example it is a plain NumPy array), but the function body does need `jax.numpy`, isn't that a bit inconsistent?

Ah, I see. The important point here is that transformations (like `jit` and `jacfwd`) act on functions, so if you want to transform a function, it must be implemented in terms of JAX functions.
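
To make that concrete: what matters is whether NumPy is applied to the *traced* values flowing through the function. In this minimal sketch (the names `SCALE` and `scaled_exp` are hypothetical), NumPy still computes constants that do not depend on the function's argument, since those are ordinary concrete values at trace time; only the operations on the argument itself go through `jax.numpy`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Computed once with NumPy; it does not depend on any traced argument.
SCALE = np.log(2.0)

def scaled_exp(w):
    # NumPy on a concrete constant is fine even inside a transformed function...
    offset = np.sqrt(4.0)
    # ...but anything that touches the traced argument `w` must use jax.numpy.
    return SCALE * jnp.exp(w) + offset

grad_fn = jax.jit(jax.grad(scaled_exp))
print(grad_fn(0.0))  # derivative of log(2) * exp(w) + 2 at w = 0, i.e. log(2)
```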

For convenience, if a JAX function receives a numpy array as an argument, it will implicitly convert it to a JAX array if necessary.
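
A short sketch of that convenience, reusing the `model_jnp` model from the issue (redefined so the snippet runs on its own). The NumPy input is converted to a JAX array automatically, and `np.asarray` converts the result back for the surrounding NumPy code. This assumes a recent JAX release that exposes the `jax.Array` type; older releases return `DeviceArray`, as in the issue's output.

```python
import jax
import jax.numpy as jnp
import numpy as np

def model_jnp(W):
    # Same model as in the issue, written with jax.numpy so it can be traced.
    return (jnp.exp(W[0]) * jnp.exp(-jnp.exp(W[1]))
            + jnp.exp(W[2]) * jnp.exp(-jnp.exp(W[3])))

val = np.array([0.69, 0.69, -1.6, -1.6])   # plain NumPy input
jac = jax.jit(jax.jacfwd(model_jnp))(val)  # converted to a JAX array implicitly

print(type(jac))          # a JAX array type, not numpy.ndarray
jac_np = np.asarray(jac)  # back to NumPy for the rest of the (NumPy) code
print(type(jac_np))       # <class 'numpy.ndarray'>
```

Keeping the conversion at the boundary like this lets the rest of the program stay in plain NumPy while only the differentiated function is written with `jax.numpy`.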