-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrainer.py
80 lines (69 loc) · 2.39 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from cmath import isnan
from re import U
import sys
sys.path.append('./')
sys.path.append('../')
import numpy as np
import random
import json
import argparse
import wandb
import jax
import hydra
import jax.numpy as jnp
import itertools
from src.trainers.trainers_control import ControlTrainer
from omegaconf import DictConfig, OmegaConf
import pandas as pd
from tqdm import tqdm
import numpy as np
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Pool
import multiprocessing as mp
from tqdm.contrib.logging import logging_redirect_tqdm
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
task_to_trainer={
'minigrid_pixel':ControlTrainer,
'minigrid_onehot':ControlTrainer,
}
@hydra.main(version_base=None, config_path="config", config_name="default_config")
def main(config: DictConfig):
logger.info("Starting Job for Config:\n"+str(OmegaConf.to_yaml(config)))
tags=config.tags.split(',') if config.tags is not None else []
if config.use_wandb:
run = wandb.init(project=config.project_name,tags=tags,settings=wandb.Settings(start_method="fork"),config=OmegaConf.to_container(config))
else:
run=None
key=jax.random.PRNGKey(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)
trainer_config=config.trainer
env_config=config.task
#Train the model
kwargs={'global_args':config,'trainer_config':trainer_config,'env_config':env_config,
'seed':config.seed,'key':key,'wandb_run':run}
trainer=task_to_trainer[env_config['task']](**kwargs)
pbar = tqdm(total=config.steps)
step_count=0
last_step_count=0
with logging_redirect_tqdm():
while True:
loss,metrics,step_count=trainer.step()
pbar.update(n=step_count-last_step_count)
last_step_count=step_count
if metrics is not None:
logger.info("Seed: "+str(config.seed)+" Steps: "+str(step_count)+" Metrics: "+str(metrics))
if config.use_wandb: run.log({'seed':config.seed,**metrics
},step=step_count)
if step_count>=config.steps:
break
pbar.close()
if config.use_wandb:
wandb.finish()
#Need to do something about logging val_metric
if __name__=='__main__':
mp.set_start_method('forkserver')
main()