# tensorflow_rename_variables.py
# Source: a fork of LPDI-EPFL/masif.
import getopt
import sys
import difflib
import tensorflow as tf
# Synopsis for rename mode: substring replacement and/or prefixing of
# variable names, optionally as a dry run.
usage_str = ' '.join((
    'python tensorflow_rename_variables.py',
    '--checkpoint_dir=path/to/dir/',
    '--replace_from=substr',
    '--replace_to=substr',
    '--add_prefix=abc',
    '--dry_run',
))

# Synopsis for find mode; the substring may carry a leading '!'.
find_usage_str = ' '.join((
    'python tensorflow_rename_variables.py',
    '--checkpoint_dir=path/to/dir/',
    "--find_str=['!']substr",
))

# Synopsis for compare mode, which takes two checkpoint directories.
comp_usage_str = ' '.join((
    'python tensorflow_rename_variables.py',
    '--checkpoint_dir=path/to/dir/',
    '--checkpoint_dir2=path/to/dir/',
))
def print_usage_str():
    """Print the usage banner listing all three invocation forms."""
    synopses = (usage_str, find_usage_str, comp_usage_str)
    print('Please specify a checkpoint_dir. Usage:')
    # The alternatives are separated by a line containing just "or".
    print('\nor\n'.join(synopses))
    print('Note: checkpoint_dir should be a *DIR*, not a file')
def compare(checkpoint_dir, checkpoint_dir2):
    """Report variables of one checkpoint that the other one lacks.

    For every variable name listed for ``checkpoint_dir`` that is not listed
    for ``checkpoint_dir2``, print the name together with the similar-looking
    names from ``checkpoint_dir2`` returned by
    ``difflib.get_close_matches``. The comparison is one-directional: names
    that exist only in ``checkpoint_dir2`` are not reported.

    Args:
        checkpoint_dir: Checkpoint directory whose variable names are checked.
        checkpoint_dir2: Checkpoint directory searched for those names.
    """
    names1 = [name for name, _ in tf.train.list_variables(checkpoint_dir)]
    names2 = [name for name, _ in tf.train.list_variables(checkpoint_dir2)]
    # Build the set once so each membership test is O(1); the list itself is
    # still what gets handed to get_close_matches.
    known = set(names2)
    for name in names1:
        if name not in known:
            print('{} close matches: {}'.format(
                name, difflib.get_close_matches(name, names2)))
def find(checkpoint_dir, find_str):
    """Print the checkpoint variables that match a substring query.

    A plain ``find_str`` reports every variable name containing it. A
    ``find_str`` that starts with ``'!'`` inverts the query: the ``'!'`` is
    stripped and every variable name that does NOT contain the remainder is
    reported.

    Args:
        checkpoint_dir: Checkpoint directory whose variables are listed.
        find_str: Substring to look for, optionally prefixed with ``'!'``.
    """
    negate = find_str.startswith('!')
    needle = find_str[1:] if negate else find_str
    for var_name, _ in tf.train.list_variables(checkpoint_dir):
        present = needle in var_name
        if negate:
            if not present:
                print('%s missing from %s.' % (needle, var_name))
        elif present:
            print('Found %s in %s.' % (needle, var_name))
def _renamed(var_name, replace_from, replace_to, add_prefix):
    """Return the name ``var_name`` should carry after the requested edits."""
    # Always start from the current name so that prefix-only (or no-op)
    # invocations still have a well-defined result.
    new_name = var_name
    if replace_from is not None and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    """Rename the variables stored in a checkpoint directory.

    Each variable name has ``replace_from`` replaced by ``replace_to`` (only
    when both are given) and is then prefixed with ``add_prefix`` (when
    given). In a dry run the old and new names are only printed; no variable
    values are loaded and nothing is written. Otherwise every variable is
    re-created under its new name and saved to the path recorded in the
    directory's checkpoint state.

    Args:
        checkpoint_dir: Directory containing the checkpoint.
        replace_from: Substring to replace, or None to skip replacement.
        replace_to: Replacement substring, or None to skip replacement.
        add_prefix: Prefix for every variable name; falsy to skip.
        dry_run: If true, only print what would be renamed.

    Raises:
        ValueError: If saving was requested but no checkpoint state could be
            read from ``checkpoint_dir``.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if not dry_run and checkpoint is None:
        # get_checkpoint_state returns None when no state is available; fail
        # here rather than with an AttributeError after all the work is done.
        raise ValueError(
            'No checkpoint state found in %s.' % (checkpoint_dir,))
    with tf.compat.v1.Session() as sess:
        for var_name, _ in tf.train.list_variables(checkpoint_dir):
            new_name = _renamed(var_name, replace_from, replace_to, add_prefix)
            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
                continue
            print('Renaming %s to %s.' % (var_name, new_name))
            # Load the stored value and re-create it under the new name.
            # NOTE(review): this relies on the Saver below, built without a
            # var_list, picking up every variable created here -- presumably
            # TF1-style graph mode; confirm eager execution is disabled.
            value = tf.train.load_variable(checkpoint_dir, var_name)
            tf.Variable(value, name=new_name)
        if not dry_run:
            # Save the variables
            saver = tf.compat.v1.train.Saver()
            sess.run(tf.compat.v1.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
def main(argv):
    """Parse command-line options and run the selected mode.

    Modes, in order of precedence once ``--checkpoint_dir`` is given:
    compare (``--checkpoint_dir2``), find (``--find_str``), otherwise rename.

    Args:
        argv: Command-line arguments without the program name.

    Exits with status 2 on an option-parsing error or when
    ``--checkpoint_dir`` is missing, and with status 0 after printing help.
    """
    checkpoint_dir = None
    checkpoint_dir2 = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    find_str = None

    # 'help' takes no argument; a trailing '=' would make getopt reject a
    # plain `--help` with "option --help requires argument".
    long_options = ['help', 'checkpoint_dir=',
                    'replace_from=', 'replace_to=',
                    'add_prefix=', 'dry_run',
                    'find_str=',
                    'checkpoint_dir2=']
    try:
        opts, _ = getopt.getopt(argv, 'h', long_options)
    except getopt.GetoptError as e:
        print(e)
        print_usage_str()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--checkpoint_dir2':
            checkpoint_dir2 = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--find_str':
            find_str = arg

    # Every mode needs the primary checkpoint directory.
    if not checkpoint_dir:
        print_usage_str()
        sys.exit(2)

    if checkpoint_dir2:
        compare(checkpoint_dir, checkpoint_dir2)
    elif find_str:
        find(checkpoint_dir, find_str)
    else:
        rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
# Script entry point: hand everything after the program name to main().
if __name__ == '__main__':
    main(sys.argv[1:])