RuntimeError: weight tensor should be defined either for all 19 classes or no classes but got weight tensor of shape: [1, 19]
pomeloooo opened this issue
I get the following error:
RuntimeError: weight tensor should be defined either for all 19 classes or no classes but got weight tensor of shape: [1, 19]
How can I solve this?
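A minimal reproduction, assuming a plain classification setup with 19 classes and random data (the shapes are illustrative, not taken from the training script). `nn.CrossEntropyLoss` expects `weight` to be a 1-D tensor with one entry per class, so a `(1, 19)` tensor is rejected even though it holds 19 values:

```python
import torch
import torch.nn as nn

num_classes = 19
logits = torch.randn(4, num_classes)          # (batch, classes); illustrative data
labels = torch.randint(0, num_classes, (4,))  # one class index per sample
bad_weight = torch.ones(1, num_classes)       # 19 values, but 2-D: shape (1, 19)

criterion = nn.CrossEntropyLoss(weight=bad_weight)
try:
    criterion(logits, labels)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)  # the message complains about the [1, 19] weight shape
```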
I have the same error. Have you solved it?
The weight's shape should be (19,), not (1, 19).
In train_SemantciKITTI.py, at line 94, squeeze the weight tensor before building the loss:
class_weights = torch.from_numpy(train_dataset.get_class_weight()).float().cuda()
class_weights = class_weights.squeeze(0)  # drop the leading dimension: (1, 19) -> (19,)
self.criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='none')
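A CPU-only sketch of this fix, for checking the shapes without a GPU. The `.cuda()` call is dropped, and the NumPy array is a stand-in for whatever `get_class_weight()` returns, assumed here to have shape `(1, 19)`. `squeeze(0)` removes the leading dimension, and with `reduction='none'` the loss comes back per sample:

```python
import numpy as np
import torch
import torch.nn as nn

num_classes = 19
# Stand-in for train_dataset.get_class_weight(); assumed to return shape (1, 19).
raw_weights = np.ones((1, num_classes), dtype=np.float64)

class_weights = torch.from_numpy(raw_weights).float()
class_weights = class_weights.squeeze(0)  # (1, 19) -> (19,)

criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='none')
logits = torch.randn(4, num_classes)          # illustrative (batch, classes) scores
labels = torch.randint(0, num_classes, (4,))
loss = criterion(logits, labels)

print(class_weights.shape)  # torch.Size([19])
print(loss.shape)           # torch.Size([4]): one loss value per sample
```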
Alternatively, reshape the weights when they are loaded:
class_weights = torch.from_numpy(train_dataset.get_class_weight()).float().view(19).cuda()  # with this change, training runs
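`view(19)` hardcodes the class count, so it only works when the array holds exactly 19 values. A slightly more general sketch uses a hypothetical helper, `to_class_weight_tensor` (not part of the repository). It flattens with `reshape(-1)` and checks the count explicitly, so a mismatch fails with a clearer message:

```python
import numpy as np
import torch

def to_class_weight_tensor(raw, num_classes):
    """Hypothetical helper: turn class weights of any shape into a 1-D float tensor."""
    weights = torch.from_numpy(np.asarray(raw)).float().reshape(-1)  # flatten to (N,)
    if weights.numel() != num_classes:
        raise ValueError(
            f"expected {num_classes} class weights, got {weights.numel()}"
        )
    return weights

w = to_class_weight_tensor(np.ones((1, 19)), num_classes=19)
print(w.shape)  # torch.Size([19])
```

On a GPU, `.cuda()` can be appended to the result as in the line above.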