google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

Remove `tree_map` deprecation filter after Flax upgrades minimum Python version to 3.10

chiamp opened this issue.

Context:

  • As of JAX 0.4.26, `jax.tree_map` is deprecated.
  • #3823 renames all `jax.tree_map` usages in Flax to `jax.tree_util.tree_map`; however, CI still fails because of the CLU dependency, which also uses `jax.tree_map`.
  • After CLU was fixed and a new CLU release was pushed, the error persists in the Python 3.9 CI tests.
    • This is because Flax pins an older CLU release (one predating the `tree_map` fix) on Python versions below 3.10: newer CLU releases use `match`-`case` syntax, which is only available in Python 3.10 and later. Since Flax still supports Python 3.9, the 3.9 tests must run against the older CLU, which still uses `jax.tree_map`.
  • Our current workaround is the deprecation warning filter added in #3828.
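A minimal sketch of what a targeted deprecation filter does, using only Python's standard `warnings` module. This is not the filter from #3828: `legacy_tree_map` and `other_deprecated` are hypothetical stand-ins, and the pattern assumes JAX's warning text begins with "jax.tree_map is deprecated". The rule ignores only `DeprecationWarning`s whose message matches that pattern, so unrelated warnings still surface.

```python
import warnings

# Assumption: JAX's deprecation message starts with this text. `message` is a
# regex matched (case-insensitively) against the start of the warning message.
TREE_MAP_PATTERN = r"jax\.tree_map is deprecated"


def legacy_tree_map(f, xs):
    """Hypothetical stand-in for the deprecated jax.tree_map (flat lists only)."""
    warnings.warn(
        "jax.tree_map is deprecated: use jax.tree_util.tree_map instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return [f(x) for x in xs]


def other_deprecated(x):
    """Hypothetical unrelated deprecation that the filter must NOT hide."""
    warnings.warn("some_other_api is deprecated", DeprecationWarning, stacklevel=2)
    return x


def run_with_tree_map_filter(fn, *args):
    """Call fn with only the tree_map deprecation ignored.

    Returns (result, messages of the warnings that were still raised).
    """
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning by default...
        warnings.simplefilter("always")
        # ...but ignore the targeted one. filterwarnings() inserts at the front
        # of the filter list, so this rule is checked before "always".
        warnings.filterwarnings(
            "ignore", message=TREE_MAP_PATTERN, category=DeprecationWarning
        )
        result = fn(*args)
    return result, [str(w.message) for w in caught]


if __name__ == "__main__":
    print(run_with_tree_map_filter(legacy_tree_map, lambda x: x * 2, [1, 2, 3]))
    print(run_with_tree_map_filter(other_deprecated, 5))
```

Under the same assumption about the message text, the rule can also be written in the `-W` / `PYTHONWARNINGS` form `ignore:jax.tree_map is deprecated:DeprecationWarning`, which is the syntax pytest's `filterwarnings` option accepts as well.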

Once Flax raises its minimum supported Python version to 3.10, we should remove the deprecation warning filter and drop the pin on the older CLU version.
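The pin is tied to 3.10 for a syntactic reason: Python 3.9 cannot even parse a module containing a `match` statement, so a CLU release using that syntax cannot be imported there. A small standard-library check of this, compiling a toy `match`-`case` snippet (not taken from CLU) on the running interpreter:

```python
import sys

# Toy match-case snippet for illustration; not CLU's actual code.
MATCH_SOURCE = """
match value:
    case 0:
        label = "zero"
    case _:
        label = "nonzero"
"""


def parses_match_case() -> bool:
    """Return True if the running interpreter can compile match-case syntax."""
    try:
        compile(MATCH_SOURCE, "<match-case-demo>", "exec")
    except SyntaxError:
        return False
    return True


if __name__ == "__main__":
    print(sys.version_info[:2], parses_match_case())
```

Once 3.10 is the minimum, this check holds on every supported interpreter, which is why the version-conditional CLU pin would no longer be needed.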