google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Potential bug in forward-mode AD with NaNs in intermediate results

romanngg opened this issue · comments

Consider

import jax.numpy as np
from jax import jacfwd, jacrev

f = lambda x: x[0] + np.mean(x[1])
x = (np.ones(()), np.ones((0,)))

f(x)  # DeviceArray(nan, dtype=float32)

Then reverse- and forward-mode Jacobians don't match:

jacfwd(f)(x)  # (DeviceArray(nan, dtype=float32), DeviceArray([], dtype=float32))

and

jacrev(f)(x)  # (DeviceArray(1., dtype=float32), DeviceArray([], dtype=float32))

Colab:
https://colab.research.google.com/gist/icml2022anon/3aabf1cb928d01a6756956a86a5947cc/wrong_jacobians.ipynb
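A sketch of where the two modes diverge, using jax.jvp and jax.vjp directly (the comments are my reading of how the nan propagates, not an authoritative trace):

```python
import jax.numpy as jnp
from jax import jvp, vjp

f = lambda x: x[0] + jnp.mean(x[1])
x = (jnp.ones(()), jnp.ones((0,)))

# Forward mode: the tangent for x[1] passes through the mean over an
# empty axis, whose 0/0 division yields nan; that nan is then added to
# the tangent of x[0], so the whole output tangent is nan.
_, out_tangent = jvp(f, (x,), ((jnp.ones(()), jnp.ones((0,))),))

# Reverse mode: the cotangent flowing back to x[0] goes straight
# through the addition and never touches the empty mean, so it
# survives as 1.
_, f_vjp = vjp(f, x)
(cotangents,) = f_vjp(jnp.ones(()))
```

Here out_tangent comes back nan while cotangents[0] is 1., matching the jacfwd/jacrev results above.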

Mathematically, both Jacobians are wrong: the result should be (DeviceArray(0., dtype=float32), DeviceArray([], dtype=float32)). But assuming that is not possible, (DeviceArray(1., dtype=float32), DeviceArray([], dtype=float32)) seems more reasonable to me; it also matches the behavior of differentiating f = lambda x: (x + np.nan), where both Jacobians are 1..
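For comparison, here is a minimal sketch of that x + np.nan case, using the same import spelling as above:

```python
import jax.numpy as np
from jax import jacfwd, jacrev

f = lambda x: x + np.nan

# Addition passes the tangent/cotangent of x through unchanged, so
# both modes report a derivative of 1. even though f(x) is always nan.
fwd = jacfwd(f)(1.0)
rev = jacrev(f)(1.0)
```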

Not sure if this is a bug, or if it falls into undefined behavior territory.

Thanks for bringing this up.

Arguably, though, I don't think both Jacobians are wrong. You could say that the Jacobian of g = lambda x: x + y (where x and y are of type f32[]) should be 1.0 for any value of y; that would make the jacrev Jacobian correct. After all, for the corresponding math function g: ℝ → ℝ, g(x) = x + y, for any real-valued x and y we have (g(x + ε) - g(x)) / ε = (x + ε - x) / ε = 1.

The issue, as in #1052, is that with the inclusion of nan values, floats no longer operate like a field, and arrays of floats no longer operate like vector spaces. In particular, the above statement no longer holds: the Jacobian of g may not be 1 for some values of y. If y is nan, then (g(x + eps) - g(x)) / eps is nan for any eps...
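A quick sketch of that breakdown (grad here differentiates with respect to the first argument only, which is its default):

```python
import jax.numpy as jnp
from jax import grad

g = lambda x, y: x + y

# Autodiff reports dg/dx = 1. regardless of y, even when y is nan...
autodiff = grad(g)(1.0, jnp.nan)

# ...but the finite-difference quotient it is supposed to approximate
# is nan for every eps, because nan - nan = nan.
eps = 1e-3
fd = (g(1.0 + eps, jnp.nan) - g(1.0, jnp.nan)) / eps
```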

So I'd say the Jacobian of the mathematical function corresponding to g is only mathematically defined when y is not nan. And similarly for your function f. When the math is undefined, autodiff has no hope.

It's a subtle issue, though! Sometimes autodiff fails to correspond to the math; here, the math itself is undefined, so the autodiff result is undefined too, with the precise value returned depending on operational details.

Thanks a lot Matt - you're right, this makes sense as undefined behavior!