themrzmaster / git-re-basin-pytorch

Git Re-Basin: Merging Models modulo Permutation Symmetries in PyTorch

result of resnet20 on cifar10 seems to be bad

zwei-lin opened this issue · comments

Hi,
Thanks for the PyTorch code of Git Re-Basin.
I trained two resnet20 models on CIFAR-10 and used git-re-basin to match the weights. However, the test accuracy after weight interpolation between the two models is not good (see the figure below). The result is no better than naive interpolation. Is there anything wrong with the code?
[figure: cifar10_resnet_weight_matching_interp_accuracy_epoch]
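
For reference, the interpolation sweep I'm evaluating looks roughly like the sketch below. The names are placeholders: `sd_b_permuted` stands for model B's state dict after applying the permutation found by weight matching, and `model`/`test_loader` come from the training scripts.

```python
import torch

def lerp_state_dicts(sd_a, sd_b, lam):
    """(1 - lam) * A + lam * B for floating-point tensors; non-float buffers
    (e.g. BatchNorm's num_batches_tracked) are copied from A."""
    return {k: ((1 - lam) * sd_a[k] + lam * sd_b[k])
            if torch.is_floating_point(sd_a[k]) else sd_a[k]
            for k in sd_a}

@torch.no_grad()
def interp_accuracy_sweep(model, sd_a, sd_b_permuted, test_loader, device="cuda"):
    """Test accuracy along the linear path between model A and (permuted) model B."""
    model = model.to(device).eval()
    results = []
    for lam in torch.linspace(0.0, 1.0, 11):
        model.load_state_dict(lerp_state_dicts(sd_a, sd_b_permuted, lam.item()))
        correct = total = 0
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
        results.append((lam.item(), correct / total))
    return results
```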

This may be an issue with using batchnorm instead of layernorm as noted by @KellerJordan. Check out https://twitter.com/kellerjordan0/status/1570837651741364226 for more info

Yes @samuela! I am running some tests, and with GroupNorm it works. Pushing the code soon.
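
In the meantime, for anyone who wants to try it before the updated code lands, one generic way to swap BatchNorm for GroupNorm in an existing model is to replace the modules in place. This is just a sketch, not necessarily how the repo will do it; num_groups=8 is an arbitrary choice that happens to divide the 16/32/64 channel widths of a CIFAR resnet20.

```python
import torch.nn as nn

def batchnorm_to_groupnorm(module, num_groups=8):
    """Recursively replace every BatchNorm2d with a GroupNorm over the same
    number of channels. Each layer's channel count must be divisible by num_groups."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            batchnorm_to_groupnorm(child, num_groups)
    return module
```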

Thanks!

Fixed!

Glad to see someone taking a stab at porting this to pytorch - thanks for the work there!

Not sure if this is a separate issue or not, but I'm still getting fairly poor performance for RN22 with the GroupNorm (see below). Definitely better than naive and zwei-lin's post, but hardly an "almost convex basin".

Here's what I did:

python -m train.resnet_cifar10_train --lr 0.01 --epochs 100 --seed 10
python -m train.resnet_cifar10_train --lr 0.01 --epochs 100 --seed 11
python -m matching.resnet_cifar10_weight_matching --model_a cifar10_10_resnet_depth_22_2.pt --model_b cifar10_11_resnet_depth_22_2.pt

Do you get better results with RN22? Do you use different args?

[figure: cifar10_resnet22_2_weight_matching_interp_accuracy_epoch]

Eh, to answer my own question, the method seems to be a bit hit or miss, but restarting a few times seems to eventually find a good solution.
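
In case it helps anyone else, the restart loop I ended up with is roughly the sketch below. `run_matching` and `evaluate_midpoint` are hypothetical callables wrapping the repo's matching script and a test-accuracy check at the interpolation midpoint; they are not functions from this repo.

```python
import torch

def best_of_n_restarts(run_matching, evaluate_midpoint, n_restarts=5):
    """Re-run weight matching with different seeds and keep the permutation
    whose 50/50 interpolation has the best test accuracy."""
    best_perm, best_acc = None, float("-inf")
    for seed in range(n_restarts):
        torch.manual_seed(seed)        # reseed whatever RNG the matching routine uses
        perm = run_matching(seed)      # hypothetical: returns a permutation of model B
        acc = evaluate_midpoint(perm)  # hypothetical: accuracy at lambda = 0.5
        if acc > best_acc:
            best_perm, best_acc = perm, acc
    return best_perm, best_acc
```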

This is not my code, so I can't really comment on specifics, but in my experience common failure modes include

  • model is not wide enough (we present 32x width ResNet results on CIFAR-10 in the paper)
  • normalization details can matter sometimes. Batch normalization is naughty (you have to recalculate batch stats after averaging weights, a la SWA; see the sketch after this list). I haven't run any experiments with GroupNorm, so I'm not sure how that would work/not work. But that's definitely diverging from the code that we ran in the paper.
  • PermutationSpec is missing a permutation/has a bug somewhere
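
For the BatchNorm case, the SWA-style reset I mean is roughly the following sketch, using `torch.optim.swa_utils.update_bn`; `model` and `train_loader` are placeholders for whatever your training code produces.

```python
import torch
from torch.optim.swa_utils import update_bn

@torch.no_grad()
def reset_bn_stats(model, train_loader, device=torch.device("cuda")):
    """Recompute BatchNorm running statistics (e.g. after loading interpolated
    weights) with one pass over the training data, as SWA does."""
    model.to(device)
    update_bn(train_loader, model, device=device)
    return model.eval()
```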

I'm also noticing that your learning rate (0.01) seems to be a bit lower than the learning rate we used in the paper (0.1 max, cosine decay). That shouldn't make a big difference but it's something else to look into.

We also ran for IIRC 250 epochs. Not sure that would matter much either, but it's something else to consider.
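
For reference, a 0.1-peak cosine schedule over 250 epochs would look something like the sketch below in PyTorch; the momentum/weight-decay values and the train_one_epoch helper are generic placeholders, not values from the paper.

```python
import torch

def make_optimizer_and_scheduler(model, num_epochs=250, peak_lr=0.1):
    """SGD with peak LR 0.1 decayed to ~0 by cosine annealing over num_epochs.
    Momentum and weight decay here are generic CIFAR defaults, not the paper's."""
    optimizer = torch.optim.SGD(model.parameters(), lr=peak_lr,
                                momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    return optimizer, scheduler

# Usage: step the scheduler once per epoch.
# for epoch in range(num_epochs):
#     train_one_epoch(model, optimizer, train_loader)  # hypothetical helper
#     scheduler.step()
```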

But again, I'm not familiar with this code, so I can't comment too much on specifics. HTH in terms of a few things to try, though!