forked from facebookresearch/ijepa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_distributed.py
113 lines (94 loc) · 2.91 KB
/
main_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import logging
import pprint
import sys
from pathlib import Path
import submitit
import yaml
from ijepa import main as app_main
# Root logger: INFO and above, written to stdout rather than the default stderr.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger()

# Command-line interface of the launcher. The parsed namespace is consumed
# by launch() through the module-level ``args`` variable.
parser = argparse.ArgumentParser()

# -- output location for submitit logs
parser.add_argument(
    "--folder",
    help="location to save submitit logs",
    type=str,
)

# -- which configuration(s) to run
parser.add_argument(
    "--batch-launch",
    help="whether fname points to a file to batch-lauch several config files",
    action="store_true",
)
parser.add_argument(
    "--fname",
    default="configs.yaml",
    help="yaml file containing config file names to launch",
    type=str,
)

# -- cluster resources requested for each job
parser.add_argument(
    "--partition",
    help="cluster partition to submit jobs on",
    type=str,
)
parser.add_argument("--nodes", default=1, help="num. nodes to request for job", type=int)
parser.add_argument(
    "--tasks-per-node", default=1, help="num. procs to per node", type=int
)
parser.add_argument(
    "--time",
    default=4300,
    help="time in minutes to run job",
    type=int,
)
class Trainer:
    """Callable job object handed to submitit by ``launch()``.

    Calling an instance loads the YAML config named by ``fname`` and passes
    the parsed parameters to ``app_main``. ``checkpoint`` returns a fresh
    ``Trainer`` wrapped in a ``DelayedSubmission`` so that a requeued job
    runs with ``resume_preempt=True``.
    """

    def __init__(self, fname="configs.yaml", load_model=None) -> None:
        """Store the job description.

        Args:
            fname: path of the YAML file holding the run parameters.
            load_model: forwarded to ``app_main`` as ``resume_preempt``;
                ``None`` is treated as ``False``.
        """
        self.fname = fname
        self.load_model = load_model

    def __call__(self) -> None:
        """Load the config file and run ``app_main`` with it."""
        fname = self.fname
        load_model = self.load_model
        logger.info(f"called-params {fname}")

        # -- load script params
        # NOTE(review): yaml.FullLoader can construct more than plain data
        # types. If config files may come from an untrusted source, consider
        # yaml.safe_load instead; left unchanged here to keep accepting
        # whatever the existing configs contain.
        with Path(fname).open() as y_file:
            params = yaml.load(y_file, Loader=yaml.FullLoader)
        logger.info("loaded params...")
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(params)

        # A fresh submission has load_model=None, which means "do not resume".
        resume_preempt = False if load_model is None else load_model
        app_main(args=params, resume_preempt=resume_preempt)

    def checkpoint(self):
        """Return the submission to run when this job is requeued.

        submitit's checkpointing protocol calls this method when a job is
        preempted or times out (see the submitit checkpointing docs); the
        returned ``DelayedSubmission`` wraps a new ``Trainer`` for the same
        config with ``load_model=True``.
        """
        fb_trainer = Trainer(self.fname, load_model=True)
        return submitit.helpers.DelayedSubmission(
            fb_trainer,
        )
def launch():
    """Submit one training job per config file through submitit.

    Reads the module-level ``args`` namespace (set in the ``__main__``
    guard) for the log folder, the cluster partition, the time limit, the
    node and task counts, and the config file name. The id of every
    submitted job is logged.

    Raises:
        ValueError: if ``--folder`` was not supplied. ``--folder`` has no
            default, and ``Path(None)`` would otherwise fail with an
            unhelpful ``TypeError``.
    """
    if args.folder is None:
        raise ValueError(
            "--folder is required: it is the location to save submitit logs"
        )

    executor = submitit.AutoExecutor(
        # "%j" is submitit's placeholder for the job id.
        folder=Path(args.folder) / "job_%j",
        slurm_max_num_timeout=20,
    )
    executor.update_parameters(
        slurm_partition=args.partition,
        slurm_mem_per_gpu="55G",
        timeout_min=args.time,
        nodes=args.nodes,
        tasks_per_node=args.tasks_per_node,
        cpus_per_task=10,
        # Same value as tasks_per_node, i.e. one GPU requested per task.
        gpus_per_node=args.tasks_per_node,
    )

    config_fnames = [args.fname]
    with executor.batch():
        jobs = [executor.submit(Trainer(cf)) for cf in config_fnames]

    # Job ids are read only after the batch context has exited, which is
    # when submitit actually submits the batched jobs.
    for job in jobs:
        logger.info(f"Job id: {job.job_id}")
if __name__ == "__main__":
    # ``args`` is bound at module level on purpose: launch() takes no
    # parameters and reads this global directly.
    args = parser.parse_args()
    launch()