-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathir_weight_extractor.py
64 lines (57 loc) · 2.46 KB
/
ir_weight_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import sys
import argparse
import pickle
import struct
import xml.etree.ElementTree as et
# IR precision string (output port ``precision`` attribute) -> (struct format
# character, element size in bytes). Precisions missing here (e.g. BF16 or the
# packed sub-byte types I4/U4/U1) have no struct equivalent and are skipped.
_PRECISION_FORMATS = {
    'FP64': ('d', 8),
    'FP32': ('f', 4),
    'FP16': ('e', 2),
    'I64': ('q', 8),
    'I32': ('i', 4),
    'I16': ('h', 2),
    'I8': ('b', 1),
    'U64': ('Q', 8),
    'U32': ('I', 4),
    'U16': ('H', 2),
    'U8': ('B', 1),
    'BOOL': ('?', 1),
}


def _unpack_blob(blob, prec, name):
    """Decode *blob* as a tuple of little-endian elements of IR precision *prec*.

    Raises:
        ValueError: if the blob length is not a whole number of elements.
    """
    code, itemsize = _PRECISION_FORMATS[prec]
    count, remainder = divmod(len(blob), itemsize)
    if remainder:
        raise ValueError(
            'blob of layer {!r} is {} bytes, not a multiple of the {}-byte {} '
            'element size'.format(name, len(blob), itemsize, prec))
    # A repeat count ('<1000f') keeps the format string short even for huge
    # blobs, instead of repeating the format character once per element.
    return struct.unpack('<{}{}'.format(count, code), blob)


def dumpWeight(model):
    """Extract the weights of every Const layer of an OpenVINO IR model and pickle them.

    Reads ``<model>.xml`` and ``<model>.bin``. For each Const layer whose
    ``<data>`` element has ``offset`` and ``size`` attributes, the layer's slice
    of the ``.bin`` buffer is decoded according to the ``precision`` of its
    first output port. The result is written to ``<model>_wgt.pickle`` as
    ``{layer_name: [precision_str, dims, values]}``, where *dims* is the list
    of ``<dim>`` texts (strings) of that port and *values* is a tuple of the
    decoded little-endian numbers.

    Const layers with no output port, or with a precision that has no
    ``struct`` equivalent, are reported on stderr and skipped.

    Args:
        model: Path of the IR model without the ``.xml``/``.bin`` extension.

    Returns:
        The dict that was written to the pickle file.

    Raises:
        ValueError: if a blob lies outside the ``.bin`` file or its size is not
            a multiple of the element size.
    """
    # Read the whole IR weight buffer; each Const layer references a slice of it.
    with open(model + '.bin', 'rb') as f:
        bin_weight = f.read()

    # Parse the IR XML and walk its layers looking for Const nodes.
    root = et.parse(model + '.xml').getroot()
    layers = root.find('layers')
    weight = {}  # { blobName : [ precStr, dims, weightBuf ] }
    print(' size : nodeName')
    # `is not None` is required here: an Element with no children is falsy.
    for layer in (layers if layers is not None else ()):
        if layer.attrib.get('type') != 'Const':
            continue
        data = layer.find('data')
        if data is None or 'offset' not in data.attrib or 'size' not in data.attrib:
            continue
        name = layer.attrib['name']
        offset = int(data.attrib['offset'])
        size = int(data.attrib['size'])
        # Slicing past the end would silently return a short blob whose
        # length no longer matches the declared dims, so reject it.
        if offset < 0 or size < 0 or offset + size > len(bin_weight):
            raise ValueError(
                'blob of layer {!r} (offset={}, size={}) lies outside the '
                '{}-byte weight file'.format(name, offset, size, len(bin_weight)))
        port = layer.find('output/port')
        prec = port.attrib.get('precision') if port is not None else None
        if prec not in _PRECISION_FORMATS:
            print('Warning: skipping layer {!r}: unsupported precision {}'.format(name, prec),
                  file=sys.stderr)
            continue
        blob = bin_weight[offset:offset + size]  # this layer's slice of the buffer
        dims = [dim.text for dim in port.findall('dim')]  # shape, kept as strings
        weight[name] = [prec, dims, _unpack_blob(blob, prec, name)]
        print('{:8} : {}'.format(len(blob), name))

    fname = model + '_wgt.pickle'
    with open(fname, 'wb') as f:
        pickle.dump(weight, f)
    print('\n' + fname + ' is generated')
    return weight
def main(argv=None):
    """Command-line entry point: validate the ``-m`` model path and extract its weights.

    The model path must end in ``.xml``; the matching ``.bin`` file next to it
    is read by :func:`dumpWeight`.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Raises:
        SystemExit: status 2 when ``-m`` is missing (argparse), status 1 when the
            path does not have an ``.xml`` extension.
    """
    print('*** OpenVINO IR model weight data extractor')
    parser = argparse.ArgumentParser()
    # required=True: previously a missing -m left args.model as None and
    # crashed inside os.path.splitext with a TypeError.
    parser.add_argument('-m', '--model', type=str, required=True, help='input IR model path')
    args = parser.parse_args(argv)
    model, ext = os.path.splitext(args.model)
    if ext != '.xml':
        print('The specified model is not \'.xml\' file', file=sys.stderr)
        # Exit status 1 rather than -1, which POSIX shells report as 255.
        sys.exit(1)
    dumpWeight(model)


if __name__ == "__main__":
    main()