Commit 49a4adf

Initial implementation of parallel batch submission
gpetretto committed Sep 5, 2024
1 parent 96baad0 commit 49a4adf
Showing 9 changed files with 351 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/jobflow_remote/cli/__init__.py
@@ -1,5 +1,6 @@
# Import the submodules with a local app to register them to the main app
import jobflow_remote.cli.admin
import jobflow_remote.cli.batch
import jobflow_remote.cli.execution
import jobflow_remote.cli.flow
import jobflow_remote.cli.job
67 changes: 67 additions & 0 deletions src/jobflow_remote/cli/batch.py
@@ -0,0 +1,67 @@
from typing import Annotated, Optional

import typer

from jobflow_remote.cli.formatting import get_batch_processes_table
from jobflow_remote.cli.jf import app
from jobflow_remote.cli.jfr_typer import JFRTyper
from jobflow_remote.cli.types import verbosity_opt
from jobflow_remote.cli.utils import (
exit_with_warning_msg,
get_config_manager,
get_job_controller,
out_console,
)
from jobflow_remote.jobs.batch import RemoteBatchManager

app_batch = JFRTyper(
name="batch", help="Helper utils handling batch jobs", no_args_is_help=True
)
app.add_typer(app_batch)


@app_batch.command(name="list")
def processes_list(
worker: Annotated[
Optional[str],
typer.Option(
"--worker",
"-w",
help="Select the worker.",
),
] = None,
verbosity: verbosity_opt = 0,
) -> None:
"""
Show the list of processes being executed on the batch workers.
Increasing verbosity will require connecting to the host.
"""

jc = get_job_controller()

batch_processes = jc.get_batch_processes(worker)
if not batch_processes or not any(wbc for wbc in batch_processes.values()):
exit_with_warning_msg("No batch processes running")

cm = get_config_manager()
project = cm.get_project()
workers = project.workers
running_jobs = {}
if verbosity > 0:
for worker_name in batch_processes:
worker_config = workers[worker_name]
host = worker_config.get_host()
host.connect()
remote_batch_manager = RemoteBatchManager(
host, worker_config.batch.jobs_handle_dir
)
running_jobs[worker_name] = remote_batch_manager.get_running()

table = get_batch_processes_table(
batch_processes=batch_processes,
workers=workers,
running_jobs=running_jobs,
verbosity=verbosity,
)

out_console.print(table)
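A hypothetical command-line session for the new sub-app (the worker name is invented, and the verbosity flag is assumed to follow the usual -v convention of the other jobflow-remote commands):

jf batch list
jf batch list -w batch_worker -v

Without options all workers are considered; with increased verbosity the host of each worker is contacted through RemoteBatchManager to also report the running Job ids.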
9 changes: 9 additions & 0 deletions src/jobflow_remote/cli/execution.py
@@ -76,6 +76,14 @@ def run_batch(
help=("The maximum number of jobs that will be executed by the batch job"),
),
] = None,
parallel_jobs: Annotated[
Optional[int],
typer.Option(
"--parallel-jobs",
"-pj",
help=("Number of jobs executed in parallel"),
),
] = None,
) -> None:
"""Run Jobs in batch mode."""
run_batch_jobs(
@@ -85,4 +93,5 @@ def run_batch(
max_time=max_time,
max_wait=max_wait,
max_jobs=max_jobs,
parallel_jobs=parallel_jobs,
)
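The new value is forwarded to run_batch_jobs, which lives outside the diffs shown here and is presumably fed from BatchConfig.parallel_jobs when the batch process is submitted automatically. A hypothetical manual invocation, assuming the command is exposed as run-batch under the execution sub-app, as the module and function names suggest:

jf execution run-batch --parallel-jobs 4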
40 changes: 40 additions & 0 deletions src/jobflow_remote/cli/formatting.py
@@ -12,6 +12,7 @@

from jobflow_remote.cli.utils import ReprStr, fmt_datetime
from jobflow_remote.jobs.state import JobState
from jobflow_remote.remote.data import get_job_path
from jobflow_remote.utils.data import convert_utc_time

if TYPE_CHECKING:
@@ -255,3 +256,42 @@ def get_worker_table(workers: dict[str, WorkerBase], verbosity: int = 0):
table.add_row(*row)

return table


def get_batch_processes_table(
batch_processes: dict,
workers: dict[str, WorkerBase],
running_jobs: dict[str, list[tuple[str, int, str]]],
verbosity: int = 0,
):
table = Table(title="Batch processes info")
table.add_column("Process ID")
table.add_column("Process UUID")
table.add_column("Worker")
table.add_column("Process folder")
if verbosity > 0:
table.add_column("Running Job ids (Index)")

for worker_name, processes_data in batch_processes.items():
worker = workers[worker_name]
for process_id, process_uuid in processes_data.items():
row = [
process_id,
process_uuid,
worker_name,
get_job_path(process_uuid, None, worker.batch.work_dir),
]

if verbosity > 0:
jobs_data = running_jobs.get(worker_name, [])
process_jobs = [
f"{job_data[0]} ({job_data[1]})"
for job_data in jobs_data
if job_data[2] == process_uuid
]

row.append("\n".join(process_jobs))

table.add_row(*row)

return table
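For reference, the shapes of the two inputs consumed here, with invented values: batch_processes comes from JobController.get_batch_processes and running_jobs from RemoteBatchManager.get_running.

# Invented example values; only the structure is meaningful.
batch_processes = {
    "batch_worker": {"1234567": "process-uuid-a"}  # {queue process id: batch process uuid}
}
running_jobs = {
    # one (job uuid, job index, batch process uuid) tuple per running Job
    "batch_worker": [("job-uuid-1", 1, "process-uuid-a")]
}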
5 changes: 4 additions & 1 deletion src/jobflow_remote/config/base.py
@@ -131,7 +131,10 @@ class BatchConfig(BaseModel):
)
max_time: Optional[int] = Field(
None,
description="Maximum time after which a job will not submit more jobs (seconds). To help avoid hitting the walltime",
description="Maximum time after which a job will not start more jobs (seconds). To help avoid hitting the walltime",
)
parallel_jobs: Optional[int] = Field(
None, description="Number of jobs executed in parallel in the same process"
)
model_config = ConfigDict(extra="forbid")

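To illustrate the new field, a minimal sketch of a worker batch configuration built directly with the pydantic model; max_time and parallel_jobs are the fields shown above, jobs_handle_dir and work_dir are taken from their usage elsewhere in this commit, the paths are invented and any remaining options are assumed to keep their defaults.

from jobflow_remote.config.base import BatchConfig

# Sketch only: let each batch process execute up to 4 Jobs concurrently and
# stop starting new Jobs after 11 hours, e.g. for an assumed 12 h walltime.
batch = BatchConfig(
    jobs_handle_dir="/scratch/user/jfr_handle",  # accessed as worker.batch.jobs_handle_dir
    work_dir="/scratch/user/jfr_batch_work",     # accessed as worker.batch.work_dir
    max_time=11 * 3600,
    parallel_jobs=4,
)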
78 changes: 71 additions & 7 deletions src/jobflow_remote/jobs/batch.py
@@ -3,6 +3,7 @@
import logging
import os
import random
from contextlib import ExitStack
from pathlib import Path
from typing import TYPE_CHECKING

@@ -11,6 +12,7 @@
if TYPE_CHECKING:
from jobflow_remote.remote.host import BaseHost


logger = logging.getLogger(__name__)


@@ -86,18 +88,20 @@ def get_submitted(self) -> list[str]:
Returns
-------
The list of file names in the directory.
The list of file names in the submitted directory.
"""
return self.host.listdir(self.submitted_dir)

def get_terminated(self) -> list[tuple[str, int, str]]:
"""
Get job ids and process ids of the terminated jobs from the corresponding
directory.
directory on the host.
Returns
-------
list
The list of job ids, job indexes and batch process uuids in the host
terminated directory.
"""
terminated = []
for i in self.host.listdir(self.terminated_dir):
@@ -107,6 +111,16 @@ def get_terminated(self) -> list[tuple[str, int, str]]:
return terminated

def get_running(self) -> list[tuple[str, int, str]]:
"""
Get job ids and process ids of the running jobs from the corresponding
directory on the host.
Returns
-------
list
The list of job ids, job indexes and batch process uuids in the host
running directory.
"""
running = []
for filename in self.host.listdir(self.running_dir):
job_id, _index, process_uuid = filename.split("_")
@@ -127,23 +141,62 @@ class LocalBatchManager:
Used in the worker to execute the batch Jobs.
"""

def __init__(self, files_dir: str | Path, process_id: str) -> None:
def __init__(
self,
files_dir: str | Path,
process_id: str,
multiprocess_lock=None,
) -> None:
"""
Parameters
----------
files_dir
The full path to directory where the files to handle the jobs
to be executed in batch processes are stored.
process_id
The uuid associated to the batch process.
multiprocess_lock
A lock from the multiprocessing module to be used when executing jobs in
parallel with other processes of the same worker.
"""
self.process_id = process_id
self.files_dir = Path(files_dir)
self.multiprocess_lock = multiprocess_lock
self.submitted_dir = self.files_dir / SUBMITTED_DIR
self.running_dir = self.files_dir / RUNNING_DIR
self.terminated_dir = self.files_dir / TERMINATED_DIR
self.lock_dir = self.files_dir / LOCK_DIR

def get_job(self) -> str | None:
"""
Randomly select a job from the submitted directory to be executed.
Move the file to the running directory.
Locks prevent the same job from being executed by other processes.
If no job can be executed, None is returned.
Returns
-------
str | None
The name of the job that was selected, or None if no job can be executed.
"""
files = os.listdir(self.submitted_dir)

while files:
selected = random.choice(files)
try:
with Lock(
str(self.lock_dir / selected), lifetime=60, default_timeout=0
):
with ExitStack() as lock_stack:
# when executing jobs in parallel, avoid concurrent interactions
# between processes belonging to the same batch job
if self.multiprocess_lock:
lock_stack.enter_context(self.multiprocess_lock)
lock_stack.enter_context(
Lock(
str(self.lock_dir / selected),
lifetime=60,
default_timeout=0,
)
)
os.remove(self.submitted_dir / selected)
(self.running_dir / f"{selected}_{self.process_id}").touch()
return selected
@@ -155,5 +208,16 @@ def get_job(self) -> str | None:
return None

def terminate_job(self, job_id: str, index: int) -> None:
"""
Terminate a job by removing the corresponding file from the running
directory and adding a new file in the terminated directory.
Parameters
----------
job_id
The uuid of the job to terminate.
index
The index of the job to terminate.
"""
os.remove(self.running_dir / f"{job_id}_{index}_{self.process_id}")
(self.terminated_dir / f"{job_id}_{index}_{self.process_id}").touch()
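The multiprocess_lock is what allows several LocalBatchManager instances, one per parallel process of the same batch job, to pick jobs from the same directories without interfering. A minimal, hypothetical sketch of such a parallel loop (run_parallel, _worker_loop and the omitted Job execution are placeholders, not the actual run_batch_jobs implementation):

import multiprocessing

from jobflow_remote.jobs.batch import LocalBatchManager


def _worker_loop(files_dir: str, process_id: str, mp_lock) -> None:
    manager = LocalBatchManager(files_dir, process_id, multiprocess_lock=mp_lock)
    while True:
        selected = manager.get_job()
        if selected is None:
            break  # no submitted jobs left to pick up
        job_id, index = selected.split("_")
        # ... actually execute the Job here (omitted in this sketch) ...
        manager.terminate_job(job_id, int(index))


def run_parallel(files_dir: str, process_id: str, parallel_jobs: int) -> None:
    # A single lock shared by all sibling processes of the same batch job.
    mp_lock = multiprocessing.Lock()
    procs = [
        multiprocessing.Process(target=_worker_loop, args=(files_dir, process_id, mp_lock))
        for _ in range(parallel_jobs)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()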
30 changes: 26 additions & 4 deletions src/jobflow_remote/jobs/jobcontroller.py
Expand Up @@ -2700,6 +2700,21 @@ def count_flows(
)
return self.flows.count_documents(query)

def count_jobs_states(self, states: list[JobState]) -> dict[JobState, int]:
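"""
Count the number of Jobs in each of the given states.

Parameters
----------
states
    The list of Job states to count.

Returns
-------
dict
    A dictionary {state: number of Jobs}, with 0 for states without
    matching Jobs.
"""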
pipeline = [
{"$match": {"state": {"$in": [s.value for s in states]}}},
{"$group": {"_id": "$state", "count": {"$sum": 1}}},
]
result = self.jobs.aggregate(pipeline) or []
out = {}
for r in result:
out[JobState(r["_id"])] = r["count"]

for state in states:
out[state] = out.get(state, 0)

return out

def get_jobs_info_by_flow_uuid(
self, flow_uuid, projection: list | dict | None = None
):
@@ -3628,7 +3643,9 @@ def _cancel_queue_process(self, job_doc: dict) -> None:
f"The connection to host {host} could not be closed.", exc_info=True
)

def get_batch_processes(self, worker: str) -> dict[str, str]:
def get_batch_processes(
self, worker: str | None = None
) -> dict[str, dict[str, str]]:
"""
Get the batch processes associated with a given worker.
@@ -3643,9 +3660,14 @@ def add_batch_process(
A dictionary with the {worker_name: {process_id: process_uuid}} of the
batch processes running on the workers.
"""
result = self.auxiliary.find_one({"batch_processes": {"$exists": True}})
if worker:
query = {f"batch_processes.{worker}": {"$exists": True}}
else:
query = {"batch_processes": {"$exists": True}}

result = self.auxiliary.find_one(query)
if result:
return result["batch_processes"].get(worker, {})
return result["batch_processes"] or {}
return {}

def add_batch_process(
@@ -3674,7 +3696,7 @@ def add_batch_process(
"""
return self.auxiliary.find_one_and_update(
{"batch_processes": {"$exists": True}},
{"$push": {f"batch_processes.{worker}.{process_id}": process_uuid}},
{"$set": {f"batch_processes.{worker}.{process_id}": process_uuid}},
upsert=True,
)
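Given the $push to $set change above, each batch process is stored as a single key under its worker in the auxiliary document, which is also the structure now returned by get_batch_processes; an invented example of the nesting, plus a note on the new count_jobs_states helper:

# Invented ids; only the {worker: {process_id: process_uuid}} nesting matters.
batch_processes = {
    "batch_worker_1": {
        "1234567": "uuid-of-batch-process-a",
        "1234568": "uuid-of-batch-process-b",
    },
}

# count_jobs_states fills a zero for requested states with no matching Job,
# e.g. (assuming an existing JobController instance `jc`):
# jc.count_jobs_states([JobState.RUNNING, JobState.COMPLETED])
# -> {JobState.RUNNING: 3, JobState.COMPLETED: 120}  (illustrative counts)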

