Shivanandroy / simpleT5

simpleT5 is built on top of PyTorch Lightning⚡️ and Transformers🤗 and lets you quickly train your T5 models.

TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask' in onnx_predict function

farshadfiruzi opened this issue

Hello,
When I run the fine-tuned mT5 model under ONNX, I get the following error:

```
TypeError                                 Traceback (most recent call last)
in
----> 1 model.onnx_predict(text)

~\AppData\Roaming\Python\Python38\site-packages\simplet5\simplet5.py in onnx_predict(self, source_text)
469 """ generates prediction from ONNX model """
470 token = self.onnx_tokenizer(source_text, return_tensors="pt")
--> 471 tokens = self.onnx_model.generate(
472 input_ids=token["input_ids"],
473 attention_mask=token["attention_mask"],

C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27         with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30

C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs)
1051 input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
1052 )
-> 1053 return self.beam_search(
1054 input_ids,
1055 beam_scorer,

C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
1788 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1789
-> 1790 outputs = self(
1791 **model_inputs,
1792 return_dict=True,

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask'
```
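
For reference, this is roughly how I call the ONNX path. The conversion call below is written from memory and is an assumption about the simpleT5 0.1.1 API, so the exact method name may differ:

```python
from simplet5 import SimpleT5

model = SimpleT5()

# load the fine-tuned mT5 checkpoint produced by model.train()
model.load_model("mt5", "path/to/trained/model/directory", use_gpu=False)

# convert to ONNX and run inference
# (assumption: this is the conversion method name in simpleT5 0.1.1)
model.convert_and_load_onnx_model("path/to/trained/model/directory")

text = "some input text ..."
model.onnx_predict(text)  # raises the TypeError shown in the traceback above
```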

I tried downgrading transformers and onnxruntime, but the error remains.

Which transformers/simpleT5 version are you using?

I am using transformers=4.8.2 and simpleT5=0.1.1.
I also tried newer versions of transformers (4.9.0 and 4.9.1), but that didn't fix the error.

The issue is fixed in the latest version.
Install it with: `pip install --upgrade simplet5`

It works perfectly now. Thanks a lot.

May I ask how exactly you fixed this? I'm looking for the PR or code change that fixed it - I'm trying to adapt this code to MBart and I'm getting the exact same error. @Shivanandroy @farshadfiruzi

Hi @radurevutchi, the current version of simpleT5 only supports training and inference for T5/mT5/byT5 models. Support for quantization and ONNX Runtime has been dropped because of version conflict issues.
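
For anyone who still needs the ONNX path: the TypeError happens because newer transformers releases have generate()/beam_search() pass extra arguments such as cross_attn_head_mask to the wrapped model's forward(), which does not declare them. A rough, generic workaround sketch (an assumption on my part, not simpleT5's actual fix) is to strip those keys out of what prepare_inputs_for_generation() returns, so they never reach forward():

```python
import functools

def drop_unsupported_generate_kwargs(model,
                                     unsupported=("cross_attn_head_mask",
                                                  "head_mask",
                                                  "decoder_head_mask")):
    """Patch model.prepare_inputs_for_generation so that keys the model's
    forward() does not accept are removed before generate() calls forward().
    Generic sketch, not the simpleT5 implementation."""
    original = model.prepare_inputs_for_generation

    @functools.wraps(original)
    def patched(*args, **kwargs):
        inputs = original(*args, **kwargs)
        for key in unsupported:
            inputs.pop(key, None)  # silently drop keys forward() cannot take
        return inputs

    model.prepare_inputs_for_generation = patched
    return model
```

Applying `drop_unsupported_generate_kwargs(model.onnx_model)` before calling onnx_predict() should keep the unexpected keyword from reaching forward(), but this is untested and may differ from whatever fix landed in simpleT5.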

Below is what SimpleT5 offers:

```python
from simplet5 import SimpleT5

model = SimpleT5()

model.from_pretrained("t5", "t5-base")

model.train(train_df=train_df,  # pandas dataframe with 2 columns: source_text & target_text
            eval_df=eval_df,    # pandas dataframe with 2 columns: source_text & target_text
            source_max_token_len=512,
            target_max_token_len=128,
            batch_size=8,
            max_epochs=5,
            use_gpu=True,
            outputdir="outputs",
            early_stopping_patience_epochs=0,
            precision=32
            )

# load trained T5 model
model.load_model("t5", "path/to/trained/model/directory", use_gpu=False)

# predict
model.predict("input text for prediction")
```

If you want to adapt it for mBART or any other model, I would encourage you to write separate methods for quantization and ONNX support in addition to the training method.
How to export your model to ONNX: https://huggingface.co/transformers/serialization.html
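
To make that concrete, here is a minimal sketch of an encoder-only export with plain torch.onnx.export, not tied to simpleT5's dropped implementation. For full generation you would also export the decoder (ideally with past key values); paths and file names are placeholders:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_dir = "path/to/trained/model/directory"  # placeholder: your fine-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).eval()

# have the encoder return plain tuples so the ONNX tracer can handle its outputs
model.encoder.config.return_dict = False

sample = tokenizer("an example input", return_tensors="pt")

torch.onnx.export(
    model.encoder,                                    # export the encoder stack only
    (sample["input_ids"], sample["attention_mask"]),
    "t5_encoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    opset_version=12,
)
```

An encoder-only graph is enough for feature extraction; seq2seq generation additionally needs the decoder (and, for speed, its past key values), which is why ONNX exporters usually split the model into two or three graphs.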