Can't load optimizer state due to `state_steps`
rowhanm opened this issue · comments
Hi, I recently upgraded to PyTorch 1.12 and have had issues loading a saved optimizer state with FSDP; the problem seems related to something addressed in earlier comments. From what I understand, Adam's `step` state changed into a singleton tensor in 1.12, and when I call `gather_full_optim_state_dict()` this `step` is converted to an int.
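For context, here is a minimal sketch of the new behavior (assumes PyTorch >= 1.12; the parameter and optimizer are purely illustrative):

```python
import torch

# Illustration only (assumes PyTorch >= 1.12): Adam/AdamW keep the
# per-parameter `step` counter as a singleton (0-dim) tensor in state.
param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([param])
param.grad = torch.randn(4)
opt.step()

step = opt.state[param]["step"]
# On 1.12+ this is a 0-dim tensor rather than a plain int.
print(type(step), step)
```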
Sample saving code:

```python
model = FSDP(model, ...)
# call on all ranks
optim_state = model.gather_full_optim_state_dict(optimizer)
if rank == 0:
    # save only on rank 0
    checkpoint = {
        'optimizer': optim_state,
        ...
    }
    torch.save(checkpoint, snapshot_name)
```
Now when I load this optimizer state back, I do the following:

```python
model = FSDP(model, ...)
torch.distributed.barrier()
# on all ranks
checkpoint = torch.load(snapshot_name)
curr_opt_state_dict = checkpoint["optimizer"]
optim_shard_dict = model.get_shard_from_optim_state_dict(curr_opt_state_dict)
optimizer.load_state_dict(optim_shard_dict)
```
This always fails the assertion in the Adam code (https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py#L204), presumably because `step` was converted to an int within FSDP while Adam expects a singleton tensor.
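For reference, the failing check amounts to something like the following (a paraphrase of the upstream guard, not the exact source):

```python
import torch

# Paraphrase of the guard in torch/optim/adamw.py: every entry of
# `state_steps` must be a tensor, or the optimizer raises immediately.
def check_state_steps(state_steps):
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

check_state_steps([torch.tensor(3.0)])  # passes
try:
    check_state_steps([3])              # plain int -> RuntimeError
except RuntimeError as e:
    print("raised:", e)
```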
My question is: am I saving the state dict correctly? Do I need to call `optimizer.state_dict()` on top of `model.gather_full_optim_state_dict()`?
A workaround I'm using to bypass the assertion is to convert the ints back to singleton tensors inside the `adamw` function; however, that does not seem safe. Any thoughts?
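A hedged sketch of that workaround, applied to the state dict before `load_state_dict()` rather than inside `adamw` (the helper name is mine, and it assumes the standard `{'state': {...}, 'param_groups': [...]}` layout):

```python
import torch

# Sketch of the workaround: convert any int `step` entries in an
# optimizer state dict back into singleton tensors before loading.
# Assumes the standard {'state': {...}, 'param_groups': [...]} layout.
def steps_to_tensors(optim_state_dict):
    for param_state in optim_state_dict.get("state", {}).values():
        step = param_state.get("step")
        if isinstance(step, int):
            param_state["step"] = torch.tensor(float(step))
    return optim_state_dict
```

Usage would be `optimizer.load_state_dict(steps_to_tensors(optim_shard_dict))`, keeping the patch out of the optimizer internals.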
Apologies if my understanding is incorrect; I followed some of the discussion in #776 for the state-dict saving logic.
Hey, thanks for the detailed question! I think what you are doing is correct. #776 is largely different from your issue, which is related to the optimizer state.
I am not sure whether you are running into problem 1 or problem 2 below, or both:
- loading a pre-1.12 checkpoint crashes
- using the same version (post-1.12), saving a checkpoint and loading it back crashes
For 1, I suggest you use `torch.load` and `torch.save` manually and patch the checkpoint so that it is compatible with 1.12. You can save two versions of the checkpoint (one for pre-1.12, one for post-1.12) and load the correct one to avoid crashes.
For 2, that would be a bug. Please send us a minimal reproduction case if you can. A PR to fix it would be even more awesome! ;-)
Hi, I think it is the second case: saving a checkpoint and then running `optimizer.step()` (with or without loading) causes a crash.
Here is a minimal reproduction - https://gist.github.com/rowhanm/71272f157d8c9450d6b1c7639a612126.
[python==3.7.5, pytorch==1.12.0, fairscale==0.4.6 (can't upgrade because I'm restricted to py3.7; it shouldn't matter, since this function is unchanged in later versions)]
I've narrowed the problem down to this line: https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2443. I can fix the issue in my script by converting the `step` state from an int back into a singleton tensor.
I don't have much context on what the comment in the source ("comparison with original state dict would fail") means, and I'm not sure if my fix would cause issues later. If there are no side effects, my proposed fix would be to:
- either not do the original singleton tensor -> int conversion, or
- convert the `step` state back from an int after serializing (not sure where exactly this should take place).
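The second option could look roughly like this round trip (hypothetical helper names; where exactly they would hook into fairscale is the open question):

```python
import torch

# Hypothetical helpers sketching the proposed round trip: a `step`
# serialized as a plain int must be restored to a singleton tensor
# before the 1.12 Adam implementation consumes it.
def step_for_saving(step):
    # tensor -> int, so the gathered state dict serializes/compares cleanly
    return int(step.item()) if isinstance(step, torch.Tensor) else step

def step_for_loading(step):
    # int -> singleton float tensor, the form `state_steps` now requires
    return torch.tensor(float(step)) if isinstance(step, int) else step
```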
I see, this makes sense. We likely don't have a test case that catches this issue. I will find time to fix it.
btw, here is the error I got when running your sample code:
```
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/private/home/m1n/git/fairscale/t.py", line 99, in demo_basic
    optimizer.step()
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/optim/adamw.py", line 161, in step
    adamw(params_with_grad,
  File "/private/home/m1n/e/py38_miniconda_pt_nightly_fairscale/lib/python3.8/site-packages/torch/optim/adamw.py", line 204, in adamw
    raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
RuntimeError: API has changed, `state_steps` argument must contain a list of singleton tensors
```
Yep, that is the main issue :) Adam expects `state_steps` to be a list of singleton tensors, but `gather_full_optim_state_dict()` converts each `step` into an int. If you uncomment lines 108-112 in my Gist, that basically fixes the issue.
I can put together a small test case + PR that fixes it in a bit. Again, I'm not sure that is the best possible fix, given the `git blame` on that line.
Thanks for trying a fix!
My best recollection is that this is needed because if `step` is a singleton tensor, it may be treated like a sharded optimizer state and get handled by the gather function. This `step` scalar is assumed to be the same across all ranks, which is true for FSDP at least. Maybe there are reasons why it changed from a scalar to a tensor in the first place, but I haven't looked into it.
BTW, when I ran your test code with PyTorch 1.8, it gave a different error in the loss function, which is very interesting too.
Super weird. I had only tested on PyTorch 1.12, which gives this error, and on 1.11, which as expected does not, since Adam there expects `state_steps` to be an int.
I tried but unfortunately can't test with 1.8 because I don't have a GPU with the right CUDA capabilities. Could you tell me what error you see with 1.8?
No need to worry, but here is the error on 1.8:
```
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/private/home/m1n/git/fairscale/t.py", line 82, in demo_basic
    loss = criterion(preds, target)
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1047, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/functional.py", line 2693, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/private/home/m1n/e/py38_miniconda_pt181_fairscale/lib/python3.8/site-packages/torch/nn/functional.py", line 2388, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
```
Hmmm... not sure if it's a mixed-precision issue. It seems like something I've seen before with incorrect typecasting when using AMP.
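That traceback matches the usual `nll_loss` dtype complaint: class-index targets must be `int64` (Long), not float. A quick illustration with made-up tensors:

```python
import torch
import torch.nn.functional as F

# Made-up tensors, just to show the dtype requirement: class-index
# targets for cross_entropy must be Long (int64), not float.
preds = torch.randn(2, 3)
target = torch.tensor([0, 2], dtype=torch.float32)

try:
    F.cross_entropy(preds, target)  # float class indices -> RuntimeError
except RuntimeError as e:
    print("failed as expected:", type(e).__name__)

loss = F.cross_entropy(preds, target.long())  # works once cast to Long
```

If the test script builds its targets from a float tensor (or AMP casts them somewhere along the way), an explicit `.long()` on the targets would sidestep this particular 1.8 error.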