rinnakk / japanese-pretrained-models

Code for producing Japanese pretrained models provided by rinna Co., Ltd.

Home Page: https://huggingface.co/rinna

Tensor size does not match

hirokisince1998 opened this issue

Description

GPT-2 training fails with the error "RuntimeError: The size of tensor a (768) must match the size of tensor b (1024) at non-singleton dimension 3".

I followed the steps in "Train japanese-gpt2-xsmall from scratch", except that n_gpus was set to 1 and mecab_dict_path was changed to point to unidic-csj-3.0.1.1.

What's wrong?

Full output of python -m task.pretrain_gpt2.train:

local rank: [0], global_rank: [0]
Number of training files: 502
Number of dev files: 1
----- Loading dev data -----
{'n_docs': 10000, 'n_sents': 131762, 'n_tokens': 4241376}
----- Hyper-parameters -----
balanced_corpora: None
batch_size: 20
check_loss_after_n_step: 100.0
checkpoint_path: None
corpora: ['jp_cc100', 'jp_wiki']
enable_log: True
eval_batch_size: 40
filename_note: None
init_lr: 0.0007
l2_penalty: 0.01
master_port: 12321
max_grad_norm: 1.0
max_seq_len: 1024
model_config_filepath: model/gpt2-ja-xsmall-config.json
model_size: xsmall
n_accum_steps: 3
n_epochs: 10
n_gpus: 1
n_nodes: 1
n_train_files_per_group: 10
n_training_steps: 1600000
n_warmup_steps: 2000.0
node_rank: 0
resume_training: False
save_model: True
seed: 42
small_data: False
use_amp: True
validate_after_n_step: 5000.0
world_size: 1
{'n_docs': 1367409, 'n_sents': 8632681, 'n_tokens': 288213354}
Traceback (most recent call last):
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/var/tmp/hiroki/japanese-pretrained-models/src/task/pretrain_gpt2/train.py", line 580, in <module>
    train(0, config)
  File "/var/tmp/hiroki/japanese-pretrained-models/src/task/pretrain_gpt2/train.py", line 409, in train
    loss, ppl = forward_step(model, tokenizer, batch_data)
  File "/var/tmp/hiroki/japanese-pretrained-models/src/task/pretrain_gpt2/train.py", line 85, in forward_step
    gpt2_outputs = model(input_ids=input_ids, return_dict=True)
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/transformers-4.4.2-py3.8.egg/transformers/models/gpt2/modeling_gpt2.py", line 904, in forward
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/transformers-4.4.2-py3.8.egg/transformers/models/gpt2/modeling_gpt2.py", line 752, in forward
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/transformers-4.4.2-py3.8.egg/transformers/models/gpt2/modeling_gpt2.py", line 290, in forward
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/transformers-4.4.2-py3.8.egg/transformers/models/gpt2/modeling_gpt2.py", line 241, in forward
  File "/Users/hiroki/.conda/envs/transformers/lib/python3.8/site-packages/transformers-4.4.2-py3.8.egg/transformers/models/gpt2/modeling_gpt2.py", line 176, in _attn
RuntimeError: The size of tensor a (768) must match the size of tensor b (1024) at non-singleton dimension 3
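
What the numbers mean (an editor's note, inferred from the traceback): in transformers 4.4.2, GPT2's attention layer builds its causal-mask buffer from config.n_ctx, while the position embeddings and the training script's max_seq_len go by config.n_positions. With n_ctx = 768 in the config but 1024-token batches, the 1024×1024 attention scores are masked with a 768×768 buffer, and the broadcast fails exactly at dimension 3. A minimal sketch that should reproduce the error under transformers == 4.4.2 (the vocabulary and layer sizes below are hypothetical, chosen only for illustration):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Misaligned context sizes: the causal-mask buffer (n_ctx) is smaller
# than the position-embedding table (n_positions) and the batches.
config = GPT2Config(
    vocab_size=32000,                  # hypothetical; any value works here
    n_positions=1024,
    n_ctx=768,                         # the mismatch: should equal n_positions
    n_embd=512, n_layer=6, n_head=8,   # hypothetical xsmall-like sizes
)
model = GPT2LMHeadModel(config)

input_ids = torch.zeros((1, 1024), dtype=torch.long)  # one full-length sequence
model(input_ids=input_ids, return_dict=True)
# RuntimeError: The size of tensor a (768) must match the size of tensor b
# (1024) at non-singleton dimension 3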

Environment

python == 3.8.13
PyTorch == 1.12.1
transformers == 4.4.2

Hi, I have made a commit to fix a misalignment in the config file.
Please try it again, preferably with a newer version of transformers.
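
For anyone who hits this later: a quick sanity check after pulling the fix (a sketch, not code from the repo; run it from the same working directory as training) is to confirm that the context sizes in model/gpt2-ja-xsmall-config.json cover the max_seq_len = 1024 used in the run above:

from transformers import GPT2Config

config = GPT2Config.from_json_file("model/gpt2-ja-xsmall-config.json")
assert config.n_positions >= 1024  # must cover max_seq_len
# n_ctx only exists in older transformers releases; newer ones size the
# causal mask from n_positions, so this mismatch cannot recur there.
assert getattr(config, "n_ctx", config.n_positions) == config.n_positions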

Solved, thank you.

cool.