Support saving/loading models larger than CPU memory
achalddave opened this issue
We should use FSDP's SHARDED_STATE_DICT when saving/loading checkpoints, so that each rank only handles its own shard and no single rank has to materialize the entire model in CPU memory, similar to https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py.