append GIoU, DIoU, CIoU, and models can be generated from *.cfg files during training
Tianxiaomo committed Jun 30, 2020
1 parent 4a5b6c1 commit 439fff6
Showing 6 changed files with 82 additions and 36 deletions.
4 changes: 4 additions & 0 deletions cfg.py
@@ -13,6 +13,10 @@
from easydict import EasyDict

Cfg = EasyDict()

Cfg.use_darknet_cfg = True
Cfg.cfgfile = 'cfg/yolov4.cfg'

Cfg.batch = 64
Cfg.subdivisions = 16
Cfg.width = 608
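Note: these two new settings drive model construction in train.py (see the last hunk below): when Cfg.use_darknet_cfg is true, the network is built by parsing Cfg.cfgfile with tool.darknet2pytorch.Darknet; otherwise the hand-written Yolov4 module is instantiated.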
1 change: 0 additions & 1 deletion models.py
@@ -403,7 +403,6 @@ def forward(self, input1, input2, input3):
return [x2, x10, x18]



class Yolov4(nn.Module):
def __init__(self, yolov4conv137weight=None, n_classes=80, inference=False):
super().__init__()
5 changes: 3 additions & 2 deletions tool/coco_annotation.py
@@ -16,7 +16,7 @@
import os

"""hyper parameters"""
json_file_path = 'E:/Dataset/coco2017/annotations_trainval2017/annotations/instances_val2017.json'
json_file_path = 'E:/Dataset/mscoco2017/annotations/instances_train2017.json'
images_dir_path = 'mscoco2017/train2017/'
output_path = '../data/val.txt'

@@ -31,7 +31,8 @@
annotations = data['annotations']
for ant in tqdm(annotations):
id = ant['image_id']
name = os.path.join(images_dir_path, images[id]['file_name'])
# name = os.path.join(images_dir_path, images[id]['file_name'])
name = os.path.join(images_dir_path, '{:012d}.jpg'.format(id))
cat = ant['category_id']

if cat >= 1 and cat <= 11:
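Note: the rewritten name line relies on the COCO 2017 naming convention, where each image file is named with its zero-padded 12-digit image id, so the path can be built directly from image_id without consulting the images lookup. For example (illustrative id):

    '{:012d}.jpg'.format(5205)  # -> '000000005205.jpg'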
14 changes: 8 additions & 6 deletions tool/darknet2pytorch.py
@@ -207,18 +207,20 @@ def forward(self, x):
self.loss = self.models[ind](x)
outputs[ind] = None
elif block['type'] == 'yolo':
if self.training:
pass
else:
boxes = self.models[ind](x)
out_boxes.append(boxes)
# if self.training:
# pass
# else:
# boxes = self.models[ind](x)
# out_boxes.append(boxes)
boxes = self.models[ind](x)
out_boxes.append(boxes)
elif block['type'] == 'cost':
continue
else:
print('unknown type %s' % (block['type']))

if self.training:
return self.loss
return out_boxes
else:
return get_region_boxes(out_boxes)

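Net effect of this hunk: the YOLO head outputs are collected unconditionally, and in training mode Darknet.forward now returns those raw outputs so the loss can be computed externally (by Yolo_loss in train.py) instead of via the internal self.loss.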
3 changes: 2 additions & 1 deletion tool/yolo_layer.py
@@ -267,7 +267,8 @@ def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, str
self.model_out = model_out

def forward(self, output, target=None):

if self.training:
return output
masked_anchors = []
for m in self.anchor_mask:
masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
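This mirrors the darknet2pytorch change: during training each YoloLayer returns its raw feature map immediately, deferring box decoding to inference.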
91 changes: 65 additions & 26 deletions train.py
@@ -16,19 +16,20 @@
from torch import optim
from tensorboardX import SummaryWriter
import logging
import os, sys
import os, sys, math
from tqdm import tqdm
from dataset import Yolo_dataset
from cfg import Cfg
from models import Yolov4
import argparse
from easydict import EasyDict as edict
from torch.nn import functional as F
from tool.darknet2pytorch import Darknet

import numpy as np


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True, GIoU=False, DIoU=False, CIoU=False):
"""Calculate the Intersection of Unions (IoUs) between bounding boxes.
IoU is calculated as the ratio of the area of the intersection
to the area of the union.
@@ -48,41 +48,73 @@ def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
box in :obj:`bbox_b`.
from: https://github.com/chainer/chainercv
https://github.com/ultralytics/yolov3/blob/eca5b9c1d36e4f73bf2f94e141d864f1c2739e23/utils/utils.py#L262-L282
"""
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError

# top left
if xyxy:
# intersection top left
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
# bottom right
# intersection bottom right
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
# convex (smallest enclosing box) top left and bottom right
con_tl = torch.min(bboxes_a[:, None, :2], bboxes_b[:, :2])
con_br = torch.max(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
# centerpoint distance squared
rho2 = ((bboxes_a[:, None, 0] + bboxes_a[:, None, 2]) - (bboxes_b[:, 0] + bboxes_b[:, 2])) ** 2 / 4 + (
(bboxes_a[:, None, 1] + bboxes_a[:, None, 3]) - (bboxes_b[:, 1] + bboxes_b[:, 3])) ** 2 / 4

w1 = bboxes_a[:, 2] - bboxes_a[:, 0]
h1 = bboxes_a[:, 3] - bboxes_a[:, 1]
w2 = bboxes_b[:, 2] - bboxes_b[:, 0]
h2 = bboxes_b[:, 3] - bboxes_b[:, 1]

area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
# intersection top left
tl = torch.max((bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2))
# bottom right
# intersection bottom right
br = torch.min((bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2))

# convex (smallest enclosing box) top left and bottom right
con_tl = torch.min((bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2))
con_br = torch.max((bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2))
# centerpoint distance squared
rho2 = ((bboxes_a[:, None, :2] - bboxes_b[:, :2]) ** 2 / 4).sum(dim=-1)

w1 = bboxes_a[:, 2]
h1 = bboxes_a[:, 3]
w2 = bboxes_b[:, 2]
h2 = bboxes_b[:, 3]

area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
return area_i / (area_a[:, None] + area_b - area_i)


def bboxes_giou(bboxes_a, bboxes_b, xyxy=True):
pass


def bboxes_diou(bboxes_a, bboxes_b, xyxy=True):
pass


def bboxes_ciou(bboxes_a, bboxes_b, xyxy=True):
pass
area_u = area_a[:, None] + area_b - area_i
iou = area_i / area_u

if GIoU or DIoU or CIoU:
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
area_c = torch.prod(con_br - con_tl, 2) # convex area
return iou - (area_c - area_u) / area_c # GIoU
if DIoU or CIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
# convex diagonal squared
c2 = torch.pow(con_br - con_tl, 2).sum(dim=2) + 1e-16
if DIoU:
return iou - rho2 / c2 # DIoU
elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w1 / h1).unsqueeze(1) - torch.atan(w2 / h2), 2)
with torch.no_grad():
alpha = v / (1 - iou + v)
return iou - (rho2 / c2 + v * alpha) # CIoU
return iou
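
For reference, the variants implemented above follow the cited papers: GIoU = IoU - (C - U)/C, where C is the area of the smallest enclosing box and U the union; DIoU = IoU - rho^2/c^2, where rho^2 is the squared distance between box centers and c^2 the squared diagonal of the enclosing box; CIoU additionally subtracts alpha*v, an aspect-ratio consistency penalty. A minimal usage sketch of the extended function (tensor values are made up for illustration; assumes the bboxes_iou defined above is in scope):

    import torch

    # two query boxes against two candidates, all in (x1, y1, x2, y2) form
    a = torch.tensor([[0., 0., 10., 10.],
                      [5., 5., 15., 15.]])
    b = torch.tensor([[2., 2., 12., 12.],
                      [20., 20., 30., 30.]])

    iou  = bboxes_iou(a, b)             # plain IoU, shape (2, 2), values in [0, 1]
    ciou = bboxes_iou(a, b, CIoU=True)  # IoU minus distance and aspect-ratio
                                        # penalties; the far-away second column
                                        # comes out negative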


class Yolo_loss(nn.Module):
@@ -150,7 +183,10 @@ def build_target(self, pred, labels, batchsize, fsize, n_ch, output_id):
truth_j = truth_j_all[b, :n]

# calculate iou between truth and reference anchors
anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors[output_id])
anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors[output_id], CIoU=True)

# temp = bbox_iou(truth_box.cpu(), self.ref_anchors[output_id])

best_n_all = anchor_ious_all.argmax(dim=1)
best_n = best_n_all % 3
best_n_mask = ((best_n_all == self.anch_masks[output_id][0]) |
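Passing CIoU=True means anchor matching now ranks the reference anchors by Complete IoU rather than plain IoU, so center distance and aspect ratio also influence which anchor each ground-truth box is assigned to.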
@@ -296,13 +296,13 @@ def burnin_schedule(i):
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate / config.batch, betas=(0.9, 0.999), eps=1e-08)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)

criterion = Yolo_loss(device=device, batch=config.batch // config.subdivisions,n_classes=config.classes)
criterion = Yolo_loss(device=device, batch=config.batch // config.subdivisions, n_classes=config.classes)
# scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, patience=6, min_lr=1e-7)
# scheduler = CosineAnnealingWarmRestarts(optimizer, 0.001, 1e-6, 20)

model.train()
for epoch in range(epochs):
#model.train()
# model.train()
epoch_loss = 0
epoch_step = 0

Expand All @@ -323,7 +359,7 @@ def burnin_schedule(i):

epoch_loss += loss.item()

if global_step % config.subdivisions == 0:
if global_step % config.subdivisions == 0:
optimizer.step()
scheduler.step()
model.zero_grad()
@@ -378,9 +378,9 @@ def get_args(**kwargs):
help='GPU', dest='gpu')
parser.add_argument('-dir', '--data-dir', type=str, default=None,
help='dataset dir', dest='dataset_dir')
parser.add_argument('-pretrained',type=str,default=None,help='pretrained yolov4.conv.137')
parser.add_argument('-classes',type=int,default=80,help='dataset classes')
parser.add_argument('-train_label_path',dest='train_label',type=str,default='train.txt',help="train label path")
parser.add_argument('-pretrained', type=str, default=None, help='pretrained yolov4.conv.137')
parser.add_argument('-classes', type=int, default=80, help='dataset classes')
parser.add_argument('-train_label_path', dest='train_label', type=str, default='train.txt', help="train label path")
args = vars(parser.parse_args())

for k in args.keys():
@@ -431,7 +467,10 @@ def get_date_str():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

model = Yolov4(cfg.pretrained,n_classes=cfg.classes)
if cfg.use_darknet_cfg:
model = Darknet(cfg.cfgfile)
else:
model = Yolov4(cfg.pretrained, n_classes=cfg.classes)

if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)
