diff --git a/tools/pytorch2Msnhnet/PytorchToMsnhnet.py b/tools/pytorch2Msnhnet/PytorchToMsnhnet.py index a876bfe2..f48cd3a4 100644 --- a/tools/pytorch2Msnhnet/PytorchToMsnhnet.py +++ b/tools/pytorch2Msnhnet/PytorchToMsnhnet.py @@ -718,6 +718,109 @@ def _view(inData, *args): ccc.append(x) dataSize = inData.shape[1]*inData.shape[2]*inData.shape[3] + if inData.shape[0] != 1: + raise NotImplementedError("params error") + + if len(list(args)) == 1: + if args[0] != -1: + raise NotImplementedError("params error") + msnhnet.buildView(str(x._cdata),1,1,dataSize) + + if len(list(args)) == 2: + if args[0] == -1 and args[1] != -1: + if dataSize % args[1] != 0: + raise NotImplementedError("params error") + dim1 = dataSize/args[1] + dim2 = args[1] + msnhnet.buildView(str(x._cdata),1,dim1,dim2) + elif args[0] != -1 and args[1] == -1: + if dataSize % args[1] != 0: + raise NotImplementedError("params error") + dim1 = args[0] + dim2 = dataSize/args[0] + msnhnet.buildView(str(x._cdata),1,dim1,dim2) + elif args[0] != -1 and args[1] != -1: + if dataSize % (args[1]*args[0]) != 0: + raise NotImplementedError("params error") + dim1 = arg[0] + dim2 = arg[1] + msnhnet.buildView(str(x._cdata),1,dim1,dim2) + else: + raise NotImplementedError("params error") + if len(list(args)) == 3: + if args[0] == -1 and args[1] != -1 and args[2] != -1: + if dataSize % (args[1]*args[2]) != 0: + raise NotImplementedError("params error") + dim0 = dataSize /(args[1]*args[2]) + dim1 = args[1] + dim2 = args[2] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + elif args[0] != -1 and args[1] == -1 and args[2] != -1: + if dataSize % (args[0]*args[2]) != 0: + raise NotImplementedError("params error") + dim0 = args[0] + dim1 = dataSize/(args[0]*args[2]) + dim2 = args[2] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + elif args[0] != -1 and args[1] != -1 and args[2] == -1: + if dataSize % (args[0]*args[1]) != 0: + raise NotImplementedError("params error") + dim0 = args[0] + dim1 = args[1] + dim2 = dataSize/(args[0]*args[1]) + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + elif args[0] != -1 and args[1] != -1 and args[2] != -1: + if dataSize / (args[0]*args[1]*args[2]) != 1: + raise NotImplementedError("params error") + dim0 = args[0] + dim1 = args[1] + dim2 = args[2] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + if len(list(args)) == 4: + if args[0] == -1: + if dataSize/(args[1]*args[2]*args[3])==1 : + dim0 = args[1] + dim1 = args[2] + dim2 = args[3] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + else: + raise NotImplementedError("params error") + elif args[0] == 1: + if args[1] == -1 and args[2] != -1 and args[3] != -1: + if dataSize % (args[1]*args[2]) != 0: + raise NotImplementedError("params error") + dim0 = dataSize /(args[2]*args[3]) + dim1 = args[2] + dim2 = args[3] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + elif args[1] != -1 and args[2] == -1 and args[3] != -1: + if dataSize % (args[1]*args[3]) != 0: + raise NotImplementedError("params error") + dim0 = args[1] + dim1 = dataSize/(args[1]*args[3]) + dim2 = args[3] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + elif args[1] != -1 and args[2] != -1 and args[3] == -1: + if dataSize % (args[1]*args[2]) != 0: + raise NotImplementedError("params error") + dim0 = args[1] + dim1 = args[2] + dim2 = dataSize/(args[1]*args[2]) + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + elif args[1] != -1 and args[2] != -1 and args[3] != -1: + if dataSize / (args[1]*args[2]*args[3]) != 1: + raise NotImplementedError("params error") + dim0 = args[1] + dim1 = args[2] + dim2 = args[3] + msnhnet.buildView(str(x._cdata),dim0,dim1,dim2) + return x + +def _reshape(inData, *args): + x=raw_reshape(inData, *args) + ccc.append(x) + dataSize = inData.shape[1]*inData.shape[2]*inData.shape[3] + if inData.shape[0] != 1: raise NotImplementedError("params error") @@ -869,6 +972,8 @@ def _expand_as(inData, *args): for t in [torch.Tensor]: raw_view = t.view t.view = _view + raw_reshape = t.reshape + t.reshape = _reshape raw_mean = t.mean t.mean = _mean raw__add__ = t.__add__ @@ -898,8 +1003,6 @@ def _expand_as(inData, *args): raw__expand_as__ = t.expand_as t.expand_as = _expand_as - - def trans(net, inputVar, msnhnet_path, msnhbin_path): Hook.hookInited = True msnhnet.buildConfig(str(id(inputVar)), inputVar.size())