About the accuracy computation.
rookiecm opened this issue
Thanks for sharing your code. I am a little confused about the accuracy computation in L69-70 of wrapper.py:
```python
acc_i = (torch.argmax(image_logits) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t()) == ground_truth).sum()
```
It seems that torch.argmax, when called without a dim argument, returns the index of the maximum value across all dimensions (an index into the flattened tensor), while ground_truth holds one index per row or column. Should we change it to:
```python
acc_i = (torch.argmax(image_logits, 0) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t(), 0) == ground_truth).sum()
```
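To make the difference concrete, here is a minimal sketch in plain Python (no PyTorch needed) that mimics both forms of torch.argmax. The helpers argmax_flat, argmax_dim and count_correct, and the 3x3 logits, are hypothetical illustrations, not code from wrapper.py. The sketch also assumes ground_truth is arange(N), i.e. the i-th image matches the i-th text.

```python
# Pure-Python stand-ins for torch.argmax, so the example runs without PyTorch.
from typing import List, Sequence

Matrix = List[List[float]]


def argmax_flat(m: Matrix) -> int:
    """Mimics torch.argmax(m) with no dim: an index into the row-major flattened matrix."""
    flat = [x for row in m for x in row]
    return max(range(len(flat)), key=flat.__getitem__)


def argmax_dim(m: Matrix, dim: int) -> List[int]:
    """Mimics torch.argmax(m, dim) for a 2-D input: dim=0 gives one index per column, dim=1 one per row."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D matrix")
    lines = [list(col) for col in zip(*m)] if dim == 0 else m
    return [max(range(len(line)), key=line.__getitem__) for line in lines]


def count_correct(pred: Sequence[int], ground_truth: Sequence[int]) -> int:
    """Number of positions where the predicted index matches the label."""
    return sum(int(p == g) for p, g in zip(pred, ground_truth))


# Hypothetical 3x3 similarity matrix: every matching pair has the highest score
# in its row and column, so a correct metric should report 3/3 in both directions.
image_logits = [
    [5.0, 1.0, 0.0],
    [0.0, 6.0, 2.0],
    [1.0, 0.0, 9.0],
]
ground_truth = [0, 1, 2]  # plays the role of torch.arange(3)

# Original code: a single flat index (8 here) compared against every label.
flat_idx = argmax_flat(image_logits)
buggy = sum(int(flat_idx == g) for g in ground_truth)

# Proposed fix: one prediction per column of image_logits and of its transpose.
acc_i = count_correct(argmax_dim(image_logits, 0), ground_truth)
transposed = [list(col) for col in zip(*image_logits)]
acc_t = count_correct(argmax_dim(transposed, 0), ground_truth)

print(flat_idx, buggy, acc_i, acc_t)  # 8 0 3 3
```

Even though every pair is ranked correctly, the original expression counts 0 correct, because the single flat index 8 never equals any label. In PyTorch, `torch.argmax(image_logits.t(), 0)` is equivalent to `torch.argmax(image_logits, dim=1)`. Assuming image_logits has shape [num_images, num_texts], dim=0 picks the best image for each text and dim=1 picks the best text for each image.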
Thanks.
It looks like it! Thanks for catching this bug. I'll send a fix right now.
Made an update and also condensed it into a single term due to the diagonal being the same. Should look good now!