Skip to content

Commit

Permalink
add view contiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
msnh2012 committed Aug 31, 2020
1 parent 2bce53a commit 07bfba4
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 11 deletions.
3 changes: 3 additions & 0 deletions tools/pytorch2Msnhnet/ExampleDeepLabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
deeplabv3.load_state_dict(ccc)
deeplabv3.eval()


input=torch.ones([1,3,224,224])

input.view()

'''
# trans msnhnet file only
transNet(deeplabv3, input, "deeplabv3.msnhnet")
Expand Down
2 changes: 1 addition & 1 deletion tools/pytorch2Msnhnet/ExampleResnet18.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
'''

# trans msnhnet and msnhbin file
trans(resnet18, input,"resnet18.msnhnet","resnet18.msnhbin")
trans(resnet18, input,"resnet18.msnhnet","resnet18.msnhbin")
11 changes: 10 additions & 1 deletion tools/pytorch2Msnhnet/MsnhBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,13 @@ def buildReduction(self, name, type, axis):
self.index = self.index + 1
self.net = self.net + "reduction:\n"
self.net = self.net + " type: " + type + "\n"
self.net = self.net + " axis: " + str(int(axis-1)) + "\n"
self.net = self.net + " axis: " + str(int(axis-1)) + "\n"

def buildView(self, name, dim0, dim1, dim2):
self.name_index_dict[name]=self.index
self.net = self.net + "#" + str(self.index) + "\n"
self.index = self.index + 1
self.net = self.net + "view:\n"
self.net = self.net + " dim0: " + str(int(dim0)) + "\n"
self.net = self.net + " dim1: " + str(int(dim1)) + "\n"
self.net = self.net + " dim2: " + str(int(dim2)) + "\n"
123 changes: 114 additions & 9 deletions tools/pytorch2Msnhnet/PytorchToMsnhnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from struct import pack


msnhnet = Msnhnet()
ccc = []
index = 0
Expand Down Expand Up @@ -700,14 +701,124 @@ def _log10(raw, inData, *args):
msnhnet.buildVariableOp(str(x._cdata),"","log10")
return x

# ===== Variable op not supported ======
''' TODO '''
def _contiguous(inData, *args):
log( "contiguous-i" , args[0]._cdata)
x = raw__contiguous__(inData, *args)
ccc.append(x)

key = msnhnet.getLastKey()
val = msnhnet.name_index_dict[key]
msnhnet.name_index_dict.pop(key)
msnhnet.name_index_dict[str(x._cdata)] = val
log( "contiguous-o" , x._cdata)
return x

def _view(inData, *args):
x=raw_view(inData, *args)
ccc.append(x)
raise NotImplementedError("view not supported yet")
dataSize = inData.shape[1]*inData.shape[2]*inData.shape[3]

if inData.shape[0] != 1:
raise NotImplementedError("params error")

if len(list(args)) == 1:
if args[0] != -1:
raise NotImplementedError("params error")
msnhnet.buildView(str(x._cdata),1,1,dataSize)

if len(list(args)) == 2:
if args[0] == -1 and args[1] != -1:
if dataSize % args[1] != 0:
raise NotImplementedError("params error")
dim1 = dataSize/args[1]
dim2 = args[1]
msnhnet.buildView(str(x._cdata),1,dim1,dim2)
elif args[0] != -1 and args[1] == -1:
if dataSize % args[1] != 0:
raise NotImplementedError("params error")
dim1 = args[0]
dim2 = dataSize/args[0]
msnhnet.buildView(str(x._cdata),1,dim1,dim2)
elif args[0] != -1 and args[1] != -1:
if dataSize % (args[1]*args[0]) != 0:
raise NotImplementedError("params error")
dim1 = arg[0]
dim2 = arg[1]
msnhnet.buildView(str(x._cdata),1,dim1,dim2)
else:
raise NotImplementedError("params error")
if len(list(args)) == 3:
if args[0] == -1 and args[1] != -1 and args[2] != -1:
if dataSize % (args[1]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = dataSize /(args[1]*args[2])
dim1 = args[1]
dim2 = args[2]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[0] != -1 and args[1] == -1 and args[2] != -1:
if dataSize % (args[0]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = args[0]
dim1 = dataSize/(args[0]*args[2])
dim2 = args[2]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[0] != -1 and args[1] != -1 and args[2] == -1:
if dataSize % (args[0]*args[1]) != 0:
raise NotImplementedError("params error")
dim0 = args[0]
dim1 = args[1]
dim2 = dataSize/(args[0]*args[1])
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[0] != -1 and args[1] != -1 and args[2] != -1:
if dataSize / (args[0]*args[1]*args[2]) != 1:
raise NotImplementedError("params error")
dim0 = args[0]
dim1 = args[1]
dim2 = args[2]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
if len(list(args)) == 4:
if args[0] == -1:
if dataSize/(args[1]*args[2]*args[3])==1 :
dim0 = args[1]
dim1 = args[2]
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
else:
raise NotImplementedError("params error")
elif args[0] == 1:
if args[1] == -1 and args[2] != -1 and args[3] != -1:
if dataSize % (args[1]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = dataSize /(args[2]*args[3])
dim1 = args[2]
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[1] != -1 and args[2] == -1 and args[3] != -1:
if dataSize % (args[1]*args[3]) != 0:
raise NotImplementedError("params error")
dim0 = args[1]
dim1 = dataSize/(args[1]*args[3])
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[1] != -1 and args[2] != -1 and args[3] == -1:
if dataSize % (args[1]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = args[1]
dim1 = args[2]
dim2 = dataSize/(args[1]*args[2])
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[1] != -1 and args[2] != -1 and args[3] != -1:
if dataSize / (args[1]*args[2]*args[3]) != 1:
raise NotImplementedError("params error")
dim0 = args[1]
dim1 = args[2]
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
print(msnhnet.net)
return x

# ===== Variable op not supported ======
''' TODO '''
def _unsqueeze(inData, *args):
x = raw__unsqueeze__(inData, *args)
ccc.append(x)
Expand All @@ -720,12 +831,6 @@ def _expand_as(inData, *args):
raise NotImplementedError("expand_as not supported yet")
return x

def _contiguous(inData, *args):
x = raw__contiguous__(inData, *args)
ccc.append(x)
raise NotImplementedError("contiguous not supported yet")
return x

F.conv2d = Hook(F.conv2d,_conv2d)
F.max_pool2d = Hook(F.max_pool2d,_max_pool2d)
F.avg_pool2d = Hook(F.avg_pool2d,_avg_pool2d)
Expand Down
2 changes: 2 additions & 0 deletions tools/pytorch2Msnhnet/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Alpha version, maybe have some bugs. Only official op is supported, customized o
- log10
- mean
- permute
- view
- contiguous
- sqrt
- pow
- sum
Expand Down

0 comments on commit 07bfba4

Please sign in to comment.