How to convert yolo layer output to bounding box coordinates?
greyowl opened this issue · comments
The YOLO layers seem to have an output shape of
[batch, height, width, 3 * (5 + num_classes)]
where height and width are the size of the grid.
This is what I've tried, referencing model.py:
import tensorflow as tf

def get_bbox_with_anchors(prediction, anchors, classes_num=8):
    # split the channel axis into the three anchor boxes
    split_preds = tf.split(prediction, 3, axis=-1)
    pred = tf.stack(split_preds, axis=-2)
    # get grid shape
    grid_size = tf.shape(pred)[1]
    box_xy, box_wh, objectness, class_probs = tf.split(
        pred, (2, 2, 1, classes_num), axis=-1)
    # apply transforms
    # box_xy = tf.sigmoid(box_xy)
    # darknet does not seem to apply sigmoid on offset
    objectness = tf.sigmoid(objectness)
    class_probs = tf.sigmoid(class_probs)
    pred_box = tf.concat((box_xy, box_wh), axis=-1)  # original xywh for loss
    # calculate offsets
    grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)  # [gx, gy, 1, 2]
    box_xy = (box_xy + tf.cast(grid, tf.float32)) / \
        tf.cast(grid_size, tf.float32)
    box_wh = tf.exp(box_wh) * anchors
    box_x1y1 = box_xy - box_wh / 2
    box_x2y2 = box_xy + box_wh / 2
    bbox = tf.concat([box_x1y1, box_x2y2], axis=-1)
    return bbox, objectness, class_probs
When I overlay these bboxes on the original image, they seem to be off by quite a bit.
I am not sure what I am doing wrong.
Any help is appreciated.
@greyowl Hey, I'm sort of new to TF, did you resolve this finally? How did you get the actual bbox coordinates from the output layers?
I don't have the code, but I ended up implementing the output -> bbox coordinates conversion in TensorFlow.
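For reference, here is a minimal sketch of what such an output -> bbox conversion might look like. It is not the code mentioned above. It is written in NumPy so it runs without a model, and it assumes the usual YOLOv3 conventions: the channel axis is laid out anchor by anchor, the x/y offsets go through a sigmoid before the cell index is added, and `anchors` are given as fractions of the network input size. `decode_yolo_head` and `sigmoid` are illustrative names, not functions from model.py.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def decode_yolo_head(prediction, anchors, num_classes):
    """Decode one raw YOLO head into normalized corner boxes.

    prediction: [batch, grid_h, grid_w, num_anchors * (5 + num_classes)]
    anchors:    [num_anchors, 2] anchor (w, h) as fractions of the input
                size (assumption; divide pixel anchors by the input size)
    Returns bbox [batch, grid_h, grid_w, num_anchors, 4] as (x1, y1, x2, y2)
    in the 0..1 range, plus objectness and class probabilities.
    """
    batch, grid_h, grid_w, _ = prediction.shape
    num_anchors = anchors.shape[0]

    # Separate the anchors: [..., A * (5 + C)] -> [..., A, 5 + C]
    pred = prediction.reshape(batch, grid_h, grid_w, num_anchors, 5 + num_classes)
    t_xy = pred[..., 0:2]
    t_wh = pred[..., 2:4]
    objectness = sigmoid(pred[..., 4:5])
    class_probs = sigmoid(pred[..., 5:])

    # Cell offsets: grid[row, col] holds (col, row), i.e. (x, y)
    cols, rows = np.meshgrid(np.arange(grid_w), np.arange(grid_h))
    grid = np.stack([cols, rows], axis=-1)[:, :, np.newaxis, :].astype(np.float32)

    # Center: sigmoid keeps the offset inside its cell, then normalize by grid size
    box_xy = (sigmoid(t_xy) + grid) / np.array([grid_w, grid_h], dtype=np.float32)
    # Size: exponent scales the anchor
    box_wh = np.exp(t_wh) * anchors

    bbox = np.concatenate([box_xy - box_wh / 2, box_xy + box_wh / 2], axis=-1)
    return bbox, objectness, class_probs
```

Compared with the snippet in the question, the differences are the sigmoid on the x/y offsets and the separate normalization by grid width and height. If the anchors are given in pixels, `exp(t_wh) * anchors` produces pixel sizes while the centers are in 0..1, which could also explain boxes that look far off. The same steps map directly onto `tf.sigmoid`, `tf.exp`, `tf.meshgrid` and `tf.concat`.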
I'll look that up, thanks a lot!
@nishanthcgit Hi, were you able to figure out how to get bounding boxes? I tried referencing other repositories including model.py, but I always got zero bounding boxes. Any help will be appreciated. Thanks
I had looked into how TensorFlow itself does the output -> coordinates transformation in its own models. It's not actually too hard; just poke around in the functions that get called. I don't remember the exact name of the function now, but I found one in there that does the same thing and used it.
Check out the TensorFlow Object Detection API. It is 100% in there; I am just not sure of the exact function name now.
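In case it helps with the "zero bounding boxes" and "off quite a bit" symptoms: once the boxes are decoded to normalized (x1, y1, x2, y2), they still have to be filtered by score and scaled to the image. The sketch below is one way to do that in NumPy. It assumes the image was resized straight to the network input (no letterbox padding), and that the score is objectness times class probability. `boxes_to_pixels` and the 0.5 threshold are illustrative choices, not something taken from the Object Detection API.

```python
import numpy as np


def boxes_to_pixels(bbox, objectness, class_probs, image_w, image_h,
                    score_threshold=0.5):
    """Keep confident boxes and scale them to pixel coordinates.

    bbox:        [..., 4] normalized (x1, y1, x2, y2)
    objectness:  [..., 1] values in 0..1
    class_probs: [..., num_classes] values in 0..1
    """
    num_classes = class_probs.shape[-1]
    boxes = bbox.reshape(-1, 4)
    # Per-class score = objectness * class probability (assumed convention)
    scores = (objectness * class_probs).reshape(-1, num_classes)

    class_ids = scores.argmax(axis=-1)
    best_scores = scores.max(axis=-1)
    keep = best_scores >= score_threshold

    # Clip to the image, then convert the 0..1 range to pixels
    scale = np.array([image_w, image_h, image_w, image_h], dtype=np.float32)
    pixel_boxes = np.clip(boxes[keep], 0.0, 1.0) * scale
    return pixel_boxes, best_scores[keep], class_ids[keep]
```

If nothing survives the threshold, printing `best_scores.max()` is a quick check on whether the sigmoid was applied and the weights loaded correctly. Overlapping detections would still need non-max suppression afterwards, for example with `tf.image.non_max_suppression`.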