Potential bug in forward-mode AD with NaNs in intermediate results
romanngg opened this issue
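Consider

```python
import jax.numpy as np  # the issue's `np` is jax.numpy
from jax import jacfwd, jacrev

f = lambda x: x[0] + np.mean(x[1])
x = (np.ones(()), np.ones((0,)))
f(x)  # DeviceArray(nan, dtype=float32)
```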
Then the reverse- and forward-mode Jacobians don't match:

```python
jacfwd(f)(x)  # (DeviceArray(nan, dtype=float32), DeviceArray([], dtype=float32))
jacrev(f)(x)  # (DeviceArray(1., dtype=float32), DeviceArray([], dtype=float32))
```
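One plausible way to see where the mismatch comes from, sketched in plain Python rather than JAX. It assumes `np.mean` is computed as a sum divided by the element count; that is a simplifying assumption about JAX's internals, not a reading of its source. Under that assumption, the forward-mode tangent of the mean of an empty array is `0 / 0 = nan`, and that `nan` gets added to the tangent of `x[0]`. In reverse mode, the cotangent of `x[0]` comes straight from the addition, and the mean's cotangent is spread over zero elements, so it never reaches `x[0]`. The helpers `ieee_div`, `f_jvp` and `f_vjp` are hypothetical names used only for this illustration.

```python
import math

def ieee_div(a, b):
    """Float division following IEEE 754 when b == 0 (plain Python would raise)."""
    if b != 0:
        return a / b
    if a == 0 or math.isnan(a):
        return math.nan  # 0/0 and nan/0 are nan
    return math.copysign(math.inf, a) * math.copysign(1.0, b)

def f_jvp(x0, xs, dx0, dxs):
    """Forward mode for f(x0, xs) = x0 + mean(xs), assuming mean = sum / n."""
    n = float(len(xs))
    mean, dmean = ieee_div(sum(xs), n), ieee_div(sum(dxs), n)
    # With xs empty, dmean is 0/0 = nan, and it is added to the tangent dx0.
    return x0 + mean, dx0 + dmean

def f_vjp(x0, xs, ct):
    """Reverse mode for the same f, given the output cotangent ct."""
    n = float(len(xs))
    ct_x0 = ct  # transpose of add: the cotangent reaches x0 unchanged
    # Transpose of mean: each element of xs receives ct / n. With xs empty
    # there are no elements to receive it, so nothing touches ct_x0.
    ct_xs = [ieee_div(ct, n) for _ in xs]
    return ct_x0, ct_xs

print(f_jvp(1.0, [], 1.0, []))  # (nan, nan)
print(f_vjp(1.0, [], 1.0))      # (1.0, [])
```

These mirror the `jacfwd` and `jacrev` results for `x[0]`. With a non-empty `xs`, for example `f_jvp(1.0, [2.0, 4.0], 1.0, [0.0, 0.0])` and `f_vjp(1.0, [2.0, 4.0], 1.0)`, both modes give `1.0` for `x0`.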
Mathematically, both Jacobians are wrong: they should be `(DeviceArray(0., dtype=float32), DeviceArray([], dtype=float32))`. But assuming that is not possible, `(DeviceArray(1., dtype=float32), DeviceArray([], dtype=float32))` seems more reasonable to me. It also matches the behavior of differentiating `f = lambda x: x + np.nan`, where both Jacobians are `1`.
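For comparison, here is a minimal plain-Python sketch of why the `x + np.nan` case gives `1` in both modes. It assumes the standard JVP rule for addition, and the name `add_jvp` is hypothetical. The `nan` is a constant, so its tangent is `0.0` and the `nan` appears only in the primal output. In `f`, by contrast, the `nan` comes from `np.mean` of an empty array, which plausibly puts a `0 / 0` into the tangent as well.

```python
import math

def add_jvp(x, dx, y, dy):
    """JVP rule for z = x + y: primals add and tangents add."""
    return x + y, dx + dy

# Differentiate g(x) = x + nan at x = 1.0 along dx = 1.0.
# The nan is a constant, so its tangent is 0.0: it only poisons the primal.
primal, tangent = add_jvp(1.0, 1.0, math.nan, 0.0)
print(primal, tangent)  # nan 1.0
```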
Not sure if this is a bug, or if it falls into undefined behavior territory.
Thanks for bringing this up.
Arguably, though, they aren't both wrong. You could say that the Jacobian of `g = lambda x: x + y` (where `x` and `y` are of type `f32[]`) should be `1.0` for any value of `y`; that would make the `jacrev` Jacobian correct. After all, for the corresponding math function g: ℝ → ℝ, g(x) = x + y, for any real-valued x and y we have (g(x + ε) - g(x)) / ε = ((x + ε + y) - (x + y)) / ε = 1.
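The difference-quotient argument can be checked with exact arithmetic. This minimal sketch uses Python's `fractions.Fraction` as a stand-in for the reals, since rationals have no rounding and no `nan`. `difference_quotient` is a hypothetical helper.

```python
from fractions import Fraction

def difference_quotient(g, x, eps):
    """Slope of g between x and x + eps."""
    return (g(x + eps) - g(x)) / eps

# Exact rationals stand in for real numbers: no rounding and no nan.
quotients = []
for y in (Fraction(0), Fraction(-7, 3), Fraction(10**9)):
    g = lambda x, y=y: x + y
    for eps in (Fraction(1, 10), Fraction(1, 10**6)):
        quotients.append(difference_quotient(g, Fraction(2, 5), eps))
print(set(quotients))  # {Fraction(1, 1)}
```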
The issue, as in #1052, is that once `nan` values are included, floats no longer behave like a field, and arrays of floats no longer behave like vector spaces. In particular, the statement above no longer holds: the Jacobian of `g` may not be `1` for some values of `y`. If `y` is `nan`, then `(g(x + eps) - g(x)) / eps` is `nan` for any small `eps`.
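Repeating the difference-quotient check with IEEE 754 floats and `y = nan` shows the breakdown. In this minimal sketch, `difference_quotient` is a hypothetical helper. The sketch also exposes the root cause: `nan` has no additive inverse.

```python
import math

def difference_quotient(g, x, eps):
    """Slope of g between x and x + eps."""
    return (g(x + eps) - g(x)) / eps

y = math.nan
g = lambda x: x + y
quotients = [difference_quotient(g, 1.0, eps) for eps in (1e-1, 1e-4, 1e-8)]
print(quotients)  # [nan, nan, nan]

# Root cause: nan - nan is nan, not 0, so the y terms never cancel.
print(y - y)  # nan
```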
So I'd say the Jacobian of the mathematical function corresponding to `g` is only defined when `y` is not `nan`. The same holds for your function `f`, since `np.mean` of an empty array is `nan`. When the math is undefined, autodiff has no hope.

It's a subtle issue, though! Sometimes autodiff fails to correspond to the math. In this case the math is simply undefined, so the autodiff result is undefined too, and the precise value returned depends on operational details.
Thanks a lot, Matt, you're right: this makes sense as undefined behavior!