-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtune.py
112 lines (70 loc) · 2.63 KB
/
tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import os
import tvm
import nnvm
import numpy as np
from nnvm import testing
from tvm import autotvm
import tvm.contrib.graph_runtime as runtime
from tuner import tuner
# Command-line interface for the tuning run.
#   --local / --remote pick where candidate kernels are measured (see runner below).
#   --board / --device can override the compilation target chosen from --target.
parser = argparse.ArgumentParser(description='Auto-tune a network with TVM AutoTVM')
parser.add_argument('--network', type=str, default=None, help='Network Architecture')
parser.add_argument('--target', type=str, default='cuda', help='Deploy Target')
parser.add_argument('--board', type=str, help='board')
parser.add_argument('--dtype', type=str, default='float32', help='Data Type')
parser.add_argument('--tuner', type=str, default='xgb', help='Select Tuner')
parser.add_argument('--recompile', action='store_true', help='ReCompile')
parser.add_argument('--local', action='store_true',
                    help='Measure candidates with a local runner on this machine')
parser.add_argument('--remote', action='store_true',
                    help='Measure candidates on a remote device over RPC')
parser.add_argument('--resume', action='store_true', help='continue')
parser.add_argument('--device', type=str,
                    help='Device on the board to tune for (e.g. cpu or gpu)')
parser.add_argument('--size', type=int, help='Network input size')
args = parser.parse_args()
# Resolve the TVM compilation target from --target, then let --board/--device
# override it. `target_host` stays None unless a board needs cross-compiled host code.
target_host = None
target = None
if args.target == 'cuda':
    # Alternative: 'cuda -libs=cudnn' (or tvm.target.cuda()) to use cuDNN where supported.
    target = 'cuda'
elif args.target == 'llvm':
    target = 'llvm'
else:
    print('[!] Not Supported Yet')
if args.board == 'tx2':
    if args.device == 'cpu':
        target = 'llvm'
    elif args.device == 'gpu':
        target = 'cuda'
    # Host-side code is cross-compiled for the board's 64-bit ARM (aarch64) CPU.
    target_host = 'llvm -target=aarch64-linux-gnu'
# Key passed to the RPC runner when --remote is used. It is defined unconditionally
# so the remote runner never hits an undefined name.
# NOTE(review): hard-coded to 'tx2' even for other boards — confirm.
device_key = 'tx2'
if target is None:
    # Neither --target nor a board override produced a target; continuing would
    # only fail later with a NameError, so stop with a usage message instead.
    parser.error('unsupported --target {!r} (expected cuda or llvm)'.format(args.target))
# Choose where AutoTVM benchmarks each candidate configuration.
if args.remote:
    # Measure on a device registered under `device_key` with the RPC tracker
    # at localhost:9190.
    runner = autotvm.RPCRunner(
        device_key,
        host='localhost',
        port=9190,
        number=5,
        timeout=4)
elif args.local:
    # Measure on this machine.
    runner = autotvm.LocalRunner(number=20, repeat=3, timeout=4)
else:
    # Without a runner the measure_option passed to the tuner cannot be built
    # (previously a NameError on `runner`); require an explicit choice.
    parser.error('one of --local or --remote is required')
# Pick the tuning-log path. Network, input size, board, device and dtype are all
# encoded so that different configurations never share one log file.
# (Previously a trailing assignment overwrote this with '<network>.<board>.<device>.log',
# which made the dtype-specific names below dead code.)
# NOTE(review): assumes `tuner` uses log_filename as-is and does not prepend
# its own directory — confirm.
if args.dtype == 'float32':
    log_filename = 'log/{}-{}.{}.{}.log'.format(args.network, args.size, args.board, args.device)
elif args.dtype == 'float16':
    log_filename = 'log/{}-{}.{}.{}.fp16.log'.format(args.network, args.size, args.board, args.device)
else:
    # Any other dtype gets its own suffix rather than leaving log_filename unset.
    log_filename = 'log/{}-{}.{}.{}.{}.log'.format(
        args.network, args.size, args.board, args.device, args.dtype)
# Make sure the log directory exists before the path is handed to the tuner.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
print('[*] Log File : ', log_filename)
# Keyword arguments for the project's `tuner` class, then run the tuning.
# NOTE(review): args.resume is parsed but not forwarded here — confirm whether
# resuming is handled elsewhere.
option = {
    'recompile': args.recompile,
    'board': args.board,
    # Fall back to plain llvm when no board-specific host target was chosen.
    'target_host': target_host if target_host is not None else 'llvm',
    'network': args.network,
    'dtype': args.dtype,
    'target': target,
    'log_filename': log_filename,
    'tuner': args.tuner,
    'input_size': args.size,
    # Presumably forwarded to AutoTVM's tune(n_trial=..., early_stopping=...) — confirm in tuner.
    'n_trial': 5000,
    'early_stopping': 1500,
    # Candidates are compiled locally (10 s build timeout) and measured by `runner`.
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=runner),
}
T = tuner(**option)
T.tune()