postBG / DTA.pytorch

Official implementation of Drop to Adapt: Learning Discriminative Features for Unsupervised Domain Adaptation presented at ICCV 2019.


Problem when saving and loading models

mboudiaf opened this issue

Hi,

First things first, thank you for writing and sharing such well-structured code. I just wanted to bring one issue to your attention: as it stands, the code has problems when someone wants to save a model and reload it later. You override the load_state_dict() method (to properly recover weights from ImageNet pretraining, I guess), but you use the same overridden method to recover weights from a given checkpoint:

    if args.classifier_ckpt_path:
        print("Load class classifier from {}".format(args.classifier_ckpt_path))
        ckpt = torch.load(args.classifier_ckpt_path)
        class_classifier.load_state_dict(ckpt['classifier_state_dict'])

Because of the filtering conditions in your custom load_state_dict() method, this does not properly restore the weights of the previously saved network. When loading from a checkpoint, you should call PyTorch's standard load_state_dict() method directly. I am bringing this to your attention because the problem does not raise any error, and it is not obvious from the training statistics of the restored network. Thanks again for your work :)
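The failure mode described above can be reproduced with a minimal sketch, assuming PyTorch is installed. The class `PretrainedFilteringClassifier`, its filtering rule (skip every `fc.*` key), and the helper `same_weights` are hypothetical stand-ins, not the repository's actual code. The sketch only shows that a filtering override can drop checkpoint weights without any error, while calling the base `nn.Module.load_state_dict` restores everything.

```python
import torch
import torch.nn as nn


class PretrainedFilteringClassifier(nn.Module):
    """Hypothetical model that overrides load_state_dict() so that
    ImageNet-pretrained weights can be loaded selectively."""

    def __init__(self, in_features: int = 8, num_classes: int = 3):
        super().__init__()
        self.features = nn.Linear(in_features, in_features)
        self.fc = nn.Linear(in_features, num_classes)

    def load_state_dict(self, state_dict, strict=True):
        # Assumed filtering rule: ignore the classification head ("fc.*"),
        # e.g. because an ImageNet head has a different number of classes.
        kept = {k: v for k, v in state_dict.items() if not k.startswith("fc.")}
        # Fill the skipped keys with the model's own current values, so
        # nothing is reported as missing and no error is raised.
        merged = {k: v.clone() for k, v in self.state_dict().items()}
        merged.update(kept)
        return super().load_state_dict(merged, strict=strict)


def same_weights(a: nn.Module, b: nn.Module) -> bool:
    """True if both modules hold identical tensors under identical keys."""
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(torch.equal(sa[k], sb[k]) for k in sa)


torch.manual_seed(0)
trained = PretrainedFilteringClassifier()
with torch.no_grad():  # stand-in for training: shift every parameter
    for p in trained.parameters():
        p.add_(1.0)
ckpt = {"classifier_state_dict": {k: v.clone() for k, v in trained.state_dict().items()}}

# Problematic path: the override silently keeps the fresh, untrained head.
restored_filtered = PretrainedFilteringClassifier()
restored_filtered.load_state_dict(ckpt["classifier_state_dict"])

# Suggested fix: bypass the override and call PyTorch's own method.
restored_direct = PretrainedFilteringClassifier()
nn.Module.load_state_dict(restored_direct, ckpt["classifier_state_dict"])

print("filtered load matches checkpoint:", same_weights(trained, restored_filtered))  # False
print("direct load matches checkpoint:", same_weights(trained, restored_direct))  # True
```

In the snippet from the issue, the equivalent change would be calling `nn.Module.load_state_dict(class_classifier, ckpt['classifier_state_dict'])` in the checkpoint branch. This thread does not say which form the maintainers' fix actually takes.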

Hi, @mboudiaf.
Thanks for using our code. We have just fixed the issue you raised.
Please check it out and let us know if any problem remains.
Thanks again.