Load from checkpoint doesn't load model for inference
golnaz-hs opened this issue
Bug description
Hi,
I trained an attention U-Net model, and ModelCheckpoint saved the trained model as a .ckpt file. Now, at inference time, I get an error when loading the model. The same code worked about a month ago, and I can't figure out what the problem is. (version 2.3.0)
What version are you seeing the problem on?
master
How to reproduce the bug
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

# DatasetConfig, data_module, model, train_loader, lr_rate_monitor and
# MedicalSegmentationModel are defined elsewhere in the original notebook.

data_module.setup()

tb_logger = TensorBoardLogger(save_dir="lightning_logs", name=f"{DatasetConfig.MODEL_NAME}")
logger = pl.loggers.CSVLogger(save_dir="logs/", name=f"{DatasetConfig.MODEL_NAME}")

# Keep the checkpoint with the best validation IoU.
model_checkpoint = ModelCheckpoint(
    monitor="valid_iou",
    mode="max",
    filename="ckpt_{epoch:03d}-vloss_{valid_loss:.4f}_vf1_{valid_iou:.4f}",
    auto_insert_metric_name=False,
)

trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    max_epochs=DatasetConfig.NUM_EPOCHS,
    # enable_model_summary=False,
    callbacks=[model_checkpoint, lr_rate_monitor],
    precision="16-mixed",
    # limit_val_batches=0.1,
    val_check_interval=len(train_loader),
    num_sanity_val_steps=0,
    logger=[logger, tb_logger],
)

trainer.fit(model, data_module)

## inference
CKPT_PATH = "logs/deeplabv3p-test/version_2/checkpoints/ckpt_023-vloss_0.2485_vf1_0.5624.ckpt"
model = MedicalSegmentationModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
model = model.eval()
Error messages and logs
RuntimeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 model = MedicalSegmentationModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
2 # model = model.to(device)
3 model = model.eval()
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/model_helpers.py in wrapper(*args, **kwargs)
123 " Please call it on the class type and make sure the return value is used."
124 )
--> 125 return self.method(cls, *args, **kwargs)
126
127 return wrapper
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
1584
1585 """
-> 1586 loaded = _load_from_checkpoint(
1587 cls, # type: ignore[arg-type]
1588 checkpoint_path,
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/saving.py in _load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
89 return _load_state(cls, checkpoint, **kwargs)
90 if issubclass(cls, pl.LightningModule):
---> 91 model = _load_state(cls, checkpoint, strict=strict, **kwargs)
92 state_dict = checkpoint["state_dict"]
93 if not state_dict:
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/saving.py in _load_state(cls, checkpoint, strict, **cls_kwargs_new)
185
186 # load the state_dict on the model automatically
--> 187 keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
188
189 if not strict:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for MedicalSegmentationModel:
Missing key(s) in state_dict: "model.decoder.blocks.x_0_0.conv1.0.weight", "model.decoder.blocks.x_0_0.conv1.1.weight", "model.decoder.blocks.x_0_0.conv1.1.bias", "model.decoder.blocks.x_0_0.conv1.1.running_mean", "model.decoder.blocks.x_0_0.conv1.1.running_var", "model.decoder.blocks.x_0_0.conv2.0.weight", "model.decoder.blocks.x_0_0.conv2.1.weight", "model.decoder.blocks.x_0_0.conv2.1.bias", "model.decoder.blocks.x_0_0.conv2.1.running_mean", "model.decoder.blocks.x_0_0.conv2.1.running_var", "model.decoder.blocks.x_0_1.conv1.0.weight", "model.decoder.blocks.x_0_1.conv1.1.weight", "model.decoder.blocks.x_0_1.conv1.1.bias", "model.decoder.blocks.x_0_1.conv1.1.running_mean", "model.decoder.blocks.x_0_1.conv1.1.running_var", "model.decoder.blocks.x_0_1.conv2.0.weight", "model.decoder.blocks.x_0_1.conv2.1.weight", "model.decoder.blocks.x_0_1.conv2.1.bias", "model.decoder.blocks.x_0_1.conv2.1.running_mean", "model.decoder.blocks.x_0_1.conv2.1.running_var", "model.decoder.blocks.x_1_1.conv1.0.weight", "model.decoder.blocks.x_1_1.conv1.1.weight", "model.decoder.blocks.x_1_1.conv1.1.bias", "model.decoder.blocks.x_1_1.conv1.1.running_mean", "model.decoder.blocks.x_1_1.conv1.1.running_var", "model.decoder.blocks.x_1_1.conv2.0.weight", "model.decoder.blocks.x_1_1.conv2.1.weight", "model.decoder.blocks.x_1_1.conv2.1.bias", "model.decoder.blocks.x_1_1.conv2.1.running_mean", "model.decoder.blocks.x_1_1.conv2.1.running_var", "model.decoder.blocks.x_0_2.conv1.0.weight", "model.decoder.bl...
Unexpected key(s) in state_dict: "model.decoder.aspp.0.convs.0.0.weight", "model.decoder.aspp.0.convs.0.1.weight", "model.decoder.aspp.0.convs.0.1.bias", "model.decoder.aspp.0.convs.0.1.running_mean", "model.decoder.aspp.0.convs.0.1.running_var", "model.decoder.aspp.0.convs.0.1.num_batches_tracked", "model.decoder.aspp.0.convs.1.0.0.weight", "model.decoder.aspp.0.convs.1.0.1.weight", "model.decoder.aspp.0.convs.1.1.weight", "model.decoder.aspp.0.convs.1.1.bias", "model.decoder.aspp.0.convs.1.1.running_mean", "model.decoder.aspp.0.convs.1.1.running_var", "model.decoder.aspp.0.convs.1.1.num_batches_tracked", "model.decoder.aspp.0.convs.2.0.0.weight", "model.decoder.aspp.0.convs.2.0.1.weight", "model.decoder.aspp.0.convs.2.1.weight", "model.decoder.aspp.0.convs.2.1.bias", "model.decoder.aspp.0.convs.2.1.running_mean", "model.decoder.aspp.0.convs.2.1.running_var", "model.decoder.aspp.0.convs.2.1.num_batches_tracked", "model.decoder.aspp.0.convs.3.0.0.weight", "model.decoder.aspp.0.convs.3.0.1.weight", "model.decoder.aspp.0.convs.3.1.weight", "model.decoder.aspp.0.convs.3.1.bias", "model.decoder.aspp.0.convs.3.1.running_mean", "model.decoder.aspp.0.convs.3.1.running_var", "model.decoder.aspp.0.convs.3.1.num_batches_tracked", "model.decoder.aspp.0.convs.4.1.weight", "model.decoder.aspp.0.convs.4.2.weight", "model.decoder.aspp.0.convs.4.2.bias", "model.decoder.aspp.0.convs.4.2.running_mean", "model.decoder.aspp.0.convs.4.2.running_var", "model.decoder.aspp.0.convs.4.2.num_batche...
size mismatch for model.segmentation_head.0.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 16, 3, 3]).
Environment
No response
More info
No response
cc @awaelchli
Hey @golnaz-hs
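The error message is telling you that the weights in the checkpoint have names like
model.decoder.aspp.0.convs.0.1.bias
(listed under "Unexpected key(s)"), but the model you are instantiating expects names like
model.decoder.blocks.x_0_0.conv1.1.bias
(listed under "Missing key(s)"), and so on.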
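To see which names differ without going through load_from_checkpoint, one option is to compare the key sets directly. The sketch below roughly mirrors how nn.Module.load_state_dict classifies keys: "missing" means the model expects a key the checkpoint lacks, "unexpected" means the checkpoint has a key the model doesn't know, and a shape difference on a shared key is a size mismatch. compare_state_dicts is a helper written for this example (not a Lightning or PyTorch API), and OldNet/NewNet are hypothetical toy modules standing in for the trained and current models. It is an approximation: PyTorch fills in BatchNorm's num_batches_tracked buffer itself, so a plain set difference may list a few extra keys that the real error would not.

```python
import torch
from torch import nn


def compare_state_dicts(model: nn.Module, state_dict: dict) -> dict:
    """Roughly mirror how load_state_dict classifies keys.

    missing    -> expected by the model but absent from the checkpoint
    unexpected -> present in the checkpoint but unknown to the model
    mismatched -> present in both, with different tensor shapes
    """
    model_sd = model.state_dict()
    model_keys, ckpt_keys = set(model_sd), set(state_dict)
    mismatched = {
        key: (tuple(state_dict[key].shape), tuple(model_sd[key].shape))
        for key in sorted(model_keys & ckpt_keys)
        if state_dict[key].shape != model_sd[key].shape
    }
    return {
        "missing": sorted(model_keys - ckpt_keys),
        "unexpected": sorted(ckpt_keys - model_keys),
        "mismatched": mismatched,  # key -> (checkpoint shape, model shape)
    }


# Hypothetical toy models: the "trained" one has an `aspp` layer and a 1x1 head,
# the "current" one has a `blocks` layer and a 3x3 head.
class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.aspp = nn.Conv2d(3, 8, kernel_size=1)
        self.head = nn.Conv2d(8, 3, kernel_size=1)


class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Conv2d(3, 8, kernel_size=3)
        self.head = nn.Conv2d(8, 3, kernel_size=3)


report = compare_state_dicts(NewNet(), OldNet().state_dict())
for category, keys in report.items():
    print(category, keys)
```

For the real checkpoint, something like compare_state_dicts(model, torch.load(CKPT_PATH, map_location="cpu", weights_only=False)["state_dict"]) should work, where model is a freshly constructed MedicalSegmentationModel (its constructor arguments are not shown in the issue). weights_only=False is only appropriate for checkpoint files you trust.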
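So the model definition has changed since the checkpoint was saved. It is not only that submodules were renamed: the checkpoint's decoder has aspp.* layers that the current model lacks, while the current model expects blocks.x_*_* layers, and the size mismatch on model.segmentation_head.0.weight (torch.Size([3, 256, 1, 1]) in the checkpoint vs torch.Size([3, 16, 3, 3]) in the model) shows that the layer shapes differ too. Note also that the checkpoint path points to logs/deeplabv3p-test/, so it may come from a different experiment than the model you are now instantiating. A checkpoint can only be reloaded into a model whose parameter names and shapes match the ones it was saved with. Either revert the model definition to the one used for training, or, if only the names changed, rename the keys in the checkpoint's state dict to match the new names.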
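If the architecture is otherwise identical and only submodule names moved, the keys can be rewritten and saved as a new checkpoint. A minimal sketch, assuming a plain prefix-to-prefix rename is enough: rename_keys and rename_checkpoint are hypothetical helpers, and the prefixes in the demo are made up. It relies only on the fact, visible in the traceback, that the Lightning checkpoint keeps the weights under checkpoint["state_dict"]. Renaming cannot fix a size mismatch such as the one on model.segmentation_head.0.weight, and passing strict=False to load_from_checkpoint does not help with that either, since PyTorch raises on shape mismatches regardless of strict.

```python
import os
import tempfile

import torch


def rename_keys(state_dict: dict, prefix_map: dict) -> dict:
    """Return a copy of state_dict with old key prefixes replaced by new ones.

    prefix_map maps an old prefix (e.g. "model.decoder.old_name.") to its new
    prefix. Keys that match no prefix are kept unchanged.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        renamed[new_key] = value
    return renamed


def rename_checkpoint(src: str, dst: str, prefix_map: dict) -> None:
    """Rewrite the "state_dict" entry of a .ckpt file and save it as a copy."""
    # weights_only=False because Lightning checkpoints can hold non-tensor
    # objects (e.g. hyperparameters); only do this for files you trust.
    checkpoint = torch.load(src, map_location="cpu", weights_only=False)
    checkpoint["state_dict"] = rename_keys(checkpoint["state_dict"], prefix_map)
    torch.save(checkpoint, dst)


# Demo on a tiny fake checkpoint with made-up key names.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "old.ckpt")
    dst = os.path.join(tmp, "renamed.ckpt")
    torch.save({"state_dict": {"model.enc.w": torch.zeros(2), "model.head.b": torch.ones(1)}}, src)
    rename_checkpoint(src, dst, {"model.enc.": "model.encoder."})
    renamed_keys = sorted(torch.load(dst, map_location="cpu", weights_only=False)["state_dict"])

print(renamed_keys)
```

Writing to a new file rather than overwriting the original keeps the trained checkpoint intact in case the mapping turns out to be wrong.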