deep-learning-with-pytorch / dlwpt-code

Code for the book Deep Learning with PyTorch by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Home Page: https://www.manning.com/books/deep-learning-with-pytorch


p2ch11: label_g[:,1] as target label in computeBatchLoss()?

4179e1 opened this issue

Hi,

I'm having trouble understanding the label_g[:,1] used in computeBatchLoss():

def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
    input_t, label_t, _series_list, _center_list = batch_tup

    input_g = input_t.to(self.device, non_blocking=True)
    label_g = label_t.to(self.device, non_blocking=True)

    logits_g, probability_g = self.model(input_g)

    loss_func = nn.CrossEntropyLoss(reduction='none')
    loss_g = loss_func(
        logits_g,
        label_g[:,1],
    )
    start_ndx = batch_ndx * batch_size

Assume the batch size is 32; then logits_g will have shape [32, 2].
label_g has the same shape, [32, 2]. If I'm not mistaken, it should be the one-hot vector defined in

dlwpt-code/p2ch11/dsets.py

Lines 203 to 210 in d6c0210

        pos_t = torch.tensor([
                not candidateInfo_tup.isNodule_bool,
                candidateInfo_tup.isNodule_bool
            ],
            dtype=torch.long,
        )

        return candidate_t, pos_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
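
For reference, here is a minimal sketch (with an assumed isNodule_bool value; not code from the book) of what one such label row looks like:

import torch

isNodule_bool = True  # assumed value, for illustration only

# One-hot encoding: [not nodule, nodule] -> [0, 1] for a positive candidate
pos_t = torch.tensor([not isNodule_bool, isNodule_bool], dtype=torch.long)
print(pos_t)  # tensor([0, 1])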

My question is: in the CrossEntropyLoss call, should we use label_g instead of label_g[:,1] (which takes the second column of each row)? Something like:

        loss_g = loss_func(
            logits_g,
            label_g, # the one-hot vector instead of label_g[:,1]
        )

or

        loss_g = loss_func(
            logits_g,
            torch.argmax(label_g, dim=1), # if we want to use the index 
        )

Thanks

Hmm, just realized that it's a shortcut that only applies to binary classification: label_g[:,1] and torch.argmax(label_g, dim=1) produce the same result:

t = torch.tensor([
    [1, 0],
    [0, 1],
    [1, 0],
])

print(torch.argmax(t, dim=1))  # tensor([0, 1, 0])
print(t[:,1])                  # tensor([0, 1, 0])
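
A quick sanity check (a minimal sketch with made-up logits) that both target forms therefore give identical per-sample losses:

import torch
import torch.nn as nn

logits = torch.randn(3, 2)  # made-up logits for 3 samples
label = torch.tensor([
    [1, 0],
    [0, 1],
    [1, 0],
])

loss_func = nn.CrossEntropyLoss(reduction='none')
loss_a = loss_func(logits, label[:, 1])                 # column trick
loss_b = loss_func(logits, torch.argmax(label, dim=1))  # explicit argmax
print(torch.equal(loss_a, loss_b))  # True

Also, if I recall correctly, since PyTorch 1.10 nn.CrossEntropyLoss additionally accepts class-probability targets of shape [N, C], so passing label_g.float() (not the long one-hot tensor as-is) would be a third equivalent option on recent versions.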

However, this doesn't hold for multi-class classification:

t = torch.tensor([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
])

print(torch.argmax(t, dim=1))  # tensor([0, 1, 2])
print(t[:,1])                  # tensor([0, 1, 0])
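
So a hypothetical multi-class version of this loss would have to use the argmax indices (a sketch, not code from the book):

import torch
import torch.nn as nn

logits = torch.randn(3, 3)  # made-up logits for 3 samples, 3 classes
label = torch.tensor([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
])

loss_func = nn.CrossEntropyLoss(reduction='none')
loss_g = loss_func(logits, torch.argmax(label, dim=1))  # indices 0, 1, 2
print(loss_g.shape)  # torch.Size([3])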