HOWTOs tracking issue
marcvanzee opened this issue
Marc van Zee commented
This issue tracks which HOWTOs we would like to add.
Process
- If you would like to work on a HOWTO, please create a PR for it and mention this issue in the PR description. Once the PR is merged, we will check the corresponding box below.
- If you think we should add a HOWTO, please reply to this issue and we will add it to the list below.
HOWTOs
- Data-parallel training (sketch after this list). @gmittal mentioned they would like to add this based on #1982.
- Best practices for dynamic length inputs.
- Loading MNIST from torchvision and HuggingFace Datasets (see #1853 for more details).
- Correctly dealing with the last batch during eval (see #1850 for more details).
- Gradient checkpointing (sketch after this list).
- Using nn.apply and nn.bind (see #1087).
- Mixed precision training (suggested by @lkhphuc; sketch after this list).
- Dropout guide, similar to the BatchNorm guide (sketch after this list).
- How to load data from different sources (torch, tf.data, HuggingFace), explaining that Flax really only cares about JAX numpy arrays.
- How to do gradient accumulation (sketch after this list).
- Freezing parameters (sketch after this list).
- Training with multiple optimizers
- Flax RNG Design
- Using scan-over-layers to trade off peak memory against speed.
- How to use Module.bind() (sketch after this list).
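
Some rough sketches of what a few of these HOWTOs might cover are below. All models, shapes, and hyperparameters in them are illustrative assumptions, not settled recommendations.

For data-parallel training, a minimal sketch with jax.pmap; the MLP, learning rate, and batch shapes are made up:

```python
import functools

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):  # hypothetical toy model
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

def loss_fn(params, xs, ys):
    preds = model.apply(params, xs)
    return jnp.mean((preds - ys) ** 2)

@functools.partial(jax.pmap, axis_name='batch')
def train_step(params, xs, ys):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='batch')
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

# Replicate the params and give the data a leading device axis.
n = jax.local_device_count()
params = jax.device_put_replicated(params, jax.local_devices())
xs = jnp.ones((n, 16, 8))  # [devices, per-device batch, features]
ys = jnp.ones((n, 16, 1))
params, loss = train_step(params, xs, ys)
```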
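
For gradient checkpointing, a sketch using nn.remat; the depth and width here are arbitrary:

```python
import flax.linen as nn

# nn.remat wraps a module so its intermediate activations are recomputed
# during the backward pass instead of being stored, trading extra compute
# for lower peak memory.
RematDense = nn.remat(nn.Dense)

class DeepStack(nn.Module):  # hypothetical deep model
    @nn.compact
    def __call__(self, x):
        for _ in range(16):
            x = RematDense(features=128)(x)
        return x
```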
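
For mixed precision training, a tentative sketch using the dtype argument that Linen layers accept; the exact recommended pattern may change once the "DTYPE" FLIP is finalized:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# `dtype` controls the computation dtype of a layer, while the parameters
# themselves stay float32 by default.
class HalfPrecisionMLP(nn.Module):  # hypothetical model
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32, dtype=jnp.bfloat16)(x))
        return nn.Dense(1, dtype=jnp.bfloat16)(x)

model = HalfPrecisionMLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8), jnp.bfloat16))
```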
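
For the Dropout guide, a sketch of the two things it would likely cover, the deterministic flag and the 'dropout' RNG stream; the Classifier model is hypothetical:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Classifier(nn.Module):  # hypothetical model
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.relu(nn.Dense(32)(x))
        # Dropout is stochastic, so it needs its own 'dropout' RNG stream,
        # and it must be disabled at eval time via `deterministic`.
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(1)(x)

model = Classifier()
x = jnp.ones((1, 8))
params = model.init(jax.random.PRNGKey(0), x, train=False)

# Training needs the 'dropout' RNG; eval does not.
y_train = model.apply(params, x, train=True,
                      rngs={'dropout': jax.random.PRNGKey(1)})
y_eval = model.apply(params, x, train=False)
```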
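
For gradient accumulation, a sketch built on optax.MultiSteps; the toy parameters, gradients, and the choice of 4 accumulation steps are assumptions:

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((3,))}  # toy parameters
grads = {'w': jnp.ones((3,))}    # toy gradients, one per micro-batch

# MultiSteps accumulates gradients and only applies the inner Adam update
# on every 4th call; in between, the returned updates are all zeros.
tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
opt_state = tx.init(params)

for _ in range(4):  # four micro-batches == one effective batch
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
```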
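
For freezing parameters, a sketch using optax.multi_transform with optax.set_to_zero on the frozen subtree; the parameter tree and the 'backbone'/'head' split are made up:

```python
import jax
import jax.numpy as jnp
import optax
from flax import traverse_util

# Toy parameter tree: freeze 'backbone', train only 'head'.
params = {
    'backbone': {'kernel': jnp.ones((8, 16))},
    'head': {'kernel': jnp.ones((16, 1))},
}

def label_fn(params):
    flat = traverse_util.flatten_dict(params)
    labels = {k: 'frozen' if k[0] == 'backbone' else 'trainable' for k in flat}
    return traverse_util.unflatten_dict(labels)

# 'frozen' leaves get zero updates, 'trainable' leaves get Adam updates.
tx = optax.multi_transform(
    {'trainable': optax.adam(1e-3), 'frozen': optax.set_to_zero()},
    label_fn)
opt_state = tx.init(params)

grads = jax.tree_util.tree_map(jnp.ones_like, params)  # toy gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```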
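
For Module.bind(), a minimal sketch with a hypothetical AutoEncoder; bind returns a copy of the module with variables attached, so you can call submodules interactively:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):  # hypothetical module
    def setup(self):
        self.encoder = nn.Dense(4)
        self.decoder = nn.Dense(8)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

x = jnp.ones((1, 8))
model = AutoEncoder()
variables = model.init(jax.random.PRNGKey(0), x)

# bind() attaches the variables to the module, so submodules and methods
# can be called directly without going through apply().
bound = model.bind(variables)
z = bound.encoder(x)      # call a submodule on its own
x_rec = bound.decoder(z)
```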
Phúc H. Lê Khắc commented
A HOWTO on mixed precision training, please. It probably only makes sense to do this after the "DTYPE" FLIP is finalized.
Marc van Zee commented
Thanks @lkhphuc, I think that is a great suggestion and I added it to the list.
Mrinal Tyagi commented
@marcvanzee Can you describe the Dropout guide (similar to the BatchNorm guide) in a bit more detail?