The model loaded is not from a torch.nn.Dataparallel object
kremHabashy opened this issue
Hi Yuan, I fine-tuned the SSAST model on my own data and am trying to load it for demonstration purposes. When I pass in the path of the "best_audio_model.pth" that is generated after fine-tuning, I get the following error from ast_models.py:
ValueError: The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.
I'm not really sure why I am getting this error, as I am following the instructions shown in the example under the if __name__ == '__main__' statement. What am I missing? Shouldn't the created models already be wrapped with DataParallel, as seen in traintest.py?
Hi there,
It is a known bug that is not actually related to DataParallel.
The problem is that I used a trick: the pretraining hyperparameters are encoded in the model, and the code uses the existence of those hyperparameters to check whether the model is a DataParallel object. The SSL pretraining code does save the hyperparameters, but the fine-tuning code does not. So when you load a fine-tuned model for another round of testing, the code cannot find the hyperparameters and concludes that the model is not a DataParallel object.
ssast/src/models/ast_models.py
Lines 146 to 147 in 35ae7ab
As a temporary workaround, you can change these two lines of code:
ssast/src/models/ast_models.py
Lines 146 to 147 in 35ae7ab
I will find time to fix it.
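To illustrate why the check misfires, here is a minimal sketch that uses plain dicts in place of real checkpoints. The helper names has_pretraining_hparams and has_dataparallel_prefix are hypothetical and not part of ast_models.py; the key names are the ones quoted in this thread. It contrasts a hyperparameter-based check with one that only looks at the "module." prefix that torch.nn.DataParallel adds to every state_dict key.

```python
# Stand-ins for checkpoints: only the key names matter here, so the values are dummies.
pretrained_sd = {
    "module.v.patch_embed.proj.weight": None,
    "module.p_input_fdim": None,  # hyperparameter saved by the pretraining code
    "module.p_input_tdim": None,
}
finetuned_sd = {
    "module.v.patch_embed.proj.weight": None,
    "module.mlp_head.1.weight": None,  # no p_input_* keys: fine-tuning did not save them
}


def has_pretraining_hparams(sd):
    """Hypothetical form of the check described above: look for the saved hyperparameters."""
    return "module.p_input_fdim" in sd and "module.p_input_tdim" in sd


def has_dataparallel_prefix(sd):
    """Alternative check: every key of a DataParallel state_dict starts with 'module.'."""
    return len(sd) > 0 and all(key.startswith("module.") for key in sd)


for name, sd in [("pretrained", pretrained_sd), ("fine-tuned", finetuned_sd)]:
    print(name, has_pretraining_hparams(sd), has_dataparallel_prefix(sd))
```

The fine-tuned stand-in fails the first check and passes the second, which is the situation behind the misleading ValueError.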
Hey Yuan, sorry, what should I change them to?
Can you paste the code you used to initialize the AST model for fine-tuning (not inference)?
# initialize an AST model
torch.cuda.set_device('cuda:0')
device = torch.device("cuda:0")
pretrained_mdl_path = r"C:\Users\habashyk\venvs\SSAST2\src\finetune\NRC\exp\test01-NRC-f10-16-t10-16-b12-lr1e-4-ft_avgtok-base---1x-noiseFalse-3\models\best_audio_model.pth"
# load the checkpoint so that its patch shape can be read
sd = torch.load(pretrained_mdl_path, map_location=device)
p_fshape, p_tshape = sd['module.v.patch_embed.proj.weight'].shape[2], sd['module.v.patch_embed.proj.weight'].shape[3]
ast_mdl = ASTModel(fstride=10,
                   tstride=10,
                   fshape=16,
                   tshape=16,
                   input_fdim=128,
                   input_tdim=50,
                   model_size='base',
                   pretrain_stage=False,
                   load_pretrained_mdl_path=pretrained_mdl_path)
ast_mdl = torch.nn.DataParallel(ast_mdl)
ast_mdl.load_state_dict(sd, strict=False)
ast_mdl.cuda()
ast_mdl.eval()
This now gives the error:
RuntimeError: Error(s) in loading state_dict for DataParallel: size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 50, 768]) from checkpoint, the shape in current model is torch.Size([1, 26, 768]).
Have you changed anything in ast_models.py?
I guess the following might work:
pretrained_mdl_path = a  # path to the SSL-pretrained checkpoint
finetuned_mdl_path = b  # path to your fine-tuned checkpoint
ast_mdl = ASTModel(fstride=10,  # should be consistent with what you fine-tuned the model with
                   tstride=10,
                   fshape=16,
                   tshape=16,
                   input_fdim=128,
                   input_tdim=50,
                   model_size='base',
                   pretrain_stage=False,  # should be False
                   load_pretrained_mdl_path=pretrained_mdl_path)  # should be pretrained_mdl_path, not finetuned_mdl_path
ast_mdl = torch.nn.DataParallel(ast_mdl)
sd = torch.load(finetuned_mdl_path, map_location=device)  # now load finetuned_mdl_path
ast_mdl.load_state_dict(sd, strict=True)  # I suggest using strict=True first to see the differences, then moving to strict=False
To answer the first question, I commented out line 161:
# p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
I then hard-coded the appropriate number of mel bins (128) and the target length of my dataset (50).
Just to clarify, is the pretrained model path the one for SSAST-Base_Patch-400.pth?
If so, the error is now:
size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 514, 768]) from checkpoint, the shape in current model is torch.Size([1, 26, 768]).
Just to clarify, is the pretrained model path the one for SSAST-Base_Patch-400.pth?
Yes, I meant that. Can you let me know where the error code is pointing to?
This is the current output
now load a SSL pretrained models from C:\Users\habashyk\venvs\SSAST2\pretrained_model\SSAST-Base-Patch-400.pth
pretraining patch split stride: frequency=16, time=16
pretraining patch shape: frequency=16, time=16
pretraining patch array dimension: frequency=8, time=3
pretraining number of patches=24
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [7], in <cell line: 8>()
5 pretrained_mdl_path = r"C:\Users\habashyk\venvs\SSAST2\pretrained_model\SSAST-Base-Patch-400.pth"
6 finetuned_mdl_path = r"C:\Users\habashyk\venvs\SSAST2\src\finetune\NRC\exp\test01-NRC-f10-16-t10-16-b12-lr1e-4-ft_avgtok-base---1x-noiseFalse-3\models\best_audio_model.pth"
----> 8 ast_mdl = ASTModel(fstride=10, # should be consistent with that you fine-tune the model with
9 tstride=10,
10 fshape=16,
11 tshape=16,
12 input_fdim=128,
13 input_tdim=50,
14 model_size='base',
15 pretrain_stage=False, # should be False
16 load_pretrained_mdl_path=pretrained_mdl_path) # should be pretrained_mdl_path, not fine_tuned_mdl_path
18 ast_mdl = torch.nn.DataParallel(ast_mdl)
20 sd = torch.load(finetuned_mdl_path, map_location=device) # now load finetuned_mdl_path
File ~\venvs\SSAST2\src\models\ast_models.py:175, in ASTModel.__init__(self, label_dim, fshape, tshape, fstride, tstride, input_fdim, input_tdim, model_size, pretrain_stage, load_pretrained_mdl_path)
171 audio_model = ASTModel(fstride=p_fshape, tstride=p_tshape, fshape=p_fshape, tshape=p_tshape,
172 input_fdim=128, input_tdim=50, pretrain_stage=True,
173 model_size=model_size)
174 audio_model = torch.nn.DataParallel(audio_model)
--> 175 audio_model.load_state_dict(sd, strict=False)
177 self.v = audio_model.module.v
178 self.original_embedding_dim = self.v.pos_embed.shape[2]
File ~\Anaconda3\envs\SSAST2\lib\site-packages\torch\nn\modules\module.py:1497, in Module.load_state_dict(self, state_dict, strict)
1492 error_msgs.insert(
1493 0, 'Missing key(s) in state_dict: {}. '.format(
1494 ', '.join('"{}"'.format(k) for k in missing_keys)))
1496 if len(error_msgs) > 0:
-> 1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1498 self.__class__.__name__, "\n\t".join(error_msgs)))
1499 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 514, 768]) from checkpoint, the shape in current model is torch.Size([1, 26, 768]).
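The sequence lengths in these error messages can be reproduced by hand. The sketch below assumes that the patch grid follows the usual convolution output-size formula and that two extra tokens are added to the patch sequence, which matches the numbers in the log above (8 x 3 = 24 patches, position embedding of length 26). The function name pos_embed_len is hypothetical, and the 1024-frame pretraining input is an assumption inferred from the length 514, not something stated in this thread.

```python
def pos_embed_len(input_fdim, input_tdim, fshape, tshape, fstride, tstride, extra_tokens=2):
    """Length of the position embedding: number of patches plus the extra tokens."""
    f_dim = (input_fdim - fshape) // fstride + 1  # patches along the frequency axis
    t_dim = (input_tdim - tshape) // tstride + 1  # patches along the time axis
    return f_dim * t_dim + extra_tokens


# Pretraining-style model built for a 128 x 50 input with non-overlapping 16 x 16 patches.
print(pos_embed_len(128, 50, 16, 16, 16, 16))    # 8 * 3 + 2 = 26
# Fine-tuned model: same input, but a stride of 10 gives overlapping patches.
print(pos_embed_len(128, 50, 16, 16, 10, 10))    # 12 * 4 + 2 = 50
# Released pretrained checkpoint, assuming a 128 x 1024 input.
print(pos_embed_len(128, 1024, 16, 16, 16, 16))  # 8 * 64 + 2 = 514
```

Each size mismatch reported so far pairs two of these numbers, which suggests a checkpoint being loaded into a model that was built with a different input length or stride.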
To answer the first question, I commented out line 161
Can you uncomment that line (so the code stays as in the original) and then try the script I suggested?
The output error message remains the same.
That's weird. Regarding this part:
pretrained_mdl_path = a
finetuned_mdl_path = b
ast_mdl = ASTModel(fstride=10,  # should be consistent with what you fine-tuned the model with
                   tstride=10,
                   fshape=16,
                   tshape=16,
                   input_fdim=128,
                   input_tdim=50,
                   model_size='base',
                   pretrain_stage=False,  # should be False
                   load_pretrained_mdl_path=pretrained_mdl_path)  # should be pretrained_mdl_path, not finetuned_mdl_path
You should have run this for your fine-tuning, right? Did you get the error at that time?
For fine-tuning, I ran the run.sh script below:
#!/bin/bash
##SBATCH -p sm
##SBATCH -x sls-sm-1,sls-2080-[1,3],sls-1080-3,sls-sm-[5,12]
#SBATCH -p gpu
#SBATCH -x sls-titan-[0-2]
#SBATCH --gres=gpu:4
#SBATCH -c 4
#SBATCH -n 1
#SBATCH --mem=30000
#SBATCH --job-name="ast_as"
#SBATCH --output=./slurm_log/log_%j.txt
set -x
# comment this line if not running on sls cluster
# . /data/sls/scratch/share-201907/slstoolchainrc
# source ../../../venvssast/bin/activate
export TORCH_HOME=../../pretrained_models
mkdir -p ./exp
if [ -e SSAST-Base-Patch-400.pth ]
then
    echo "pretrained model already downloaded."
else
    wget https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Base-Patch-400.pth?dl=1 -O SSAST-Base-Patch-400.pth
fi
pretrain_path=C:/Users/habashyk/venvs/ssast/pretrained_model/SSAST-Base-Patch-400.pth
dataset=NRC
dataset_mean=-0.94124633
dataset_std=2.8967743
target_length=50
noise=False
task=ft_avgtok
model_size=base
head_lr=1
tr_data=C:/Users/habashyk/venvs/ast/src/NRC_Work/Data/Original_split/datafiles/train_data.json
te_data=C:/Users/habashyk/venvs/ast/src/NRC_Work/Data/eval/datafiles/eval_data.json
bal=none
lr=1e-4
freqm=25
timem=0
mixup=0
fstride=10
tstride=10
fshape=16
tshape=16
batch_size=12
epoch=20
exp_dir=./exp/test01-${dataset}-f${fstride}-${fshape}-t${tstride}-${tshape}-b${batch_size}-lr${lr}-${task}-${model_size}-${pretrain_exp}-${pretrain_model}-${head_lr}x-noise${noise}-3
CUDA_CACHE_DISABLE=1 python -W ignore ../../run.py --dataset ${dataset} \
--data-train ${tr_data} --data-val ${te_data} --exp-dir $exp_dir \
--label-csv C:/Users/habashyk/venvs/ast/src/NRC_Work/nrc_class_label_indices.csv --n_class 4 \
--lr $lr --n-epochs ${epoch} --batch-size $batch_size --save_model True \
--freqm $freqm --timem $timem --mixup ${mixup} --bal ${bal} \
--tstride $tstride --fstride $fstride --fshape ${fshape} --tshape ${tshape} --warmup False --task ${task} \
--model_size ${model_size} --adaptschedule False \
--pretrained_mdl_path ${pretrain_path} \
--dataset_mean ${dataset_mean} --dataset_std ${dataset_std} --target_length ${target_length} \
--num_mel_bins 128 --head_lr ${head_lr} --noise ${noise} \
--lrscheduler_start 10 --lrscheduler_step 5 --lrscheduler_decay 0.5 --wa True --wa_start 6 --wa_end 20 \
--loss BCE --metrics mAP
$SHELL
There was no error then
That should call
Lines 136 to 138 in 35ae7ab
If you use the same arguments, initializing the AST model should not raise an error.
I noticed I had a mistake in the code. This is the error now, with the proper modifications:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [3], in <cell line: 21>()
18 ast_mdl = torch.nn.DataParallel(ast_mdl)
20 sd = torch.load(finetuned_mdl_path, map_location=device) # now load finetuned_mdl_path
---> 21 ast_mdl.load_state_dict(sd, strict=True) # I suggest to use True to see the different and then move to False
23 ast_mdl.cuda()
24 ast_mdl.eval()
File ~\Anaconda3\envs\SSAST2\lib\site-packages\torch\nn\modules\module.py:1497, in Module.load_state_dict(self, state_dict, strict)
1492 error_msgs.insert(
1493 0, 'Missing key(s) in state_dict: {}. '.format(
1494 ', '.join('"{}"'.format(k) for k in missing_keys)))
1496 if len(error_msgs) > 0:
-> 1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1498 self.__class__.__name__, "\n\t".join(error_msgs)))
1499 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.mlp_head.1.weight: copying a param with shape torch.Size([4, 768]) from checkpoint, the shape in current model is torch.Size([527, 768]).
size mismatch for module.mlp_head.1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([527]).
Fine-tuning once again does not give an error, but loading the model afterwards gives the error above.
You should add label_dim=4 when you initialize the AST model.
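A shape comparison along these lines can surface this kind of mismatch before calling load_state_dict, and can also recover the label count from the checkpoint. This is only a sketch on plain dicts of shape tuples; the helper names shape_mismatches and infer_label_dim are hypothetical. With real objects, the shapes could be collected as {k: tuple(v.shape) for k, v in sd.items()}, and in the same way from the model's state_dict().

```python
def shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {key: (checkpoint_shape, model_shape)} for shared keys whose shapes differ."""
    return {
        key: (shape, model_shapes[key])
        for key, shape in checkpoint_shapes.items()
        if key in model_shapes and model_shapes[key] != shape
    }


def infer_label_dim(checkpoint_shapes, head_key="module.mlp_head.1.weight"):
    """A linear layer's weight is (out_features, in_features), so dim 0 is the class count."""
    return checkpoint_shapes[head_key][0]


# Shapes copied from the error message above.
checkpoint = {"module.mlp_head.1.weight": (4, 768), "module.mlp_head.1.bias": (4,)}
model = {"module.mlp_head.1.weight": (527, 768), "module.mlp_head.1.bias": (527,)}

print(shape_mismatches(checkpoint, model))
print(infer_label_dim(checkpoint))  # 4
```

The inferred value could then be passed as label_dim when constructing the model, instead of relying on the default of 527 seen in the error.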
Amazing!! It works fine now. Thank you so much for your help, Yuan. Let me know if there is anything I can help with!
Thanks, great to know it works
I am going to add an instruction to the readme file on the correct way to do this.
If you don't mind, can you send the fine-tuned model to my email, yuangong@mit.edu? I want to test it myself. I won't use it for any other purpose.
Also, is the performance consistent with what you got from the fine-tuning run?
Hi,
I ran into a similar issue, but in my case it happened when loading the pretrained model for fine-tuning. It seems the trick you used here:
ssast/src/models/ast_models.py
Lines 146 to 147 in 35ae7ab
didn't play nicely with my setup. However, the fix was simple: the keys in 'sd' had an extra 'module' prefix. For example, when I changed the lookups to
p_fshape, p_tshape = sd['module.module.v.patch_embed.proj.weight'].shape[2], sd['module.module.v.patch_embed.proj.weight'].shape[3]
and
p_input_fdim, p_input_tdim = sd['module.module.p_input_fdim'].item(), sd['module.module.p_input_tdim'].item()
It seemed to work just fine!
Hugo
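A doubled prefix like this can appear, for example, when a model that is already wrapped in DataParallel is wrapped a second time before saving. Instead of editing each lookup, the keys could be normalized once after loading. The following is a sketch on a plain dict; normalize_module_prefix is a hypothetical helper, not part of the ssast code.

```python
def normalize_module_prefix(sd, prefix="module."):
    """Rewrite every key so it starts with exactly one 'module.' prefix.

    Repeated leading prefixes are collapsed, and keys without a prefix gain one.
    """
    normalized = {}
    for key, value in sd.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        normalized[prefix + key] = value
    return normalized


# Dummy values stand in for tensors; only the key names matter here.
sd = {
    "module.module.v.patch_embed.proj.weight": "w",
    "module.module.p_input_fdim": "f",
    "v.pos_embed": "p",
}
print(sorted(normalize_module_prefix(sd)))
```

After normalization, the original lookups such as sd['module.p_input_fdim'] would work without modification.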