Skip to content

Commit

Permalink
chore(launch): pushToRunQueueByName mutation support in sdk (wandb#4292)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning authored Oct 14, 2022
1 parent 6ef7dbb commit a7e5e81
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 130 deletions.
123 changes: 123 additions & 0 deletions tests/unit_tests/tests_launch/test_launch_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import wandb
from wandb.sdk.launch.launch_add import launch_add


def test_launch_add_default(relay_server, user):
proj = "test_project"
uri = "https://github.com/wandb/examples.git"
entry_point = ["python", "/examples/examples/launch/launch-quickstart/train.py"]
args = {
"uri": uri,
"project": proj,
"entity": user,
"queue_name": "default",
"entry_point": entry_point,
}

run = wandb.init(project=proj)

with relay_server() as relay:
queued_run = launch_add(**args)

assert queued_run.id
assert queued_run.state == "pending"
assert queued_run.entity == args["entity"]
assert queued_run.project == args["project"]
assert queued_run.queue_name == args["queue_name"]

for comm in relay.context.raw_data:
q = comm["request"].get("query")
# below should fail for non-existent default queue,
# then fallback to legacy method
if q and "mutation pushToRunQueueByName(" in str(q):
assert comm["response"].get("data", {}).get("pushToRunQueueByName") is None
elif q and "mutation pushToRunQueue(" in str(q):
assert comm["response"]["data"]["pushToRunQueue"] is not None

run.finish()


def test_push_to_runqueue_exists(relay_server, user):
proj = "test_project"
queue = "existing-queue"
uri = "https://github.com/wandb/examples.git"
entry_point = ["python", "/examples/examples/launch/launch-quickstart/train.py"]
args = {
"uri": uri,
"project": proj,
"entity": user,
"queue": "default",
"entry_point": entry_point,
}

run = wandb.init(project=proj)
api = wandb.sdk.internal.internal_api.Api()

with relay_server() as relay:
api.create_run_queue(entity=user, project=proj, queue_name=queue, access="USER")

result = api.push_to_run_queue(queue, args)

assert result["runQueueItemId"]

for comm in relay.context.raw_data:
q = comm["request"].get("query")
if q and "mutation pushToRunQueueByName(" in str(q):
assert comm["response"]["data"] is not None
elif q and "mutation pushToRunQueue(" in str(q):
raise Exception("should not be falling back to legacy here")

run.finish()


def test_push_to_default_runqueue_notexist(relay_server, user):
api = wandb.sdk.internal.internal_api.Api()
proj = "test_project"
uri = "https://github.com/wandb/examples.git"
entry_point = ["python", "/examples/examples/launch/launch-quickstart/train.py"]

launch_spec = {
"uri": uri,
"entity": user,
"project": proj,
"entry_point": entry_point,
}
run = wandb.init(project=proj)

with relay_server():
res = api.push_to_run_queue("nonexistent-queue", launch_spec)

assert not res

run.finish()


def test_push_to_runqueue_old_server(relay_server, user, monkeypatch):
proj = "test_project"
queue = "existing-queue"
uri = "https://github.com/wandb/examples.git"
entry_point = ["python", "/examples/examples/launch/launch-quickstart/train.py"]
args = {
"uri": uri,
"project": proj,
"entity": user,
"queue": "default",
"entry_point": entry_point,
}

run = wandb.init(project=proj)
api = wandb.sdk.internal.internal_api.Api()

monkeypatch.setattr(
"wandb.sdk.internal.internal_api.Api.push_to_run_queue_by_name",
lambda *args: None,
)

with relay_server():
api.create_run_queue(entity=user, project=proj, queue_name=queue, access="USER")

result = api.push_to_run_queue(queue, args)

assert result["runQueueItemId"]

run.finish()
61 changes: 1 addition & 60 deletions tests/unit_tests_old/tests_launch/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
import wandb.util as util
import yaml
from wandb.apis import PublicApi
from wandb.apis.public import Run
from wandb.errors import CommError, LaunchError
from wandb.errors import LaunchError
from wandb.sdk.launch.agent.agent import LaunchAgent
from wandb.sdk.launch.builder.build import pull_docker_image
from wandb.sdk.launch.builder.docker import DockerBuilder
from wandb.sdk.launch.launch_add import launch_add
from wandb.sdk.launch.utils import PROJECT_DOCKER_ARGS, PROJECT_SYNCHRONOUS

from tests.unit_tests_old.utils import fixture_open, notebook_path
Expand Down Expand Up @@ -352,18 +350,6 @@ def test_launch_resource_args(
check_mock_run_info(mock_with_run_info, EMPTY_BACKEND_CONFIG, kwargs)


def test_launch_add_base_queued_run(live_mock_server):
queued_run = launch_add("https://wandb.ai/mock_server_entity/tests/runs/1")
assert queued_run.state == "pending"
assert queued_run.id == "1"
assert queued_run.entity == "mock_server_entity"
assert queued_run.project == "tests"

live_mock_server.set_ctx({"run_queue_item_return_type": "claimed"})
run = queued_run.wait_until_finished()
assert isinstance(run, Run)


@pytest.mark.skipif(
sys.version_info < (3, 5),
reason="wandb launch is not available for python versions <3.5",
Expand Down Expand Up @@ -712,51 +698,6 @@ def test_run_in_launch_context_with_artifacts_no_match(
assert arti_info["used_name"] == "old_name:v0"


def test_push_to_runqueue(live_mock_server, test_settings):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
launch_spec = {
"uri": "https://wandb.ai/mock_server_entity/test/runs/1",
"entity": "mock_server_entity",
"project": "test",
}
api.push_to_run_queue("default", launch_spec)
ctx = live_mock_server.get_ctx()
assert len(ctx["run_queues"]["1"]) == 1


def test_push_to_default_runqueue_notexist(live_mock_server, test_settings):
live_mock_server.set_ctx({"run_queues_return_default": False})
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
launch_spec = {
"uri": "https://wandb.ai/mock_server_entity/test/runs/1",
"entity": "mock_server_entity",
"project": "test",
}
api.push_to_run_queue("default", launch_spec)
ctx = live_mock_server.get_ctx()
assert len(ctx["run_queues"]["1"]) == 1


def test_push_to_runqueue_notfound(live_mock_server, test_settings, capsys):
api = wandb.sdk.internal.internal_api.Api(
default_settings=test_settings, load_settings=False
)
launch_spec = {
"uri": "https://wandb.ai/mock_server_entity/test/runs/1",
"entity": "mock_server_entity",
"project": "test",
}
api.push_to_run_queue("not-found", launch_spec)
ctx = live_mock_server.get_ctx()
_, err = capsys.readouterr()
assert len(ctx["run_queues"]["1"]) == 0
assert "Unable to push to run queue not-found. Queue not found" in err


# this test includes building a docker container which can take some time,
# hence the timeout. caching should usually keep this under 30 seconds
@pytest.mark.flaky
Expand Down
26 changes: 0 additions & 26 deletions tests/unit_tests_old/tests_launch/test_launch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,6 @@ def patched_update_finished(self, job_id):
)


def test_launch_add_default(runner, test_settings, live_mock_server):
args = [
"https://wandb.ai/mock_server_entity/test_project/runs/run",
"--project=test_project",
"--entity=mock_server_entity",
"--queue=default",
]
result = runner.invoke(cli.launch, args)
assert result.exit_code == 0
ctx = live_mock_server.get_ctx()
assert len(ctx["run_queues"]["1"]) == 1


def test_launch_add_config_file(runner, test_settings, live_mock_server):
args = [
"https://wandb.ai/mock_server_entity/test_project/runs/run",
"--project=test_project",
"--entity=mock_server_entity",
"--queue=default",
]
result = runner.invoke(cli.launch, args)
assert result.exit_code == 0
ctx = live_mock_server.get_ctx()
assert len(ctx["run_queues"]["1"]) == 1


# this test includes building a docker container which can take some time.
# hence the timeout. caching should usually keep this under 30 seconds
@pytest.mark.flaky
Expand Down
13 changes: 12 additions & 1 deletion tests/unit_tests_old/tests_launch/test_launch_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def job_download_func(root):
check_mock_run_info(mock_with_run_info, EMPTY_BACKEND_CONFIG, kwargs)


def test_launch_add_container_queued_run(live_mock_server, mocked_public_artifact):
def test_launch_add_container_queued_run(
live_mock_server, mocked_public_artifact, monkeypatch
):
def job_download_func(root=None):
if root is None:
root = tempfile.mkdtemp()
Expand All @@ -229,6 +231,15 @@ def job_download_func(root=None):

mocked_public_artifact(job_download_func)

def patched_push_to_run_queue_by_name(*args, **kwargs):
return {"runQueueItemId": "1"}

monkeypatch.setattr(
wandb.sdk.internal.internal_api.Api,
"push_to_run_queue_by_name",
lambda *arg, **kwargs: patched_push_to_run_queue_by_name(*arg, **kwargs),
)

queued_run = launch_add(job="test-job:v0")
with pytest.raises(CommError):
queued_run.wait_until_finished()
Loading

0 comments on commit a7e5e81

Please sign in to comment.