facebookresearch / ijepa

Official codebase for I-JEPA, the Image-based Joint-Embedding Predictive Architecture. First outlined in the CVPR paper, "Self-supervised learning from images with a joint-embedding predictive architecture."

Training loss increases

ChristianEschen opened this issue · comments

Hi.

I am trying to train the ijepa vit-huge_16_448 model on a dataset of medical images.
I am using 24 A100 GPUs with 40 GB of memory each.
I adjusted the learning rate using the linear scaling rule.
The original ijepa setup uses lr=0.001 with batch_size=16 on 16 GPUs, for a total batch size of 256.
In my experiment I use batch_size=6 on 24 GPUs, for a total batch size of 144.
The ratio between these batch sizes is 144/256 ≈ 0.56, so my learning rate should be about 0.001 * 0.56 = 0.00056.
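
For reference, here is that arithmetic as a quick sketch (the variable names are just mine):

```python
# Linear scaling rule: new_lr = base_lr * (new_batch / base_batch)
base_lr = 0.001
base_batch = 16 * 16   # batch_size=16 on 16 GPUs -> 256
my_batch = 6 * 24      # batch_size=6 on 24 GPUs  -> 144

scaled_lr = base_lr * (my_batch / base_batch)
print(scaled_lr)       # 0.0005625, which I rounded to 0.00056
```
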
The loss decreases in the beginning, but after 3 epochs it starts increasing:

epoch 1 avg loss: 0.028
epoch 2 avg loss: 0.005
epoch 3 avg loss: 0.005
epoch 4 avg loss: 0.006
epoch 5 avg loss: 0.008
epoch 6 avg loss: 0.012
epoch 7 avg loss: 0.015

Why does the loss increase at this point?

MidoAssran commented

Hi @ChristianEschen,

Just a few points:

  • the default batch size is 2048 in all experiments (e.g., see the vit-huge_14_224 config, which is run on 16 GPUs). The provided config for the H/16-448 is intended for use with more than 16 GPUs.
  • the linear learning-rate scaling rule doesn't hold very well for self-supervised vision transformers. If you want to run with a smaller batch size, you could first try keeping all other hyper-parameters fixed, and then tune as needed if you see a performance drop.
  • finally, yes, the behaviour you observe in the training loss is expected. It is counterintuitive, but as the network improves, the representations become more semantic, the prediction task becomes harder, and the loss can go up. From an optimization perspective, even though this is the loss you compute during training, the targets come from a momentum (moving-average) copy of the encoder and keep changing, so the loss can increase even as the model improves (see the sketch below).
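
To make that last point concrete, here is a rough sketch of a momentum (EMA) target-encoder update; the module and variable names are placeholders rather than the exact code in this repo. Because the targets produced by the target encoder keep moving, the per-step loss is not a monotone measure of progress:

```python
import copy
import torch

# Rough sketch (assumed names, not this repo's exact code) of the momentum /
# EMA update of the target encoder in a joint-embedding setup.
encoder = torch.nn.Linear(8, 8)           # stand-in for the context encoder (ViT)
target_encoder = copy.deepcopy(encoder)   # target starts as a copy of the encoder
for p in target_encoder.parameters():
    p.requires_grad = False               # the target is never updated by gradients

momentum = 0.996  # typically annealed towards 1.0 over training

@torch.no_grad()
def update_target(encoder, target_encoder, m):
    # target <- m * target + (1 - m) * encoder
    for q, k in zip(encoder.parameters(), target_encoder.parameters()):
        k.mul_(m).add_((1.0 - m) * q.detach())

# Called after every optimizer step on `encoder`; since the prediction targets
# keep moving, a rising loss does not necessarily mean training is broken.
update_target(encoder, target_encoder, momentum)
```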

edit:
I'll keep this issue open for now; let me know if you still need support (otherwise I'll close it).

I'm curious how one can develop and train a new model while observing an increasing loss during training. In my experience, monitoring the training loss is a fundamental step in debugging neural networks.

Hi @MidoAssran,
So, if the training loss keeps increasing, how should we decide which checkpoint to save? I am using a minimum-validation-loss criterion, but by that criterion the best model is from epoch 3, and I can't find where this repo's code defines its checkpoint-saving criterion.
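
For context, this is roughly the criterion I am using right now; a minimal sketch with a placeholder model and validation loss, not code from this repo:

```python
import torch

# Minimal sketch of a min-validation-loss checkpoint criterion; the model,
# the validation loss, and the file name are placeholders, not this repo's code.
model = torch.nn.Linear(8, 8)
num_epochs = 7
best_val_loss = float("inf")

for epoch in range(num_epochs):
    # ... one epoch of training would go here ...
    val_loss = torch.rand(1).item()   # stand-in for a real validation loss

    if val_loss < best_val_loss:      # save only when validation improves
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_checkpoint.pt")
```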