-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
159 lines (137 loc) · 7.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Common libs
import numpy as np
import multiprocessing as mp
import os, sys, time, glob, pickle, psutil, argparse, importlib
sys.path.insert(0, f'{os.getcwd()}')
# Custom libs
from config import load_config, log_config
from utils.logger import print_mem, redirect_io
from config.utils import get_snap
def get_last_train(cfg):
    """Return the last entry found under ``results/<dataset>/<name>/``.

    The entries are compared as plain strings, so "last" means the greatest
    path in lexicographic order.

    Args:
        cfg: config object providing ``dataset`` (lower-cased for the path)
            and ``name``.

    Returns:
        The matching path that sorts last, or ``None`` when nothing matches.
    """
    candidates = glob.glob(f'results/{cfg.dataset.lower()}/{cfg.name}/*')
    if not candidates:
        return None
    return max(candidates)
# Command-line interface. Options are registered from an ordered table so the
# whole set can be read at a glance; the order is the one shown by --help.
# Every option except --debug (and --set / --cfg_path, which have no explicit
# default) defaults to None.
parser = argparse.ArgumentParser()
_cli_options = [
    (('-c', '--cfg_path'), dict(type=str, help='config path')),
    (('--gpus',), dict(type=str, default=None, help='the number/ID of GPU(s) to use [default: 1], 0 to use cpu only')),
    (('--mode',), dict(type=str, default=None, help='options: train, val, test')),
    (('--seed',), dict(type=int, default=None, dest='rand_seed', help='random seed for use')),
    (('--data_path',), dict(type=str, default=None, help='path to dataset dir = data_path/dataset_name')),
    (('--model_path',), dict(type=str, default=None, help='pretrained model path')),
    (('--saving_path',), dict(type=str, default=None, help='specified saving path')),
    (('--num_votes',), dict(type=float, default=None, help='least num of votes of each point (default to 30)')),
    # 'a' selects every cpu reported by multiprocessing, an empty string maps to None
    (('--num_threads',), dict(
        type=lambda n: mp.cpu_count() if n == 'a' else int(n) if n else None,
        default=None,
        help='the number of cpu to use for data loading',
    )),
    (('--set',), dict(type=str, help='external source to set the config - str of dict / yaml file')),
    (('--debug',), dict(action='store_true', help='debug mode')),
]
for _option_names, _option_kwargs in _cli_options:
    parser.add_argument(*_option_names, **_option_kwargs)
FLAGS = parser.parse_args()
# sys.argv = sys.argv[:1] # clean extra argv
# ---------------------------------------------------------------------------- #
# solve env & cfg
# ---------------------------------------------------------------------------- #
# Resolves the run configuration in a fixed order:
#   1. load the config named by --cfg_path and overlay the command-line flags,
#   2. export the CUDA / TensorFlow environment variables,
#   3. import TensorFlow and the project modules,
#   4. post-process the config (mode, snapshot path, memory saving, seed, debug).

# explicit check rather than `assert`, which is skipped when python runs with -O
if FLAGS.cfg_path is None:
    parser.error('the following arguments are required: -c/--cfg_path')

# load config - config path: config(dir).dataset_name(py).config_name(py_class)
cfg = load_config(cfg_path=FLAGS.cfg_path)

# update config: every flag that is not None overrides the loaded value.
# NOTE: 'debug' is a store_true flag (False when absent, never None), so it
# always overwrites cfg.debug, even when --debug was not passed.
for arg in ['data_path', 'model_path', 'saving_path', 'mode', 'gpus', 'rand_seed', 'num_threads', 'num_votes', 'debug']:
    flag_value = getattr(FLAGS, arg)
    if flag_value is not None:
        setattr(cfg, arg, flag_value)
# --set carries ';'-separated entries, each handed to cfg.update as-is
if FLAGS.set:
    for arg in FLAGS.set.split(';'):
        cfg.update(arg)

# env setting: visible gpu, tf warnings (level = '0'/'3')
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu_devices
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
if cfg.mixed_precision:
    os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

# imported only now so that the environment variables above are already in place
import tensorflow as tf
if tf.__version__.split('.')[0] == '2':
    # under TensorFlow 2.x, switch to the v1 compatibility API
    tf = tf.compat.v1
    tf.disable_v2_behavior()

import models
import datasets
from utils.tester import ModelTester
from utils.trainer import ModelTrainer
from utils.tf_graph_builder import GraphBuilder

# solve config
if cfg.dataset in ['S3DIS']:
    # 'test' is rewritten to 'validation' (which still matches the 'val' check
    # used later) - presumably S3DIS has no separate test split; TODO confirm
    cfg.mode = cfg.mode.replace('test', 'validation')
if cfg.model_path and os.path.isdir(cfg.model_path):
    # a directory is resolved to a concrete snapshot via get_snap(step='last')
    cfg.model_path = get_snap(cfg.model_path, step='last')
if cfg.save_memory:  # use gradient-checkpointing to save memory
    import utils.memory_saving_gradients
    tf.__dict__['gradients'] = utils.memory_saving_gradients.gradients_memory  # one from the: gradients_speed, gradients_memory, gradients_collection
if isinstance(cfg.rand_seed, int):  # manual set seed
    tf.set_random_seed(cfg.rand_seed)
    np.random.seed(cfg.rand_seed)
if cfg.debug:  # debug mode: fixed saving path, log straight to stdout
    cfg.saving_path = 'test'
    cfg.log_file = sys.stdout
# ---------------------------------------------------------------------------- #
# training
# ---------------------------------------------------------------------------- #
# Training phase: entered when the mode string contains 'train'. Prepares the
# saving dir and the log file, logs the config, then runs ModelTrainer.
if 'train' in cfg.mode:
# result dir: results/dataset_name/config_name/Log_time/...
if not cfg.saving_path:
# NOTE(review): np.random may already have been seeded from cfg.rand_seed earlier in this script;
# if so, this delay is likely identical for runs sharing a seed - confirm the collision avoidance still holds
time.sleep(np.random.randint(1, 10)) # random sleep (avoid same log dir)
# dataset_name = '_'.join([i for i in [cfg.dataset.lower(), cfg.version, cfg.validation_split] if i]) # default version / validation_split specified in dataset class
# no saving path given: derive one from the dataset, the config name and the current UTC time
cfg.saving_path = f'results/{cfg.dataset.lower()}/{cfg.name}/' + time.strftime('Log_%Y-%m-%d_%H-%M-%S', time.gmtime())
os.makedirs(cfg.saving_path, exist_ok=True)
# default log file lives inside the saving dir
if not cfg.log_file:
cfg.log_file = os.path.join(cfg.saving_path, 'log_train.txt')
# a path is turned into a writable file object; any non-str value (e.g. sys.stdout set in debug mode) is used as-is
if isinstance(cfg.log_file, str):
# NOTE(review): this handle is never explicitly closed in the code shown here
cfg.log_file = open(cfg.log_file, 'w')
# log the config twice: once with log_config's default output (presumably stdout), once into the log file
log_config(cfg)
log_config(cfg, f_out=cfg.log_file)
# actual training
print_mem('>>> start training', check_time=True)
with redirect_io(cfg.log_file, cfg.debug):
trainer = ModelTrainer(cfg)
trainer.train()
print(flush=True)
print_mem('>>> finished training', check_time=True)
# set cfg.gpus to 1 when more than one GPU was in use - presumably so the val/test phase below runs on a single GPU; TODO confirm
if cfg.gpu_num > 1:
cfg.gpus = 1
# Validation / test phase: entered when the mode string contains 'test' or 'val'.
# Collects the snapshot(s) to evaluate, then runs val_test on each of them.
if 'test' in cfg.mode or 'val' in cfg.mode:
# find chosen snap (and saving_path if not specified)
log_config(cfg)
if cfg.model_path and 'train' not in cfg.mode: # specified for val/test (not for continue training)
snap_list = [cfg.model_path]
# saving_path = directory of the snapshot, cut before the first occurrence of cfg.snap_dir, without trailing '/'
cfg.saving_path = os.path.dirname(cfg.model_path).split(cfg.snap_dir)[0].rstrip('/') # ensure at least is a dir
elif cfg.saving_path:
# every '<snap_prefix>*.meta' file under saving_path/snap_dir, with the 5-char '.meta' suffix removed
snap_list = [f[:-5] for f in glob.glob(os.path.join(cfg.saving_path, cfg.snap_dir, f'{cfg.snap_prefix}*.meta'))]
else:
raise ValueError('provide either cfg.model_path (snap) or cfg.saving_path (dir)')
assert len(snap_list) > 0, f'no snap found in saving_path={cfg.saving_path}'
# Evaluate one snapshot: rebuild the graph, restore the weights, then run the
# validation and/or test voting depending on cfg.mode.
#   snap: snapshot path (without the '.meta' suffix when taken from the glob above)
def val_test(snap):
# using the saved model
# step = text after the last '<snap_prefix>-' and before the next '.'
step = snap.split(f'{cfg.snap_prefix}-')[-1].split('.')[0]
# at least one file must start with the snapshot path, and the saving dir must exist
assert len(glob.glob(snap + '*')) > 0 and os.path.isdir(cfg.saving_path), f'err path: chosen_snap = {snap}, saving_path = {cfg.saving_path}'
print('using restored model, chosen_snap =', snap, flush=True)
with tf.Graph().as_default():
g = GraphBuilder(cfg) # build fresh compute graph
# presumably select_list holds name patterns of the variables to restore - TODO confirm against GraphBuilder.restore
g.restore(restore_snap=snap, select_list=['model/.*'])
tester = ModelTester(cfg)
if 'val' in cfg.mode:
# per-step validation log inside the saving dir
log_file = os.path.join(cfg.saving_path, f'log_validation.txt_{step}')
with redirect_io(log_file, cfg.debug):
log_config(cfg)
print('using restored model, chosen_snap =', snap, flush=True)
tester.val_vote(g.sess, g.ops, g.dataset, g.model, num_votes=cfg.num_votes) # fresh voting
print(flush=True)
print_mem('>>> finished val', check_time=True)
if 'test' in cfg.mode:
# per-step test log, plus a per-step path handed to test_vote as test_path
log_file = os.path.join(cfg.saving_path, f'log_test.txt_{step}')
test_path = os.path.join(cfg.saving_path, f'test_{step}')
with redirect_io(log_file, cfg.debug):
log_config(cfg)
tester.test_vote(g.sess, g.ops, g.dataset, g.model, num_votes=cfg.num_votes, test_path=test_path)
print(flush=True)
print_mem('>>> finished test', check_time=True)
# evaluate every collected snapshot in turn
for snap in snap_list:
val_test(snap)
# cleanup: make sure no child process outlives this script
# first ask the children tracked by multiprocessing to stop
for child in mp.active_children():
    child.terminate()
# then force-kill every remaining descendant of this process
parent = psutil.Process(os.getpid())
children = parent.children(recursive=True)
for child in children:
    try:
        child.kill()
    except psutil.NoSuchProcess:
        # the child exited after it was listed (e.g. it was just terminated
        # above); nothing is left to kill, so this must not abort the cleanup
        pass