tsunghan-wu / RandLA-Net-pytorch

:four_leaf_clover: Pytorch Implementation of RandLA-Net (https://arxiv.org/abs/1911.11236)



RuntimeError: weight tensor should be defined either for all 19 classes or no classes but got weight tensor of shape: [1, 19]

pomeloooo opened this issue.

RuntimeError: weight tensor should be defined either for all 19 classes or no classes but got weight tensor of shape: [1, 19]
How can I solve this?

I have the same error. Have you solved it?

The weight's shape should be (19,), not (1, 19). In train_SemanticKITTI.py, line 94, squeeze the extra dimension before building the loss:

    class_weights = torch.from_numpy(train_dataset.get_class_weight()).float().cuda()
    # get_class_weight() yields shape (1, 19); CrossEntropyLoss expects a 1-D (19,) weight
    class_weights = class_weights.squeeze(0)
    self.criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='none')
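A minimal, CPU-only sketch of the same fix outside the training script, for checking the shapes in isolation. It assumes `get_class_weight()` returns a NumPy array of shape (1, 19), as the error message suggests; `fake_weights` is a hypothetical stand-in for that array, `.cuda()` is omitted so it runs anywhere, and the logits/labels shapes (batch, classes, points) are assumptions modeled on a per-point segmentation loss.

```python
import numpy as np
import torch
import torch.nn as nn

NUM_CLASSES = 19

# Hypothetical stand-in for train_dataset.get_class_weight(): a (1, 19) array,
# matching the shape reported in the error message.
fake_weights = np.random.rand(1, NUM_CLASSES).astype(np.float32)

class_weights = torch.from_numpy(fake_weights).float()
print(class_weights.shape)  # torch.Size([1, 19]), the shape from the error

# Drop the leading dimension so the weight is 1-D, one entry per class.
class_weights = class_weights.squeeze(0)
print(class_weights.shape)  # torch.Size([19])

criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='none')

# Assumed shapes: logits (batch, classes, points), labels (batch, points).
logits = torch.randn(2, NUM_CLASSES, 100)
labels = torch.randint(0, NUM_CLASSES, (2, 100))
loss = criterion(logits, labels)
print(loss.shape)  # torch.Size([2, 100]) because reduction='none'
```

With `reduction='none'` the loss keeps one value per point, so any masking or averaging has to happen afterwards in the training loop.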

Changing the line to this also makes it run:

    class_weights = torch.from_numpy(train_dataset.get_class_weight()).float().view(19).cuda()
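For a (1, 19) array, `squeeze(0)` and `view(19)` produce the same 1-D tensor; `view(19)` simply hard-codes the class count. The sketch below compares them, assuming the (1, 19) shape from the error. `flatten_class_weights` is a hypothetical helper, not part of the repository: it uses `reshape(-1)` so the class count is not baked in, and checks the length explicitly instead of relying on the loss to complain.

```python
import numpy as np
import torch

def flatten_class_weights(weights: np.ndarray, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: turn a (1, C) or (C,) weight array into a 1-D float tensor."""
    w = torch.from_numpy(weights).float().reshape(-1)
    if w.numel() != num_classes:
        raise ValueError(f"expected {num_classes} class weights, got {w.numel()}")
    return w

raw = np.arange(19, dtype=np.float64).reshape(1, 19)  # assumed shape (1, 19)

a = torch.from_numpy(raw).float().squeeze(0)  # fix from the first reply
b = torch.from_numpy(raw).float().view(19)    # fix from the second reply
c = flatten_class_weights(raw, 19)

print(torch.equal(a, b), torch.equal(b, c))  # True True
print(flatten_class_weights(raw[0], 19).shape)  # torch.Size([19]): 1-D input also works
```

One design difference worth knowing: `squeeze(0)` silently leaves a tensor unchanged if dimension 0 is not of size 1, while `view(19)` raises if the element count is not exactly 19.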