-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathpbt_wrapper.py
63 lines (41 loc) · 1.67 KB
/
pbt_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
from multiprocessing import Process
from subprocess import Popen, PIPE
def create_worker(type, task, ps_hosts, worker_hosts, task_index):
if task == "mueller":
p = Popen([
'python3', 'mueller_tf.py', ps_hosts, worker_hosts, '--job_name={}'.format(type), '--task_index={}'.format(task_index)])
elif task == "toy":
p = Popen([
'python3', 'pbtv2_tf.py', ps_hosts, worker_hosts, '--job_name={}'.format(type), '--task_index={}'.format(task_index)])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("size", type=int)
parser.add_argument("task", type=str)
args = parser.parse_args()
population_size = args.size
task = args.task
# create cluster specifications
ps_hosts = '--ps_hosts='
worker_hosts = '--worker_hosts='
hostnames = ['localhost:{}'.format(i) for i in range(2222, 2222+(args.size+1))]
ps_hosts = ps_hosts + hostnames[0]
worker_hosts = worker_hosts + ','.join(hostnames[1:])
# create Processes
processes = []
for i in range(population_size):
if i == 0:
_p = Process(
target=create_worker,
args=('ps', task, ps_hosts, worker_hosts, 0)
)
processes.append(_p)
_p = Process(
target=create_worker,
args=('worker', task, ps_hosts, worker_hosts, i)
)
processes.append(_p)
for process in processes:
process.start()
for process in processes:
process.join()