Remove tree_multimap filterwarning once Jraph uses jax.tree_map instead of jax.tree_multimap
marcvanzee opened this issue
Marc van Zee commented
Jraph uses jax.tree_multimap, which is deprecated and causes our tests to fail. Therefore, I added a filterwarnings entry to pytest.ini. We should remove this filter once Jraph fixes this.
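For readers unfamiliar with how such a filter works, here is a minimal stdlib-only sketch. It assumes a strict setup where warnings are escalated to errors, and it uses a hypothetical function `legacy_tree_multimap` with a made-up warning message as a stand-in for the real deprecated JAX call. The exact message and warning category JAX emits are not taken from the source.

```python
import warnings


def legacy_tree_multimap():
    """Hypothetical stand-in for a deprecated call such as jax.tree_multimap."""
    warnings.warn(
        "jax.tree_multimap is deprecated, use jax.tree_map instead",  # made-up message
        DeprecationWarning,
        stacklevel=2,
    )
    return "ok"


def run_strict():
    """Escalate every warning to an exception, so the deprecated call fails."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        return legacy_tree_multimap()


def run_with_filters():
    """Same strict setup, plus one targeted ignore rule for the known warning."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        # filterwarnings() inserts at the front of the filter list by default,
        # so this rule is checked before the catch-all "error" rule above.
        # The message argument is a regex matched against the start of the text.
        warnings.filterwarnings(
            "ignore", message=r"jax\.tree_multimap", category=DeprecationWarning
        )
        return legacy_tree_multimap()
```

Calling `run_strict()` raises `DeprecationWarning`, while `run_with_filters()` returns `"ok"`. In a pytest.ini, a comparable setup would be a `filterwarnings` list with `error` followed by an `ignore:...` line (later entries take precedence there). Deleting that ignore line is what this issue tracks once Jraph no longer calls the deprecated function.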
Marc van Zee commented
Update: Optax is fine now, but Jraph is still problematic. See google-deepmind/jraph#38