google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

HOWTOS tracking issue

marcvanzee opened this issue

This issue tracks which HOWTOs we would like to add.

Process

  • If you would like to work on a HOWTO, please create a PR for it and mention this issue in the PR description. Once the PR is merged, we will check the corresponding box below.
  • If you think we should add a HOWTO, please reply to this issue and we will add it to the list below.

HOWTOs

  • Data-parallel training. @gmittal mentioned they would like to add this based on #1982.
  • Best practices for dynamic length inputs.
  • Loading MNIST from torchvision and HuggingFace datasets (see #1853 for more details).
  • Correctly dealing with the last batch during eval (see #1850 for more details).
  • Gradient checkpointing.
  • Using nn.apply and nn.bind (see #1087).
  • Mixed precision training (suggested by @lkhphuc).
  • Dropout guide (similar to the BatchNorm guide).
  • How to load data from different sources (torch, tf.data, HuggingFace), explaining that in Flax we really only care about JAX numpy arrays (see the data-loading sketch after this list).
  • How to do gradient accumulation.
  • Freezing parameters (see the optimizer sketch after this list).
  • Training with multiple optimizers (the same sketch applies).
  • Flax RNG Design
  • Using scan-over-layers to trade off peak memory against speed.
  • How to use Module.bind()
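
For the data-loading item, a minimal sketch of the point that Flax only consumes JAX numpy arrays, so batches from any loader just need an explicit conversion. The dict-of-arrays batch format and the helper name are assumptions, not an existing Flax API:

```python
import numpy as np
import jax.numpy as jnp

def to_jax_batch(batch):
    # Works the same whether `batch` comes from a torch DataLoader, tf.data,
    # or HuggingFace datasets: Flax only cares about JAX numpy arrays, so
    # convert every field explicitly.
    return {k: jnp.asarray(np.asarray(v)) for k, v in batch.items()}
```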

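For the freezing-parameters and multiple-optimizers items, one possible pattern is optax.multi_transform over labelled parameters. This is only a sketch; the 'embedding' path filter is a made-up example and `params` is assumed to be the FrozenDict returned by `model.init()['params']`:

```python
import optax
from flax import traverse_util
from flax.core import freeze, unfreeze

def make_optimizer(params):
    # Label every leaf of the param tree: anything under a (hypothetical)
    # 'embedding' collection is frozen, everything else is trainable.
    flat = traverse_util.flatten_dict(unfreeze(params))
    labels = freeze(traverse_util.unflatten_dict(
        {path: ('frozen' if 'embedding' in path else 'trainable') for path in flat}
    ))
    # One optax transform per label: Adam for trainable parameters and
    # set_to_zero() for frozen ones. Replacing set_to_zero() with a second
    # real optimizer gives the multiple-optimizers case.
    return optax.multi_transform(
        {'trainable': optax.adam(1e-3), 'frozen': optax.set_to_zero()},
        labels,
    )
```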
How to train with mixed precision, please. It probably only makes sense to do this after the "DTYPE" FLIP is finalized.
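
Not the official recipe, just a rough sketch of what half-precision compute currently looks like in linen: passing dtype=jnp.bfloat16 makes the layers compute in bfloat16 while parameters stay in float32. The eventual HOWTO would presumably follow whatever the DTYPE FLIP decides:

```python
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # dtype controls the computation dtype; parameters are still created
        # and stored in float32 unless param_dtype is changed as well.
        x = nn.Dense(128, dtype=jnp.bfloat16)(x)
        x = nn.relu(x)
        return nn.Dense(10, dtype=jnp.bfloat16)(x)
```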

Thanks @lkhphuc, I think that is a great suggestion and I added it to the list.

@marcvanzee Can you describe the Dropout guide (similar to the BatchNorm guide) a bit more?
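
Not sure what scope is intended, but presumably something like the BatchNorm guide's treatment of mutable state, applied instead to the extra 'dropout' RNG stream and the deterministic flag. A minimal sketch of that pattern, with a made-up model:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Classifier(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.relu(nn.Dense(128)(x))
        # Dropout is a no-op when deterministic=True (eval); during training
        # it draws from the 'dropout' RNG stream passed to apply().
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(10)(x)

model = Classifier()
x = jnp.ones((4, 32))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
logits = model.apply(variables, x, train=True,
                     rngs={'dropout': jax.random.PRNGKey(1)})
```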