ziplab / SN-Net

[CVPR 2023 Highlight] This is the official implementation of "Stitchable Neural Networks".

Home Page:https://snnet.github.io

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

数据集类别替换

DDxk369 opened this issue · comments

您好,您的工作真的太棒了,但是我想问下,如果我想把imgnet换成私有的三分类数据的话,我需要如何更改?因为加载的模型似乎只有1000类,烦请百忙之中帮忙解答,谢谢!

 File "main.py", line 283, in <module>
    main()
  File "main.py", line 131, in main
    anchor = create_model(
  File "C:\Users\asus\anaconda3\lib\site-packages\timm\models\factory.py", line 71, in create_model
    model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)
  File "C:\Users\asus\PycharmProjects\SN-Net-main\SN-Net-main\stitching_deit\models.py", line 69, in deit_tiny_patch16_224
    model.load_state_dict(checkpoint["model"])
  File "C:\Users\asus\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViTAnchor:

您好,谢谢您的支持~ 如果想在自定义数据集上训练的话,需要注意创建dataset的时候保证类别数量(args.nb_classes)是一致的:

dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)

源码位置:https://github.com/ziplab/SN-Net/blob/c2dcf306089dfe982080a4f43a7ecebffd66abc6/stitching_deit/main.py#LL73C38-L73C51

args.nb_classes这个参数会传进anchor初始化函数,使得prediction head这个fc是 embed_dim x nb_classes大小。

因为anchor是pretrain的,直接用timm训练好的weight应该会导致权重加载不进去,你可以手动下载预训练好的weight,然后手动load,不strict,类似

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)