Replacing the dataset classes
DDxk369 opened this issue
DDxk369 commented
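Hello, your work is really great. I have a question: if I want to replace ImageNet with a private three-class dataset, what do I need to change? The loaded model seems to have only 1000 classes. I'd appreciate your help when you have a moment, thanks!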
File "main.py", line 283, in <module>
main()
File "main.py", line 131, in main
anchor = create_model(
File "C:\Users\asus\anaconda3\lib\site-packages\timm\models\factory.py", line 71, in create_model
model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)
File "C:\Users\asus\PycharmProjects\SN-Net-main\SN-Net-main\stitching_deit\models.py", line 69, in deit_tiny_patch16_224
model.load_state_dict(checkpoint["model"])
File "C:\Users\asus\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViTAnchor:
Zizheng Pan commented
Hi, thanks for your support~ If you want to train on a custom dataset, make sure the number of classes (args.nb_classes) is consistent when the dataset is created:
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
The args.nb_classes parameter is passed to the anchor's initialization function, so the prediction head (the fc layer) has size embed_dim x nb_classes.
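To make the head shape concrete, here is a minimal sketch that assumes the prediction head is a plain torch.nn.Linear layer. The value embed_dim = 192 (the DeiT-Tiny width) and the names head_1000 and head_3 are illustrative assumptions, not code from this repository.

```python
import torch
import torch.nn as nn

embed_dim = 192  # assumed: embedding width of DeiT-Tiny

# Hypothetical heads: one sized for 1000 classes, one for a 3-class dataset.
head_1000 = nn.Linear(embed_dim, 1000)
head_3 = nn.Linear(embed_dim, 3)

# PyTorch stores Linear weights as (out_features, in_features),
# i.e. (nb_classes, embed_dim).
print(tuple(head_1000.weight.shape))  # (1000, 192)
print(tuple(head_3.weight.shape))     # (3, 192)

# A batch of 2 feature vectors maps to 2 x 3 logits.
logits = head_3(torch.randn(2, embed_dim))
print(tuple(logits.shape))  # (2, 3)
```

The two heads have different weight shapes, which is why a checkpoint trained with 1000 classes cannot be copied directly into a 3-class model.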
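Because the anchors are pretrained, directly using the weights trained by timm will probably fail to load. You can download the pretrained weights manually and load them yourself with strict=False, along these lines: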
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
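One caveat: in PyTorch, strict=False only tolerates missing and unexpected keys. A tensor that has the same name but a different shape (such as a 1000-class head loaded into a 3-class head) still raises a size-mismatch RuntimeError. A possible workaround is to drop the mismatched entries before loading. The sketch below uses a tiny stand-in model; TinyNet and load_matching_weights are hypothetical names for illustration, not part of SN-Net or timm.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone plus classification head."""
    def __init__(self, embed_dim=8, nb_classes=1000):
        super().__init__()
        self.body = nn.Linear(4, embed_dim)
        self.head = nn.Linear(embed_dim, nb_classes)

    def forward(self, x):
        return self.head(self.body(x))

def load_matching_weights(model, state_dict):
    """Load only entries whose name and shape match; return skipped keys."""
    own = model.state_dict()
    kept = {k: v for k, v in state_dict.items()
            if k in own and v.shape == own[k].shape}
    skipped = sorted(k for k in state_dict if k not in kept)
    # strict=False lets the keys we filtered out stay at their fresh init.
    model.load_state_dict(kept, strict=False)
    return skipped

# Simulate a 1000-class checkpoint, stored under "model" as in the traceback.
pretrained = TinyNet(nb_classes=1000)
checkpoint = {"model": pretrained.state_dict()}

# New model for a 3-class dataset.
model = TinyNet(nb_classes=3)
skipped = load_matching_weights(model, checkpoint["model"])
print(skipped)  # ['head.bias', 'head.weight']
```

With this approach the backbone weights come from the checkpoint, while the skipped head keeps its random initialization and is trained on the new dataset.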