Reproducing multi-scale testing results
MendelXu opened this issue
Thanks for sharing your code. According to your paper, the multi-scale testing result on fold 0 is 52.5 mIoU, but I can only get about 51 mIoU. Could you give me some advice on how to reproduce your result? Below is my multi-scale testing code, modified from train.py.
with torch.no_grad():
    print('----Evaluation----')
    model = model.eval()
    valset.history_mask_list = [None] * 1000
    all_inter, all_union, all_predict = [0] * 5, [0] * 5, [0] * 5
    for i_iter, batch in enumerate(valloader):
        query_rgb, query_mask, support_rgb, support_mask, history_mask, sample_class, index = batch
        query_rgb = (query_rgb).cuda(0)
        support_rgb = (support_rgb).cuda(0)
        support_mask = (support_mask).cuda(0)
        query_mask = (query_mask).cuda(0).long()  # change formation for crossentropy use
        query_mask = query_mask[:, 0, :, :]  # remove the second dim, change formation for crossentropy use
        history_mask = (history_mask).cuda(0)
        pred_softmax = torch.zeros(1, 2, query_rgb.size(-2), query_rgb.size(-1)).cuda(0)
        for scale in [0.7, 1, 1.3]:
            query_ = nn.functional.interpolate(query_rgb, scale_factor=scale, mode='bilinear', align_corners=True)
            scale_pred = model(query_, support_rgb, support_mask, history_mask)
            scale_pred_softmax = F.softmax(scale_pred, dim=1)
            pred_softmax += nn.functional.interpolate(scale_pred_softmax, size=query_rgb.size()[-2:], mode='bilinear',
                                                      align_corners=True)
        pred_softmax /= 3.
        # update history mask
        for j in range(support_mask.shape[0]):
            sub_index = index[j]
            valset.history_mask_list[sub_index] = pred_softmax[j]
        # pred = nn.functional.interpolate(pred, size=query_rgb.size()[-2:], mode='bilinear',
        #                                  align_corners=True)  # upsample
        _, pred_label = torch.max(pred_softmax, 1)
        inter_list, union_list, _, num_predict_list = get_iou_v1(query_mask, pred_label)
        for j in range(query_mask.shape[0]):  # batch size
            all_inter[sample_class[j] - (options.fold * 5 + 1)] += inter_list[j]
            all_union[sample_class[j] - (options.fold * 5 + 1)] += union_list[j]
    IOU = [0] * 5
    for j in range(5):
        IOU[j] = all_inter[j] / all_union[j]
    mean_iou = np.mean(IOU)
    print('IOU:%.4f' % (mean_iou))
Thanks very much.
Hello, it seems you do not use iter_time in your code, so the IOM (Iterative Optimization Module) is skipped.
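A minimal sketch of what the reply describes. It assumes that iter_time sets how many full evaluation passes are run, and that each pass feeds the previous pass's prediction back in as the history mask used by the IOM. It uses NumPy instead of PyTorch so that it runs without a GPU or checkpoint. The names evaluate_iterative, foreground_iou and toy_predict are hypothetical and do not come from the repository. The multi-scale averaging mirrors the 0.7/1/1.3 loop in the code above. The per-class accumulators are reset at the start of every pass, while the history list persists across passes.

```python
import numpy as np


def foreground_iou(pred_label, gt):
    """Foreground (label 1) intersection and union for one query image."""
    inter = np.logical_and(pred_label == 1, gt == 1).sum()
    union = np.logical_or(pred_label == 1, gt == 1).sum()
    return int(inter), int(union)


def evaluate_iterative(predict, samples, num_classes=5, iter_time=5,
                       scales=(0.7, 1.0, 1.3)):
    """Hypothetical sketch of an iterative (IOM-style) evaluation loop.

    predict(query, history, scale) must return a (2, H, W) array of softmax
    probabilities already resized to the query resolution.
    samples is a list of (query, gt_mask, class_idx), class_idx in [0, num_classes).
    Returns the mean IoU of every pass.
    """
    history = [None] * len(samples)  # persists across passes, like history_mask_list
    miou_per_pass = []
    for _ in range(iter_time):
        inter = [0] * num_classes  # reset the per-class accumulators each pass
        union = [0] * num_classes
        for idx, (query, gt, cls) in enumerate(samples):
            # Multi-scale test: average the softmax maps over all scales.
            prob = np.mean([predict(query, history[idx], s) for s in scales], axis=0)
            history[idx] = prob  # fed back as the history mask on the next pass
            pred_label = prob.argmax(axis=0)
            i, u = foreground_iou(pred_label, gt)
            inter[cls] += i
            union[cls] += u
        # Skip classes that never had foreground (avoids division by zero in toy runs).
        ious = [i / u for i, u in zip(inter, union) if u > 0]
        miou_per_pass.append(float(np.mean(ious)))
    return miou_per_pass


def toy_predict(query, history, scale):
    """Stand-in model: predicts background everywhere until a history mask exists."""
    # scale is ignored here; a real model would resize the query by this factor.
    if history is None:
        fg = np.zeros_like(query, dtype=float)
    else:
        fg = query.astype(float)
    return np.stack([1.0 - fg, fg])


gt = np.array([[0, 1], [1, 1]])
print(evaluate_iterative(toy_predict, [(gt, gt, 0)], num_classes=1, iter_time=3))
# prints [0.0, 1.0, 1.0]
```

Here toy_predict stands in for the network. It outputs background everywhere until a history mask exists, so the first pass scores 0 and later passes score 1. This makes the effect of the feedback loop visible. With the real model, check that the stored history mask has the shape the dataset and model expect before adding such a loop. The code above stores the upsampled, full-resolution softmax.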
Got it. Thank you very much.