Skip to content

Commit

Permalink
eliminate grid sensitivity
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianxiaomo committed Jul 2, 2020
1 parent 439fff6 commit 76f513a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions tool/darknet2pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def create_network(self, blocks):
yolo_layer.num_anchors = int(block['num'])
yolo_layer.anchor_step = len(yolo_layer.anchors) // yolo_layer.num_anchors
yolo_layer.stride = prev_stride
yolo_layer.scale_x_y = float(block['scale_x_y'])
# yolo_layer.object_scale = float(block['object_scale'])
# yolo_layer.noobject_scale = float(block['noobject_scale'])
# yolo_layer.class_scale = float(block['class_scale'])
Expand Down
7 changes: 4 additions & 3 deletions tool/yolo_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch



def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1,
def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
validation=False):
# Output would be invalid if it does not satisfy this assert
# assert (output.size(1) == (5 + num_classes) * num_anchors)
Expand Down Expand Up @@ -158,7 +158,7 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, only_ob

# Apply sigmoid(), exp() and softmax() to slices
#
bxy = torch.sigmoid(bxy)
bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
bwh = torch.exp(bwh)
det_confs = torch.sigmoid(det_confs)
cls_confs = torch.nn.Softmax(dim=2)(cls_confs)
Expand Down Expand Up @@ -263,6 +263,7 @@ def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, str
self.thresh = 0.6
self.stride = stride
self.seen = 0
self.scale_x_y = 1

self.model_out = model_out

Expand All @@ -274,5 +275,5 @@ def forward(self, output, target=None):
masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
masked_anchors = [anchor / self.stride for anchor in masked_anchors]

return yolo_forward(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask))
return yolo_forward(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask),scale_x_y=self.scale_x_y)

0 comments on commit 76f513a

Please sign in to comment.