clovaai / donut

Official Implementation of OCR-free Document Understanding Transformer (Donut) and Synthetic Document Generator (SynthDoG), ECCV 2022

Home Page: https://arxiv.org/abs/2111.15664

Input type (float) and bias type (struct c10::BFloat16) should be the same

Coder-Vishali opened this issue

When I try to execute the below code:

```python
from donut import DonutModel
import torch
from PIL import Image

pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base")
if torch.cuda.is_available():
    # GPU: run the whole model in half precision
    pretrained_model.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
else:
    # CPU: only the encoder weights are cast to bfloat16
    pretrained_model.encoder.to(torch.bfloat16)
pretrained_model.eval()

task_name = "synthdog"
task_prompt = f"<s_{task_name}>"

input_img = Image.open(r"C:\Project_Files\donut_vqa\image\1_GREEN_crop_31.jpg")
output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
print(output)
```

I get the below error:

> RuntimeError                              Traceback (most recent call last)
> [c:\Project_Files\donut_vqa\colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb](file:///C:/Project_Files/donut_vqa/colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb) Cell 7 line 1
>      [16](vscode-notebook-cell:/c%3A/Project_Files/donut_vqa/colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb#W6sZmlsZQ%3D%3D?line=15) task_prompt = f"<s_{task_name}>"
>      [18](vscode-notebook-cell:/c%3A/Project_Files/donut_vqa/colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb#W6sZmlsZQ%3D%3D?line=17) input_img = Image.open(r"C:\Project_Files\donut_vqa\image\crop_31.jpg")
> ---> [19](vscode-notebook-cell:/c%3A/Project_Files/donut_vqa/colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb#W6sZmlsZQ%3D%3D?line=18) output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
>      [20](vscode-notebook-cell:/c%3A/Project_Files/donut_vqa/colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb#W6sZmlsZQ%3D%3D?line=19) print(output)
> 
> File [c:\Project_Files\donut_vqa\.venv\lib\site-packages\donut\model.py:452](file:///C:/Project_Files/donut_vqa/.venv/lib/site-packages/donut/model.py:452), in DonutModel.inference(self, image, prompt, image_tensors, prompt_tensors, return_json, return_attentions)
>     448     prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
>     450 prompt_tensors = prompt_tensors.to(self.device)
> --> 452 last_hidden_state = self.encoder(image_tensors)
>     453 if self.device.type != "cuda":
>     454     last_hidden_state = last_hidden_state.to(torch.float32)
> 
> File [c:\Project_Files\donut_vqa\.venv\lib\site-packages\torch\nn\modules\module.py:1518](file:///C:/Project_Files/donut_vqa/.venv/lib/site-packages/torch/nn/modules/module.py:1518), in Module._wrapped_call_impl(self, *args, **kwargs)
>    1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
>    1517 else:
> -> 1518     return self._call_impl(*args, **kwargs)
> 
> File [c:\Project_Files\donut_vqa\.venv\lib\site-packages\torch\nn\modules\module.py:1527](file:///C:/Project_Files/donut_vqa/.venv/lib/site-packages/torch/nn/modules/module.py:1527), in Module._call_impl(self, *args, **kwargs)
>    1522 # If we don't have any hooks, we want to skip the rest of the logic in
>    1523 # this function, and just call forward.
>    1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
>    1525         or _global_backward_pre_hooks or _global_backward_hooks
> ...
>     455                     _pair(0), self.dilation, self.groups)
> --> 456 return F.conv2d(input, weight, bias, self.stride,
>     457                 self.padding, self.dilation, self.groups)
> 
> RuntimeError: Input type (float) and bias type (struct c10::BFloat16) should be the same

What should the input image shape be?
The image I am using has shape (19, 273, 3).

Check your package versions. You might want to stay with the exact versions listed in the project's requirements.txt.
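
To compare an environment against pinned versions, one option is to print what is actually installed. The sketch below uses only the standard library; the package list and the helper name `installed_versions` are assumptions for illustration and should be adjusted to whatever the project's requirements actually name.

```python
from importlib import metadata

# Assumed package list for illustration only; replace it with the names
# pinned in the project's requirements.
PACKAGES = ["torch", "transformers", "timm", "donut-python"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```

Running this in the same virtual environment as the notebook gives a list that can be checked line by line against the pinned versions.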

I think this problem occurs when you run on the CPU instead of a GPU.

Try running on a GPU instead.
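
For context, the error message says that a convolution received a float32 input while its parameters were bfloat16. The sketch below reproduces that kind of mismatch with plain PyTorch and shows two generic ways to make the dtypes agree. The `nn.Conv2d` layer is a hypothetical stand-in, not Donut's actual encoder layer, and this is not a confirmed fix for the library itself.

```python
import torch
from torch import nn

# Stand-in for an encoder's first convolution (hypothetical; not Donut's real layer).
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)
conv.to(torch.bfloat16)           # same kind of cast as encoder.to(torch.bfloat16)
image = torch.rand(1, 3, 32, 32)  # torch.rand returns float32 by default

# 1) Reproduce the mismatch: float32 input, bfloat16 weight and bias.
try:
    conv(image)
    mismatch_message = None
except RuntimeError as err:
    mismatch_message = str(err)
print("mismatch:", mismatch_message)

# 2) One way out: cast the input to the dtype of the layer's parameters.
param_dtype = next(conv.parameters()).dtype
out_bf16 = conv(image.to(param_dtype))
print(out_bf16.dtype)

# 3) Another way out: keep both the layer and the input in float32 on CPU.
conv.float()                      # casts the parameters back to float32 in place
out_fp32 = conv(image)
print(out_fp32.dtype)
```

Applied to the code in the question, the second option would correspond to skipping the `encoder.to(torch.bfloat16)` cast on CPU, at the cost of more memory. Whether that works with a given Donut version is an assumption to verify.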