weiaicunzai / pytorch-cifar100

Practice on cifar100(ResNet, DenseNet, VGG, GoogleNet, InceptionV3, InceptionV4, Inception-ResNetv2, Xception, Resnet In Resnet, ResNext,ShuffleNet, ShuffleNetv2, MobileNet, MobileNetv2, SqueezeNet, NasNet, Residual Attention Network, SENet, WideResNet)

How to handle unbalanced data?

jS5t3r opened this issue

In train.py, I replaced the line:

loss_function = nn.CrossEntropyLoss()

with

cw = torch.tensor([0.11850941, 0.14937713, 0.11032023, 0.62179323], dtype=torch.float32).cuda() # class weights for 1, 2, 3, 4
loss_function = nn.CrossEntropyLoss(ignore_index=255, weight=cw, reduction='mean')

I get the error:

RuntimeError: weight tensor should be defined either for all 100 classes or no classes but got weight tensor of shape: [4]

Analyzing the code:

loss = loss_function(outputs, labels)

Debug output:

outputs.shape
torch.Size([1, 100])

labels
tensor([3], device='cuda:0')

What do I need to change so that the weighted loss works with my classes?
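
One possible reading of the error: the `weight` argument of `nn.CrossEntropyLoss` needs one entry per class the model outputs, and the debug output above shows 100 output columns against 4 weights. The following is only a minimal sketch, assuming the data really has 4 classes labelled 0 to 3 and that the network's final layer is changed to emit 4 logits. A random tensor stands in for the model output, `num_classes` is a name made up for illustration, and everything runs on the CPU.

```python
import torch
import torch.nn as nn

# Assumption: the dataset has 4 classes, labelled 0..3 (not 1..4).
num_classes = 4

# Class weights from the post, one entry per class.
cw = torch.tensor([0.11850941, 0.14937713, 0.11032023, 0.62179323],
                  dtype=torch.float32)
loss_function = nn.CrossEntropyLoss(ignore_index=255, weight=cw, reduction='mean')

# Stand-in for the network output: shape [batch, num_classes].
# In the post this tensor has shape [1, 100], which is what triggers the error.
outputs = torch.randn(1, num_classes)
labels = torch.tensor([3])

# The weight vector must have as many entries as outputs has columns.
assert cw.shape[0] == outputs.shape[1]

loss = loss_function(outputs, labels)
print(loss.item())
```

If the model has to keep all 100 outputs, the alternative would be a weight tensor of length 100 instead; which of the two applies depends on how many classes the dataset actually contains.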