Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.

About the accuracy computation.

rookiecm opened this issue:

Thanks for sharing your code. I am a little confused about the accuracy computation at L69-70 of wrapper.py:

acc_i = (torch.argmax(image_logits) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t()) == ground_truth).sum()

It seems that torch.argmax, when called without a dim argument, returns the index of the maximum over the flattened tensor (a single scalar), whereas ground_truth holds one target index per row or column. Should we change it to the following?

acc_i = (torch.argmax(image_logits, 0) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t(), 0) == ground_truth).sum()

Thanks.
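
To make the behavior in question concrete, here is a minimal sketch on a made-up 3x3 logit matrix. The values are hypothetical and chosen only for illustration (they are not taken from wrapper.py); the sketch assumes rows index images, columns index texts, and matched pairs sit on the diagonal.

```python
import torch

# Hypothetical similarity matrix: row i is image i, column j is text j.
image_logits = torch.tensor([
    [0.90, 0.10, 0.00],
    [0.20, 0.95, 0.10],
    [0.30, 0.00, 0.70],
])
ground_truth = torch.arange(3)  # image i should match text i

# Without `dim`, argmax runs over the flattened tensor and returns one scalar index.
flat_idx = torch.argmax(image_logits)

# That scalar is broadcast against ground_truth, so the count says nothing
# about per-row correctness.
flat_hits = (flat_idx == ground_truth).sum()

# With `dim=1`, argmax runs within each row: one predicted text per image.
row_idx = torch.argmax(image_logits, dim=1)
row_hits = (row_idx == ground_truth).sum()

print(flat_idx.item(), flat_hits.item())   # flattened index and its hit count
print(row_idx.tolist(), row_hits.item())   # per-row predictions and hit count
```

Here every row's largest entry is on the diagonal, yet the flattened call reports zero hits because the single index it returns (4) does not equal any entry of ground_truth.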

It looks like it! Thanks for catching this bug. I'll send a fix right now.

Made an update and also condensed it into a single term due to the diagonal being the same. Should look good now!
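
For readers who want to compute top-1 accuracy on a square logit matrix themselves, the following is a rough sketch. It is not the change that was committed to the repository: the function name `retrieval_accuracy` and the example values are hypothetical, and it assumes `logits[i, j]` scores image i against text j with matched pairs on the diagonal.

```python
from typing import Tuple

import torch


def retrieval_accuracy(logits: torch.Tensor) -> Tuple[float, float]:
    """Return (image-to-text, text-to-image) top-1 accuracy.

    Hypothetical helper: assumes a square matrix whose diagonal
    holds the matched image/text pairs.
    """
    n = logits.shape[0]
    target = torch.arange(n, device=logits.device)
    # dim=1: for each image (row), pick the best-scoring text.
    acc_i2t = (logits.argmax(dim=1) == target).float().mean().item()
    # dim=0: for each text (column), pick the best-scoring image.
    acc_t2i = (logits.argmax(dim=0) == target).float().mean().item()
    return acc_i2t, acc_t2i


# Made-up 2x2 example.
example_logits = torch.tensor([[0.9, 0.1],
                               [0.8, 0.3]])
acc_i2t, acc_t2i = retrieval_accuracy(example_logits)
print(acc_i2t, acc_t2i)
```

Dividing by the batch size (via `mean`) keeps the value in [0, 1], whereas the `.sum()` used in the snippets above yields a raw count.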