yangjianxin1 / ClipCap-Chinese

An image captioning model based on ClipCap

Running bash scripts/train_finetune_gpt2.sh fails with RuntimeError: mat1 dim 1 must match mat2 dim 0

zengyi1001 opened this issue

The full error is:

Traceback (most recent call last):
  File "train.py", line 150, in <module>
    main(args)
  File "train.py", line 131, in main
    train(model, train_dataloader, dev_dataloader, optimizer, scheduler, args)
  File "train.py", line 57, in train
    logits = model(clip_embeds, caption_ids, mask)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data2/patrickzeng/LLM/CLIP/ClipCap-Chinese/models/model.py", line 105, in forward
    prefix_embeds = self.clip_project(clip_embeds).view(-1, self.prefix_len, self.prefix_size)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data2/patrickzeng/LLM/CLIP/ClipCap-Chinese/models/model.py", line 29, in forward
    return self.model(x)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/patrickzeng/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

The GPT2 model is loaded with the following statement:
self.gpt2 = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

How can I fix this?
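
One way to narrow this down: the traceback ends in a Linear layer inside the Sequential that clip_project wraps, and this message from torch.addmm generally means the last dimension of the input does not equal that layer's in_features. The sketch below is not from the repository; it is a hypothetical helper (find_linear_mismatch) that walks a chain of layers and reports the first one whose in_features disagrees with the incoming width. The stand-in layers and the widths 512, 768, 3840 and 7680 are made-up illustration values, not the project's actual configuration.

```python
from types import SimpleNamespace


def find_linear_mismatch(layers, input_width):
    """Return (index, expected_in_features, actual_width) for the first
    Linear-like layer whose in_features differs from the width it would
    receive, or None if every layer lines up.

    Hypothetical helper: it only relies on the `in_features` and
    `out_features` attributes, so it works on plain stand-in objects too.
    """
    width = input_width
    for idx, layer in enumerate(layers):
        in_features = getattr(layer, "in_features", None)
        if in_features is None:
            # Layers without in_features (e.g. activations) keep the width.
            continue
        if in_features != width:
            return idx, in_features, width
        width = layer.out_features
    return None


# Stand-in for an MLP of the form Linear -> activation -> Linear.
# All sizes here are illustrative assumptions.
mlp = [
    SimpleNamespace(in_features=512, out_features=3840),
    SimpleNamespace(),  # activation: no in_features attribute
    SimpleNamespace(in_features=3840, out_features=7680),
]

print(find_linear_mismatch(mlp, 512))
print(find_linear_mismatch(mlp, 768))
```

With the real model, a call along the lines of find_linear_mismatch(list(model.clip_project.model), clip_embeds.shape[-1]) placed just before the failing line in train.py should show which width disagrees, assuming clip_project.model is the nn.Sequential that the traceback suggests. If the first layer is reported, the stored CLIP embeddings probably have a different width than the one the projection was built for.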