From 310e4f1aafbfb9845d5f5889cf11b49fc73addd0 Mon Sep 17 00:00:00 2001 From: Jonathan Lessinger Date: Sun, 7 Jan 2024 17:13:41 -0500 Subject: [PATCH] [AIC-py][editor] server v2 --- python/requirements.txt | 4 +- python/src/aiconfig/editor/server/server.py | 5 +- .../src/aiconfig/editor/server/server_v2.py | 247 +++++++++++++ .../editor/server/server_v2_common.py | 348 ++++++++++++++++++ .../editor/server/server_v2_operation_lib.py | 181 +++++++++ .../editor/server/server_v2_run_operation.py | 298 +++++++++++++++ .../aiconfig/editor/server/server_v2_utils.py | 250 +++++++++++++ .../src/aiconfig/scripts/aiconfig_cli_v2.py | 126 +++++++ python/tests/test_editor_server_v2.py | 260 +++++++++++++ 9 files changed, 1716 insertions(+), 3 deletions(-) create mode 100644 python/src/aiconfig/editor/server/server_v2.py create mode 100644 python/src/aiconfig/editor/server/server_v2_common.py create mode 100644 python/src/aiconfig/editor/server/server_v2_operation_lib.py create mode 100644 python/src/aiconfig/editor/server/server_v2_run_operation.py create mode 100644 python/src/aiconfig/editor/server/server_v2_utils.py create mode 100644 python/src/aiconfig/scripts/aiconfig_cli_v2.py create mode 100644 python/tests/test_editor_server_v2.py diff --git a/python/requirements.txt b/python/requirements.txt index 0e3b7ffde..5b40f6b5f 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -3,10 +3,12 @@ black flake8 flask-cors flask[async] +fastapi google-generativeai huggingface_hub hypothesis==6.91.0 -lastmile-utils==0.0.14 +hypercorn +lastmile-utils==0.0.16 mock nest_asyncio nltk diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py index f9cd467f9..052ffab00 100644 --- a/python/src/aiconfig/editor/server/server.py +++ b/python/src/aiconfig/editor/server/server.py @@ -300,9 +300,10 @@ def run_async_config_in_thread(): status=200, content_type="application/json", ) - + # Run without streaming inference_options = 
InferenceOptions(stream=stream) + def run_async_config_in_thread(): asyncio.run( aiconfig.run( @@ -324,7 +325,7 @@ def run_async_config_in_thread(): code=200, aiconfig=aiconfig, ).to_flask_format() - + except Exception as e: return HttpResponseWithAIConfig( # diff --git a/python/src/aiconfig/editor/server/server_v2.py b/python/src/aiconfig/editor/server/server_v2.py new file mode 100644 index 000000000..906ea4e66 --- /dev/null +++ b/python/src/aiconfig/editor/server/server_v2.py @@ -0,0 +1,247 @@ +## SECTION: Imports and Constants + +import logging +import os +from contextlib import asynccontextmanager +from typing import cast +import webbrowser + +import lastmile_utils.lib.core.api as core_utils +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.staticfiles import StaticFiles +from hypercorn.asyncio import serve # type: ignore +from hypercorn.config import Config +from hypercorn.typing import ASGIFramework +from result import Err, Ok, Result +import result + +from aiconfig.editor.server import server_v2_utils as server_utils +import aiconfig.editor.server.server_v2_common as server_common +import aiconfig.editor.server.server_v2_operation_lib as operation_lib + + +THIS_DIR = os.path.dirname(os.path.realpath(__file__)) +STATIC_DIR = os.path.join(THIS_DIR, "static") + +DEFAULT_PORT = 8080 + +logger: logging.Logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log") + + +## SECTION: Global state initialization + + +global_state = server_common.GlobalState( + editor_config=server_common.EditServerConfig(), + active_instances=dict(), +) + + +async def _run_websocket_connection(initial_loop_state: server_utils.LoopState, websocket: WebSocket) -> Result[str, str]: + """This is the main websocket loop.""" + loop_state = initial_loop_state + instance_state = loop_state.instance_state + instance_id = 
@asynccontextmanager
async def _app_lifespan(app: FastAPI):
    """FastAPI lifespan hook: on shutdown, close every websocket still registered.

    Nothing is done at startup besides logging. After the ``yield`` (i.e. when the
    app shuts down) each entry of ``global_state.active_instances`` is closed via
    ``_cleanup_websocket_connection`` and the outcomes are logged.

    Args:
        app: The FastAPI application this lifespan is attached to (unused here).
    """
    global global_state
    logger.info("Start lifespan")
    yield
    logger.info("Shutting down app.")
    # Snapshot the registry first: each cleanup awaits, and a connection handler
    # running in the meantime may add entries. Iterating the live dict would then
    # raise "dictionary changed size during iteration".
    open_connections = list(global_state.active_instances.items())
    cleanup_res_list = [await _cleanup_websocket_connection(instance_id, websocket_state) for instance_id, websocket_state in open_connections]
    cleanup_ok, cleanup_err = core_utils.result_reduce_list_separate(cleanup_res_list)
    logger.info("Cleaned up websockets. %s", cleanup_ok)
    if len(cleanup_err) > 0:
        logger.error("Failed to clean up websockets. %s", cleanup_err)
    # Empty the registry instead of `del global_state`: deleting the module-level
    # name makes any later lifespan run or websocket connect in the same process
    # fail with NameError.
    global_state.active_instances.clear()
async def _run_backend_server_on_port(log_level: str | int, port: int) -> server_common.ServerBindOutcome:
    """Serve the FastAPI app with hypercorn on ``localhost:<port>`` until it stops.

    Args:
        log_level: Logging level as a name (upper-cased before use) or a numeric
            level (converted with ``logging.getLevelName``).
        port: TCP port to bind on localhost.

    Returns:
        ``SUCCESS`` once ``serve`` returns normally, ``PORT_IN_USE`` when the port
        could not be bound (so the caller may try another one), ``OTHER_FAILURE``
        for every other error. Exceptions are logged, never propagated.
    """
    import errno  # local import keeps this block self-contained

    logger.info(f"Running backend server on port {port}")

    log_level_for_hypercorn = (
        #
        log_level.upper()
        if isinstance(log_level, str)
        else logging.getLevelName(log_level)
    )
    fastapi_app: ASGIFramework = cast(ASGIFramework, app)
    try:
        logger.info(f"Starting server on port {port}")
        await serve(
            fastapi_app,
            Config.from_mapping(
                #
                _bind=[f"localhost:{port}"],
                loglevel=log_level_for_hypercorn,
                use_reloader=True,
                # One year, in seconds: effectively never time out idle connections.
                keep_alive_timeout=365 * 24 * 3600,
            ),
        )
        logger.info(f"Done running server on port {port}")
        return server_common.ServerBindOutcome.SUCCESS
    except OSError as e_os:
        # Only "this port cannot be bound" errors are worth retrying on another
        # port. Any other OSError (e.g. localhost failing to resolve) would fail
        # identically on every port, so reporting it as PORT_IN_USE would send the
        # caller scanning the whole port range and mislabel the error.
        if e_os.errno in (errno.EADDRINUSE, errno.EACCES):
            logger.warning(f"Port in use: {port}: {e_os}")
            return server_common.ServerBindOutcome.PORT_IN_USE
        logger.error(f"Failed to run backend server on port {port}: {type(e_os)}")
        logger.error(core_utils.ErrWithTraceback(e_os))
        return server_common.ServerBindOutcome.OTHER_FAILURE
    except Exception as e:
        logger.error(f"Failed to run backend server on port {port}: {type(e)}")
        logger.error(core_utils.ErrWithTraceback(e))
        return server_common.ServerBindOutcome.OTHER_FAILURE
HTTP endpoints: static files, root, and websocket connect + + +@app.get("/") +def home(): + logger.info(f"ROOT, {os.getcwd()}") + index_path = os.path.join(STATIC_DIR, "index.html") + res_index = core_utils.read_text_file(index_path) + match res_index: + case Ok(index): + return HTMLResponse(index) + case Err(e): + logger.error(f"Failed to load index.html: {e}") + return HTMLResponse(f"

Failed to load index.html: {e}

") + + +@app.get("/api/server_status") +def server_status(): + data = {"status": "OK"} + return JSONResponse(content=data, status_code=200) + + +@app.websocket("/ws_manage_aiconfig_instance") +async def accept_and_run_websocket(websocket: WebSocket): + logger.info("Accepting websocket connection") + await websocket.accept() + global global_state + + initial_loop_state = await server_utils.LoopState.new(websocket, global_state.editor_config) + res_websocket: Result[str, str] = await result.do_async( + await _run_websocket_connection(initial_loop_state_ok, websocket) + # + for initial_loop_state_ok in initial_loop_state + ) + logger.info(f"{res_websocket=}") + match res_websocket: + case Ok(result_): + return JSONResponse(content=result_, status_code=200) + case Err(e): + return JSONResponse(content=f"Failed to run websocket: {e}", status_code=500) + + +## SECTION: Global state management + + +async def _init_app_state(app: FastAPI, edit_config: server_common.EditServerConfig): + logger.setLevel(edit_config.log_level) + logger.info("Edit config: %s", edit_config.model_dump_json()) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + app.mount("/static", StaticFiles(directory=os.path.join(STATIC_DIR, "static")), name="static") + + res_load_module = await ( + server_common.get_validated_path(edit_config.parsers_module_path) + # + .and_then_async(operation_lib.load_user_parser_module) + ) + match res_load_module: + case Ok(_module): + logger.info(f"Loaded module {edit_config.parsers_module_path}, output: {_module}") + case Err(e): + logger.warning(f"Failed to load module {edit_config.parsers_module_path}: {e}") + + +async def _cleanup_websocket_connection(instance_id: str, websocket_state: server_common.ConnectionState) -> Result[str, str]: + logger.info(f"Closing websocket connection for instance {websocket_state}") + try: + await websocket_state.websocket.close() + return 
class Save(core_utils.Record):
    # Operation record for the "save" command.
    # `command_name` is the literal tag used to pick this record out of the
    # Command union (SerializableCommand discriminates on "command_name").
    command_name: Literal["save"]
    # Destination path, not yet validated. Unlike Load.path it has no default,
    # so it must always be supplied.
    path: server_common.UnvalidatedPath
class Response(core_utils.Record):
    """Message sent back to the client for one command on one editor instance."""

    instance_id: str
    message: str
    is_success: bool
    aiconfig_instance: AIConfigRuntime | None
    # TODO: make this a more constrained type
    data: Any | None = None

    def to_json(self) -> core_utils.JSONObject:
        """Build the JSON-style mapping for this response.

        The AIConfig, if any, is converted with ``aiconfig_to_json`` and stored
        under the ``"aiconfig"`` key.
        """
        payload = {
            "instance_id": self.instance_id,
            "message": self.message,
            "is_success": self.is_success,
            "data": self.data,
            "aiconfig": aiconfig_to_json(self.aiconfig_instance),
        }
        return core_utils.JSONObject(payload)

    def serialize(self) -> str:
        """Return ``to_json()`` encoded as a JSON string."""
        as_json = self.to_json()
        return json.dumps(as_json)

    @staticmethod
    def from_error_message(instance_id: str, message: str) -> "Response":
        """Create a failed response that carries only ``message`` and no AIConfig."""
        failure = Response(
            instance_id=instance_id,
            message=message,
            is_success=False,
            aiconfig_instance=None,
        )
        return failure
def resolve_path(path: str) -> str:
    """Return *path* as an absolute path, with a leading ``~`` expanded first."""
    home_expanded = os.path.expanduser(path)
    absolute = os.path.abspath(home_expanded)
    return absolute
@dataclass
class GlobalState:
    """Mutable, process-wide state of the editor server.

    server_v2 creates a single module-level instance and shares it between the
    server entrypoint and the websocket handlers.
    """

    # TODO: is there a better way to pass this into websocket connections?
    # Server configuration; replaced by run_backend_server with the config it is given.
    editor_config: EditServerConfig
    # Connection registry keyed by instance_id; the websocket loop adds an entry
    # when it starts handling an instance.
    active_instances: dict[str, ConnectionState]
def operation_to_aiconfig_path(operation: server_common.Operation) -> server_common.UnvalidatedPath | None:
    """Return the path carried by a Load or Save operation, or None for any other operation.

    The value is returned as-is; Load's path may itself be None.
    """
    carries_path = isinstance(operation, (server_common.Load, server_common.Save))
    if not carries_path:
        return None
    return operation.path
def _import_module_from_path(path_to_module: str) -> Result[ModuleType, str]:
    """Import the Python source file at ``path_to_module`` as a module.

    The module is named after the file (base name without a trailing ``.py``)
    and registered in ``sys.modules`` under that name before it is executed.

    Args:
        path_to_module: Path to the source file; resolved with
            ``server_common.resolve_path`` before use.

    Returns:
        ``Ok(module)`` on success. ``Err`` when no import spec or loader can be
        built for the path, or when building the spec or executing the module
        raises (reported through ``core_utils.ErrWithTraceback``).
    """
    logger.debug(f"{path_to_module=}")
    resolved_path = server_common.resolve_path(path_to_module)
    logger.debug(f"{resolved_path=}")
    # Strip only a trailing ".py". `.replace(".py", "")` would also delete ".py"
    # from the middle of a file name (e.g. "my.pyutils.py" -> "myutils").
    module_name = os.path.basename(resolved_path).removesuffix(".py")

    try:
        spec = importlib.util.spec_from_file_location(module_name, resolved_path)
        if spec is None:
            return Err(f"Could not import module from path: {resolved_path}")
        if spec.loader is None:
            return Err(f"Could not import module from path: {resolved_path} (no loader)")

        module = importlib.util.module_from_spec(spec)
        previous_module = sys.modules.get(module_name)
        # Register before executing so the module can be found by name while its
        # own body runs.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module importable under this name:
            # put back whatever was registered before (or nothing).
            if previous_module is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous_module
            raise
        return Ok(module)
    except Exception as e:
        return core_utils.ErrWithTraceback(e)
import WebSocket +from result import Err, Ok, Result + +from aiconfig.schema import ExecuteResult, Prompt + +logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log") + + +## SECTION: AIConfig operation run functions (following operations sent over websocket) + + +async def run_operation( + operation: server_common.Operation, aiconfig_instance: AIConfigRuntime, instance_id: str, websocket: WebSocket +) -> server_common.OperationOutput: + match operation: + case server_common.GetInstanceStatus(): + return await run_get_instance_status(instance_id) + case server_common.ListModels(): + return await run_list_models(instance_id) + case server_common.LoadModelParserModule(path=path_raw): + return await run_load_model_parser_module(instance_id, path_raw) + case server_common.Create(): + return await run_create(instance_id) + case server_common.Load(path=path_raw): + return await run_load(aiconfig_instance, instance_id, path_raw) + case server_common.Run(): + return await run_run(aiconfig_instance, instance_id, operation, websocket) + case server_common.AddPrompt(): + return await run_add_prompt(aiconfig_instance, instance_id, operation) + case server_common.UpdatePrompt(): + return await run_update_prompt(aiconfig_instance, instance_id, operation) + case server_common.DeletePrompt(): + return await run_delete_prompt(aiconfig_instance, instance_id, operation) + case server_common.Save(): + return await run_save(aiconfig_instance, instance_id, operation) + case server_common.UpdateModel(): + return await run_update_model(aiconfig_instance, instance_id, operation) + case server_common.SetParameter(): + return await run_set_parameter(aiconfig_instance, instance_id, operation) + case server_common.SetParameters(): + return await run_set_parameters(aiconfig_instance, instance_id, operation) + case server_common.DeleteParameter(): + return await run_delete_parameter(aiconfig_instance, instance_id, operation) + case server_common.SetName(): + return await 
run_set_name(aiconfig_instance, instance_id, operation) + case server_common.SetDescription(): + return await run_set_description(aiconfig_instance, instance_id, operation) + case server_common.MockRun(): + return await run_mock_run(aiconfig_instance, instance_id, operation) + + +async def run_get_instance_status(instance_id: str) -> server_common.OperationOutput: + return server_common.OperationOutput( + instance_id=instance_id, + message="See `data` field for instance status", + is_success=True, + aiconfig_instance=None, + data={"status": "OK"}, + ) + + +async def run_list_models( + instance_id: str, +) -> server_common.OperationOutput: + ids: list[str] = ModelParserRegistry.parser_ids() # type: ignore + return server_common.OperationOutput( + instance_id=instance_id, + message="Listed models", + is_success=True, + aiconfig_instance=None, + data={"ids": ids}, + ) + + +async def run_load_model_parser_module(instance_id: str, path_raw: server_common.UnvalidatedPath) -> server_common.OperationOutput: + load_module_result = await server_common.get_validated_path(path_raw).and_then_async(operation_lib.load_user_parser_module) + match load_module_result: + case Ok(_module): + return server_common.OperationOutput( + instance_id=instance_id, + message=f"Loaded module {path_raw}, output: {_module}", + is_success=True, + aiconfig_instance=None, + ) + case Err(e): + return server_common.OperationOutput( + instance_id=instance_id, + message=f"Failed to load module {path_raw}: {e}", + is_success=False, + aiconfig_instance=None, + ) + + +async def run_create(instance_id: str) -> server_common.OperationOutput: + aiconfig_instance = operation_lib.safe_run_create() + return operation_lib.aiconfig_result_to_operation_output(aiconfig_instance, instance_id) + + +async def run_load( + aiconfig_instance: AIConfigRuntime, + instance_id: str, + path_raw: server_common.UnvalidatedPath | None, +) -> server_common.OperationOutput: + if path_raw is None: + return 
server_common.OperationOutput( + instance_id=instance_id, + message="No path given, but AIConfig is loaded into memory. Here it is!", + is_success=True, + aiconfig_instance=aiconfig_instance, + ) + else: + res_path_val = server_common.get_validated_path(path_raw) + res_aiconfig = res_path_val.and_then(operation_lib.safe_load_from_disk) + message = f"Loaded AIConfig from {res_path_val}. This may have overwritten in-memory changes." + logger.warning(message) + return server_common.OperationOutput( + instance_id=instance_id, + message=message, + is_success=res_aiconfig.is_ok(), + aiconfig_instance=res_aiconfig.unwrap_or(None), + ) + + +async def run_run( + aiconfig_instance: AIConfigRuntime, instance_id: str, inputs: server_common.Run, websocket: WebSocket +) -> server_common.OperationOutput: + if inputs.stream: + logger.info(f"Running streaming operation: {inputs}") + + out_queue: Queue[core_utils.JSONObject | None] = Queue() + + def _stream_callback_queue_put(data: core_utils.JSONObject, accumulated_data: Any, index: int) -> None: + logger.debug(f"[stream callback]put {data=}") + # time.sleep(0.1) + out_queue.put(data) + + inference_options = InferenceOptions(stream=True, stream_callback=_stream_callback_queue_put) + + @core_utils.safe_run_fn_async + async def _run_streaming_inner(aiconfig_instance: AIConfigRuntime, inputs: server_common.Run) -> list[ExecuteResult]: + return await aiconfig_instance.run(inputs.prompt_name, inputs.params, inference_options) # type: ignore + + # This function gets asyncio.run() to type check + async def _run_inner_wrap(): + return await _run_streaming_inner(aiconfig_instance, inputs) + + @dataclasses.dataclass + class ThreadOutput: + value: Result[list[ExecuteResult], str] = Err("Not set") + + def _run_thread_main(out_queue: Queue[dict[str, str] | None], output: ThreadOutput) -> Result[list[ExecuteResult], str]: + method_output = asyncio.run(_run_inner_wrap()) + logger.info(f"Ran operation: {inputs}") + output.value = method_output + 
@operation_lib.operation_input_to_output
async def run_add_prompt(aiconfig_instance: AIConfigRuntime, inputs: server_common.AddPrompt) -> None:
    """Add the prompt described by `inputs` to `aiconfig_instance`.

    Reads `prompt_name`, `prompt_data` and `index` from `inputs` and forwards
    them positionally to `AIConfigRuntime.add_prompt`. The decorator turns this
    into an operation that also takes the instance id and returns an
    `OperationOutput`.
    """
    name, data, position = inputs.prompt_name, inputs.prompt_data, inputs.index
    return aiconfig_instance.add_prompt(name, data, position)
async def run_mock_run(aiconfig_instance: AIConfigRuntime, instance_id: str, inputs: server_common.MockRun) -> server_common.OperationOutput:
    """Sleep for `inputs.seconds` in total and add two mock prompts.

    The first prompt is added after 0.1 seconds and the second after the
    remaining time. Prompt names are consecutive integers following the
    largest existing prompt name (starting at 1 for an empty config).

    Returns a failed output, without sleeping, if `seconds` is below 0.2, and
    a failed output if a ValueError is raised while computing or adding the
    prompts (e.g. an existing prompt name is not an integer).
    """
    logger.info(f"Running operation: {inputs}")
    total_seconds = inputs.seconds
    if total_seconds < 0.2:
        return server_common.OperationOutput(
            instance_id=instance_id,
            message=f"Sleep time must be at least 0.2 second, got {total_seconds}",
            is_success=False,
            aiconfig_instance=None,
        )

    first_sleep = 0.1
    second_sleep = total_seconds - first_sleep
    await asyncio.sleep(first_sleep)
    try:
        # `default=0` covers the empty-config case, so numbering starts at 1.
        highest_existing = max((int(key) for key in aiconfig_instance.prompt_index.keys()), default=0)
        first_id = highest_existing + 1
        second_id = first_id + 1

        first_name = str(first_id)
        aiconfig_instance.add_prompt(first_name, Prompt(name=first_name, input=f"mock_prompt_{first_id}"))
        await asyncio.sleep(second_sleep)
        second_name = str(second_id)
        aiconfig_instance.add_prompt(second_name, Prompt(name=second_name, input=f"mock_prompt_{second_id}"))
    except ValueError as e:
        err = core_utils.ErrWithTraceback(e)
        return server_common.OperationOutput(
            instance_id=instance_id,
            message=f"Test aiconfig is invalid. All prompt names must be ints. Got {err}, {aiconfig_instance.prompt_index.keys()}",
            is_success=False,
            aiconfig_instance=None,
        )

    return server_common.OperationOutput(
        instance_id=instance_id,
        message=f"Blocked for {total_seconds} seconds and added prompts {first_id}, {second_id}",
        is_success=True,
        aiconfig_instance=aiconfig_instance,
    )
def schedule_operation_task(
    operation: server_common.Operation, instance_state: server_common.InstanceState, websocket: WebSocket
) -> asyncio.Task[server_common.OperationOutcome]:
    """Start `operation` as a background asyncio task and return the task.

    The task resolves to the `OperationOutcome` produced by
    `_get_operation_outcome` for the given instance state and websocket.
    Must be called while an event loop is running.
    """
    logger.info("Running create task")
    # Enter run mode: hand the coroutine straight to the event loop.
    scheduled = asyncio.create_task(_get_operation_outcome(operation, instance_state, websocket))
    logger.info("Created task")
    return scheduled
Result[tuple[server_common.Response | None, LoopState], str]: + if loop_state.operation_task is None: + return await _handle_new_operation_case(loop_state, websocket) + else: + return await _handle_existing_operation_case(loop_state, loop_state.operation_task, websocket) + + +async def _handle_new_operation_case( + current_loop_state: LoopState, + websocket: WebSocket, +) -> Result[tuple[server_common.Response | None, LoopState], str]: + instance_state = current_loop_state.instance_state + instance_id = instance_state.instance_id + res_command = await current_loop_state.recv_task + new_recv_task = schedule_receive_task(websocket) + logger.info(f"{res_command=}") + match res_command: + case Ok(command): + match command: + case server_common.Cancel(): + logger.info("Received cancel command but no operation is running") + response = server_common.Response.from_error_message(instance_id, "Received cancel command but no operation is running") + new_loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=new_recv_task) + return Ok((response, new_loop_state)) + case _: + loop_state = LoopState( + instance_state=instance_state, + operation_task=schedule_operation_task(command, instance_state, websocket), + recv_task=new_recv_task, + ) + return Ok((None, loop_state)) + case Err(e): + logger.error(f"Failed to parse command: {e}") + response = server_common.Response.from_error_message(instance_id, f"Failed to parse command: {e}") + loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=new_recv_task) + return Ok((response, loop_state)) + + +async def _handle_existing_operation_case( + current_loop_state: LoopState, current_operation_task: asyncio.Task[server_common.OperationOutcome], websocket: WebSocket +) -> Result[tuple[server_common.Response | None, LoopState], str]: + # both recv task and operation task are running + # Wait for one to finish + logger.info("Both recv task and operation task running. 
Waiting for one to finish.") + instance_state = current_loop_state.instance_state + + done, pending = await asyncio.wait([current_operation_task, current_loop_state.recv_task], return_when=asyncio.FIRST_COMPLETED) + # At least one is now done + if not any(t.done() for t in [current_loop_state.recv_task, current_operation_task]): + return Err(f"Got done and pending sets, but tasks are still running. This should not happen. {done=}, {pending=}") + else: + if current_operation_task.done(): + done_task = server_common.DoneTask(current_operation_task) + return Ok(_handle_operation_task_done(current_loop_state, done_task, instance_state)) + else: + done_task = server_common.DoneTask(current_loop_state.recv_task) + return Ok(_handle_recv_task_done(current_operation_task, done_task, instance_state, websocket)) + + +def _handle_operation_task_done( + current_loop_state: LoopState, + current_operation_done_task: server_common.DoneTask[server_common.OperationOutcome], + instance_state: server_common.InstanceState, +) -> tuple[server_common.Response | None, LoopState]: + current_operation_task = current_operation_done_task.task + if current_operation_task.cancelled(): + logger.info("Operation task cancelled") + loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=current_loop_state.recv_task) + return (None, loop_state) + else: + logger.info("Operation task done") + task_result = current_operation_task.result() + loop_state = LoopState(instance_state=task_result.instance_state, operation_task=None, recv_task=current_loop_state.recv_task) + response = operation_lib.operation_output_to_response(task_result.operation_output) + return (response, loop_state) + + +def _handle_recv_task_done( + current_operation_task: asyncio.Task[server_common.OperationOutcome], + current_recv_done_task: server_common.DoneTask[Result[server_common.Command, str]], + instance_state: server_common.InstanceState, + websocket: WebSocket, +) -> tuple[server_common.Response 
| None, LoopState]: + current_recv_task = current_recv_done_task.task + instance_id = instance_state.instance_id + if current_recv_task.cancelled(): + logger.info("Recv task cancelled") + loop_state = LoopState(instance_state=instance_state, operation_task=current_operation_task, recv_task=schedule_receive_task(websocket)) + return (None, loop_state) + else: + logger.info("Recv task done") + res_command = current_recv_task.result() + match res_command: + case Ok(command): + return _handle_new_command_while_operation_task_running(command, current_operation_task, instance_state, websocket) + case Err(e): + logger.error(f"Failed to parse command: {e}") + response = server_common.Response.from_error_message(instance_id, f"Failed to parse command: {e}") + loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=schedule_receive_task(websocket)) + return (response, loop_state) + + +def _handle_new_command_while_operation_task_running( + command: server_common.Command, + current_operation_task: asyncio.Task[server_common.OperationOutcome], + instance_state: server_common.InstanceState, + websocket: WebSocket, +) -> tuple[server_common.Response | None, LoopState]: + instance_id = instance_state.instance_id + match command: + case server_common.Cancel(): + logger.info("Received cancel command while operation task is running. Cancelling operation.") + current_operation_task.cancel() + loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=schedule_receive_task(websocket)) + response = server_common.Response( + instance_id=instance_id, + message="Cancelling command", + is_success=True, + aiconfig_instance=None, + ) + return (response, loop_state) + case _: + # TODO: something other than _ + logger.info("Received operation while operation task is running. 
async def _get_operation_outcome(
    operation: server_common.Operation,
    instance_state: server_common.InstanceState,
    websocket: WebSocket,
) -> server_common.OperationOutcome:
    """Run `operation` and compute the instance state that follows from it.

    The new state keeps the instance id. It uses the AIConfig instance from
    the operation output if there is one, otherwise the current instance. It
    uses the path from `operation_lib.operation_to_aiconfig_path(operation)`
    if that is truthy, otherwise the current path.
    """
    instance_id = instance_state.instance_id
    previous_instance = instance_state.aiconfig_instance
    previous_path = instance_state.aiconfig_path

    operation_output = await operations.run_operation(operation, previous_instance, instance_id, websocket)

    next_instance = operation_output.aiconfig_instance
    if next_instance is None:
        next_instance = previous_instance

    next_path = operation_lib.operation_to_aiconfig_path(operation)
    if not next_path:
        next_path = previous_path

    logger.debug("Updated instance: %s", next_instance)
    next_state = server_common.InstanceState(instance_id=instance_id, aiconfig_instance=next_instance, aiconfig_path=next_path)
    return server_common.OperationOutcome(operation_output=operation_output, instance_state=next_state)
AIConfigCLIConfig(core_utils.Record): + log_level: str | int = "WARNING" + + +logging.basicConfig(format=core_utils.LOGGER_FMT) +LOGGER = logging.getLogger(__name__) + + +async def main_with_args(argv: list[str]) -> int: + final_result = await run_subcommand(argv) + match final_result: + case Ok(msg): + LOGGER.info("Final result: Ok:\n%s", msg) + return 0 + case Err(e): + LOGGER.critical("Final result err: %s", e) + return core_utils.result_to_exitcode(Err(e)) + + +async def run_subcommand(argv: list[str]) -> Result[str, str]: + LOGGER.info("Running subcommand") + subparser_record_types = {"edit": server_common.EditServerConfig} + main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_record_types=subparser_record_types) + + res_cli_config = core_utils.parse_args(main_parser, argv[1:], AIConfigCLIConfig) + res_cli_config.and_then(_process_cli_config) + + subparser_name = core_utils.get_subparser_name(main_parser, argv[1:]) + LOGGER.info(f"Running subcommand: {subparser_name}") + + if subparser_name == "edit": + LOGGER.debug("Running edit subcommand") + res_edit_config = core_utils.parse_args(main_parser, argv[1:], server_common.EditServerConfig) + LOGGER.debug(f"{res_edit_config.is_ok()=}") + res_servers = await res_edit_config.and_then_async(_run_editor_servers) + out: Result[str, str] = await result.do_async( + # + Ok(",".join(res_servers_ok)) + # + for res_servers_ok in res_servers + ) + return out + else: + return Err(f"Unknown subparser: {subparser_name}") + + +def _sigint(procs: list[subprocess.Popen[bytes]]) -> Result[str, str]: + LOGGER.info("sigint") + for p in procs: + p.send_signal(signal.SIGINT) + return Ok("Sent SIGINT to frontend servers.") + + +async def _run_editor_servers(edit_config: server_common.EditServerConfig) -> Result[list[str], str]: + LOGGER.info("Running editor servers") + frontend_procs = _run_frontend_server_background() if edit_config.server_mode in [server_common.ServerMode.DEBUG_SERVERS] else Ok([]) + match frontend_procs: 
def _run_frontend_server_background() -> Result[list[subprocess.Popen[bytes]], str]:
    """Start the editor frontend processes without waiting for them.

    Launches `yarn` and then `yarn start` in the editor client directory
    (relative to the current working directory), and writes "n" plus a newline
    to the stdin of `yarn start`.
    NOTE(review): presumably this answers an interactive prompt; confirm.

    Returns:
        Ok([yarn_process, yarn_start_process]) on success. On any failure the
        processes already started are terminated and an Err built by
        `core_utils.ErrWithTraceback` is returned.
    """
    LOGGER.info("Running frontend server in background")
    client_dir = "python/src/aiconfig/editor/client"
    started: list[subprocess.Popen[bytes]] = []
    try:
        started.append(subprocess.Popen(["yarn"], cwd=client_dir))
        start_proc = subprocess.Popen(["yarn", "start"], cwd=client_dir, stdin=subprocess.PIPE)
        started.append(start_proc)

        # Explicit check instead of `assert`, which is stripped under -O.
        if start_proc.stdin is None:
            raise RuntimeError("yarn start was launched without a stdin pipe")
        start_proc.stdin.write(b"n\n")
        # Popen pipes are buffered by default; without a flush the child may
        # never see the input.
        start_proc.stdin.flush()
    except Exception as e:
        # Do not leak processes that were started before the failure.
        for proc in started:
            try:
                proc.terminate()
            except OSError as terminate_error:
                LOGGER.warning("Could not terminate frontend process %s: %s", proc.pid, terminate_error)
        return core_utils.ErrWithTraceback(e)

    return Ok(started)
+RunningServerConfig = NewType("RunningServerConfig", server_common.EditServerConfig) + + +@dataclass +class ConnectedWebsocket: + websocket: WebSocket + running_server: RunningServerConfig + + +class GetConnectedWebsocketFn(Protocol): + @abstractmethod + def __call__(self, edit_config: server_common.EditServerConfig) -> ConnectedWebsocket: + pass + + +def _get_cli_command(subcommand_cfg: server_common.EditServerConfig) -> list[str]: + match subcommand_cfg: + case server_common.EditServerConfig( + # + server_port=server_port_, + aiconfig_path=aiconfig_path_, + server_mode=server_mode_, + parsers_module_path=parsers_module_path_, + ): + subcommand = "edit" + str_server_mode = server_mode_.name.lower() + cmd = shlex.split( + f""" + python -m 'aiconfig.scripts.aiconfig_cli_v2' {subcommand} \ + --server-port={server_port_} \ + --server-mode={str_server_mode} \ + --aiconfig-path={aiconfig_path_} \ + --parsers-module-path={parsers_module_path_} + """ + ) + return cmd + + +def _make_edit_config(**kwargs) -> server_common.EditServerConfig: # type: ignore + TEST_PORT = 8011 + defaults = dict( + server_port=TEST_PORT, + server_mode=server_common.ServerMode.DEBUG_BACKEND, + parsers_module_path="aiconfig.model_parser", + ) + given_kwargs: dict[str, Any] = kwargs + new_kwargs = core_utils.dict_union_allow_replace(defaults, given_kwargs, on_conflict="replace") + return server_common.EditServerConfig(**new_kwargs) + + +def _make_connected_websocket(get_connected_websocket: GetConnectedWebsocketFn, edit_config: server_common.EditServerConfig) -> ConnectedWebsocket: + connected_websocket = get_connected_websocket(edit_config) + running_server = connected_websocket.running_server + assert running_server.aiconfig_path == edit_config.aiconfig_path + return connected_websocket + + +def _make_default_connected_websocket(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path) -> ConnectedWebsocket: + the_path = (tmp_path / "the.aiconfig.json").as_posix() + edit_config = 
@pytest.fixture
def get_connected_websocket(get_running_server: Callable[[server_common.EditServerConfig], RunningServerConfig], request: pytest.FixtureRequest):
    """Fixture returning a factory for a websocket connected to a running server.

    The factory starts (or reuses) the server for the given config, opens a
    websocket to `/ws_manage_aiconfig_instance`, and runs one synchronisation
    round trip. The socket is closed at test teardown.
    """

    def _make(edit_config: server_common.EditServerConfig):
        running_server = get_running_server(edit_config)
        url = f"ws://localhost:{running_server.server_port}/ws_manage_aiconfig_instance"
        ws = create_connection(url)

        def cleanup():
            ws.close()

        # Register the finalizer before the first round trip so the socket is
        # closed even if synchronisation raises.
        request.addfinalizer(cleanup)
        synchronize_with_server(ws)
        return ConnectedWebsocket(websocket=ws, running_server=running_server)

    return _make
print(f"sent {cmd_str=}, {res=}") + return res + + +def ws_receive_response(websocket: WebSocket) -> dict[str, Any] | None: + resp_str = websocket.recv() + if not resp_str: + return None + resp = json.loads(resp_str) + return resp + + +def synchronize_with_server(websocket: WebSocket): + """Send a command and wait for response + to make sure server is initialized and websocket is set up.""" + ws_send_command( + websocket, + "mock_run", + command_params=dict(seconds=0.2), + ) + _sync_resp = ws_receive_response(websocket) + # print(f"{sync_resp=}") + + +def test_editor_server_start_new_file(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path): + the_path = (tmp_path / "the.aiconfig.json").as_posix() + edit_config = _make_edit_config(aiconfig_path=the_path) + assert core_utils.load_json_file(the_path).is_err() + connected_websocket = _make_connected_websocket(get_connected_websocket, edit_config) + assert connected_websocket.running_server.aiconfig_path == the_path + + running_server = connected_websocket.running_server + + assert core_utils.load_json_file(running_server.aiconfig_path).is_ok() + + +def _get_mock_default_prompt(prompt_num: int) -> dict[str, Any]: + return dict(name=str(prompt_num), input=f"mock_prompt_{prompt_num}", metadata=None, outputs=[]) + + +def test_editor_mock_run_simple(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path): + connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path) + websocket = connected_websocket.websocket + ws_send_command(websocket, "mock_run", dict(seconds=0.3)) + resp = ws_receive_response(websocket) + assert resp is not None + message, is_success, aiconfig, _data = resp["message"], resp["is_success"], resp["aiconfig"], resp["data"] + assert is_success + # synchronize_with_server runs mock_run with 0.2 seconds, which adds + # prompts 1 and 2. + # This command is expected to add 3 and 4. 
+ assert message == "Blocked for 0.3 seconds and added prompts 3, 4" + prompts = aiconfig.get("prompts", []) + assert prompts == [_get_mock_default_prompt(i + 1) for i in range(4)] + print("Done") + + +def test_editor_mock_run_cancel_rollback_1(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path): + """Cancel immediately, before any mutation can happen""" + connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path) + websocket = connected_websocket.websocket + + aiconfig_before = _load_aiconfig(websocket) + print(f"{aiconfig_before=}") + + ws_send_command(websocket, "mock_run", dict(seconds=0.3)) + ws_send_command(websocket, "cancel", dict()) + cancel_resp = ws_receive_response(websocket) + print(f"{cancel_resp=}") + ws_send_command(websocket, "load", dict()) + aiconfig_after = _load_aiconfig(websocket) + assert aiconfig_before == aiconfig_after + + +def test_editor_mock_run_cancel_rollback_2(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path): + """wait until some mutation happens, then cancel""" + connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path) + websocket = connected_websocket.websocket + + aiconfig_before = _load_aiconfig(websocket) + print(f"{aiconfig_before=}") + + ws_send_command(websocket, "mock_run", dict(seconds=0.3)) + + # mock run does its first mutation after 0.1 seconds. + # Waiting 0.2 before canceling allows it to happen, testing rollback. 
+ time.sleep(0.2) + ws_send_command(websocket, "cancel", dict()) + cancel_resp = ws_receive_response(websocket) + print(f"{cancel_resp=}") + aiconfig_after = _load_aiconfig(websocket) + assert aiconfig_before == aiconfig_after, f"{aiconfig_before=}, {aiconfig_after=}" + + +def test_editor_mock_run_cancel_fast_walltime(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path): + """wait until some mutation happens, then cancel""" + connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path) + websocket = connected_websocket.websocket + + # Control: block for 0.3 seconds, don't cancel + ts_start = time.time() + ws_send_command(websocket, "mock_run", dict(seconds=0.3)) + _mock_run_resp = ws_receive_response(websocket) + ts_end = time.time() + s_elapsed = ts_end - ts_start + print(f"[control]{s_elapsed=}") + assert s_elapsed >= 0.3 + + ts_start = time.time() + ws_send_command(websocket, "mock_run", dict(seconds=0.3)) + # Cancel immediately + ws_send_command(websocket, "cancel", dict()) + cancel_resp = ws_receive_response(websocket) + ts_end = time.time() + s_elapsed = ts_end - ts_start + print(f"[cancel]{s_elapsed=}") + assert s_elapsed < 0.01 + + assert cancel_resp is not None + assert cancel_resp["message"] == "Cancelling command" + assert cancel_resp["is_success"] == True