Not able to execute pengi.generate and pengi.describe
radhavishnu opened this issue
Vishnu Radhakrishnan commented
After a successful pip install, running
from wrapper import PengiWrapper as Pengi
pengi = Pengi(config="base")
generated_response = pengi.generate(audio_paths='/content/Robin.mp3',
    text_prompts=["generate metadata"],
    add_texts=[""],
    max_len=30,
    beam_size=3,
    temperature=1.0,
    stop_token=' <|endoftext|>'
)
produces the following error:
RuntimeError Traceback (most recent call last)
/content/Pengi/wrapper.py in get_model_and_tokenizer(self, config_path)
93 try:
---> 94 model.load_state_dict(model_state_dict)
95 except:
RuntimeError: Error(s) in loading state_dict for PENGI:
Unexpected key(s) in state_dict: "caption_encoder.base.embeddings.position_ids", "caption_decoder.gpt.transformer.h.0.attn.bias", "caption_decoder.gpt.transformer.h.0.attn.masked_bias", "caption_decoder.gpt.transformer.h.1.attn.bias", "caption_decoder.gpt.transformer.h.1.attn.masked_bias", "caption_decoder.gpt.transformer.h.2.attn.bias", "caption_decoder.gpt.transformer.h.2.attn.masked_bias", "caption_decoder.gpt.transformer.h.3.attn.bias", "caption_decoder.gpt.transformer.h.3.attn.masked_bias", "caption_decoder.gpt.transformer.h.4.attn.bias", "caption_decoder.gpt.transformer.h.4.attn.masked_bias", "caption_decoder.gpt.transformer.h.5.attn.bias", "caption_decoder.gpt.transformer.h.5.attn.masked_bias", "caption_decoder.gpt.transformer.h.6.attn.bias", "caption_decoder.gpt.transformer.h.6.attn.masked_bias", "caption_decoder.gpt.transformer.h.7.attn.bias", "caption_decoder.gpt.transformer.h.7.attn.masked_bias", "caption_decoder.gpt.transformer.h.8.attn.bias", "caption_decoder.gpt.transformer.h.8.attn.masked_bias", "caption_decoder.gpt.transformer.h.9.attn.bias", "caption_decoder.gpt.transformer.h.9.attn.masked_bias", "caption_decoder.gpt.transformer.h.10.attn.bias", "caption_decoder.gpt.transformer.h.10.attn.masked_bias", "caption_decoder.gpt.transformer.h.11.attn.bias", "caption_decoder.gpt.transformer.h.11.attn.masked_bias".
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2150
2151 if len(error_msgs) > 0:
-> 2152 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2153 self.__class__.__name__, "\n\t".join(error_msgs)))
2154 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for PENGI:
Missing key(s) in state_dict: "audio_encoder.base.htsat.spectrogram_extractor.stft.conv_real.weight", "audio_encoder.base.htsat.spectrogram_extractor.stft.conv_imag.weight", "audio_encoder.base.htsat.logmel_extractor.melW", "audio_encoder.base.htsat.bn0.weight", "audio_encoder.base.htsat.bn0.bias", "audio_encoder.base.htsat.bn0.running_mean", "audio_encoder.base.htsat.bn0.running_var", "audio_encoder.base.htsat.patch_embed.proj.weight", "audio_encoder.base.htsat.patch_embed.proj.bias", "audio_encoder.base.htsat.patch_embed.norm.weight", "audio_encoder.base.htsat.patch_embed.norm.bias", "audio_encoder.base.htsat.layers.0.blocks.0.norm1.weight", "audio_encoder.base.htsat.layers.0.blocks.0.norm1.bias", "audio_encoder.base.htsat.layers.0.blocks.0.attn.relative_position_bias_table", "audio_encoder.base.htsat.layers.0.blocks.0.attn.relative_position_index", "audio_encoder.base.htsat.layers.0.blocks.0.attn.qkv.weight", "audio_encoder.base.htsat.layers.0.blocks.0.attn.qkv.bias", "audio_encoder.base.htsat.layers.0.blocks.0.attn.proj.weight", "audio_encoder.base.htsat.layers.0.blocks.0.attn.proj.bias", "audio_encoder.base.htsat.layers.0.blocks.0.norm2.weight", "audio_encoder.base.htsat.layers.0.blocks.0.norm2.bias", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc1.weight", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc1.bias", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc2.weight", "audio_encoder.base.htsat.layers.0.blocks.0.mlp.fc2.bias", "audio_encoder.base.htsat.laye...
Unexpected key(s) in state_dict: "ncoder.base.htsat.spectrogram_extractor.stft.conv_real.weight", "ncoder.base.htsat.spectrogram_extractor.stft.conv_imag.weight", "ncoder.base.htsat.logmel_extractor.melW", "ncoder.base.htsat.bn0.weight", "ncoder.base.htsat.bn0.bias", "ncoder.base.htsat.bn0.running_mean", "ncoder.base.htsat.bn0.running_var", "ncoder.base.htsat.bn0.num_batches_tracked", "ncoder.base.htsat.patch_embed.proj.weight", "ncoder.base.htsat.patch_embed.proj.bias", "ncoder.base.htsat.patch_embed.norm.weight", "ncoder.base.htsat.patch_embed.norm.bias", "ncoder.base.htsat.layers.0.blocks.0.norm1.weight", "ncoder.base.htsat.layers.0.blocks.0.norm1.bias", "ncoder.base.htsat.layers.0.blocks.0.attn.relative_position_bias_table", "ncoder.base.htsat.layers.0.blocks.0.attn.relative_position_index", "ncoder.base.htsat.layers.0.blocks.0.attn.qkv.weight", "ncoder.base.htsat.layers.0.blocks.0.attn.qkv.bias", "ncoder.base.htsat.layers.0.blocks.0.attn.proj.weight", "ncoder.base.htsat.layers.0.blocks.0.attn.proj.bias", "ncoder.base.htsat.layers.0.blocks.0.norm2.weight", "ncoder.base.htsat.layers.0.blocks.0.norm2.bias", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc1.weight", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc1.bias", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc2.weight", "ncoder.base.htsat.layers.0.blocks.0.mlp.fc2.bias", "ncoder.base.htsat.layers.0.blocks.1.attn_mask", "ncoder.base.htsat.layers.0.blocks.1.norm1.weight", "ncoder.base.htsat.layers.0.blocks.1.norm1.bias", "ncode...
Soham commented
Have you downloaded the checkpoints and moved them to the configs folder?
Vishnu Radhakrishnan commented
Yes, I downloaded them and saved them as base.pth and base_no_text_enc.pth in the config folder.
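One way to double-check that the two files were not swapped during the copy is to load both checkpoints and compare their keys. This is a minimal sketch, not from the repo: the configs/ paths and the possible "model" wrapper key are assumptions, so adjust them to your setup.
import torch

# Assumed locations; adjust to wherever the checkpoints were saved.
for path in ["configs/base.pth", "configs/base_no_text_enc.pth"]:
    ckpt = torch.load(path, map_location="cpu")
    # Some checkpoints wrap the weights under a "model" key; otherwise use the dict directly.
    state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    keys = list(state_dict.keys())
    print(f"{path}: {len(keys)} keys, e.g. {keys[:3]}")
If the file named base.pth prints keys that look truncated (e.g. the "ncoder..." prefixes from the traceback above), the filenames were most likely switched.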
Soham commented
Two issues:
- There might be a file name mismatch at your end. I can reproduce your error when I switch the weight file names, i.e. rename base_no_text_enc.pth to base.pth. Maybe there was a file name switch during the copy at your end?
- Unrelated, but the audio file has to be passed as a list: audio_paths=['/content/Robin.mp3']
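Putting both points together, the corrected call would look something like this (same parameters as in the original report, with audio_paths wrapped in a list; this assumes base.pth and base_no_text_enc.pth are correctly named in the configs folder):
from wrapper import PengiWrapper as Pengi

pengi = Pengi(config="base")

generated_response = pengi.generate(
    audio_paths=['/content/Robin.mp3'],  # note: a list, even for a single file
    text_prompts=["generate metadata"],
    add_texts=[""],
    max_len=30,
    beam_size=3,
    temperature=1.0,
    stop_token=' <|endoftext|>'
)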