Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

continue training from checkpoint seems broken (high loss values), while reasonable with .eval()

yairkit opened this issue · comments

I tried to load my trained model from a checkpoint for fine-tune training.
On the first "on_val_step()" the output seems OK; the loss scale is the same as at the end of pre-training.
But on the first "on_train_step()" the output is totally different, very bad, just like training from scratch.

That behavior happens both when I:

  1. stop a training run in the middle and then rerun the same training with "resume from checkpoint"

  2. manually load a model from a checkpoint after pre-training finished, as follows:
    checkpoint = torch.load(config['pre_trained_weights_checkpoint'], map_location=lambda storage, loc: storage)
    experiment.load_state_dict(checkpoint['state_dict'])
    (where "experiment" is my pl.LightningModule)

Am I doing something wrong?
What is the best practice in PL for continuing training a model from the weights at the point it stopped?

Thanks.

The Trainer argument resume_from_checkpoint only restores trainer settings (global step etc.) and loads the state dict of the model.
You also need to load the correct hyperparameters. For example, if your learning rate changes and you start with the initial one, it could lead to a jump in loss as you describe.

experiment = Experiment.load_from_checkpoint("epoch-1.ckpt") 
trainer = Trainer(..., resume_from_checkpoint="epoch-1.ckpt")
trainer.fit(experiment)
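
For context, a rough sketch of how the hyperparameter side of this works, assuming the module calls self.save_hyperparameters() in its __init__ (the Experiment class and its arguments here are illustrative, not from the original report):

import torch
import pytorch_lightning as pl

class Experiment(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, hidden_dim=64):
        super().__init__()
        # Stores the init arguments in the checkpoint under "hyper_parameters",
        # so load_from_checkpoint can rebuild the module with the same values.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, self.hparams.hidden_dim)

    def configure_optimizers(self):
        # The restored learning rate is used again when training resumes.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# load_from_checkpoint re-creates the module with the saved hyperparameters and
# loads the weights; resume_from_checkpoint (on the Trainer) additionally
# restores the global step, epoch, optimizer and scheduler states.
experiment = Experiment.load_from_checkpoint("epoch-1.ckpt")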

@awaelchli Thanks for the response!

Maybe I wasn’t precise enough:
The problem is that, for the first forward pass, even without training (without doing any loss.backward() or optimizer.step()),
I already get a loss that indicates my model is garbage when it’s configured with model.train().
But everything is OK when I use model.eval() (for the exact same code, dataloader, etc.).
It’s as if calling train() made my model completely useless.
I guess it has something to do with the BN layers (I used DDP with sync_batch_norm), but I can't really find the exact problem.
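
A rough sketch of the kind of check that can isolate this, with the checkpoint path, the Experiment class and the batch as placeholders: compare the loss on the same batch in train() vs. eval() mode right after loading, and look at the BatchNorm running statistics.

import torch

# Placeholders: "Experiment" is the LightningModule class, "batch" is one batch
# taken from the same dataloader used during pre-training.
experiment = Experiment.load_from_checkpoint("pretrained.ckpt")

with torch.no_grad():
    experiment.eval()                      # BatchNorm uses running mean/var
    loss_eval = experiment.training_step(batch, 0)
    experiment.train()                     # BatchNorm uses per-batch statistics
    loss_train = experiment.training_step(batch, 0)
print(f"eval-mode loss: {loss_eval}, train-mode loss: {loss_train}")

# If the gap is large, check whether the running statistics restored from the
# checkpoint look sane (not zeros, not wildly off for your data):
for name, buf in experiment.named_buffers():
    if "running_mean" in name or "running_var" in name:
        print(name, buf.mean().item(), buf.std().item())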

did anyone face the same issue?

Do you think you could reproduce the issue with this template
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py
by adding a batch norm somewhere and running with sync_batch_norm just like you did?
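
For reference, a rough outline of what such a repro could look like, assuming the BoringModel and RandomDataset classes from that template are importable (layer sizes and trainer arguments here are arbitrary):

import torch
import pytorch_lightning as pl
from pl_examples.bug_report_model import BoringModel, RandomDataset

class BNBoringModel(BoringModel):
    def __init__(self):
        super().__init__()
        # Add a BatchNorm layer so sync_batchnorm has something to convert.
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.BatchNorm1d(32),
            torch.nn.Linear(32, 2),
        )

train_loader = torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=8)
model = BNBoringModel()
trainer = pl.Trainer(max_epochs=2, gpus=2, accelerator="ddp", sync_batchnorm=True)
trainer.fit(model, train_loader)
trainer.save_checkpoint("repro.ckpt")
# Then run again with resume_from_checkpoint="repro.ckpt" and compare the first
# training-step loss with the last one before the interruption.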

I'm not sure if this is related, but I think I have the same issue here. I wanted to resume training today and the loss was a lot higher than when I finished training yesterday:

[screenshot: loss curves; the resumed run (orange) starts noticeably higher than the original run (grey)]

Disclaimer: I've just started using PyTorch Lightning (thank you guys for that awesome framework!!), so perhaps I did something wrong. This is how I tried to continue the training:

model = MyModel.load_from_checkpoint(chkpt)
trainer = pl.Trainer(resume_from_checkpoint=chkpt, gpus=parser.gpus, distributed_backend=parser.distributed_backend)
trainer.fit(model, datamodule=data)

@awaelchli Could you give an example of how I'd load the correct hyperparameters? I'm using LR scheduling, and I thought the whole point of passing resume_from_checkpoint to the trainer was that the trainer would load the LR and other hyperparameters from the checkpoint.

Is the x-axis the global step?
Can you verify that on the global step chart the epoch corresponds to the global step at which you resume training (orange start point)?

@gergol Except for the fact that I continued from the latest step of the pre-training
(as if your orange graph started at exactly 5.5K), it seems pretty much the same.
Mine was even worse.

Can you share some high level details of your architecture / data domain?

Do you think you could reproduce the issue with this template
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py
by adding a batch norm somewhere and running with sync_batch_norm just like you did?

@awaelchli I will try to do that later and update, Thanks!

@awaelchli Yeah, the x-axis is the global step. My train dataset size is 967 and the batch size is 104, so with drop_last==True there are 9 steps per epoch (967 // 104 = 9), and the global step of the checkpoint (5292) corresponds exactly to epoch 588 (5292 / 9 = 588), which is the epoch after the checkpoint was saved. I think that's as it should be.

@yairkit The "overlapping" of the orange and grey paths is because the checkpoint was the "best model score" one which was saved a few epochs before I canceled the training.

The model I'm using is a PyTorch Lightning port I made from this one here.

@awaelchli Continuing training indeed doesn't work. I verified that the code flow goes into https://github.com/PyTorchLightning/pytorch-lightning/blob/207ff728c940ff7d8bb317a83d22378b759c9292/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L84 and properly restores things, but it looks like the Trainer is re-initialized to default settings somehow (or there are some weird copies of the Trainer object)!? I haven't had the time to dig deeper. For me this happens with DDP+AMP, and it didn't happen on PL 0.8.1, so this is an issue after the upgrade for me.

Anyway, I think you can create a minimal repro with a manual checkpoint at, say, trainer.global_step == 100 (do the manual checkpoint as in the docs with self.trainer.save_checkpoint(...) from the LightningModule). If you load that checkpoint, global_step will start from zero again.
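
A sketch of that repro idea, reusing the BoringModel from the bug_report template (class and file names here are arbitrary):

import pytorch_lightning as pl
from pl_examples.bug_report_model import BoringModel

class ManualCkptModel(BoringModel):
    def training_step(self, batch, batch_idx):
        # Manually write a checkpoint once the trainer reaches global step 100.
        if self.trainer.global_step == 100:
            self.trainer.save_checkpoint("manual_step100.ckpt")
        return super().training_step(batch, batch_idx)

# First run: train past step 100 so "manual_step100.ckpt" gets written.
# Second run: resume from it and check whether the step counter was restored.
trainer = pl.Trainer(resume_from_checkpoint="manual_step100.ckpt", max_steps=110)
trainer.fit(ManualCkptModel())
print(trainer.global_step)  # expected to be around 100 if the trainer state was restored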

@dlrac We currently don't have a way to fully restore all trainer settings.
We restore the global step, the current epoch and the states of the callbacks (and maybe a few other things).
It is very well possible that you have different trainer settings compared to when you trained the checkpoint you are loading now, and this can influence the optimization. If you used the argparse approach and told the LightningModule to save the hyperparameters, you should be able to load them and double-check your trainer args.
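
For example, a quick way to see what actually got saved is to open the checkpoint file directly; the keys below are the ones Lightning writes into the checkpoint (the filename is a placeholder):

import torch

ckpt = torch.load("epoch-1.ckpt", map_location="cpu")
# Trainer progress that gets restored on resume:
print(ckpt["global_step"], ckpt["epoch"])
# Hyperparameters stored by self.save_hyperparameters(), if it was used:
print(ckpt.get("hyper_parameters", {}))
# Optimizer and LR-scheduler states are also in the checkpoint:
print(ckpt["optimizer_states"][0]["param_groups"][0]["lr"])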

The only reason I can imagine for a reset global step is that you manually invoked
self.trainer.save_checkpoint(...) with the argument weights_only=True.
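
To illustrate the difference (a sketch, called from inside a LightningModule; file names are arbitrary):

# Full checkpoint: weights plus optimizer/scheduler states, callbacks and trainer progress.
self.trainer.save_checkpoint("full.ckpt")
# Weights-only checkpoint: skips the optimizer/scheduler and callback states; per the
# comment above, resuming from it can leave the trainer counters at zero.
self.trainer.save_checkpoint("weights_only.ckpt", weights_only=True)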

@awaelchli A few weird things happened, but they are on our end. :)

@bongoramondo Yeah, so we were working on a complicated image -> sequence task where the allowed sequence values were pulled from many files. One of our developers did something like:

vocabulary = list(set(token_list_0 + token_list_1 + ... + token_list_n))

Since set() is not ordered and its iteration order depends on PYTHONHASHSEED, which can't be set programmatically from inside the running process, loading the weights in a different program run loaded the tokens in a different order.
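
For anyone hitting the same thing, a sketch of the kind of fix (the token list names are the ones from the snippet above): make the ordering deterministic before indices are assigned, e.g. by sorting.

# Deduplicate, then sort, so the token-to-index mapping is identical across runs
# regardless of PYTHONHASHSEED / set iteration order.
vocabulary = sorted(set(token_list_0 + token_list_1 + token_list_n))
token_to_idx = {token: idx for idx, token in enumerate(vocabulary)}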

The reason I first suspected a problem on the PL side is that we had run into some other issues in the past (e.g. EarlyStopping and multi-GPU didn't play along well, and similar).

Closing this, feel free to reopen in case anything is unclear!

Hi @yairkit, are you sure you had this problem and not some other mistake? Did you find any solution to it? Are you able to reproduce it? I'm facing a similar problem, so I wanted to know how you proceeded.

I'm having the same problem. After resume_from_checkpoint, the loss is higher than at the last step before checkpointing. Maybe the trainer does not resume the learning rate?

Hi, have you solved the problem?

I am seeing a similar issue: I can resume training flawlessly on a single GPU, but when doing multi-GPU training my train loss registers a big increase.

@awaelchli you mention that you cannot restore all trainer settings. Do you have any suggestions for what I could look for to debug this?

The LR schedule does resume correctly in my case. I also pass the same hyperparameter settings to both the (interrupted) training run and to the resumed one.
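
In case it helps others debugging the same thing, a sketch of one way to compare the restored optimizer state against the checkpoint right when training resumes (the callback and the checkpoint path are illustrative; the checkpoint keys are the ones Lightning stores):

import torch
import pytorch_lightning as pl

class ResumeDebugCallback(pl.Callback):
    # Print the saved vs. live learning rate and the global step once training starts.
    def on_train_start(self, trainer, pl_module):
        ckpt = torch.load("last.ckpt", map_location="cpu")
        saved_lr = ckpt["optimizer_states"][0]["param_groups"][0]["lr"]
        live_lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"saved lr={saved_lr}  live lr={live_lr}  global_step={trainer.global_step}")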

Any progress here? I am also seeing this problem with PL 2.2.1.

I am not using a learning-rate scheduler; the learning rate is constant. I still see this jump. In addition, the loss was consistently improving before stopping, but after resuming it stops improving and my early-stop callback kicks in and stops the run.

Considering switching to pure PyTorch for training due to this issue.