pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.

Home Page: https://pytorch.org/examples

nn.Module load_state_dict strict=False can silently load a corrupted model

tonydavis629 opened this issue

Context

Loading a model with model.load_state_dict(state_dict, strict=False) can produce a model that runs without error but is not the model described by state_dict.

I trained my model with PyTorch Lightning and tried to load its state_dict from a checkpoint. Every key in the state_dict was prefixed with model., presumably because my LightningModule stores the network in a .model attribute. With strict=False the model loaded without error, but its output was nonsense. Manually removing the model. prefix from all of the keys and dropping strict=False solved the problem.
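The manual fix described above can be sketched as a small helper that strips the prefix before a strict load. strip_prefix is a hypothetical name, and the sketch assumes every key belonging to the network starts with exactly "model."; keys without that prefix are passed through unchanged, so a strict load will still flag them.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "model.") -> "OrderedDict[str, Any]":
    """Hypothetical helper: return a copy of state_dict with `prefix` removed
    from the start of every key that has it. Key order is preserved."""
    stripped: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped
```

Usage would look like model.load_state_dict(strip_prefix(torch.load(path)["state_dict"])), keeping the default strict=True so any leftover mismatch still raises. PyTorch also provides torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips a prefix in place; it lives in a utility module, so it is worth checking against your installed version.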

Your Environment

  • PyTorch version: 1.12.0
  • Operating System and version: Ubuntu 22.04
  • Conda environment: pytorch-lightning 1.6.5, torchvision 0.13.0
  • Installed using source? [yes/no]: No
  • Are you planning to deploy it using docker container? [yes/no]: No
  • Is it a CPU or GPU environment?: GPU

Expected Behavior

model.load_state_dict(state_dict, strict=False) should raise an error if it cannot match the state_dict to the model or otherwise fails to load the model correctly.

Current Behavior

model.load_state_dict(state_dict, strict=False) loads the model incorrectly without an error message.

Possible Solution

Remove the strict flag from nn.Module.load_state_dict, or change its behavior so that it raises an error.
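Short of changing PyTorch itself, a caller can keep the leniency of strict=False and still fail loudly, because load_state_dict returns a NamedTuple with missing_keys and unexpected_keys fields even when strict=False. This is a minimal sketch: load_state_dict_loudly is a hypothetical helper, and treating "zero keys loaded" as fatal is an assumption about what counts as a corrupted load.

```python
import warnings

from torch import nn


def load_state_dict_loudly(model: nn.Module, state_dict) -> None:
    """Hypothetical wrapper: lenient like strict=False, but never silent."""
    # strict=False still reports the keys it skipped; it just does not raise.
    result = model.load_state_dict(state_dict, strict=False)
    n_loaded = len(model.state_dict()) - len(result.missing_keys)
    if n_loaded == 0:
        # Nothing was copied, e.g. every key carries an unexpected "model." prefix.
        raise RuntimeError(
            "no keys matched the model; first unexpected keys: "
            f"{result.unexpected_keys[:3]}"
        )
    if result.missing_keys or result.unexpected_keys:
        # Partial match: load what we can, but make the mismatch visible.
        warnings.warn(
            f"partial load: {len(result.missing_keys)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys"
        )
```

This catches the case in this report, where every key carried a prefix and nothing was loaded, while a partial mismatch only produces a warning; where to draw that line is a judgment call.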

Steps to Reproduce

  1. Create a LightningModule that trains a model and stores it as an attribute (e.g. self.model).
  2. Save a checkpoint.
  3. Load the model using model.load_state_dict(state_dict, strict=False).
  4. Check whether the output is as expected.
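The steps above can be reduced to a minimal reproduction without Lightning. This sketch assumes the checkpoint's "state_dict" is simply the network's own state_dict with every key prefixed by "model."; the small nn.Linear is a hypothetical stand-in for Faster R-CNN.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the trained network that Lightning stored under self.model.
trained = nn.Linear(4, 2)
# Simulate the checkpoint layout: every key gains a "model." prefix.
ckpt_state_dict = {"model." + k: v for k, v in trained.state_dict().items()}

# A freshly initialised copy that we want to restore.
fresh = nn.Linear(4, 2)
result = fresh.load_state_dict(ckpt_state_dict, strict=False)

# No exception was raised, yet nothing was copied into `fresh`.
print(result.missing_keys)     # ['weight', 'bias']
print(result.unexpected_keys)  # ['model.weight', 'model.bias']
print(torch.equal(fresh.weight, trained.weight))  # False
```

The skipped keys are reported in the return value, which is easy to miss because the call itself succeeds.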

Other users report the same problem here: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/37

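import torch
import torchvision
import pytorch_lightning as pl
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# marvel_classes is assumed to be defined elsewhere in the reporter's project.


class detector(pl.LightningModule):
    def __init__(self, model_path=None, classes=marvel_classes, threshold=0.5, finetune=False, batch_size=10):
        super().__init__()
        self.threshold = threshold
        self.classes = classes
        self.finetune = finetune
        self.batch_size = batch_size
        self.model = self.load_model(model_path)

    def load_model(self, model_path=None):
        """
        Loads a Faster R-CNN model, optionally restoring weights from a saved state_dict.
        """
        print("Loading model...")
        if model_path is not None:
            print("Loading model from:", model_path)
            if model_path.endswith(".ckpt"):
                # Lightning checkpoint: this module stores the network in self.model,
                # so every key is saved as "model.<name>". Strip that prefix so the
                # keys match the bare torchvision model.
                model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=None)
                state_dict = torch.load(model_path)["state_dict"]
                state_dict = {
                    (k[len("model."):] if k.startswith("model.") else k): v
                    for k, v in state_dict.items()
                }
                # Keep the default strict=True so any remaining key mismatch raises
                # instead of silently leaving the weights at their initial values.
                model.load_state_dict(state_dict)
            else:
                model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
                model.load_state_dict(torch.load(model_path)["model_state_dict"])
        else:
            model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT, trainable_backbone_layers=3)

        print("Loading model...done")

        if self.finetune:
            # Replace the box predictor so it outputs one score per class in self.classes.
            num_classes = len(self.classes)
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        return model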

Hi @tonydavis629, the issue seems unrelated to this repo. Do you encounter it with any example in this repo?

BTW, strict defaults to True for Module.load_state_dict() in PyTorch, so usually we don't need to set it; model.load_state_dict(state_dict) is what most people use.
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict
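For comparison, a minimal sketch (again using a small nn.Linear as a hypothetical stand-in and a "model."-prefixed copy of its state_dict) showing that the default strict=True rejects the same checkpoint layout with a RuntimeError instead of loading nothing:

```python
from torch import nn

net = nn.Linear(4, 2)
# Same layout as the Lightning checkpoint in the report: keys prefixed with "model.".
prefixed = {"model." + k: v for k, v in net.state_dict().items()}

error_message = None
try:
    nn.Linear(4, 2).load_state_dict(prefixed)  # strict=True is the default
except RuntimeError as err:
    # The message lists both the missing and the unexpected keys.
    error_message = str(err)

print(error_message)
```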

Sorry, I didn't realize I was in the examples repo. Yes, I understand that most people use the default argument, but silently failing is a major issue that is hard to debug.

No worries! Please ask in the pytorch-lightning or pytorch repo if needed. I will close this issue then.