HiLab-git / SSL4MIS

Semi Supervised Learning for Medical Image Segmentation, a collection of literature reviews and code implementations.



train_cross_pseudo_supervision.py unnecessary code

IcecreamArtist opened this issue


Hi,

In the code:

loss1 = 0.5 * (ce_loss(outputs1[:args.labeled_bs],
                       label_batch[:][:args.labeled_bs].long())
               + dice_loss(outputs_soft1[:args.labeled_bs],
                           label_batch[:args.labeled_bs].unsqueeze(1)))
loss2 = 0.5 * (ce_loss(outputs2[:args.labeled_bs],
                       label_batch[:][:args.labeled_bs].long())
               + dice_loss(outputs_soft2[:args.labeled_bs],
                           label_batch[:args.labeled_bs].unsqueeze(1)))

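For ce_loss, the code fetches 'label_batch[:][:args.labeled_bs]', which can be simplified to 'label_batch[:args.labeled_bs]'. The leading '[:]' is a full slice along the batch dimension that returns the whole tensor, so the following '[:args.labeled_bs]' selects exactly the same labeled samples. This matches the indexing already used for dice_loss.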
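A quick check of the equivalence, assuming label_batch is a PyTorch integer tensor laid out as [B, H, W]. The batch size, spatial size, and the labeled_bs value below are made up for illustration and stand in for the real args.labeled_bs.

```python
import torch

# Hypothetical sizes: 4 label maps (first 2 labeled, last 2 unlabeled), each 8x8,
# mimicking the [B, H, W] layout assumed for label_batch.
labeled_bs = 2  # stands in for args.labeled_bs
label_batch = torch.randint(0, 4, (4, 8, 8))

# label_batch[:] is a full slice along dim 0 (a view of the whole tensor),
# so slicing it again with [:labeled_bs] picks the same first samples.
original = label_batch[:][:labeled_bs]
simplified = label_batch[:labeled_bs]

print(original.shape, simplified.shape)  # both torch.Size([2, 8, 8])
print(torch.equal(original, simplified))  # True

# Both results are views that start at the same memory as label_batch,
# so neither form copies data; the extra [:] only adds a redundant view.
print(original.data_ptr() == simplified.data_ptr() == label_batch.data_ptr())  # True
```

Since both expressions yield identical views, dropping the extra '[:]' changes readability only, not the computed loss.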

Discussion is welcome if I have made any mistake.