-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmain.py
37 lines (26 loc) · 959 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import shutil
from absl import app
from absl import flags
from src.trainer import TrainingManager
from src.utils import load_config
# Root directory under which each experiment gets its own sub-directory.
EXPERIMENT_PATH = "experiments"
# File name under which the config is stored inside an experiment directory.
EXPERIMENT_CONFIG_NAME = "config.yml"

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "experiment_name",
    # The default must be None: the required-flag check only rejects None, so
    # an empty-string default would let a missing --experiment_name slip
    # through and put outputs directly into the experiments root.
    None,
    "Experiment name: experiment outputs will be saved in a created experiment name directory",
)
flags.DEFINE_string("config_path", "config.yml", "Config file path")
def main(argv):
    """Run one training experiment configured through command-line flags.

    Loads the config named by ``--config_path``, creates the experiment
    directory ``EXPERIMENT_PATH/<experiment_name>`` if it does not exist yet,
    stores a copy of the config file there under ``EXPERIMENT_CONFIG_NAME``,
    and then starts training via ``TrainingManager``.

    Args:
        argv: Arguments handed over by ``app.run``; unused.
    """
    del argv  # Unused.

    config = load_config(FLAGS.config_path)

    experiment_path = os.path.join(EXPERIMENT_PATH, FLAGS.experiment_name)
    os.makedirs(experiment_path, exist_ok=True)

    # Keep a copy of the config next to the experiment outputs.
    experiment_config_path = os.path.join(experiment_path, EXPERIMENT_CONFIG_NAME)
    try:
        shutil.copy2(FLAGS.config_path, experiment_config_path)
    except shutil.SameFileError:
        # --config_path already points at this experiment's saved config
        # (e.g. when re-running an experiment from its own directory), so
        # there is nothing to copy.
        pass

    trainer = TrainingManager(config, experiment_path)
    trainer.train()
if __name__ == "__main__":
    # Script entry point: declare --experiment_name as required, then let
    # absl run main.
    flags.mark_flag_as_required("experiment_name")
    app.run(main)