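"""cloudburst_controller: listen on a RabbitMQ queue and launch a Kubernetes
Job for each message received, filling a YAML job template with values taken
from the message body."""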
import argparse
import json
import os
import threading
import time
from string import Template

import pika
import yaml
from kubernetes import client, config
from kubernetes.client.rest import ApiException

import job_monitor
parser = argparse.ArgumentParser()
parser.add_argument(
    "-queue",
    dest="queue_name",
    help="The queue name; defaults to the QUEUE env var",
    default=os.getenv("QUEUE"))
parser.add_argument(
    "-broker_url",
    dest="broker_url",
    help="The broker URL; defaults to the BROKER_URL env var",
    default=os.getenv("BROKER_URL"))
parser.add_argument(
    "-container_name",
    dest="container_name",
    help="The name of the container to start",
    default=os.getenv("CONTAINER_NAME"))
parser.add_argument(
    "-container_url",
    dest="container_url",
    help="The URL of the container to start",
    default=os.getenv("CONTAINER_URL"))
parser.add_argument(
    "-num_threads",
    dest="num_threads",
    type=int,
    help="The number of consumer threads to run",
    default=1)
parser.add_argument(
    "-max_concurrent_jobs",
    dest="max_concurrent_jobs",
    type=int,
    help="The max number of jobs to run concurrently. Not recommended; "
         "prefer k8s memory and CPU constraints to manage load",
    default=os.getenv("MAX_CONCURRENT_JOBS"))
parser.add_argument(
    "-test",
    dest="is_test",
    action="store_true",
    help="Run in test mode")
parser.add_argument(
    "-no-monitor",
    dest="no_monitor",
    action="store_true",
    help="Disable the job monitor thread which stores status in mariadb")
parser.add_argument(
    "-debug",
    dest="debug",
    action="store_true",
    help="Enable debug output")
def load_template(template_file):
    with open(template_file, 'r') as file:
        template = Template(file.read())
    return template
# attempt to load in-cluster config, or else fall back to the local
# kubeconfig (a docker setup)
try:
    config.load_incluster_config()
except config.ConfigException:
    config.load_kube_config()

# init the batch API
batch_v1 = client.BatchV1Api()

# load the YAML job template
job_template_file = 'cloudburst-job-template.yaml'
job_template = load_template(job_template_file)
def substitute_template(template, variables):
    substituted_content = template.substitute(variables)
    return yaml.safe_load(substituted_content)
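# The job template file itself is not shown here. A minimal sketch of what
# cloudburst-job-template.yaml might look like, assuming its placeholders are
# the uppercased message keys plus the JOB_NAME the controller injects (all
# field values below are illustrative, not the repo's actual template):
#
#   apiVersion: batch/v1
#   kind: Job
#   metadata:
#     name: $JOB_NAME
#   spec:
#     template:
#       spec:
#         containers:
#         - name: worker
#           image: $CONTAINER_URL
#           args: ["$WORK_ITEM"]
#         restartPolicy: Never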
def create_kubernetes_job(message):
    # load up the substitution variables
    my_vars = {}
    b_named = False
    job_name = "x"
    job_namespace = "default"
    # create the substitution variables based on the request message.
    # extract the job name and namespace when provided; keys are matched
    # case-insensitively since they are uppercased for substitution anyway
    for name1, value1 in message.items():
        my_vars[name1.upper()] = value1
        if name1.upper() == "WORK_ITEM":
            # use WORK_ITEM to identify the job name; make it unique with a timestamp
            job_name = f"job-cb-{value1}-{int(time.time_ns()/1000)}"
            my_vars["JOB_NAME"] = job_name
            if args.debug:
                print(f"naming job {job_name} based on {name1}")
            b_named = True
        elif name1.upper() == "JOB_NAMESPACE":
            job_namespace = value1
    # if no WORK_ITEM was provided, name the job after the first message field
    if not b_named:
        for name1, value1 in message.items():
            job_name = f"job-cb-{value1}-{int(time.time_ns()/1000)}"
            my_vars["JOB_NAME"] = job_name
            if args.debug:
                print(f"naming job {job_name} based on {name1}")
            break
    job_manifest = substitute_template(job_template, my_vars)
    try:
        batch_v1.create_namespaced_job(body=job_manifest, namespace=job_namespace)
        print(f"Job {job_name} created successfully in namespace {job_namespace}")
    except ApiException as e:
        print(f"Exception when creating job: {e}")
# Retrieve the number of current running and pending jobs
def get_running_jobs():
    try:
        jobs = batch_v1.list_namespaced_job(namespace="default")
        running_jobs = [job for job in jobs.items if job.status.active or (job.status.conditions and any(
            condition.type == "PodScheduled" and condition.status == "True" for condition in job.status.conditions))]
        return len(running_jobs)
    except ApiException as e:
        print(f"Exception when listing jobs: {e}")
        # return a count, not a list, so the caller's numeric comparison still works
        return 0
# Function to process RabbitMQ messages
def callback(ch, method, properties, body):
    try:
        message = json.loads(body)
        if args.debug:
            print(f"received message: {message}")
        # throttle the number of currently running jobs, if max_concurrent_jobs is set
        if args.max_concurrent_jobs is not None and args.max_concurrent_jobs > 0:
            while get_running_jobs() >= args.max_concurrent_jobs:
                if args.debug:
                    print(f"Maximum concurrent jobs ({args.max_concurrent_jobs}) running, waiting...")
                time.sleep(5)  # Wait before checking again
        create_kubernetes_job(message)
    except (json.JSONDecodeError, KeyError) as e:
        print(f"Failed to process message: {e}")
# Function to start consuming messages from RabbitMQ
def start_consuming():
    mq_url = args.broker_url
    print(f"listening on: {mq_url} queue: {args.queue_name}")
    connection = pika.BlockingConnection(pika.URLParameters(mq_url))
    channel = connection.channel()
    channel.queue_declare(queue=args.queue_name)
    # auto_ack=True acknowledges on delivery, so a crash mid-processing drops the message
    channel.basic_consume(queue=args.queue_name, on_message_callback=callback, auto_ack=True)
    channel.start_consuming()
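# A quick way to exercise the consumer from a Python shell (a sketch; assumes
# a local broker and a queue named "cloudburst"):
#
#   import json, pika
#   conn = pika.BlockingConnection(pika.URLParameters("amqp://localhost"))
#   ch = conn.channel()
#   ch.queue_declare(queue="cloudburst")
#   ch.basic_publish(exchange="", routing_key="cloudburst",
#                    body=json.dumps({"work_item": "demo"}))
#   conn.close()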
def start_job_monitor_thread():
    thread = threading.Thread(target=job_monitor.main)
    thread.daemon = True  # This makes the thread exit when the main program exits
    thread.start()
    print("Job monitor background thread started")
# Main function to start multiple consumer threads
def main():
    if not args.no_monitor:
        start_job_monitor_thread()
    threads = []
    for _ in range(args.num_threads):
        thread = threading.Thread(target=start_consuming)
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
if __name__ == "__main__":
    args = parser.parse_args()
    main()