Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.

Home Page:https://lightning.ai

Load from checkpoint doesn't load model for inference

golnaz-hs opened this issue

Bug description

Hi,
I trained an attention U-Net model, and ModelCheckpoint saved the trained model as a .ckpt file. Now, at inference time, I get an error when loading the model. The same code worked about a month ago, and I can't figure out what the problem is. (version 2.3.0)

What version are you seeing the problem on?

master

How to reproduce the bug

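import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

# DatasetConfig, MedicalSegmentationModel, model, data_module, train_loader and
# lr_rate_monitor are defined elsewhere in the notebook (not shown in this issue).
data_module.setup()

tb_logger = TensorBoardLogger(save_dir='lightning_logs', name=f'{DatasetConfig.MODEL_NAME}')
logger = pl.loggers.CSVLogger(save_dir='logs/', name=f'{DatasetConfig.MODEL_NAME}')

# Keep the checkpoint with the best validation IoU.
model_checkpoint = ModelCheckpoint(
    monitor="valid_iou",
    mode="max",
    filename="ckpt_{epoch:03d}-vloss_{valid_loss:.4f}_vf1_{valid_iou:.4f}",
    auto_insert_metric_name=False,
)

trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    max_epochs=DatasetConfig.NUM_EPOCHS,
    # enable_model_summary=False,
    callbacks=[model_checkpoint, lr_rate_monitor],
    precision="16-mixed",
    # limit_val_batches=0.1,
    val_check_interval=len(train_loader),
    num_sanity_val_steps=0,
    logger=[logger, tb_logger],
)
trainer.fit(model, data_module)

# Inference
# ModelCheckpoint saves into the first logger's directory, so the path suggests
# this checkpoint was saved by a run whose DatasetConfig.MODEL_NAME was "deeplabv3p-test".
CKPT_PATH = "logs/deeplabv3p-test/version_2/checkpoints/ckpt_023-vloss_0.2485_vf1_0.5624.ckpt"
model = MedicalSegmentationModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
model = model.eval()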

Error messages and logs


RuntimeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 model = MedicalSegmentationModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
2 # model = model.to(device)
3 model = model.eval()

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/model_helpers.py in wrapper(*args, **kwargs)
123 " Please call it on the class type and make sure the return value is used."
124 )
--> 125 return self.method(cls, *args, **kwargs)
126
127 return wrapper

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
1584
1585 """
-> 1586 loaded = _load_from_checkpoint(
1587 cls, # type: ignore[arg-type]
1588 checkpoint_path,

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/saving.py in _load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
89 return _load_state(cls, checkpoint, **kwargs)
90 if issubclass(cls, pl.LightningModule):
---> 91 model = _load_state(cls, checkpoint, strict=strict, **kwargs)
92 state_dict = checkpoint["state_dict"]
93 if not state_dict:

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/saving.py in _load_state(cls, checkpoint, strict, **cls_kwargs_new)
185
186 # load the state_dict on the model automatically
--> 187 keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
188
189 if not strict:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MedicalSegmentationModel:
Missing key(s) in state_dict: "model.decoder.blocks.x_0_0.conv1.0.weight", "model.decoder.blocks.x_0_0.conv1.1.weight", "model.decoder.blocks.x_0_0.conv1.1.bias", "model.decoder.blocks.x_0_0.conv1.1.running_mean", "model.decoder.blocks.x_0_0.conv1.1.running_var", "model.decoder.blocks.x_0_0.conv2.0.weight", "model.decoder.blocks.x_0_0.conv2.1.weight", "model.decoder.blocks.x_0_0.conv2.1.bias", "model.decoder.blocks.x_0_0.conv2.1.running_mean", "model.decoder.blocks.x_0_0.conv2.1.running_var", "model.decoder.blocks.x_0_1.conv1.0.weight", "model.decoder.blocks.x_0_1.conv1.1.weight", "model.decoder.blocks.x_0_1.conv1.1.bias", "model.decoder.blocks.x_0_1.conv1.1.running_mean", "model.decoder.blocks.x_0_1.conv1.1.running_var", "model.decoder.blocks.x_0_1.conv2.0.weight", "model.decoder.blocks.x_0_1.conv2.1.weight", "model.decoder.blocks.x_0_1.conv2.1.bias", "model.decoder.blocks.x_0_1.conv2.1.running_mean", "model.decoder.blocks.x_0_1.conv2.1.running_var", "model.decoder.blocks.x_1_1.conv1.0.weight", "model.decoder.blocks.x_1_1.conv1.1.weight", "model.decoder.blocks.x_1_1.conv1.1.bias", "model.decoder.blocks.x_1_1.conv1.1.running_mean", "model.decoder.blocks.x_1_1.conv1.1.running_var", "model.decoder.blocks.x_1_1.conv2.0.weight", "model.decoder.blocks.x_1_1.conv2.1.weight", "model.decoder.blocks.x_1_1.conv2.1.bias", "model.decoder.blocks.x_1_1.conv2.1.running_mean", "model.decoder.blocks.x_1_1.conv2.1.running_var", "model.decoder.blocks.x_0_2.conv1.0.weight", "model.decoder.bl...
Unexpected key(s) in state_dict: "model.decoder.aspp.0.convs.0.0.weight", "model.decoder.aspp.0.convs.0.1.weight", "model.decoder.aspp.0.convs.0.1.bias", "model.decoder.aspp.0.convs.0.1.running_mean", "model.decoder.aspp.0.convs.0.1.running_var", "model.decoder.aspp.0.convs.0.1.num_batches_tracked", "model.decoder.aspp.0.convs.1.0.0.weight", "model.decoder.aspp.0.convs.1.0.1.weight", "model.decoder.aspp.0.convs.1.1.weight", "model.decoder.aspp.0.convs.1.1.bias", "model.decoder.aspp.0.convs.1.1.running_mean", "model.decoder.aspp.0.convs.1.1.running_var", "model.decoder.aspp.0.convs.1.1.num_batches_tracked", "model.decoder.aspp.0.convs.2.0.0.weight", "model.decoder.aspp.0.convs.2.0.1.weight", "model.decoder.aspp.0.convs.2.1.weight", "model.decoder.aspp.0.convs.2.1.bias", "model.decoder.aspp.0.convs.2.1.running_mean", "model.decoder.aspp.0.convs.2.1.running_var", "model.decoder.aspp.0.convs.2.1.num_batches_tracked", "model.decoder.aspp.0.convs.3.0.0.weight", "model.decoder.aspp.0.convs.3.0.1.weight", "model.decoder.aspp.0.convs.3.1.weight", "model.decoder.aspp.0.convs.3.1.bias", "model.decoder.aspp.0.convs.3.1.running_mean", "model.decoder.aspp.0.convs.3.1.running_var", "model.decoder.aspp.0.convs.3.1.num_batches_tracked", "model.decoder.aspp.0.convs.4.1.weight", "model.decoder.aspp.0.convs.4.2.weight", "model.decoder.aspp.0.convs.4.2.bias", "model.decoder.aspp.0.convs.4.2.running_mean", "model.decoder.aspp.0.convs.4.2.running_var", "model.decoder.aspp.0.convs.4.2.num_batche...
size mismatch for model.segmentation_head.0.weight: copying a param with shape torch.Size([3, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 16, 3, 3]).
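To read an error like this, it helps to remember which side each list describes: PyTorch reports as missing the keys that the model being loaded expects but the checkpoint lacks, and as unexpected the keys that the checkpoint has but the model does not; keys present on both sides with different shapes show up as size mismatches. The sketch below reproduces that classification for any two mappings and collapses the long key lists into a few prefixes. `compare_state_dicts` and `key_prefixes` are hypothetical helpers, not part of Lightning or PyTorch.

```python
from typing import Any, Callable, Iterable, List, Mapping, Tuple


def compare_state_dicts(
    model_sd: Mapping[str, Any],
    ckpt_sd: Mapping[str, Any],
    shape_of: Callable[[Any], Tuple[int, ...]] = lambda t: tuple(t.shape),
) -> Tuple[List[str], List[str], List[str]]:
    """Classify keys the way torch.nn.Module.load_state_dict reports them.

    missing:    expected by the current model, absent from the checkpoint
    unexpected: present in the checkpoint, not expected by the current model
    mismatched: present in both, but with different shapes
    """
    model_keys, ckpt_keys = set(model_sd), set(ckpt_sd)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    mismatched = sorted(
        k for k in model_keys & ckpt_keys
        if shape_of(model_sd[k]) != shape_of(ckpt_sd[k])
    )
    return missing, unexpected, mismatched


def key_prefixes(keys: Iterable[str], depth: int = 3) -> List[str]:
    """Collapse a long key list to its distinct first `depth` dotted parts."""
    return sorted({".".join(k.split(".")[:depth]) for k in keys})
```

A Lightning checkpoint keeps the model weights under its `"state_dict"` key, so one possible call is `compare_state_dicts(model.state_dict(), torch.load(CKPT_PATH, map_location="cpu")["state_dict"])`, where `model` is a freshly constructed `MedicalSegmentationModel`. Recent PyTorch versions may additionally need `weights_only=False` for a trusted checkpoint that stores non-tensor objects.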

cc @awaelchli

Hey @golnaz-hs

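The error message is telling you that the weight names in the checkpoint don't match the model you are loading them into. The checkpoint contains keys like

model.decoder.aspp.0.convs.0.1.bias

but the model you are instantiating now expects keys like

model.decoder.blocks.x_0_0.conv1.1.bias

etc. ("Missing keys" are the ones the current model expects but the checkpoint lacks; "unexpected keys" are the ones in the checkpoint that the current model doesn't have.)

So the model definition has clearly changed since the checkpoint was saved: the decoder now consists of different layers, and the segmentation head has a different shape (torch.Size([3, 256, 1, 1]) in the checkpoint vs. torch.Size([3, 16, 3, 3]) in the current model). A checkpoint can only be reloaded into a model with the same layer names and shapes. Either revert the model definition to the one used for training, or, if only the layer names changed, rename the keys in the checkpoint dict to match the new names.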
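For the second option, here is a minimal sketch of renaming keys, assuming only a name prefix changed while every tensor shape stayed the same. It would not help in this issue, where the decoder layers differ and the segmentation head changed shape. `rename_prefix` and the example prefixes are hypothetical.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def rename_prefix(state_dict: Mapping[str, V], old: str, new: str) -> Dict[str, V]:
    """Return a copy of state_dict in which every key starting with `old`
    starts with `new` instead; other keys and all values are unchanged."""
    renamed: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = new + key[len(old):] if key.startswith(old) else key
        # Refuse to silently overwrite a tensor if two keys collide.
        if new_key in renamed:
            raise KeyError(f"renaming produces a duplicate key: {new_key}")
        renamed[new_key] = value
    return renamed
```

Applied to a Lightning checkpoint, this could look like `ckpt = torch.load(OLD_PATH, map_location="cpu")`, then `ckpt["state_dict"] = rename_prefix(ckpt["state_dict"], "model.decoder.old_name.", "model.decoder.new_name.")`, then `torch.save(ckpt, NEW_PATH)`, where the paths and prefixes are placeholders.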