nlpyang / BertSum

Code for the paper Fine-tune BERT for Extractive Summarization


expected mask dtype to be Bool but got Long

haidequanbu opened this issue

Hi. When trying to reproduce the results using transformers, I ran into the following error:
Traceback (most recent call last):
File "train.py", line 341, in
train(args, device_id)
File "train.py", line 273, in train
trainer.train(train_iter_fct, args.train_steps)
File "/root/code/BertSum/src/models/trainer.py", line 155, in train
self._gradient_accumulation(
File "/root/code/BertSum/src/models/trainer.py", line 321, in _gradient_accumulation
sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/code/BertSum/src/models/model_builder.py", line 96, in forward
sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/code/BertSum/src/models/encoder.py", line 97, in forward
x = self.transformer_inter[i](i, x, x, 1 - mask) # all_sents * max_tokens * dim
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/code/BertSum/src/models/encoder.py", line 68, in forward
context = self.self_attn(input_norm, input_norm, input_norm,
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/code/BertSum/src/models/neural.py", line 210, in forward
scores = scores.masked_fill(mask, -1e18)
RuntimeError: expected mask dtype to be Bool but got Long
Environment: CUDA 11, torch 1.10.1
Is there a problem with my environment?
Thanks!
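
Update: I believe this is a torch version issue rather than the environment itself. Since around torch 1.2, `masked_fill()` requires a bool mask, while this repo was written against an older torch (1.1.x, I believe), where a Long/Byte mask was still accepted. Below is a minimal sketch of the problem and the workaround I'm using — casting the mask to bool at the call site in `src/models/neural.py`. This is my own patch, not an official fix:

```python
import torch

# Minimal reproduction on torch >= 1.2: masked_fill() rejects
# integer masks, which older torch versions silently accepted.
scores = torch.randn(2, 4)
mask = torch.tensor([[1, 0, 0, 1], [0, 1, 1, 0]])  # dtype=torch.int64 (Long)

try:
    scores.masked_fill(mask, -1e18)
except RuntimeError as err:
    # "expected mask dtype to be Bool but got Long" (wording may vary by version)
    print(err)

# Workaround: cast the mask to bool, e.g. in src/models/neural.py
# at the line shown in the traceback:
#     scores = scores.masked_fill(mask.bool(), -1e18)
print(scores.masked_fill(mask.bool(), -1e18))
```

With that one-line cast the error goes away for me; alternatively, downgrading torch to the version the repo targets should avoid it entirely.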