google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Add lifted value_and_grad?

NeilGirdhar opened this issue.

Would it be worthwhile to add a lifted version of jax.value_and_grad? I imagine it would be pretty similar in implementation to the lifted vjp.