`TracerArrayConversionError` with numpy in the function for autodiff
Fred-Wu opened this issue.
I am trying the autodiff functionality in JAX. In a function that I want to differentiate, I have to use `jax.numpy` for math operations such as `exp` instead of the ones from `numpy`; otherwise I get the error message below.
```python
import jax
import jax.numpy as jnp
import numpy as np

def model_np(W):
    comp1 = np.exp(W[0])
    comp2 = np.exp(-np.exp(W[1]))
    comp3 = np.exp(W[2])
    comp4 = np.exp(-np.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

val = np.array([0.69, 0.69, -1.6, -1.6])
model_grad = jax.jit(jax.jacfwd(model_np, argnums=0))
model_grad(val)
```
```
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with
primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
tangent = Traced<ShapedArray(float32[])>with<BatchTrace(level=1/1)> with
val = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
batch_dim = 0
```
Using `jax.numpy` instead works:

```python
def model_jnp(W):
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]))
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

val = np.array([0.69, 0.69, -1.6, -1.6])
model_grad = jax.jit(jax.jacfwd(model_jnp, argnums=0))
model_grad(val)
```

```
DeviceArray([ 0.2715211 , -0.5413358 , 0.1649857 , -0.03331004], dtype=float32)
```
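As a sanity check on those numbers, the gradient of this model can also be derived by hand: with `a = exp(W[0]) * exp(-exp(W[1]))` and `b = exp(W[2]) * exp(-exp(W[3]))`, the gradient is `[a, -exp(W[1]) * a, b, -exp(W[3]) * b]`. The sketch below evaluates that formula in plain NumPy (no tracing involved) and compares it with `jax.jacfwd`; `analytic_grad` is a hypothetical helper written for this comparison, and the loose tolerance assumes JAX's default float32 precision.

```python
import jax
import jax.numpy as jnp
import numpy as np

def model_jnp(W):
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]))
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

def analytic_grad(W):
    # Hypothetical helper: hand-derived gradient, evaluated with plain NumPy.
    a = np.exp(W[0]) * np.exp(-np.exp(W[1]))  # comp1 * comp2
    b = np.exp(W[2]) * np.exp(-np.exp(W[3]))  # comp3 * comp4
    return np.array([a, -np.exp(W[1]) * a, b, -np.exp(W[3]) * b])

val = np.array([0.69, 0.69, -1.6, -1.6])
jax_grad = jax.jit(jax.jacfwd(model_jnp, argnums=0))(val)
print(jax_grad)
print(analytic_grad(val))
```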
The code that calls the autodiff function uses `numpy` everywhere else, and all other results compute fine when only `model_np` is switched to `jax.numpy` while the rest keeps using `numpy`. So I am not sure whether this is a bug or intended behavior.
Yes, that is correct: JAX transforms are only compatible with JAX functions, not NumPy functions, so you should use `jax.numpy` instead of `numpy` if you wish to use `jax.jit`, `jax.jacfwd`, and other function transformations.
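A minimal sketch of that difference, using hypothetical one-term reductions of the issue's functions (`first_np` and `first_jnp` are illustrative names): under `jax.jit(jax.jacfwd(...))` the argument is a tracer, `np.exp` tries to turn it into a concrete NumPy array via `__array__()` and fails with `jax.errors.TracerArrayConversionError`, while `jnp.exp` accepts the tracer so JAX can record and differentiate the operation.

```python
import jax
import jax.numpy as jnp
import numpy as np

def first_np(W):
    # np.exp asks its argument for a concrete NumPy array via __array__(),
    # which a JAX tracer refuses to provide while the function is being traced.
    return np.exp(W[0])

def first_jnp(W):
    # jnp.exp accepts tracers, so JAX can record the operation and differentiate it.
    return jnp.exp(W[0])

val = np.array([0.69, 0.69, -1.6, -1.6])

np_error = None
try:
    jax.jit(jax.jacfwd(first_np))(val)
except jax.errors.TracerArrayConversionError as err:
    np_error = err

grad_jnp = jax.jit(jax.jacfwd(first_jnp))(val)
print(type(np_error).__name__)
print(grad_jnp)  # only the first entry, exp(0.69), is nonzero
```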
Since the input does not need to be a `jax.numpy` array (in my example it is a plain NumPy array), but the function body must use `jax.numpy`, isn't that a bit inconsistent?
Ah, I see: the important piece here is that transformations (like `jit` and `jacfwd`) act on functions, not on the arrays passed to them, so if you want to transform a function, its body must be implemented in terms of JAX functions.
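One way to see what "acting on functions" means is `jax.make_jaxpr`, which runs a function once with abstract tracers in place of its inputs and prints the JAX primitives it recorded. A minimal sketch using the issue's `model_jnp`; the exact printed jaxpr varies between JAX versions (for example in how `W[0]` indexing is lowered), so only the presence of primitives such as `exp`, `neg`, `mul`, and `add` is meaningful here.

```python
import jax
import jax.numpy as jnp
import numpy as np

def model_jnp(W):
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]))
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

val = np.array([0.69, 0.69, -1.6, -1.6])
# make_jaxpr traces model_jnp with abstract values and returns the recorded
# program; transformations such as jit and jacfwd operate on this kind of trace.
jaxpr = jax.make_jaxpr(model_jnp)(val)
print(jaxpr)
```

A NumPy call inside the function has no way to appear in this trace, which is why `model_np` cannot be transformed.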
For convenience, if a JAX function receives a numpy array as an argument, it will implicitly convert it to a JAX array if necessary.
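A minimal sketch of the resulting split, assuming the rest of the program stays in NumPy: a plain NumPy array goes into the transformed function and is converted implicitly, and the concrete result can be turned back into a NumPy array with `np.asarray` once you are outside any transformation. Only the body of the function being transformed has to use `jax.numpy`.

```python
import jax
import jax.numpy as jnp
import numpy as np

def model_jnp(W):
    # Only the body of the transformed function needs jax.numpy.
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]))
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]))
    return comp1 * comp2 + comp3 * comp4

val = np.array([0.69, 0.69, -1.6, -1.6])  # plain NumPy input is fine

grad_fn = jax.jit(jax.grad(model_jnp))
g = grad_fn(val)          # NumPy input is converted to a JAX array implicitly
g_np = np.asarray(g)      # outside the transformation, converting back is fine
print(type(g), type(g_np))
print(np.linalg.norm(g_np))  # ordinary NumPy code can take over from here
```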