p2ch11: label_g[:,1] as target label in computeBatchLoss() ?
4179e1 opened this issue · comments
Hi,
I'm having trouble understanding the label_g[:,1] used in computeBatchLoss():
Lines 225 to 238 in d6c0210
Assume the batch size is 32; then logits_g will have shape [32, 2].
label_g has the same shape, [32, 2], and if I understand correctly it is the one-hot vector defined in
Lines 203 to 210 in d6c0210
My question is: in the CrossEntropyLoss function, should we use label_g instead of label_g[:,1] (which takes the second column of each row)? Something like:
loss_g = loss_func(
logits_g,
label_g, # the one-hot vector instead of label_g[:,1]
)
or
loss_g = loss_func(
logits_g,
torch.argmax(label_g, dim=1), # if we want to use the index
)
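For what it's worth, here is a minimal sketch (with made-up logits, not the book's actual tensors) of why both forms can work: nn.CrossEntropyLoss has always accepted class-index targets, and since PyTorch 1.10 it also accepts floating-point probability targets, which for one-hot rows give the same loss as the index form.

```python
import torch
import torch.nn as nn

# Hypothetical logits for a batch of 3 samples, 2 classes.
logits = torch.tensor([[2.0, 0.5],
                       [0.1, 1.5],
                       [1.0, 1.0]])
# One-hot targets, analogous to label_g in the book.
onehot = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 0.0]])
loss_fn = nn.CrossEntropyLoss()

# Class-index targets (what the book effectively passes via label_g[:,1]):
loss_idx = loss_fn(logits, onehot.argmax(dim=1))

# Probability targets (supported since PyTorch 1.10):
loss_prob = loss_fn(logits, onehot)

print(torch.allclose(loss_idx, loss_prob))  # True for one-hot targets
```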
Thanks
Hmm, I just realized that this is a shortcut that only applies to binary classification: label_g[:,1] and torch.argmax(label_g, dim=1) produce the same result:
t = torch.tensor([
[1, 0],
[0, 1],
[1, 0],
])
print(torch.argmax(t, dim=1))  # tensor([0, 1, 0])
print(t[:,1])                  # tensor([0, 1, 0])
However, this doesn't hold for multiclass classification:
t = torch.tensor([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
])
print(torch.argmax(t, dim=1))  # tensor([0, 1, 2])
print(t[:,1])                  # tensor([0, 1, 0])
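To make the same point with the loss itself: in the (hypothetical) three-class setup below, column 1 still yields valid-looking indices, so the code would run without error but silently train against the wrong labels, whereas argmax recovers the true classes.

```python
import torch
import torch.nn as nn

# Hypothetical 3-class logits that strongly favor classes 0, 1, 2 in turn.
logits = torch.tensor([[3.0, 0.0, 0.0],
                       [0.0, 3.0, 0.0],
                       [0.0, 0.0, 3.0]])
onehot = torch.tensor([[1, 0, 0],
                       [0, 1, 0],
                       [0, 0, 1]])
loss_fn = nn.CrossEntropyLoss()

good = loss_fn(logits, onehot.argmax(dim=1))  # correct targets: 0, 1, 2
bad = loss_fn(logits, onehot[:, 1])           # wrong targets: 0, 1, 0

print(good.item() < bad.item())  # True: column 1 mislabels the third sample
```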