google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io



Add a HOWTO showing how to get per-example gradients

avital opened this issue

There was an initial attempt at #306, but it was never completed. We'd like a HOWTO with a side-by-side diff view of a simple model (like many of our other HOWTOs) that shows how to use vmap to get per-example gradients while still getting the mean gradient over the whole batch.

(A broader example would be an implementation of DP-SGD built on top of this, but that probably belongs somewhere broader than a small HOWTO.)

Oops, we already had an issue for this! #858