google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


Variance can be negative in BatchNorm computation, leading to NaNs

billmark opened this issue · comments


Problem you have encountered:

Although variance is mathematically non-negative, the way it is computed in `normalization.py` can make it slightly negative due to round-off error. Sometimes this error is large enough to swamp the epsilon, causing `rsqrt()` to yield a NaN.

I encountered this case with a batch of large synthetic images that I was using for a test case.
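
For illustration, here is a minimal sketch (not the actual reproduction from this report, which isn't included above) of how the single-pass formula E[x^2] - (E[x])^2 can cancel to a negative number in float32 when the data has a large common offset:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Synthetic data with a large common offset: the true variance (~1e-4) is
# tiny relative to mean**2 (~1e8), so the single-pass formula cancels
# catastrophically in float32.
key = jax.random.PRNGKey(0)
x = 1e4 + 1e-2 * jax.random.normal(key, (1024,), dtype=jnp.float32)

mean = jnp.mean(x)
mean2 = jnp.mean(lax.square(x))
var = mean2 - lax.square(mean)   # mathematically >= 0, numerically maybe not

eps = 1e-6
print(var)                       # can come out as a small negative number
print(lax.rsqrt(var + eps))      # NaN whenever var < -eps
```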

I was able to fix it by changing `var = mean2 - lax.square(mean)` to `var = jnp.maximum(0.0, mean2 - lax.square(mean))` in `BatchNorm`'s `__call__` method.
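
In context, the patched moment computation would look roughly like the following sketch (not the actual Flax source; `reduction_axes` and `epsilon` stand in for the corresponding `BatchNorm` attributes):

```python
import jax.numpy as jnp
from jax import lax

def normalize(x, reduction_axes, epsilon=1e-6):
    # Sketch of BatchNorm's moment computation with the proposed clamp.
    mean = jnp.mean(x, axis=reduction_axes, keepdims=True)
    mean2 = jnp.mean(lax.square(x), axis=reduction_axes, keepdims=True)
    # Clamp away small negative values produced by round-off.
    var = jnp.maximum(0.0, mean2 - lax.square(mean))
    return (x - mean) * lax.rsqrt(var + epsilon)
```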

Mathematically, the change seems sound, so I would suggest making it unless there is a concern about performance impact or something else I'm not aware of.

What you expected to happen:

No NaNs in this case.


Thanks so much for the report! Do you happen to know how far off the epsilon guard was in the example that raised the NaN? We might add the positivity guard anyway, but I'm curious how "off" your case was.

If I remember correctly, changing the epsilon guard from 1e-6 to 1e-4 made the NaN go away, but changing it from 1e-6 to 1e-5 did not.

My sense is that the primary purpose of the epsilon should be to prevent divide-by-zero errors, since variance can legitimately be zero. A different mechanism (e.g. clamping to zero) seems more appropriate for handling negative variance produced by round-off. Currently, the code appears to be using the epsilon to deal with both problems.
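
As a sketch of that separation of concerns (again, not the actual Flax source), each guard addresses a distinct failure mode:

```python
import jax.numpy as jnp
from jax import lax

eps = 1e-6
x = jnp.ones((8,), dtype=jnp.float32)   # constant input: true variance is exactly 0

var = jnp.mean(lax.square(x)) - lax.square(jnp.mean(x))
var = jnp.maximum(0.0, var)             # clamp: absorbs negative round-off error
inv = lax.rsqrt(var + eps)              # eps: keeps the exact-zero case finite
print(inv)                              # 1 / sqrt(1e-6) = 1000.0, no NaN
```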

fixed by PR #1545