How to handle unbalanced data?
jS5t3r opened this issue
Peter Lorenz commented
Replacing in train.py the line:
loss_function = nn.CrossEntropyLoss()
with
cw = torch.tensor([0.11850941, 0.14937713, 0.11032023, 0.62179323], dtype=torch.float32).cuda() # class weights for 1, 2, 3, 4
loss_function = nn.CrossEntropyLoss(ignore_index=255, weight=cw, reduction='mean')
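The four weights above sum to 1 and are largest for the last class. The issue does not say how they were computed. A common approach, assumed here, is normalized inverse class frequency. The sketch below is plain Python so it runs without a GPU, and the per-class counts are hypothetical placeholders. In `nn.CrossEntropyLoss`, `weight[i]` applies to target index `i`, so classes labelled 1 to 4 correspond to indices 0 to 3.

```python
def inverse_frequency_weights(counts):
    """Return per-class weights proportional to 1/count, normalized to sum to 1.

    Every class must appear at least once (count > 0), otherwise 1/count fails.
    """
    inv = [1.0 / c for c in counts]
    total = sum(inv)
    return [w / total for w in inv]


# Hypothetical sample counts for class indices 0..3 (replace with real counts).
counts = [500, 400, 540, 95]
weights = inverse_frequency_weights(counts)
print([round(w, 4) for w in weights])

# In train.py the list could then be wrapped as, e.g.:
# cw = torch.tensor(weights, dtype=torch.float32).cuda()
```

The rarest class receives the largest weight, so its errors contribute more to the mean loss.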
I get the error:
RuntimeError: weight tensor should be defined either for all 100 classes or no classes but got weight tensor of shape: [4]
The error is raised by this line:
loss = loss_function(outputs, labels)
Debug output:
outputs.shape
torch.Size([1, 100])
labels
tensor([3], device='cuda:0')
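The debug output shows the mismatch. `nn.CrossEntropyLoss` expects `weight` to have one entry per column of `outputs`, so it expects C = `outputs.shape[1]` = 100 entries, but `cw` has only 4. The minimal CPU sketch below reproduces this and shows that the same loss works once the logits have 4 columns. It assumes the dataset really has 4 classes and that the network's final layer can be configured to emit 4 outputs. How that is set depends on the model definition, which the issue does not show.

```python
import torch
import torch.nn as nn

num_classes = 4  # assumption: 4 classes, matching the 4 weights below
cw = torch.tensor([0.11850941, 0.14937713, 0.11032023, 0.62179323], dtype=torch.float32)
loss_function = nn.CrossEntropyLoss(ignore_index=255, weight=cw, reduction='mean')

labels = torch.tensor([3])  # target class index, must lie in [0, num_classes)

# Logits with 100 columns (as in the debug output) reproduce the RuntimeError,
# because len(cw) != outputs.shape[1].
outputs_100 = torch.randn(1, 100)
try:
    loss_function(outputs_100, labels)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Logits whose second dimension equals len(cw) work, e.g. from a final layer
# such as nn.Linear(in_features, num_classes) instead of one with 100 outputs.
outputs_4 = torch.randn(1, num_classes)
loss = loss_function(outputs_4, labels)
print(mismatch_raised, loss.item())
```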
What do I need to change to get a weighted loss for multiple classes?