YuanGongND / ssast

Code for the AAAI 2022 paper "SSAST: Self-Supervised Audio Spectrogram Transformer".

The model loaded is not from a torch.nn.Dataparallel object

kremHabashy opened this issue

Hi Yuan, I fine-tuned the SSAST model on my own data and am trying to load it for demonstration purposes. When I pass in the path of the "best_audio_model.pth" file generated by fine-tuning, I get the following error from ast_models.py:
ValueError: The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.
I'm not sure why I am getting this error, since I am following the example under the if __name__ == '__main__' statement. What am I missing? Shouldn't the saved models already be wrapped with DataParallel, as in traintest.py?

Hi there,

It is a known bug that is not related to dataparallel.

The problem is that I used a trick: the pretraining hyperparameters are encoded in the model, and the code uses the existence of those hyperparameters to check whether the model is a DataParallel object. The SSL pretraining code does save the hyperparameters, but the fine-tuning code does not. So when you load a fine-tuned model for another round of testing, the code cannot find the hyperparameters and concludes that the model is not DataParallel.

As a temporary workaround, you can change these two lines of code:

p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()

I will find time to fix it.
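
The failure mode described above can be checked on a checkpoint before it is passed to ASTModel. The following is a minimal sketch, assuming the checkpoint is a plain state dict (a mapping from parameter names to tensors). The helper name `diagnose_checkpoint` is hypothetical and not part of the SSAST repo; the key names are the ones quoted in this thread.

```python
def diagnose_checkpoint(sd):
    """Report why the pretrained-checkpoint check might fail.

    `sd` is any mapping of parameter names to values, for example the
    result of torch.load(path, map_location="cpu").
    Hypothetical helper for illustration, not repo code.
    """
    keys = list(sd.keys())
    # DataParallel prepends "module." to every parameter name when saving.
    has_prefix = bool(keys) and all(k.startswith("module.") for k in keys)
    # Per the thread, only the SSL pretraining code saves these two entries.
    has_hparams = "module.p_input_fdim" in sd and "module.p_input_tdim" in sd

    if has_hparams:
        verdict = "looks like an SSL-pretrained checkpoint"
    elif has_prefix:
        verdict = "DataParallel checkpoint without pretraining hyperparameters (likely fine-tuned)"
    else:
        verdict = "not saved from a DataParallel-wrapped model"
    return {
        "has_module_prefix": has_prefix,
        "has_pretrain_hparams": has_hparams,
        "verdict": verdict,
    }
```

If the verdict is the second one, the ValueError about DataParallel is misleading: the checkpoint is wrapped correctly and is only missing the pretraining hyperparameters.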

Hey Yuan, sorry, what should I change them to?

Can you paste the code you use to initialize the AST model for fine-tuning (not inference)?

# initialize an AST model
torch.cuda.set_device('cuda:0')
device = torch.device("cuda:0")

pretrained_mdl_path =r"C:\Users\habashyk\venvs\SSAST2\src\finetune\NRC\exp\test01-NRC-f10-16-t10-16-b12-lr1e-4-ft_avgtok-base---1x-noiseFalse-3\models\best_audio_model.pth"

p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]

ast_mdl = ASTModel(fstride=10, 
                   tstride=10, 
                   fshape=16, 
                   tshape=16,
                   input_fdim=128, 
                   input_tdim=50,
                   model_size='base',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=pretrained_mdl_path)

ast_mdl = torch.nn.DataParallel(ast_mdl)
        
ast_mdl.load_state_dict(sd, strict=False)
ast_mdl.cuda()
ast_mdl.eval()

This now gives the error:
RuntimeError: Error(s) in loading state_dict for DataParallel: size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 50, 768]) from checkpoint, the shape in current model is torch.Size([1, 26, 768]).

Have you changed anything in ast_models.py?

I guess the following might work

pretrained_mdl_path = a
finetuned_mdl_path = b

ast_mdl = ASTModel(fstride=10,  # should match the settings you fine-tuned the model with
                   tstride=10, 
                   fshape=16, 
                   tshape=16,
                   input_fdim=128, 
                   input_tdim=50,
                   model_size='base',
                   pretrain_stage=False, # should be False
                   load_pretrained_mdl_path=pretrained_mdl_path) # should be pretrained_mdl_path, not fine_tuned_mdl_path

ast_mdl = torch.nn.DataParallel(ast_mdl)

sd = torch.load(finetuned_mdl_path, map_location=device) # now load finetuned_mdl_path
ast_mdl.load_state_dict(sd, strict=True) # I suggest using True first to see the differences, then moving to False

To answer the first question, I commented out line 161:
# p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
I then hardcoded the number of mel bins (128) and the target length of my dataset (50).

Just to clarify, is the pretrained model path the one for SSAST-Base_Patch-400.pth?
If so, the error is now:
size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 514, 768]) from checkpoint, the shape in current model is torch.Size([1, 26, 768]).

Just to clarify, is the pretrained model path the one for SSAST-Base_Patch-400.pth?

Yes, that is what I meant. Can you let me know which line the error points to?

This is the current output

now load a SSL pretrained models from C:\Users\habashyk\venvs\SSAST2\pretrained_model\SSAST-Base-Patch-400.pth
pretraining patch split stride: frequency=16, time=16
pretraining patch shape: frequency=16, time=16
pretraining patch array dimension: frequency=8, time=3
pretraining number of patches=24
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [7], in <cell line: 8>()
      5 pretrained_mdl_path = r"C:\Users\habashyk\venvs\SSAST2\pretrained_model\SSAST-Base-Patch-400.pth"
      6 finetuned_mdl_path = r"C:\Users\habashyk\venvs\SSAST2\src\finetune\NRC\exp\test01-NRC-f10-16-t10-16-b12-lr1e-4-ft_avgtok-base---1x-noiseFalse-3\models\best_audio_model.pth"
----> 8 ast_mdl = ASTModel(fstride=10,  # should be consistent with that you fine-tune the model with
      9                    tstride=10, 
     10                    fshape=16, 
     11                    tshape=16,
     12                    input_fdim=128, 
     13                    input_tdim=50,
     14                    model_size='base',
     15                    pretrain_stage=False, # should be False
     16                    load_pretrained_mdl_path=pretrained_mdl_path) # should be pretrained_mdl_path, not fine_tuned_mdl_path
     18 ast_mdl = torch.nn.DataParallel(ast_mdl)
     20 sd = torch.load(finetuned_mdl_path, map_location=device) # now load finetuned_mdl_path

File ~\venvs\SSAST2\src\models\ast_models.py:175, in ASTModel.__init__(self, label_dim, fshape, tshape, fstride, tstride, input_fdim, input_tdim, model_size, pretrain_stage, load_pretrained_mdl_path)
    171 audio_model = ASTModel(fstride=p_fshape, tstride=p_tshape, fshape=p_fshape, tshape=p_tshape,
    172                        input_fdim=128, input_tdim=50, pretrain_stage=True,
    173                        model_size=model_size)
    174 audio_model = torch.nn.DataParallel(audio_model)
--> 175 audio_model.load_state_dict(sd, strict=False)
    177 self.v = audio_model.module.v
    178 self.original_embedding_dim = self.v.pos_embed.shape[2]

File ~\Anaconda3\envs\SSAST2\lib\site-packages\torch\nn\modules\module.py:1497, in Module.load_state_dict(self, state_dict, strict)
   1492         error_msgs.insert(
   1493             0, 'Missing key(s) in state_dict: {}. '.format(
   1494                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1496 if len(error_msgs) > 0:
-> 1497     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1498                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1499 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for DataParallel:
	size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 514, 768]) from checkpoint, the shape in current model is torch.Size([1, 26, 768]).
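
The sequence lengths in these size-mismatch errors (50, 26, 514) can be reproduced with simple arithmetic. The following is a minimal sketch under three assumptions that are not stated in the thread and were chosen because they reproduce the reported numbers: the patch embedding is an unpadded convolution, the model adds 2 extra tokens to the patch sequence, and the released checkpoint was pretrained with 1024 input frames. `pos_embed_length` is a hypothetical helper.

```python
def pos_embed_length(input_fdim, input_tdim, fshape, tshape,
                     fstride, tstride, extra_tokens=2):
    """Return (f_dim, t_dim, sequence length) for an unpadded patch split.

    Each axis yields floor((input - shape) / stride) + 1 patches.
    extra_tokens=2 is an assumption that matches the shapes in this thread.
    """
    f_dim = (input_fdim - fshape) // fstride + 1
    t_dim = (input_tdim - tshape) // tstride + 1
    return f_dim, t_dim, f_dim * t_dim + extra_tokens


# Fine-tuned model: 128 mel bins, 50 frames, 16x16 patches, stride 10.
print(pos_embed_length(128, 50, 16, 16, 10, 10))    # (12, 4, 50)
# Same input, but split with the pretraining stride of 16.
print(pos_embed_length(128, 50, 16, 16, 16, 16))    # (8, 3, 26)
# Assumed pretraining input of 1024 frames, stride 16.
print(pos_embed_length(128, 1024, 16, 16, 16, 16))  # (8, 64, 514)
```

The second result matches the log lines "frequency=8, time=3" and "number of patches=24". In the traceback above, the inner pretraining-shaped model is built with input_tdim=50 at stride 16, which gives 26 positions, while the checkpoint holds 514.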

To answer the first question, I commented out line 161

Can you undo that change (so the code is back to the original) and then try the script I suggested?

The output error message remains the same.

That's weird. Regarding this part:

pretrained_mdl_path = a
finetuned_mdl_path = b

ast_mdl = ASTModel(fstride=10,  # should match the settings you fine-tuned the model with
                   tstride=10, 
                   fshape=16, 
                   tshape=16,
                   input_fdim=128, 
                   input_tdim=50,
                   model_size='base',
                   pretrain_stage=False, # should be False
                   load_pretrained_mdl_path=pretrained_mdl_path) # should be pretrained_mdl_path, not fine_tuned_mdl_path

Your fine-tuning run should have executed this, right? Did you get the error then?

For fine-tuning, I ran the run.sh script below:

#!/bin/bash
##SBATCH -p sm
##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-3,sls-sm-[5,12]
#SBATCH -p gpu
#SBATCH -x sls-titan-[0-2]
#SBATCH --gres=gpu:4
#SBATCH -c 4
#SBATCH -n 1
#SBATCH --mem=30000
#SBATCH --job-name="ast_as"
#SBATCH --output=./slurm_log/log_%j.txt

set -x
# comment this line if not running on sls cluster
# . /data/sls/scratch/share-201907/slstoolchainrc
# source ../../../venvssast/bin/activate
export TORCH_HOME=../../pretrained_models
mkdir -p ./exp

if [ -e SSAST-Base-Patch-400.pth ]
then
    echo "pretrained model already downloaded."
else
    wget https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Base-Patch-400.pth?dl=1 -O SSAST-Base-Patch-400.pth
fi

pretrain_path=C:/Users/habashyk/venvs/ssast/pretrained_model/SSAST-Base-Patch-400.pth

dataset=NRC
dataset_mean=-0.94124633
dataset_std=2.8967743
target_length=50
noise=False

task=ft_avgtok
model_size=base
head_lr=1

tr_data=C:/Users/habashyk/venvs/ast/src/NRC_Work/Data/Original_split/datafiles/train_data.json
te_data=C:/Users/habashyk/venvs/ast/src/NRC_Work/Data/eval/datafiles/eval_data.json
bal=none
lr=1e-4
freqm=25
timem=0
mixup=0
fstride=10
tstride=10
fshape=16
tshape=16
batch_size=12
epoch=20

exp_dir=./exp/test01-${dataset}-f${fstride}-${fshape}-t${tstride}-${tshape}-b${batch_size}-lr${lr}-${task}-${model_size}-${pretrain_exp}-${pretrain_model}-${head_lr}x-noise${noise}-3

CUDA_CACHE_DISABLE=1 python -W ignore ../../run.py --dataset ${dataset} \
--data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
--label-csv C:/Users/habashyk/venvs/ast/src/NRC_Work/nrc_class_label_indices.csv --n_class 4 \
--lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model True \
--freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
--tstride $tstride --fstride $fstride --fshape ${fshape} --tshape ${tshape} --warmup False --task ${task} \
--model_size ${model_size} --adaptschedule False \
--pretrained_mdl_path ${pretrain_path} \
--dataset_mean ${dataset_mean} --dataset_std ${dataset_std} --target_length ${target_length} \
--num_mel_bins 128 --head_lr ${head_lr} --noise ${noise} \
--lrscheduler_start 10 --lrscheduler_step 5 --lrscheduler_decay 0.5 --wa True --wa_start 6 --wa_end 20 \
--loss BCE --metrics mAP

$SHELL

There was no error then.

That should call

ssast/src/run.py

Lines 136 to 138 in 35ae7ab

audio_model = ASTModel(label_dim=args.n_class, fshape=args.fshape, tshape=args.tshape, fstride=args.fstride, tstride=args.tstride,
input_fdim=args.num_mel_bins, input_tdim=args.target_length, model_size=args.model_size, pretrain_stage=False,
load_pretrained_mdl_path=args.pretrained_mdl_path)

If you use the same arguments, initializing AST should not raise an error.

I noticed I had a mistake in my code. This is the error now, with the proper modifications:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [3], in <cell line: 21>()
     18 ast_mdl = torch.nn.DataParallel(ast_mdl)
     20 sd = torch.load(finetuned_mdl_path, map_location=device) # now load finetuned_mdl_path
---> 21 ast_mdl.load_state_dict(sd, strict=True) # I suggest to use True to see the different and then move to False
     23 ast_mdl.cuda()
     24 ast_mdl.eval()

File ~\Anaconda3\envs\SSAST2\lib\site-packages\torch\nn\modules\module.py:1497, in Module.load_state_dict(self, state_dict, strict)
   1492         error_msgs.insert(
   1493             0, 'Missing key(s) in state_dict: {}. '.format(
   1494                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1496 if len(error_msgs) > 0:
-> 1497     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1498                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1499 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for DataParallel:
	size mismatch for module.mlp_head.1.weight: copying a param with shape torch.Size([4, 768]) from checkpoint, the shape in current model is torch.Size([527, 768]).
	size mismatch for module.mlp_head.1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([527]).

Fine-tuning again does not give an error, but loading the model afterwards gives the error above.

You should add label_dim=4 when you initialize the AST model.
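
The class count can also be read from the fine-tuned checkpoint itself rather than remembered. The following is a minimal sketch, assuming the classifier weight is stored under the key `module.mlp_head.1.weight` (the name in the error message above) with shape [num_classes, embed_dim]. `infer_label_dim` is a hypothetical helper, and the SimpleNamespace object only stands in for a tensor so that the snippet runs without PyTorch.

```python
from types import SimpleNamespace


def infer_label_dim(sd, head_key="module.mlp_head.1.weight"):
    """Read the number of output classes from a fine-tuned state dict.

    Hypothetical helper: the first dimension of the final linear layer's
    weight equals the number of classes.
    """
    if head_key not in sd:
        raise KeyError(f"{head_key!r} not found; is this a fine-tuned DataParallel checkpoint?")
    return int(sd[head_key].shape[0])


# Stand-in for a checkpoint with the shape reported in the error message.
fake_sd = {"module.mlp_head.1.weight": SimpleNamespace(shape=(4, 768))}
print(infer_label_dim(fake_sd))  # 4
```

With a real checkpoint, `sd` would come from `torch.load(finetuned_mdl_path, map_location="cpu")`, and the result would be passed as `label_dim` when constructing ASTModel.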

Amazing!! works fine now. Thank you so much Yuan for your help, let me know if there is anything I can help with!

Thanks, great to know it works

I am going to add an instruction to the readme file on the correct way to do this.

If you don't mind, can you send the fine-tuned model to my email, yuangong@mit.edu? I want to test it myself. I won't use it for any other purpose.

Also, is the performance consistent with what you got from the fine-tuning run?

Hi,

I ran into a similar issue, but in my case it happened when I loaded the pretrained model in order to fine-tune it. It seems the trick you used here:

p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()

didn't play nicely with my setup. The fix was simple, though: the keys in 'sd' had an extra 'module' prefix. When I changed the first line to

p_fshape, p_tshape = sd['module.module.v.patch_embed.proj.weight'].shape[2], sd['module.module.v.patch_embed.proj.weight'].shape[3]

and

p_input_fdim, p_input_tdim = sd['module.module.p_input_fdim'].item(), sd['module.module.p_input_tdim'].item()

It seemed to work just fine!

Hugo
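
Instead of editing the key names in ast_models.py, the checkpoint keys can be normalized before loading. The following is a minimal sketch, assuming the checkpoint is a plain state dict; `normalize_module_prefix` is a hypothetical helper, not part of the SSAST repo. A doubled prefix can arise, for example, when a model is wrapped in DataParallel twice before saving, though the thread does not say what caused it here.

```python
def normalize_module_prefix(sd, prefix="module."):
    """Return a copy of `sd` whose keys carry exactly one leading `prefix`.

    Keys such as "module.module.v.pos_embed" and "v.pos_embed" both become
    "module.v.pos_embed". Values are passed through untouched.
    """
    normalized = {}
    for key, value in sd.items():
        bare = key
        # Strip every leading copy of the prefix.
        while bare.startswith(prefix):
            bare = bare[len(prefix):]
        # Add the prefix back exactly once.
        normalized[prefix + bare] = value
    return normalized


doubled = {"module.module.p_input_fdim": 128, "module.module.p_input_tdim": 1024}
print(sorted(normalize_module_prefix(doubled)))
# ['module.p_input_fdim', 'module.p_input_tdim']
```

The values 128 and 1024 are placeholders for illustration. After normalization, the original lookups such as `sd['module.p_input_fdim']` work without changes to the model code.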