feature: added support for onnx-simplifier; fixed self.training parameter bug in Yolo IDetect layer #2

Open · wants to merge 1 commit into base: paper
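Summary: this PR wires onnx-simplifier into models/export.py behind a new --simplify flag, adds an --inplace flag that is passed through to the IDetect layer, and changes the self.training handling in Detect/IDetect so the exported graph carries the decoded (sigmoid + grid/anchor) outputs. As a rough sketch of what the new simplification step boils down to (assuming onnx and onnx-simplifier 0.3.x, the range pinned in requirements.txt, are installed; the filename follows the pattern export.py writes for the default weights and img-size):

import onnx
from onnxsim import simplify

f = 'yolor-p6-1280-1280.onnx'                         # filename pattern produced by export.py
onnx_model = onnx.load(f)                             # load the freshly exported graph
onnx_model, check = simplify(onnx_model, check_n=3)   # fold constants, drop redundant ops, re-check 3 times
assert check, 'onnx-simplifier check failed'
onnx.save(onnx_model, f)                              # overwrite with the simplified graph
onnx.checker.check_model(onnx_model)                  # final structural check
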
52 changes: 35 additions & 17 deletions models/export.py
@@ -41,12 +41,17 @@ def convert_sync_batchnorm_to_batchnorm(module):
parser.add_argument('--weights', type=str, default='./yolor-p6.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[1280, 1280], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--inplace', action='store_true', help='set the Detect/IDetect layer inplace attribute to True')
parser.add_argument('--simplify', action='store_true', help='use onnx-simplifier')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
set_logging()
t = time.time()

detect_inplace = opt.inplace  # controlled by the --inplace flag
do_simplify = opt.simplify  # controlled by the --simplify flag

# Load PyTorch model
model = attempt_load(opt.weights, map_location=torch.device('cpu')) # load FP32 model
labels = model.names
@@ -55,7 +60,7 @@ def convert_sync_batchnorm_to_batchnorm(module):

model = convert_sync_batchnorm_to_batchnorm(model)

print(model)
#print(model)

# Checks
gs = int(max(model.stride)) # grid size (max stride)
@@ -69,12 +74,15 @@ def convert_sync_batchnorm_to_batchnorm(module):
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv) and isinstance(m.act, nn.Hardswish):
m.act = Hardswish() # assign activation
# if isinstance(m, models.yolo.Detect):
# m.forward = m.forward_export # assign forward (optional)
#if isinstance(m, models.yolo.IDetect):
# m.inplace = detect_inplace
# m.forward = m.forward_export # assign forward (optional)
model.model[-1].export = True # set Detect() layer export=True
model.model[-1].inplace = detect_inplace # set Detect() layer inplace from the --inplace flag
train = False
y = model(img) # dry run

print(y[0].shape)
print("out[0].shape: {}".format(y[0].shape))

# TorchScript export
try:
@@ -93,29 +101,39 @@ def convert_sync_batchnorm_to_batchnorm(module):

print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
f = opt.weights.replace('.pt', f'-{opt.img_size[0]}-{opt.img_size[1]}.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
torch.onnx.export(model, img, f, verbose=False,
opset_version=12,
do_constant_folding=not train,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
input_names=['images'],
output_names=['classes', 'boxes'] if y is None else ['output'])

# Checks
# Load
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model

do_simplify = True
#do_simplify = True
if do_simplify:
print('Simplifying...')
from onnxsim import simplify

onnx_model, check = simplify(onnx_model, check_n=3)
assert check, 'onnx-simplifier check failed'
onnx.save(onnx_model, f)
print('Simplify success.')

print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print("Model successfully saved and loaded")

# Check
onnx.checker.check_model(onnx_model) # check onnx model

session = ort.InferenceSession(f)
#session = ort.InferenceSession(f)

for ii in session.get_inputs():
print("input: ", ii)
#for ii in session.get_inputs():
# print("input: ", ii)

for oo in session.get_outputs():
print("output: ", oo)
#for oo in session.get_outputs():
# print("output: ", oo)

print('ONNX export success, saved as %s' % f)
except Exception as e:
@@ -138,7 +156,7 @@ def convert_sync_batchnorm_to_batchnorm(module):
print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))

"""
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 640
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 320
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 1280
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 640
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 320
PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 1280
"""
27 changes: 20 additions & 7 deletions models/yolo.py
@@ -44,7 +44,7 @@ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
def forward(self, x):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
self.training ^= not self.export
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
@@ -71,7 +71,7 @@ class IDetect(nn.Module):
stride = None # strides computed during build
export = False # onnx export

def __init__(self, nc=80, anchors=(), ch=()): # detection layer
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
super(IDetect, self).__init__()
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor
@@ -82,14 +82,19 @@ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
self.register_buffer('anchors', a) # shape(nl,na,2)
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv

self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)

self.inplace = inplace

def forward(self, x):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
# NOTE: this was a bug: self.training |= self.export
self.training ^= not self.export
print(f"IDetect module, self.training: {self.training}")

for i in range(self.nl):
x[i] = self.im[i](self.m[i](self.ia[i](x[i]))) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
@@ -100,8 +105,14 @@ def forward(self, x):
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
if self.inplace:
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)

z.append(y.view(bs, -1, self.no))

return x if self.training else (torch.cat(z, 1), x)
@@ -111,9 +122,11 @@ def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

#CFG = 'yolov5s.yaml'
CFG = 'yolor-p6.yaml'

class Model(nn.Module):
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, nid=None): # model, input channels, number of classes
def __init__(self, cfg=CFG, ch=3, nc=None, nid=None): # model, input channels, number of classes
super(Model, self).__init__()
if isinstance(cfg, dict):
self.yaml = cfg # model dict
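
For reference, the inplace and non-inplace branches added to IDetect compute the same box decode; a standalone sketch (tensor shapes follow the comments in yolo.py; illustrative only, not the module itself):

import torch

def decode(y, grid, anchor_grid, stride, inplace=True):
    # y: (bs, na, ny, nx, no) sigmoid outputs; grid: (1, 1, ny, nx, 2); anchor_grid: (1, na, 1, 1, 2)
    if inplace:
        y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid) * stride  # xy in input pixels
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid      # wh in input pixels
        return y
    # out-of-place variant: some export targets reject in-place tensor writes
    xy = (y[..., 0:2] * 2. - 0.5 + grid) * stride
    wh = (y[..., 2:4] * 2) ** 2 * anchor_grid
    return torch.cat((xy, wh, y[..., 4:]), -1)
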
1 change: 1 addition & 0 deletions requirements.txt
@@ -23,6 +23,7 @@ pandas
# export --------------------------------------
# coremltools>=4.1
# onnx>=1.8.1
# onnx-simplifier>=0.3.6
# scikit-learn==0.19.2 # for coreml quantization

# extras --------------------------------------