minzwon / sota-music-tagging-models

Error in loading the model during the training process

sainathadapa opened this issue

Hi, I'm trying to retrain the ShortChunkCNN model on the MagnaTagATune dataset. Training errors out at the 80th epoch, when self.load is called from opt_schedule:

RuntimeError: Error(s) in loading state_dict for ShortChunkCNN:
        size mismatch for spec.mel_scale.fb: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([257, 128]).

Can you tell me why this is happening? Also, can you tell me why the following snippet of code is required in self.load?

if 'spec.mel_scale.fb' in S.keys():
    S['spec.mel_scale.fb'] = torch.tensor([])

Thanks for the nice paper and the neat code!

Hi, thank you for your interest!

This is caused by how torchaudio.transforms.MelSpectrogram initializes its mel filterbank.
When you first initialize a model, the shape of model.spec.mel_scale.fb is torch.Size([0]). Only after you run model.spec(x) (for example, during a forward pass through the model) does it become torch.Size([257, 128]).
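To make the initialization issue concrete, here is a minimal sketch in plain PyTorch. It assumes the older torchaudio behaviour described above, where the filterbank buffer starts empty and is filled in place on the first forward pass. `LazyFilterbank` is a hypothetical stand-in, not the real torchaudio `MelScale`; its weights are placeholders rather than a real mel filterbank, and its buffer key is `fb` instead of `spec.mel_scale.fb`.

```python
import torch
import torch.nn as nn


class LazyFilterbank(nn.Module):
    """Hypothetical stand-in for a lazily filled `fb` buffer (not torchaudio's MelScale)."""

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.n_mels = n_mels
        # Empty at construction: the number of frequency bins is not known yet.
        self.register_buffer("fb", torch.empty(0))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # spec: (batch, n_freqs, time)
        if self.fb.numel() == 0:
            n_freqs = spec.size(1)
            # Placeholder weights, NOT a real mel filterbank.
            tmp = torch.ones(n_freqs, self.n_mels) / n_freqs
            # Grow the registered buffer in place on first use.
            self.fb.resize_(tmp.size())
            self.fb.copy_(tmp)
        # (batch, time, n_freqs) @ (n_freqs, n_mels) -> (batch, n_mels, time)
        return torch.matmul(spec.transpose(1, 2), self.fb).transpose(1, 2)


fresh = LazyFilterbank()
fresh_shape = fresh.state_dict()["fb"].shape  # torch.Size([0])

trained = LazyFilterbank()
trained(torch.randn(2, 257, 10))  # 257 = 512 // 2 + 1 frequency bins
trained_shape = trained.state_dict()["fb"].shape  # torch.Size([257, 128])

# Loading the filled buffer into a brand-new model fails on the shape check.
mismatch_error = ""
try:
    LazyFilterbank().load_state_dict(trained.state_dict())
except RuntimeError as err:
    mismatch_error = str(err)

print(fresh_shape, trained_shape, "size mismatch" in mismatch_error)
# prints: torch.Size([0]) torch.Size([257, 128]) True
```

The buffer is part of the state dict, so `load_state_dict` compares its shape just like a parameter's, which is why the empty-versus-filled difference surfaces as a size mismatch.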

Previously, I worked around this issue with

if 'spec.mel_scale.fb' in S.keys():
    S['spec.mel_scale.fb'] = torch.tensor([])

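but that turned out to be a version-dependent, temporary solution: an empty filterbank only matches a freshly initialized model, not one whose filterbank has already been filled during training.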
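The sketch below reproduces both sides of that workaround with the same hypothetical `LazyFilterbank` stand-in (redefined so the block runs on its own). `load_with_old_workaround` is a hypothetical helper that mirrors the snippet above, using the key `fb` in place of `spec.mel_scale.fb`. Blanking the saved buffer lets a fresh model load, but a model that has already run a forward pass, as in the mid-training reload from opt_schedule, hits the same error reported in this issue.

```python
import torch
import torch.nn as nn


class LazyFilterbank(nn.Module):
    """Hypothetical stand-in mimicking a lazily filled `fb` buffer."""

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.n_mels = n_mels
        self.register_buffer("fb", torch.empty(0))  # empty until first forward

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        if self.fb.numel() == 0:
            # Placeholder weights, NOT a real mel filterbank.
            tmp = torch.ones(spec.size(1), self.n_mels) / spec.size(1)
            self.fb.resize_(tmp.size())
            self.fb.copy_(tmp)
        return torch.matmul(spec.transpose(1, 2), self.fb).transpose(1, 2)


def load_with_old_workaround(model: nn.Module, state: dict) -> None:
    """Hypothetical helper: replace the saved filterbank with an empty tensor, then load."""
    state = dict(state)  # copy so the caller's state dict is untouched
    if "fb" in state:
        state["fb"] = torch.tensor([])
    model.load_state_dict(state)


spec = torch.randn(2, 257, 10)
checkpoint_model = LazyFilterbank()
checkpoint_model(spec)
saved = checkpoint_model.state_dict()  # fb has shape [257, 128]

# Case 1: fresh model, fb still empty -> both sides are shape [0], load succeeds.
fresh = LazyFilterbank()
load_with_old_workaround(fresh, saved)

# Case 2: model mid-training, fb already filled -> size mismatch, as in the issue.
in_training = LazyFilterbank()
in_training(spec)
error_text = ""
try:
    load_with_old_workaround(in_training, saved)
except RuntimeError as err:
    error_text = str(err)
print(error_text)
```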

I changed those lines as follows, so the checkpoint's filterbank is now loaded as-is:

if 'spec.mel_scale.fb' in S.keys():
    S['spec.mel_scale.fb'] = S['spec.mel_scale.fb']

If you pull the most recent version, it will work now.
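A sketch of why keeping the saved filterbank works for the mid-training reload, again using the hypothetical `LazyFilterbank` stand-in. It also shows one possible option for a freshly built model whose buffer is still empty: run a dummy forward pass before `load_state_dict`. That warm-up step is an assumption for illustration, not what the repository does; as far as I know, newer torchaudio releases build the filterbank when the module is constructed, which makes it unnecessary there.

```python
import torch
import torch.nn as nn


class LazyFilterbank(nn.Module):
    """Hypothetical stand-in mimicking a lazily filled `fb` buffer."""

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.n_mels = n_mels
        self.register_buffer("fb", torch.empty(0))  # empty until first forward

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        if self.fb.numel() == 0:
            # Placeholder weights, NOT a real mel filterbank.
            tmp = torch.ones(spec.size(1), self.n_mels) / spec.size(1)
            self.fb.resize_(tmp.size())
            self.fb.copy_(tmp)
        return torch.matmul(spec.transpose(1, 2), self.fb).transpose(1, 2)


spec = torch.randn(2, 257, 10)
trained = LazyFilterbank()
trained(spec)
saved = trained.state_dict()  # fb has shape [257, 128]

# Mid-training reload: fb is already filled, so the unmodified checkpoint
# (the equivalent of keeping S['spec.mel_scale.fb'] unchanged) loads cleanly.
in_training = LazyFilterbank()
in_training(spec)
in_training.load_state_dict(saved)

# Fresh model with a still-empty buffer: an assumed workaround is a dummy
# forward pass with the right number of frequency bins before loading.
fresh = LazyFilterbank()
with torch.no_grad():
    fresh(torch.zeros(1, 257, 1))  # fills fb with shape [257, 128]
fresh.load_state_dict(saved)
```

The dummy input only needs the right frequency dimension (257 here); its batch size and length are irrelevant because the buffer shape depends only on `spec.size(1)`.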

Best,
Minz

@sainathadapa, did it solve your problem?

It did, thanks!