-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgen_wts.py
40 lines (36 loc) · 1.34 KB
/
gen_wts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import struct
import sys
from utils.torch_utils import select_device
import os
# Convert a checkpoint's ``state_dict`` into the plain-text ``.wts`` format.

# Checkpoint used when no path is given on the command line (kept from the
# original script so running it without arguments behaves as before).
DEFAULT_PT_FILE = "/home/liwei.fang/YOLOP-main/weights/End-to-end.pth"

# Every value is serialized as a big-endian IEEE-754 float32; the packer is
# built once instead of re-parsing the format string for each element.
_PACK_BE_F32 = struct.Struct('>f').pack


def write_wts(state_dict, out_path):
    """Write *state_dict* to *out_path* in the plain-text ``.wts`` format.

    The first line is the number of entries. Each following line is
    ``<name> <count> `` followed by `` <hex>`` for every element of the
    flattened tensor, where ``<hex>`` is the 8-digit big-endian float32
    encoding of the value. (This yields two spaces between the count and
    the first value, which is kept for byte-for-byte compatibility with the
    output of the original script.)

    Args:
        state_dict: Mapping of entry name to tensor; each value must support
            ``reshape(-1).cpu().numpy()``.
        out_path: Destination file path; overwritten if it already exists.
    """
    with open(out_path, 'w') as f:
        f.write('{}\n'.format(len(state_dict)))
        for name, tensor in state_dict.items():
            values = tensor.reshape(-1).cpu().numpy().tolist()
            # Assemble the whole line in memory and write it once, rather
            # than issuing two small writes per element.
            hex_values = ''.join(' ' + _PACK_BE_F32(v).hex() for v in values)
            f.write('{} {} {}\n'.format(name, len(values), hex_values))


def main(argv=None):
    """Convert a checkpoint's ``state_dict`` entry to a ``.wts`` file.

    Usage: ``python gen_wts.py [checkpoint.pth [output.wts]]``.
    The checkpoint defaults to ``DEFAULT_PT_FILE``; the output defaults to
    ``output.wts`` in the checkpoint's directory.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    pt_file = args[0] if args else DEFAULT_PT_FILE
    if len(args) > 1:
        out_path = args[1]
    else:
        out_path = os.path.join(os.path.dirname(pt_file), 'output.wts')

    device = select_device('cpu')
    # NOTE: torch.load unpickles the file -- only convert trusted checkpoints.
    model_dict = torch.load(pt_file, map_location=device)['state_dict']
    write_wts(model_dict, out_path)


if __name__ == "__main__":
    main()