Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Slurm agent #3005

Draft
wants to merge 31 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
421d1b8
Add slurm plugin blank components
JiangJiaWei1103 Dec 14, 2024
1d1f806
feat: Add naive slurm agent create and get with rest api
JiangJiaWei1103 Dec 16, 2024
5d97126
Use asyncssh instead of REST API
JiangJiaWei1103 Dec 17, 2024
2e7f0f2
Test ssh communication and run sbatch
JiangJiaWei1103 Dec 18, 2024
9644b99
Add delete method and support slurm job state
JiangJiaWei1103 Dec 19, 2024
e41b181
feat: Submit and run SlurmTask on a remote Slurm cluster
JiangJiaWei1103 Dec 27, 2024
6db24dc
refactor: Remove redundant task_module transfer
JiangJiaWei1103 Dec 28, 2024
122c7f1
refactor: Remove redundant env var
JiangJiaWei1103 Dec 28, 2024
e9760a7
docs: Add env setup guide for local test
JiangJiaWei1103 Dec 30, 2024
e68fda9
docs: Add links and figures
JiangJiaWei1103 Dec 30, 2024
470637c
docs: Fix commit sha
JiangJiaWei1103 Dec 30, 2024
1579ab4
docs: Fix commit sha for demo guide
JiangJiaWei1103 Dec 30, 2024
0e538f0
docs: Fix links
JiangJiaWei1103 Dec 30, 2024
8229418
feat: Support SSH config in task config
JiangJiaWei1103 Dec 31, 2024
9e6d8a6
docs: Include ssh config in demo example
JiangJiaWei1103 Dec 31, 2024
e07b09a
refactor: Reduce ssh_conf option to slurm_host only
JiangJiaWei1103 Jan 7, 2025
3a7eb6d
feat: Support Slurm agent with ShellTask
JiangJiaWei1103 Jan 7, 2025
a815fd9
feat: Simplify Slurm job submission logic
JiangJiaWei1103 Jan 9, 2025
a3ea014
Added script args to agent and task
pryce-turner Jan 10, 2025
a109bd8
Add asyncssh to dependencies
JiangJiaWei1103 Jan 11, 2025
e5da665
docs: Update setup and demo for a basic use case
JiangJiaWei1103 Jan 11, 2025
0a3d9f1
docs: Update basic arch figure path
JiangJiaWei1103 Jan 11, 2025
1b0f6df
docs: Fix typo and hyperlink
JiangJiaWei1103 Jan 11, 2025
26cc201
fix: A tmp workaround to test agent locally without container_image
JiangJiaWei1103 Jan 11, 2025
16d953e
feat: Support user-defined batch script content with SlurmShellTask
JiangJiaWei1103 Jan 14, 2025
c743917
feat: Fall back to PythonTask for naive use cases
JiangJiaWei1103 Jan 15, 2025
e365dee
refactor: Define Slurm as a base task config and extend for remote sc…
JiangJiaWei1103 Jan 15, 2025
c1064d4
feat: Support PythonFunctionTask and reorganize agent structure
JiangJiaWei1103 Jan 16, 2025
361fbd1
Use poetry virtual env to avoid contamination
JiangJiaWei1103 Jan 22, 2025
fc3e34e
docs: Complete local test env setup process
JiangJiaWei1103 Jan 23, 2025
5cd58ec
docs: Add use cases ranging from basic to advanced
JiangJiaWei1103 Jan 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ class AsyncAgentExecutorMixin:

def execute(self: PythonTask, **kwargs) -> LiteralMap:
ctx = FlyteContext.current_context()
ss = ctx.serialization_settings or SerializationSettings(ImageConfig())
ss = ctx.serialization_settings or SerializationSettings(ImageConfig.auto_default_image())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this?
is this for shell task?

Copy link
Contributor Author

@JiangJiaWei1103 JiangJiaWei1103 Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we define a SlurmTask without specifying container_image (as the example python script provided above), ctx.serialization_settings will be None. Then, an error is raised which describes that PythonAutoContainerTask needs an image.

I think this is just a temporary workaround for local test and I'm still pondering how to better handle this issue.

output_prefix = ctx.file_access.get_random_remote_directory()
self.resource_meta = None

Expand Down
4 changes: 2 additions & 2 deletions flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
Convert the state from the agent to the phase in flyte.
"""
state = state.lower()
if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"]:
return TaskExecution.FAILED
elif state in ["done", "succeeded", "success"]:
elif state in ["done", "succeeded", "success", "completed"]:
return TaskExecution.SUCCEEDED
elif state in ["running", "terminating"]:
return TaskExecution.RUNNING
Expand Down
32 changes: 23 additions & 9 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ def __init__(

if task_config is not None:
fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__
if not fully_qualified_class_name == "flytekitplugins.pod.task.Pod":
if fully_qualified_class_name not in [
"flytekitplugins.pod.task.Pod",
"flytekitplugins.slurm.script.task.Slurm",
]:
raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

# Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
Expand All @@ -259,11 +262,14 @@ def __init__(
# errors.
# This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
# Rename the internal task so that there are no conflicts at serialization time. Technically these internal
# tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
# at serialization time.
self._config_task_instance._name = f"_bash.{name}"
if plugin_class.__name__ in ["SlurmShellTask"]:
self._config_task_instance = None
else:
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
# Rename the internal task so that there are no conflicts at serialization time. Technically these internal
# tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
# at serialization time.
self._config_task_instance._name = f"_bash.{name}"
self._script = script
self._script_file = script_file
self._debug = debug
Expand All @@ -275,7 +281,9 @@ def __init__(
super().__init__(
name,
task_config,
task_type=self._config_task_instance.task_type,
task_type=kwargs.pop("task_type")
if self._config_task_instance is None
else self._config_task_instance.task_type,
interface=Interface(inputs=inputs, outputs=outputs),
**kwargs,
)
Expand Down Expand Up @@ -309,7 +317,10 @@ def script_file(self) -> typing.Optional[os.PathLike]:
return self._script_file

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return self._config_task_instance.pre_execute(user_params)
if self._config_task_instance is None:
return user_params
else:
return self._config_task_instance.pre_execute(user_params)

def execute(self, **kwargs) -> typing.Any:
"""
Expand Down Expand Up @@ -367,7 +378,10 @@ def execute(self, **kwargs) -> typing.Any:
return None

def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any:
return self._config_task_instance.post_execute(user_params, rval)
if self._config_task_instance is None:
return rval
else:
return self._config_task_instance.post_execute(user_params, rval)


class RawShellTask(ShellTask):
Expand Down
5 changes: 5 additions & 0 deletions plugins/flytekit-slurm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Flytekit Slurm Plugin

The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capability of compute resource allocation, scheduling, and monitoring.

This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent.
Binary file added plugins/flytekit-slurm/assets/basic_arch.png
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing Graph.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, bro.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plugins/flytekit-slurm/assets/flyte_client.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plugins/flytekit-slurm/assets/overview_v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
113 changes: 113 additions & 0 deletions plugins/flytekit-slurm/demo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Slurm Agent Demo

In this guide, we will briefly introduce how to setup an environment to test Slurm agent locally without running the backend service (e.g., flyte agent gRPC server). It covers both basic and advanced use cases.

## Table of Content
* [Overview](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#overview)
* [Setup a Local Test Environment](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#setup-a-local-test-environment)
* [Flyte Client (Localhost)](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#flyte-client-localhost)
* [Remote Tiny Slurm Cluster](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#remote-tiny-slurm-cluster)
* [SSH Configuration](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#ssh-configuration)
* [Run a Demo](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#run-a-demo)

## Overview
Slurm agent on the highest level has three core methods to interact with a Slurm cluster:
1. `create`: Use `srun` or `sbatch` to run a job on a Slurm cluster
2. `get`: Use `scontrol show job <job_id>` to monitor the Slurm job state
3. `delete`: Use `scancel <job_id>` to cancel the Slurm job (this method is still under test)

In the simplest form, Slurm agent supports directly running a batch script using `sbatch` on a Slurm cluster as shown below:

![](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/assets/basic_arch.png)

## Setup a Local Test Environment
Without running the backend service, we can setup an environment to test Slurm agent locally. The setup consists of two main components: a client (localhost) and a remote tiny Slurm cluster. Then, we need to configure SSH connection to facilitate communication between the two, which relies on `asyncssh`.

### Flyte Client (Localhost)
1. Setup a local Flyte cluster following this [official guide](https://docs.flyte.org/en/latest/community/contribute/contribute_code.html#how-to-setup-dev-environment-for-flytekit)
2. Build a virtual environment (e.g., conda) and activate it
3. Clone Flytekit repo, checkout the Slurm agent PR, and install Flytekit
```
git clone https://github.com/flyteorg/flytekit.git
gh pr checkout 3005
make setup && pip install -e .
```
4. Install Flytekit Slurm agent
```
cd plugins/flytekit-slurm/
pip install -e .
```

### Remote Tiny Slurm Cluster
To simplify the setup process, we follow this [guide](https://github.com/JiangJiaWei1103/Slurm-101) to configure a single-host Slurm cluster, covering `slurmctld` (the central management daemon) and `slurmd` (the compute node daemon).

### SSH Configuration
To facilitate communication between the Flyte client and the remote Slurm cluster, we setup SSH on the Flyte client side as follows:
1. Create a new authentication key pair
```
ssh-keygen -t rsa -b 4096
```
2. Copy the public key into the remote Slurm cluster
```
ssh-copy-id <username>@<remote_server_ip>
```
3. Enable key-based authentication
```
# ~/.ssh/config
Host <host_alias>
HostName <remote_server_ip>
Port <ssh_port>
User <username>
IdentityFile <path_to_private_key>
```

## Run a Demo
Suppose we have a batch script to run on Slurm cluster:
```
#!/bin/bash

echo "Working!" >> ./remote_touch.txt
```

We use the following python script to test Slurm agent on the client side. A crucial part of the task configuration is specifying the target Slurm cluster and designating the batch script's path within the cluster.

```python
import os

from flytekit import workflow
from flytekitplugins.slurm import Slurm, SlurmTask


echo_job = SlurmTask(
name="echo-job-name",
task_config=Slurm(
slurm_host="<host_alias>",
batch_script_path="<path_to_batch_script_within_cluster>",
sbatch_conf={
"partition": "debug",
"job-name": "tiny-slurm",
}
)
)


@workflow
def wf() -> None:
echo_job()


if __name__ == "__main__":
from flytekit.clis.sdk_in_container import pyflyte
from click.testing import CliRunner

runner = CliRunner()
path = os.path.realpath(__file__)

print(f">>> LOCAL EXEC <<<")
result = runner.invoke(pyflyte.main, ["run", path, "wf"])
print(result.output)
```

After the Slurm job is completed, we can find the following result on Slurm cluster:

![](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/assets/slurm_basic_result.png)
4 changes: 4 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .function.agent import SlurmFunctionAgent
from .function.task import SlurmFunction, SlurmFunctionTask
from .script.agent import SlurmScriptAgent
from .script.task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask
115 changes: 115 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from dataclasses import dataclass
from typing import Dict, Optional

import asyncssh
from asyncssh import SSHClientConnection

from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class SlurmJobMetadata(ResourceMeta):
"""Slurm job metadata.

Args:
job_id: Slurm job id.
"""

job_id: str
slurm_host: str


class SlurmFunctionAgent(AsyncAgentBase):
name = "Slurm Function Agent"

# SSH connection pool for multi-host environment
_conn: Optional[SSHClientConnection] = None

def __init__(self) -> None:
super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata)

async def create(
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
**kwargs,
) -> SlurmJobMetadata:
# Retrieve task config
slurm_host = task_template.custom["slurm_host"]
srun_conf = task_template.custom["srun_conf"]
Comment on lines +41 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use task_template.custom.get("slurm_host")?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As slurm_host is a required field of the corresponding dataclass, we could assume the key "slurm_host" must exist in task_template.custom dict. Then, maybe directly accessing through the bracket is more straightforward here?

Let me know what you think. Thanks!


# Construct srun command for Slurm cluster
cmd = _get_srun_cmd(srun_conf=srun_conf, entrypoint=" ".join(task_template.container.args))

# Run Slurm job
if self._conn is None:
await self._connect(slurm_host)
res = await self._conn.run(cmd, check=True)

# Direct return for sbatch
# job_id = res.stdout.split()[-1]
# Use echo trick for srun
job_id = res.stdout.strip()

return SlurmJobMetadata(job_id=job_id, slurm_host=slurm_host)

async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
await self._connect(resource_meta.slurm_host)
res = await self._conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

# Determine the current flyte phase from Slurm job state
job_state = "running"
for o in res.stdout.split(" "):
if "JobState" in o:
job_state = o.split("=")[1].strip().lower()
cur_phase = convert_to_flyte_phase(job_state)

return Resource(phase=cur_phase)

async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
await self._connect(resource_meta.slurm_host)
_ = await self._conn.run(f"scancel {resource_meta.job_id}", check=True)

async def _connect(self, slurm_host: str) -> None:
"""Make an SSH client connection."""
self._conn = await asyncssh.connect(host=slurm_host)


def _get_srun_cmd(srun_conf: Dict[str, str], entrypoint: str) -> str:
"""Construct Slurm srun command.

Flyte entrypoint, pyflyte-execute, is run within a bash shell process.

Args:
srun_conf: Options of srun command.
entrypoint: Flyte entrypoint.

Returns:
cmd: Slurm srun command.
"""
# Setup srun options
cmd = ["srun"]
for opt, val in srun_conf.items():
cmd.extend([f"--{opt}", str(val)])

cmd.extend(["bash", "-c"])
cmd = " ".join(cmd)

cmd += f""" '# Setup environment variables
export PATH=$PATH:/opt/anaconda/anaconda3/bin;

# Run pyflyte-execute in a pre-built conda env
source activate dev;
{entrypoint};

# A trick to show Slurm job id on stdout
echo $SLURM_JOB_ID;'
"""

return cmd


AgentRegistry.register(SlurmFunctionAgent())
Loading
Loading