google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

NNX `_compute_stats` function missing `use_fast_variance` and `mask` arguments

chiamp opened this issue

The NNX `_compute_stats` function is missing the `use_fast_variance` and `mask` arguments that the Linen equivalent has. Is this intentional, or should we add them to NNX?

cc: @cgarciae

This is not intentional. Most likely these arguments were added to Linen after the initial port of the code.
If you can add `use_fast_variance`, it would be amazing!
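
For context, here is a minimal sketch of what the two arguments typically control, assuming they behave as in the Linen version: `use_fast_variance` picks between the one-pass formula E[x^2] - E[x]^2 and the two-pass formula E[(x - E[x])^2], and `mask` limits which elements contribute to the statistics. The name `compute_stats_sketch` is hypothetical and this is not the Flax implementation; cross-device averaging and dtype handling are left out.

```python
import jax.numpy as jnp


def compute_stats_sketch(x, axes, use_fast_variance=True, mask=None):
    """Hypothetical helper: mean and variance of `x` reduced over `axes`.

    `mask` (optional) is a boolean array broadcastable to `x`; only
    elements where it is True contribute to the statistics.
    """
    x = jnp.asarray(x, jnp.float32)
    mean = jnp.mean(x, axis=axes, where=mask)
    if use_fast_variance:
        # One pass over x: var = E[x^2] - E[x]^2. Rounding can push the
        # difference slightly below zero, so clamp it at zero.
        mean_sq = jnp.mean(jnp.square(x), axis=axes, where=mask)
        var = jnp.maximum(0.0, mean_sq - jnp.square(mean))
    else:
        # Two passes: var = E[(x - E[x])^2]. Slower, but numerically
        # more stable when the mean is large relative to the spread.
        centered = x - jnp.expand_dims(mean, axes)
        var = jnp.mean(jnp.square(centered), axis=axes, where=mask)
    return mean, var


# Usage: the third row is masked out (e.g. padding), so it is ignored.
x = jnp.array([[1.0, 10.0], [3.0, 20.0], [100.0, 200.0]])
mask = jnp.array([[True], [True], [False]])
mean, var = compute_stats_sketch(x, axes=0, use_fast_variance=True, mask=mask)
# mean is [2., 15.] and var is [1., 25.]; the two-pass path agrees here.
```

The sketch passes `mask` through the `where` argument of `jnp.mean`, so masked-out elements are excluded from both the sum and the element count.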