nn.Module.load_state_dict with strict=False can silently load a corrupted model
tonydavis629 opened this issue
Context
Loading a model using model.load_state_dict(state_dict, strict=False) can result in a loaded model which runs without error but is not the intended model in state_dict.
I trained my model using Lightning and attempted to load the state_dict from a checkpoint. The keys in the state_dict were prefixed with "model.", I assume because my Lightning module stores the network in a .model attribute. Using strict=False allows the model to load without error, but the model output was nonsense. Manually removing the "model." prefix from all of the keys and removing strict=False solved the problem.
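The manual fix described above can be sketched as a small helper. This is a minimal sketch, assuming every checkpoint key carries the "model." prefix; the name strip_prefix is hypothetical, the key names are illustrative, and strings stand in for tensors so the example runs without torch. PyTorch also provides torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which strips a prefix in place.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from matching keys.

    Hypothetical helper: keys that do not start with the prefix are kept as-is.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Stand-in for torch.load(model_path)["state_dict"]; real values would be tensors.
checkpoint_state = OrderedDict([
    ("model.backbone.body.conv1.weight", "tensor-0"),
    ("model.rpn.head.cls_logits.bias", "tensor-1"),
])
cleaned = strip_prefix(checkpoint_state)
print(list(cleaned))

# Intended use (requires torch, not run here):
#   model.load_state_dict(strip_prefix(state_dict))  # strict=True by default
```

Loading the cleaned dict with the default strict=True means any remaining key mismatch raises instead of being skipped.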
- Pytorch version: 1.12.0
- Operating System and version: Ubuntu 22.04
Your Environment
Conda environment
pytorch-lightning 1.6.5
torchvision 0.13.0
- Installed using source? [yes/no]: No
- Are you planning to deploy it using docker container? [yes/no]: No
- Is it a CPU or GPU environment?: GPU
Expected Behavior
model.load_state_dict(state_dict, strict=False) should throw an error if it cannot resolve the state_dict or load the model correctly.
Current Behavior
model.load_state_dict(state_dict, strict=False) loads the model incorrectly without an error message.
Possible Solution
Remove the strict flag from nn.Module.load_state_dict, or change the behavior so that it throws an error.
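Short of changing the library, a caller can check the value that load_state_dict returns: it reports missing_keys and unexpected_keys even when strict=False. Below is a minimal sketch of a guard that turns that report into an error. The names require_clean_load and LoadResult are hypothetical, and a namedtuple stands in for the real return value so the example runs without torch.

```python
from collections import namedtuple


def require_clean_load(result, allowed_missing=()):
    """Raise if a load_state_dict result reports key mismatches.

    `result` is any object with `missing_keys` and `unexpected_keys`
    attributes, which is what load_state_dict returns.
    `allowed_missing` lists keys the caller knowingly leaves uninitialized.
    """
    missing = [k for k in result.missing_keys if k not in allowed_missing]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"state_dict mismatch: {len(missing)} missing, "
            f"{len(unexpected)} unexpected; e.g. {(missing + unexpected)[:3]}"
        )
    return result


# Stand-in for the value returned by load_state_dict (hypothetical type name).
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])
prefixed = LoadResult(
    missing_keys=["backbone.body.conv1.weight"],
    unexpected_keys=["model.backbone.body.conv1.weight"],
)
try:
    require_clean_load(prefixed)
except RuntimeError as err:
    print(err)

# Intended use (requires torch, not run here):
#   require_clean_load(model.load_state_dict(state_dict, strict=False))
```

The allowed_missing argument keeps the legitimate use of strict=False, such as loading a backbone while leaving a new head uninitialized, without hiding a wholesale key mismatch.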
Steps to Reproduce
- Create a class to train a model using lightning, saving the model as an attribute.
- Save a checkpoint.
- Load the model using model.load_state_dict(state_dict, strict=False)
- Check if output is as expected
Other users complaining of the same bug here: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/37
import pytorch_lightning as pl
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# marvel_classes is defined elsewhere in the project.


class detector(pl.LightningModule):
    def __init__(self, model_path=None, classes=marvel_classes, threshold=0.5, finetune=False, batch_size=10):
        super().__init__()
        self.threshold = threshold
        self.classes = classes
        self.finetune = finetune
        self.batch_size = batch_size
        self.model = self.load_model(model_path)

    def load_model(self, model_path=None):
        """
        Loads a pretrained model using state_dict if desired
        """
        print("Loading model...")
        if model_path is not None:
            print("Loading model from:", model_path)
            if model_path.endswith(".ckpt"):
                # Lightning checkpoint: its keys are prefixed with "model.",
                # so none of them match and strict=False hides the mismatch.
                model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=False)
                state_dict = torch.load(model_path)['state_dict']
                model.load_state_dict(state_dict, strict=False)
            else:
                # Plain checkpoint that stores the weights under 'model_state_dict'.
                model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
                model.load_state_dict(torch.load(model_path)['model_state_dict'])
        else:
            model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT, trainable_backbone_layers=3)
        print("Loading model...done")
        if self.finetune:
            # Replace the box predictor head to match the number of classes.
            num_classes = len(self.classes)
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model
Hi @tonydavis629, the issue seems unrelated to this repo. Do you encounter the issue with any example in this repo?
BTW, strict is True by default for Module.load_state_dict() in PyTorch. Usually we don't need to set it, i.e. model.load_state_dict(state_dict) is what most people use.
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict
Sorry, I didn't realize I was in the examples repo. Yes, I understand that most people use the default argument, but silently failing is a major issue that is hard to debug.
No worries! Please ask in pytorch-lightning or pytorch repo if needed. I will close this issue then.