Add lifted value_and_grad?
NeilGirdhar opened this issue
Neil Girdhar commented
Would it be worthwhile to add a lifted version of jax.value_and_grad? I imagine it would be pretty similar in implementation to the lifted vjp.