zgcr / SimpleAICV_pytorch_training_examples

SimpleAICV: PyTorch training and testing examples.

reg_head in RetinaNet

nlpLover123 opened this issue · comments

Hi, thanks for your great contributions. I have a question about the implementation of RetinaNet. In losses.py, it seems that the reg_head directly outputs the absolute positions of the bounding boxes, and that the L1 loss is calculated from the difference between the ground-truth bbox positions and the reg_head output. Is my understanding correct?

commented

The reg_head output is not the absolute position of the bounding boxes; it is actually the offsets tx, ty, tw, th relative to each anchor. See the relevant code in https://github.com/zgcr/simpleAICV-pytorch-ImageNet-COCO-training/blob/master/simpleAICV/detection/losses.py:

def snap_annotations_to_txtytwth(self, anchors_gt_bboxes, anchors):
    '''
    snap each anchor ground truth bbox from format:[x_min,y_min,x_max,y_max] to format:[tx,ty,tw,th]
    '''
    # anchor width/height and center
    anchors_w_h = anchors[:, 2:] - anchors[:, :2]
    anchors_ctr = anchors[:, :2] + 0.5 * anchors_w_h

    # ground truth width/height (clamped so the log below stays finite) and center
    anchors_gt_bboxes_w_h = anchors_gt_bboxes[:, 2:] - anchors_gt_bboxes[:, :2]
    anchors_gt_bboxes_w_h = torch.clamp(anchors_gt_bboxes_w_h, min=1e-4)
    anchors_gt_bboxes_ctr = anchors_gt_bboxes[:, :2] + 0.5 * anchors_gt_bboxes_w_h

    # tx,ty: center shift in units of anchor size; tw,th: log of the size ratio
    snaped_annotations_for_anchors = torch.cat(
        [(anchors_gt_bboxes_ctr - anchors_ctr) / anchors_w_h,
         torch.log(anchors_gt_bboxes_w_h / anchors_w_h)],
        dim=1)

    # snaped_annotations_for_anchors shape:[anchor_nums, 4]
    return snaped_annotations_for_anchors
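
The encoding in snap_annotations_to_txtytwth can be traced by hand for a single box. The following is a minimal scalar sketch in plain Python (no torch) of the same arithmetic. The function name `encode_box` and the sample coordinates are made up for illustration and are not part of the repository.

```python
import math


def encode_box(anchor, gt_box, min_size=1e-4):
    """Encode one [x_min, y_min, x_max, y_max] box as [tx, ty, tw, th].

    Hypothetical scalar helper mirroring snap_annotations_to_txtytwth.
    """
    ax1, ay1, ax2, ay2 = anchor
    gx1, gy1, gx2, gy2 = gt_box

    # Anchor width/height and center.
    aw, ah = ax2 - ax1, ay2 - ay1
    acx, acy = ax1 + 0.5 * aw, ay1 + 0.5 * ah

    # Ground-truth width/height (clamped to avoid log(0)) and center.
    gw, gh = max(gx2 - gx1, min_size), max(gy2 - gy1, min_size)
    gcx, gcy = gx1 + 0.5 * gw, gy1 + 0.5 * gh

    # Center shift in units of anchor size; log of the size ratio.
    return [(gcx - acx) / aw, (gcy - acy) / ah,
            math.log(gw / aw), math.log(gh / ah)]


anchor = [0.0, 0.0, 10.0, 10.0]   # made-up anchor
gt_box = [2.0, 2.0, 12.0, 14.0]   # made-up ground-truth box
targets = encode_box(anchor, gt_box)
print(targets)
```

For these numbers the anchor center is (5, 5) and the ground-truth center is (7, 8), so tx = 0.2 and ty = 0.3. The widths match, so tw = 0, and th = log(12 / 10), roughly 0.182. These are the values the reg_head is trained to output for this anchor.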

def snap_txtytwth_to_xyxy(self, snap_boxes, anchors):
    '''
    snap reg heads to pred bboxes
    snap_boxes:[batch_size*anchor_nums,4],4:[tx,ty,tw,th]
    anchors:[batch_size*anchor_nums,4],4:[x_min,y_min,x_max,y_max]
    '''
    # anchor width/height and center
    anchors_wh = anchors[:, 2:4] - anchors[:, 0:2]
    anchors_ctr = anchors[:, 0:2] + 0.5 * anchors_wh

    # invert the encoding: exp() restores sizes, offsets shift the anchor center
    boxes_wh = torch.exp(snap_boxes[:, 2:4]) * anchors_wh
    boxes_ctr = snap_boxes[:, :2] * anchors_wh + anchors_ctr

    boxes_x_min_y_min = boxes_ctr - 0.5 * boxes_wh
    boxes_x_max_y_max = boxes_ctr + 0.5 * boxes_wh

    boxes = torch.cat([boxes_x_min_y_min, boxes_x_max_y_max], dim=1)

    # boxes shape:[anchor_nums,4]
    return boxes
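
A short decoding sketch may help connect this back to the original question: the regression targets and the reg_head output both live in offset space, and pixel coordinates only appear after decoding. The helper `decode_box`, the sample numbers, and the plain L1 sum below are assumptions for illustration (written in plain Python rather than torch); they are not the repository's loss implementation.

```python
import math


def decode_box(anchor, offsets):
    """Decode [tx, ty, tw, th] back to [x_min, y_min, x_max, y_max].

    Hypothetical scalar helper mirroring snap_txtytwth_to_xyxy.
    """
    ax1, ay1, ax2, ay2 = anchor
    tx, ty, tw, th = offsets

    aw, ah = ax2 - ax1, ay2 - ay1
    acx, acy = ax1 + 0.5 * aw, ay1 + 0.5 * ah

    # Undo the encoding: scale sizes by exp(t), shift centers by t * anchor size.
    bw, bh = math.exp(tw) * aw, math.exp(th) * ah
    bcx, bcy = tx * aw + acx, ty * ah + acy

    return [bcx - 0.5 * bw, bcy - 0.5 * bh, bcx + 0.5 * bw, bcy + 0.5 * bh]


anchor = [0.0, 0.0, 10.0, 10.0]           # made-up anchor
target = [0.2, 0.3, 0.0, math.log(1.2)]   # encoded form of the box [2, 2, 12, 14]
pred = [0.25, 0.3, 0.0, 0.1]              # made-up reg_head output

# All-zero offsets decode to the anchor itself.
same_as_anchor = decode_box(anchor, [0.0, 0.0, 0.0, 0.0])

# The target offsets decode back to the ground-truth box.
gt_box = decode_box(anchor, target)

# Plain L1 distance in offset space (an assumption, not the repo's exact loss).
l1 = sum(abs(p - t) for p, t in zip(pred, target))
print(same_as_anchor, gt_box, l1)
```

Because the targets are normalized by the anchor size and use a log ratio for width and height, the regression values stay in a small range regardless of how large the box is in pixels, which is the usual motivation for this parameterization.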