rsommerfeld / trocr

Powerful handwritten text recognition. A simple-to-use, unofficial implementation of the paper "TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models".


Accuracy goes to 0.0 frequently

bgaro opened this issue

Hi, I have a problem with training the model. The gradient seems to explode frequently, but not on every training run. Here is a graph that illustrates the problem.

[Graph: accuracy and gradient norm over training, showing the accuracy dropping to 0.0 when the gradient spikes]

I've tried printing the model's predictions at each validation step, but once the gradient explodes the model keeps predicting empty labels.
I'm using a portion of the IAM dataset, and my labels are structured this way: file-name.png,¤label¤
I'm using the character '¤' as the quote character since it does not appear in the dataset, which lets me predict double quotes (I've modified the CSV reader to use this character to mark out the label; see the sketch below).
I've tried forcing the download of the pretrained weights at the beginning of each training run, without effect.
I've also tried increasing the word length, without any effect either.
I'm surely missing something but can't see what.
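The reader change is roughly this (a sketch only; "labels.csv" and the loop are illustrative, the actual loading code lives in the repo's dataset handling):

```python
import csv

# Reads lines shaped like:  file-name.png,¤some "label" text¤
# "labels.csv" is a placeholder name for the label file.
with open("labels.csv", newline="", encoding="utf-8") as f:
    reader = csv.reader(f, quotechar="¤")
    for filename, label in reader:
        print(filename, label)  # the label may itself contain double quotes
```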

Do you have any idea what could cause the model to behave this way?
Thanks

Hi, how did you generate those graphs?

Hi, I generated those graphs by parsing the training output. For the gradient norm I followed this topic: https://discuss.pytorch.org/t/check-the-norm-of-gradients/27961/5 and simply added it to the debug print. I tried running without it in case it caused instability, but the same problem appeared.
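Roughly, the debug code I added looks like this (a minimal sketch of the approach from that topic; the function name and where it gets logged are illustrative):

```python
import torch

# Total L2 norm of the gradients, computed right after loss.backward().
def grad_norm(model: torch.nn.Module) -> float:
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5
```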

Hi bgaro, the pretrained weights are cached locally after they are downloaded for the first time. The fact that you don't see a download bar does not mean the weights are not applied at the beginning of training. They should load as long as VisionEncoderDecoderModel.from_pretrained(paths.trocr_repo) in util.py is executed.
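For reference, a minimal sketch of the loading call (the model id shown is an assumption; the repo reads the actual id from paths.trocr_repo):

```python
from transformers import VisionEncoderDecoderModel

# The model id is an assumption for illustration; util.py reads it from paths.trocr_repo.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# To rule out a stale or corrupted local cache, the download can be forced explicitly:
# model = VisionEncoderDecoderModel.from_pretrained(
#     "microsoft/trocr-base-handwritten", force_download=True
# )
```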

Now regarding your training issue:

  • Did you try looking at the predictions before the accuracy goes down? Were they already empty, or did the model predict reasonable outputs?
  • How much data are you using? You could try increasing the amount of data.
  • I would also try a lower learning rate, especially if you have a small number of images. Maybe try 5e-6 first, going even lower if needed. The setting is currently in scripts.py line 53; I might commit an update later that moves the constant to the constants.py file. A sketch is below.
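Something along these lines, as a sketch only (it assumes an AdamW optimizer and the base handwritten checkpoint; the repo's actual optimizer setup in scripts.py may differ):

```python
import torch
from transformers import VisionEncoderDecoderModel

LEARNING_RATE = 5e-6  # lowered from the current default in scripts.py

# Assumed setup for illustration; check scripts.py for the optimizer actually used.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
```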

Let me know if that helps!

Hi, I managed to solve the issue thanks to your help. To answer your questions:

  • Before the accuracy went down, the predictions were coherent. The first validation showing the sudden drop sometimes just contained repeated words or was empty, and subsequent ones were completely empty.
  • I have a dataset of about 20,000 images and haven't tried increasing it; maybe that could also solve the issue.
  • I lowered the learning rate to the value you suggested, and this solved the problem. The high gradients were still present but no longer had a negative impact on learning.
  • I also tried simply clipping the gradient norm to 100 using native PyTorch (see the sketch below). This seems to help slightly, as there is no longer a sudden drop in accuracy, but at the expense of the network's overall accuracy. Maybe another value would work better.
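The clipping itself is the standard PyTorch call between the backward pass and the optimizer step. A minimal sketch (the function and argument names are illustrative, not the repo's actual training loop):

```python
import torch

# Sketch of one training step with gradient-norm clipping.
# max_norm=100 is the value tried above; smaller values (e.g. 1.0) are more typical.
def training_step(model, optimizer, batch, loss_fn):
    optimizer.zero_grad()
    loss = loss_fn(model, batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100.0)
    optimizer.step()
    return loss.item()
```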

TL;DR: Reducing the LR to 5e-6 solved the issue.