Skip to content

Commit

Permalink
Update SDK to use new workflows API (#783)
Browse files Browse the repository at this point in the history
The only thing remaining now is the frontend. Once the frontend swaps to the new experiments and workflows API, we can remove the old experiments endpoint and all related code that was only needed for that route.

---------

Signed-off-by: Nathan Brake <[email protected]>
Co-authored-by: Peter Wilson <[email protected]>
  • Loading branch information
njbrake and peteski22 authored Feb 5, 2025
1 parent 81944e1 commit 9c2fe5b
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 51 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": [
"./lumigator/schemas",
"./lumigator/jobs"
"./lumigator/jobs",
"./lumigator/sdk"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,103 @@ def test_full_experiment_launch(
retrieve_and_validate_workflow_logs(local_client, workflow_1_details.id)
delete_experiment_and_validate(local_client, experiment_id)

experiment = local_client.post(
"/experiments/new/",
headers=POST_HEADER,
json={
"name": "test_create_exp_workflow_check_results",
"description": "Test for an experiment with associated workflows",
},
)
assert experiment.status_code == 201
experiment_id = experiment.json()["id"]

# run a workflow for that experiment
workflow_1 = WorkflowResponse.model_validate(
local_client.post(
"/workflows/",
headers=POST_HEADER,
json={
"name": "Workflow_1",
"description": "Test workflow for inf and eval",
"model": TEST_CAUSAL_MODEL,
"dataset": str(dataset.id),
"experiment_id": experiment_id,
"max_samples": 1,
},
).json()
)

# Wait till the workflow is done
workflow_1_details = wait_for_workflow_complete(local_client, workflow_1.id)

experiment_results = GetExperimentResponse.model_validate(
local_client.get(f"/experiments/new/{experiment_id}").json()
)

assert workflow_1_details.experiment_id == experiment_results.id
assert len(experiment_results.workflows) == 1
# the presigned url can be different but everything else should be the same
assert workflow_1_details.model_dump(
exclude={"artifacts_download_url"}
) == experiment_results.workflows[0].model_dump(exclude={"artifacts_download_url"})

# add another workflow to the experiment
workflow_2 = WorkflowResponse.model_validate(
local_client.post(
"/workflows/",
headers=POST_HEADER,
json={
"name": "Workflow_2",
"description": "Test workflow for inf and eval",
"model": TEST_CAUSAL_MODEL,
"dataset": str(dataset.id),
"experiment_id": experiment_id,
"max_samples": 1,
},
).json()
)

# Wait till the workflow is done
workflow_2_details = wait_for_workflow_complete(local_client, workflow_2.id)

# now get the results of the experiment
experiment_results = GetExperimentResponse.model_validate(
local_client.get(f"/experiments/new/{experiment_id}").json()
)
# make sure it has the info for both workflows
assert len(experiment_results.workflows) == 2
# make sure both workflows are in the experiment, excluding that presigned url again
assert workflow_1_details.model_dump(exclude={"artifacts_download_url"}) in [
w.model_dump(exclude={"artifacts_download_url"}) for w in experiment_results.workflows
]
assert workflow_2_details.model_dump(exclude={"artifacts_download_url"}) in [
w.model_dump(exclude={"artifacts_download_url"}) for w in experiment_results.workflows
]

# get the logs
logs_job_response = local_client.get(f"/workflows/{workflow_1_details.id}/logs")
logs = JobLogsResponse.model_validate(logs_job_response.json())
assert logs.logs is not None
# Very naive way to check whether both of the logs we expect are in here
# This will need to be updated as we improve the log retrieval structure.
assert "Inference results stored at" in logs.logs
assert "Storing evaluation results into" in logs.logs
# assert that inference comes before eval
assert logs.logs.index("Inference results stored at") < logs.logs.index(
"Storing evaluation results into"
)

# delete the experiment
local_client.delete(f"/experiments/new/{experiment_id}")
response = local_client.get(f"/experiments/new/{experiment_id}")
assert response.status_code == 404
# make sure the workflow results also were deleted
response = local_client.get(f"/workflows/{workflow_1_details.id}")
assert response.status_code == 404
response = local_client.get(f"/workflows/{workflow_2_details.id}")
assert response.status_code == 404


def test_experiment_non_existing(local_client: TestClient, dependency_overrides_services):
non_existing_id = "71aaf905-4bea-4d19-ad06-214202165812"
Expand Down
47 changes: 16 additions & 31 deletions lumigator/sdk/lumigator_sdk/experiments.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,50 @@
from http import HTTPMethod
from json import dumps
from uuid import UUID

from lumigator_schemas.experiments import (
ExperimentCreate,
ExperimentIdCreate,
ExperimentIdResponse,
ExperimentResponse,
ExperimentResultDownloadResponse,
ExperimentResultResponse,
GetExperimentResponse,
)
from lumigator_schemas.extras import ListingResponse

from lumigator_sdk.client import ApiClient
from lumigator_sdk.strict_schemas import ExperimentCreate as ExperimentCreateStrict
from lumigator_sdk.strict_schemas import ExperimentIdCreate as ExperimentIdCreateStrict


class Experiments:
EXPERIMENTS_ROUTE = "experiments"
EXPERIMENTS_ROUTE = "experiments/new"

def __init__(self, c: ApiClient):
self.__client = c

def create_experiment(self, experiment: ExperimentCreate) -> ExperimentResponse:
def create_experiment(self, experiment: ExperimentIdCreate) -> ExperimentIdResponse:
"""Creates a new experiment."""
ExperimentCreateStrict.model_validate(ExperimentCreate.model_dump(experiment))
ExperimentIdCreateStrict.model_validate(ExperimentIdCreate.model_dump(experiment))
response = self.__client.get_response(
self.EXPERIMENTS_ROUTE, HTTPMethod.POST, dumps(experiment)
self.EXPERIMENTS_ROUTE, HTTPMethod.POST, experiment.model_dump_json()
)

data = response.json()
return ExperimentResponse(**data)
return ExperimentIdResponse(**data)

def get_experiment(self, experiment_id: UUID) -> ExperimentResponse | None:
def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None:
"""Returns information on the experiment for the specified ID."""
response = self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/{experiment_id}")

data = response.json()
return ExperimentResponse(**data)
return GetExperimentResponse(**data)

def get_experiments(
self, skip: int = 0, limit: int = 100
) -> ListingResponse[ExperimentResponse]:
"""Returns information on all experiments."""
response = self.__client.get_response(self.EXPERIMENTS_ROUTE)
response = self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/all")

data = response.json()
return ListingResponse[ExperimentResponse](**data)

def get_experiment_result(self, experiment_id: UUID) -> ExperimentResultResponse | None:
"""Returns the result of the experiment for the specified ID."""
response = self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/{experiment_id}/result")

data = response.json()
return ExperimentResultResponse(**data)

def get_experiment_result_download(
self, experiment_id: UUID
) -> ExperimentResultDownloadResponse | None:
"""Returns the result of the experiment for the specified ID."""
response = self.__client.get_response(
f"{self.EXPERIMENTS_ROUTE}/{experiment_id}/result/download"
)

data = response.json()
return ExperimentResultDownloadResponse(**data)
def delete_experiment(self, experiment_id: str) -> None:
"""Deletes the experiment for the specified ID."""
self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/{experiment_id}", HTTPMethod.DELETE)
return None
4 changes: 4 additions & 0 deletions lumigator/sdk/lumigator_sdk/lumigator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from lumigator_sdk.client import ApiClient
from lumigator_sdk.completions import Completions
from lumigator_sdk.experiments import Experiments
from lumigator_sdk.health import Health
from lumigator_sdk.jobs import Jobs
from lumigator_sdk.lm_datasets import Datasets
from lumigator_sdk.models import Models
from lumigator_sdk.workflows import Workflows

# Only retries initial connections
# No HTTP errors are retried
Expand Down Expand Up @@ -54,3 +56,5 @@ def __init__(
self.jobs = Jobs(self.client)
self.datasets = Datasets(self.client)
self.models = Models(self.client)
self.workflows = Workflows(self.client)
self.experiments = Experiments(self.client)
19 changes: 7 additions & 12 deletions lumigator/sdk/lumigator_sdk/strict_schemas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from lumigator_schemas.completions import CompletionResponse
from lumigator_schemas.datasets import DatasetDownloadResponse, DatasetResponse
from lumigator_schemas.experiments import (
ExperimentCreate,
ExperimentIdCreate,
ExperimentResponse,
ExperimentResultDownloadResponse,
ExperimentResultResponse,
)
from lumigator_schemas.extras import HealthResponse, ListingResponse
from lumigator_schemas.jobs import (
Expand All @@ -20,6 +18,7 @@
JobResultResponse,
JobSubmissionResponse,
)
from lumigator_schemas.workflows import WorkflowCreateRequest
from pydantic import ConfigDict


Expand All @@ -35,22 +34,14 @@ class DatasetResponse(DatasetResponse, from_attributes=True):
model_config = ConfigDict(extra="forbid")


class ExperimentCreate(ExperimentCreate):
class ExperimentIdCreate(ExperimentIdCreate):
model_config = ConfigDict(extra="forbid")


class ExperimentResponse(ExperimentResponse, from_attributes=True):
model_config = ConfigDict(extra="forbid")


class ExperimentResultResponse(ExperimentResultResponse, from_attributes=True):
model_config = ConfigDict(extra="forbid")


class ExperimentResultDownloadResponse(ExperimentResultDownloadResponse):
model_config = ConfigDict(extra="forbid")


class HealthResponse(HealthResponse):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -101,3 +92,7 @@ class JobResultResponse(JobResultResponse, from_attributes=True):

class JobResultDownloadResponse(JobResultDownloadResponse):
model_config = ConfigDict(extra="forbid")


class WorkflowCreateRequest(WorkflowCreateRequest):
model_config = ConfigDict(extra="forbid")
47 changes: 47 additions & 0 deletions lumigator/sdk/lumigator_sdk/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from http import HTTPMethod

from lumigator_schemas.jobs import JobLogsResponse
from lumigator_schemas.workflows import (
WorkflowCreateRequest,
WorkflowDetailsResponse,
WorkflowResponse,
)

from lumigator_sdk.client import ApiClient
from lumigator_sdk.strict_schemas import WorkflowCreateRequest as WorkflowCreateRequestStrict


class Workflows:
WORKFLOWS_ROUTE = "workflows"

def __init__(self, c: ApiClient):
self.__client = c

def create_workflow(self, workflow: WorkflowCreateRequest) -> WorkflowResponse:
"""Creates a new experiment."""
WorkflowCreateRequestStrict.model_validate(WorkflowCreateRequest.model_dump(workflow))
response = self.__client.get_response(
self.WORKFLOWS_ROUTE, HTTPMethod.POST, workflow.model_dump_json()
)

data = response.json()
return WorkflowResponse(**data)

def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None:
"""Returns information on the experiment for the specified ID."""
response = self.__client.get_response(f"{self.WORKFLOWS_ROUTE}/{workflow_id}")

data = response.json()
return WorkflowDetailsResponse(**data)

def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse | None:
"""Returns information on the experiment for the specified ID."""
response = self.__client.get_response(f"{self.WORKFLOWS_ROUTE}/{workflow_id}/logs")

data = response.json()
return JobLogsResponse(**data)

def delete_workflow(self, workflow_id: str) -> None:
"""Deletes the experiment for the specified ID."""
self.__client.get_response(f"{self.WORKFLOWS_ROUTE}/{workflow_id}", HTTPMethod.DELETE)
return None
Loading

0 comments on commit 9c2fe5b

Please sign in to comment.