diff --git a/models/export.py b/models/export.py
index e1c79ff2..c6c20c08 100644
--- a/models/export.py
+++ b/models/export.py
@@ -41,12 +41,17 @@ def convert_sync_batchnorm_to_batchnorm(module):
     parser.add_argument('--weights', type=str, default='./yolor-p6.pt', help='weights path')
     parser.add_argument('--img-size', nargs='+', type=int, default=[1280, 1280], help='image size')  # height, width
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+    parser.add_argument('--inplace', action='store_true', help='set inplace=True on the YOLO Detect layer')
+    parser.add_argument('--simplify', action='store_true', help='simplify the exported model with onnx-simplifier')
     opt = parser.parse_args()
     opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand
     print(opt)
     set_logging()
     t = time.time()
+    detect_inplace = opt.inplace
+    do_simplify = opt.simplify
+
 
     # Load PyTorch model
     model = attempt_load(opt.weights, map_location=torch.device('cpu'))  # load FP32 model
     labels = model.names
@@ -55,7 +60,7 @@ def convert_sync_batchnorm_to_batchnorm(module):
 
     model = convert_sync_batchnorm_to_batchnorm(model)
 
-    print(model)
+    #print(model)
 
     # Checks
     gs = int(max(model.stride))  # grid size (max stride)
@@ -69,12 +74,15 @@ def convert_sync_batchnorm_to_batchnorm(module):
         m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
         if isinstance(m, models.common.Conv) and isinstance(m.act, nn.Hardswish):
             m.act = Hardswish()  # assign activation
-        # if isinstance(m, models.yolo.Detect):
-        #     m.forward = m.forward_export  # assign forward (optional)
+        #if isinstance(m, models.yolo.IDetect):
+        #    m.inplace = detect_inplace
+        #    m.forward = m.forward_export  # assign forward (optional)
     model.model[-1].export = True  # set Detect() layer export=True
+    model.model[-1].inplace = detect_inplace  # set Detect() layer inplace flag
+    train = False
     y = model(img)  # dry run
-    print(y[0].shape)
+    print("out[0].shape: {}".format(y[0].shape))
 
     # TorchScript export
     try:
@@ -93,29 +101,39 @@ def convert_sync_batchnorm_to_batchnorm(module):
         print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
         f = opt.weights.replace('.pt', f'-{opt.img_size[0]}-{opt.img_size[1]}.onnx')  # filename
-        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
+        torch.onnx.export(model, img, f, verbose=False,
+                          opset_version=12,
+                          do_constant_folding=not train,
+                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
+                          input_names=['images'],
                           output_names=['classes', 'boxes'] if y is None else ['output'])
 
-        # Checks
+        # Load
         onnx_model = onnx.load(f)  # load onnx model
-        onnx.checker.check_model(onnx_model)  # check onnx model
-        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
-        do_simplify = True
+        #do_simplify = True
         if do_simplify:
+            print('Simplifying...')
             from onnxsim import simplify
             onnx_model, check = simplify(onnx_model, check_n=3)
             assert check, 'assert simplify check failed'
             onnx.save(onnx_model, f)
+            print('Simplify success.')
+
+        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
+        print("Model successfully saved and loaded")
+
+        # Check
+        onnx.checker.check_model(onnx_model)  # check onnx model
 
-        session = ort.InferenceSession(f)
+        #session = ort.InferenceSession(f)
 
-        for ii in session.get_inputs():
-            print("input: ", ii)
+        #for ii in session.get_inputs():
+        #    print("input: ", ii)
 
-        for oo in session.get_outputs():
-            print("output: ", oo)
+        #for oo in session.get_outputs():
+        #    print("output: ", oo)
 
         print('ONNX export success, saved as %s' % f)
     except Exception as e:
@@ -138,7 +156,7 @@ def convert_sync_batchnorm_to_batchnorm(module):
     print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
 
 """
-PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 640
-PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 320
-PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 1280
+PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 640
+PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 320
+PYTHONPATH=. python3 ./models/export.py --weights ./weights/yolor-p6.pt --img-size 1280
 """
diff --git a/models/yolo.py b/models/yolo.py
index 256da5d8..0ef59294 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -44,7 +44,7 @@ def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
     def forward(self, x):
         # x = x.copy()  # for profiling
         z = []  # inference output
-        self.training |= self.export
+        self.training ^= not self.export
         for i in range(self.nl):
             x[i] = self.m[i](x[i])  # conv
             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
@@ -71,7 +71,7 @@ class IDetect(nn.Module):
     stride = None  # strides computed during build
     export = False  # onnx export
 
-    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
         super(IDetect, self).__init__()
         self.nc = nc  # number of classes
         self.no = nc + 5  # number of outputs per anchor
@@ -82,14 +82,19 @@ def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
         self.register_buffer('anchors', a)  # shape(nl,na,2)
         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
-        
+
         self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
         self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
+        self.inplace = inplace
+
 
     def forward(self, x):
         # x = x.copy()  # for profiling
         z = []  # inference output
-        self.training |= self.export
+        # NOTE: the original `self.training |= self.export` forced training-mode (raw) outputs during ONNX export
+        self.training ^= not self.export
+        print(f"IDetect module, self.training: {self.training}")
+
         for i in range(self.nl):
             x[i] = self.im[i](self.m[i](self.ia[i](x[i])))  # conv
             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
@@ -100,8 +105,14 @@ def forward(self, x):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                if self.inplace:
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
+
                 z.append(y.view(bs, -1, self.no))
 
         return x if self.training else (torch.cat(z, 1), x)
@@ -111,9 +122,11 @@ def _make_grid(nx=20, ny=20):
         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
 
+#CFG = 'yolov5s.yaml'
+CFG = 'yolor-p6.yaml'
+
 
 class Model(nn.Module):
-    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, nid=None):  # model, input channels, number of classes
+    def __init__(self, cfg=CFG, ch=3, nc=None, nid=None):  # model, input channels, number of classes
         super(Model, self).__init__()
         if isinstance(cfg, dict):
             self.yaml = cfg  # model dict
diff --git a/requirements.txt b/requirements.txt
index d8bcf54f..3c7ee3f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,6 +23,7 @@ pandas
 # export --------------------------------------
 # coremltools>=4.1
 # onnx>=1.8.1
+# onnx-simplifier>=0.3.6
 # scikit-learn==0.19.2  # for coreml quantization
 
 # extras --------------------------------------
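
Reviewer note: the diff comments out the `ort.InferenceSession` sanity check in `export.py`. Below is a minimal sketch of how the exported graph can be verified with onnxruntime outside the export script. The filename follows the `<weights>-<h>-<w>.onnx` naming used above and the input shape matches the default `--img-size 1280 1280`, but both are assumptions about a particular run rather than part of this change.

```python
# Post-export sanity check (sketch, not part of the diff).
# Assumes onnxruntime is installed and that export.py wrote
# './weights/yolor-p6-1280-1280.onnx' for a 1x3x1280x1280 input.
import numpy as np
import onnxruntime as ort

f = './weights/yolor-p6-1280-1280.onnx'  # assumed output of export.py
session = ort.InferenceSession(f, providers=['CPUExecutionProvider'])

for ii in session.get_inputs():
    print("input: ", ii)    # should list the 'images' input declared in torch.onnx.export
for oo in session.get_outputs():
    print("output: ", oo)   # the decoded 'output' head should be listed first

img = np.zeros((1, 3, 1280, 1280), dtype=np.float32)  # dummy batch matching --img-size 1280 1280
out = session.run(None, {'images': img})
print("out[0].shape:", out[0].shape)  # expected (1, num_predictions, num_classes + 5)
```

If `--inplace` decoding causes problems on a given runtime, re-exporting without the flag uses the concatenation-based branch added to `IDetect.forward` instead.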