diff --git a/python/requirements.txt b/python/requirements.txt
index 0e3b7ffde..5b40f6b5f 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -3,10 +3,12 @@ black
flake8
flask-cors
flask[async]
+fastapi
google-generativeai
huggingface_hub
hypothesis==6.91.0
-lastmile-utils==0.0.14
+hypercorn
+lastmile-utils==0.0.16
mock
nest_asyncio
nltk
diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py
index f9cd467f9..052ffab00 100644
--- a/python/src/aiconfig/editor/server/server.py
+++ b/python/src/aiconfig/editor/server/server.py
@@ -300,9 +300,10 @@ def run_async_config_in_thread():
status=200,
content_type="application/json",
)
-
+
# Run without streaming
inference_options = InferenceOptions(stream=stream)
+
def run_async_config_in_thread():
asyncio.run(
aiconfig.run(
@@ -324,7 +325,7 @@ def run_async_config_in_thread():
code=200,
aiconfig=aiconfig,
).to_flask_format()
-
+
except Exception as e:
return HttpResponseWithAIConfig(
#
diff --git a/python/src/aiconfig/editor/server/server_v2.py b/python/src/aiconfig/editor/server/server_v2.py
new file mode 100644
index 000000000..906ea4e66
--- /dev/null
+++ b/python/src/aiconfig/editor/server/server_v2.py
@@ -0,0 +1,247 @@
+## SECTION: Imports and Constants
+
+import logging
+import os
+from contextlib import asynccontextmanager
+from typing import cast
+import webbrowser
+
+import lastmile_utils.lib.core.api as core_utils
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+from hypercorn.asyncio import serve # type: ignore
+from hypercorn.config import Config
+from hypercorn.typing import ASGIFramework
+from result import Err, Ok, Result
+import result
+
+from aiconfig.editor.server import server_v2_utils as server_utils
+import aiconfig.editor.server.server_v2_common as server_common
+import aiconfig.editor.server.server_v2_operation_lib as operation_lib
+
+
+THIS_DIR = os.path.dirname(os.path.realpath(__file__))
+STATIC_DIR = os.path.join(THIS_DIR, "static")
+
+DEFAULT_PORT = 8080
+
+logger: logging.Logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log")
+
+
+## SECTION: Global state initialization
+
+
+global_state = server_common.GlobalState(
+ editor_config=server_common.EditServerConfig(),
+ active_instances=dict(),
+)
+
+
+async def _run_websocket_connection(initial_loop_state: server_utils.LoopState, websocket: WebSocket) -> Result[str, str]:
+ """This is the main websocket loop."""
+ loop_state = initial_loop_state
+ instance_state = loop_state.instance_state
+ instance_id = instance_state.instance_id
+ global global_state
+ global_state.active_instances[instance_id] = server_common.ConnectionState(websocket=websocket)
+
+ logger.info("Starting websocket loop")
+ logger.info(f"{instance_state.aiconfig_path=}")
+ while True:
+ logger.debug(f"{loop_state.operation_task=}, {loop_state.recv_task=}")
+ try:
+ res_handle = await server_utils.handle_websocket_loop_iteration(loop_state, websocket)
+ match res_handle:
+ case Ok((response, new_loop_state)):
+ loop_state = new_loop_state
+ if response:
+ await websocket.send_text(response.serialize())
+ case Err(e):
+ logger.critical(f"Can't generate response or update loop state.\n{e}")
+ return await _cleanup_websocket_connection(instance_state.instance_id, global_state.active_instances[instance_state.instance_id])
+ except (WebSocketDisconnect, RuntimeError) as e:
+ ewt = core_utils.ErrWithTraceback(e)
+ logger.error(f"Websocket loop terminated: {e}, {ewt}")
+ return await _cleanup_websocket_connection(instance_state.instance_id, global_state.active_instances[instance_state.instance_id])
+
+
+@asynccontextmanager
+async def _app_lifespan(app: FastAPI):
+ global global_state
+ logger.info("Start lifespan")
+ yield
+ logger.info("Shutting down app.")
+ cleanup_res_list = [
+ await _cleanup_websocket_connection(instance_id, websocket_state) for instance_id, websocket_state in global_state.active_instances.items()
+ ]
+ cleanup_ok, cleanup_err = core_utils.result_reduce_list_separate(cleanup_res_list)
+ logger.info("Cleaned up websockets. %s", cleanup_ok)
+ if len(cleanup_err) > 0:
+ logger.error("Failed to clean up websockets. %s", cleanup_err)
+ del global_state
+
+
+app = FastAPI(lifespan=_app_lifespan)
+
+
+## SECTION: Programmatic Server API (run entrypoint)
+
+
+async def run_backend_server(edit_config: server_common.EditServerConfig) -> Result[str, str]:
+ global global_state
+ global_state.editor_config = edit_config
+ global logger
+ logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log", log_level=edit_config.log_level)
+
+ await _init_app_state(app, edit_config)
+
+ def _outcome_to_str(outcome: server_common.ServerBindOutcome) -> Result[str, str]:
+ match outcome:
+ case server_common.ServerBindOutcome.SUCCESS:
+ return Ok(f"Server running on port {edit_config.server_port}")
+ case server_common.ServerBindOutcome.PORT_IN_USE:
+ return Err(f"Port {edit_config.server_port} in use")
+ case server_common.ServerBindOutcome.OTHER_FAILURE:
+ return Err(f"Failed to run server on port {edit_config.server_port}")
+
+ if edit_config.server_mode != server_common.ServerMode.DEBUG_BACKEND:
+ try:
+ logger.info(f"Opening browser at http://localhost:{edit_config.server_port}")
+ webbrowser.open(f"http://localhost:{edit_config.server_port}")
+ except Exception as e:
+ logger.warning(f"Failed to open browser: {e}. Please open http://localhost:{edit_config.server_port} manually.")
+
+ match edit_config.server_port:
+ case int():
+ result = await _run_backend_server_on_port(edit_config.log_level, edit_config.server_port)
+ return _outcome_to_str(result)
+ case None:
+ port_try = DEFAULT_PORT
+ max_port = 65535
+ while port_try < max_port:
+ backend_res_on_port = await _run_backend_server_on_port(edit_config.log_level, port_try)
+ logger.debug(f"{backend_res_on_port=}")
+ match backend_res_on_port:
+ case server_common.ServerBindOutcome.PORT_IN_USE:
+ logger.info("Going to try next port...")
+ port_try += 1
+ continue
+ case _:
+ return _outcome_to_str(backend_res_on_port)
+
+ return Err(f"Failed to run backend server on any port in range {DEFAULT_PORT} to {max_port}")
+
+
+async def _run_backend_server_on_port(log_level: str | int, port: int) -> server_common.ServerBindOutcome:
+ logger.info(f"Running backend server on port {port}")
+
+ log_level_for_hypercorn = (
+ #
+ log_level.upper()
+ if isinstance(log_level, str)
+ else logging.getLevelName(log_level)
+ )
+ fastapi_app: ASGIFramework = cast(ASGIFramework, app)
+ try:
+ logger.info(f"Starting server on port {port}")
+ await serve(
+ fastapi_app,
+ Config.from_mapping(
+ #
+ _bind=[f"localhost:{port}"],
+ loglevel=log_level_for_hypercorn,
+ use_reloader=True,
+ keep_alive_timeout=365 * 24 * 3600,
+ ),
+ )
+ logger.info(f"Done running server on port {port}")
+ return server_common.ServerBindOutcome.SUCCESS
+ except OSError as e_os:
+ logger.warning(f"Port in use: {port}: {e_os}")
+ return server_common.ServerBindOutcome.PORT_IN_USE
+ except Exception as e:
+ logger.error(f"Failed to run backend server on port {port}: {type(e)}")
+ logger.error(core_utils.ErrWithTraceback(e))
+ return server_common.ServerBindOutcome.OTHER_FAILURE
+
+
+## SECTION: Web API. HTTP endpoints: static files, root, and websocket connect
+
+
+@app.get("/")
+def home():
+ logger.info(f"ROOT, {os.getcwd()}")
+ index_path = os.path.join(STATIC_DIR, "index.html")
+ res_index = core_utils.read_text_file(index_path)
+ match res_index:
+ case Ok(index):
+ return HTMLResponse(index)
+ case Err(e):
+ logger.error(f"Failed to load index.html: {e}")
+ return HTMLResponse(f"
Failed to load index.html: {e}
")
+
+
+@app.get("/api/server_status")
+def server_status():
+ data = {"status": "OK"}
+ return JSONResponse(content=data, status_code=200)
+
+
+@app.websocket("/ws_manage_aiconfig_instance")
+async def accept_and_run_websocket(websocket: WebSocket):
+ logger.info("Accepting websocket connection")
+ await websocket.accept()
+ global global_state
+
+ initial_loop_state = await server_utils.LoopState.new(websocket, global_state.editor_config)
+ res_websocket: Result[str, str] = await result.do_async(
+ await _run_websocket_connection(initial_loop_state_ok, websocket)
+ #
+ for initial_loop_state_ok in initial_loop_state
+ )
+ logger.info(f"{res_websocket=}")
+ match res_websocket:
+ case Ok(result_):
+ return JSONResponse(content=result_, status_code=200)
+ case Err(e):
+ return JSONResponse(content=f"Failed to run websocket: {e}", status_code=500)
+
+
+## SECTION: Global state management
+
+
+async def _init_app_state(app: FastAPI, edit_config: server_common.EditServerConfig):
+ logger.setLevel(edit_config.log_level)
+ logger.info("Edit config: %s", edit_config.model_dump_json())
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ app.mount("/static", StaticFiles(directory=os.path.join(STATIC_DIR, "static")), name="static")
+
+ res_load_module = await (
+ server_common.get_validated_path(edit_config.parsers_module_path)
+ #
+ .and_then_async(operation_lib.load_user_parser_module)
+ )
+ match res_load_module:
+ case Ok(_module):
+ logger.info(f"Loaded module {edit_config.parsers_module_path}, output: {_module}")
+ case Err(e):
+ logger.warning(f"Failed to load module {edit_config.parsers_module_path}: {e}")
+
+
+async def _cleanup_websocket_connection(instance_id: str, websocket_state: server_common.ConnectionState) -> Result[str, str]:
+ logger.info(f"Closing websocket connection for instance {websocket_state}")
+ try:
+ await websocket_state.websocket.close()
+ return Ok(f"Closed websocket connection for instance {instance_id}")
+ except Exception as e:
+ return Err(f"Failed to close websocket connection for instance {instance_id}: {e}")
diff --git a/python/src/aiconfig/editor/server/server_v2_common.py b/python/src/aiconfig/editor/server/server_v2_common.py
new file mode 100644
index 000000000..cd82ca2ae
--- /dev/null
+++ b/python/src/aiconfig/editor/server/server_v2_common.py
@@ -0,0 +1,348 @@
+## SECTION: Imports
+
+from dataclasses import dataclass
+from enum import Enum
+import json
+import os
+from typing import Any, NewType, Optional, ParamSpec, TypeVar
+
+import lastmile_utils.lib.core.api as core_utils
+from fastapi import WebSocket
+from result import Err, Ok, Result
+
+import asyncio
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+from result import Result
+from typing import Any, Literal, TypeVar
+
+import lastmile_utils.lib.core.api as core_utils
+from pydantic import Field
+
+from aiconfig.schema import Prompt
+
+import aiconfig.editor.server.server_v2_common as server_common
+
+
+from aiconfig.Config import AIConfigRuntime
+
+logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log")
+
+## SECTION: Common types
+
+P = ParamSpec("P")
+T_Output = TypeVar("T_Output", covariant=True)
+
+UnvalidatedPath = NewType("UnvalidatedPath", str)
+ValidatedPath = NewType("ValidatedPath", str)
+
+
+class GetInstanceStatus(core_utils.Record):
+ # This is one of the equivalents of the old /api/server_status endpoint
+ command_name: Literal["get_instance_status"]
+ instance_id: str
+ status: str
+
+
+class ListModels(core_utils.Record):
+ command_name: Literal["list_models"]
+
+
+class LoadModelParserModule(core_utils.Record):
+ command_name: Literal["load_model_parser_module"]
+ path: server_common.UnvalidatedPath
+
+
+class Load(core_utils.Record):
+ command_name: Literal["load"]
+ path: server_common.UnvalidatedPath | None = None
+
+
+class Save(core_utils.Record):
+ command_name: Literal["save"]
+ path: server_common.UnvalidatedPath
+
+
+class Create(core_utils.Record):
+ command_name: Literal["create"]
+
+
+class Run(core_utils.Record):
+ command_name: Literal["run"]
+ prompt_name: str
+ params: dict[str, Any] = Field(default_factory=dict)
+ stream: bool = False
+
+
+class AddPrompt(core_utils.Record):
+ command_name: Literal["add_prompt"]
+ prompt_name: str
+ prompt_data: Prompt
+ index: int
+
+
+class UpdatePrompt(core_utils.Record):
+ command_name: Literal["update_prompt"]
+ prompt_name: str
+ prompt_data: Prompt
+
+
+class DeletePrompt(core_utils.Record):
+ command_name: Literal["delete_prompt"]
+ prompt_name: str
+
+
+class UpdateModel(core_utils.Record):
+ command_name: Literal["update_model"]
+ model_name: str | None
+ settings: dict[str, Any]
+ prompt_name: str | None
+
+
+class SetParameter(core_utils.Record):
+ command_name: Literal["set_parameter"]
+ parameter_name: str
+ parameter_value: str | dict[str, Any]
+ prompt_name: str | None
+
+
+class SetParameters(core_utils.Record):
+ command_name: Literal["set_parameters"]
+ parameters: dict[str, Any]
+ prompt_name: str | None
+
+
+class DeleteParameter(core_utils.Record):
+ command_name: Literal["delete_parameter"]
+ parameter_name: str
+ prompt_name: str
+
+
+class SetName(core_utils.Record):
+ command_name: Literal["set_name"]
+ name: str
+
+
+class SetDescription(core_utils.Record):
+ command_name: Literal["set_description"]
+ description: str
+
+
+class MockRun(core_utils.Record):
+ """For testing only"""
+
+ command_name: Literal["mock_run"]
+ seconds: float
+
+
+# THIS MUST BE KEPT IN SYNC WITH T_OPERATION BELOW
+Operation = (
+ GetInstanceStatus
+ | ListModels
+ | LoadModelParserModule
+ | Create
+ | Load
+ | Run
+ | AddPrompt
+ | UpdatePrompt
+ | DeletePrompt
+ | Save
+ | UpdateModel
+ | SetParameter
+ | SetParameters
+ | DeleteParameter
+ | SetName
+ | SetDescription
+ | MockRun
+)
+
+# THIS MUST BE KEPT IN SYNC WITH OPERATION ABOVE
+T_Operation = TypeVar(
+ "T_Operation",
+ GetInstanceStatus,
+ ListModels,
+ LoadModelParserModule,
+ Create,
+ Load,
+ Run,
+ AddPrompt,
+ UpdatePrompt,
+ DeletePrompt,
+ Save,
+ UpdateModel,
+ SetParameter,
+ SetParameters,
+ DeleteParameter,
+ SetName,
+ SetDescription,
+ MockRun,
+ contravariant=True,
+)
+
+
+# Cancel is the only command that is not an operation.
+class Cancel(core_utils.Record):
+ command_name: Literal["cancel"]
+
+
+Command = Operation | Cancel
+
+
+class SerializableCommand(core_utils.Record):
+ command: Command = Field(..., discriminator="command_name")
+
+
+class ServerMode(Enum):
+ DEBUG_SERVERS = "DEBUG_SERVERS"
+ DEBUG_BACKEND = "DEBUG_BACKEND"
+ PROD = "PROD"
+
+
+class ServerBindOutcome(Enum):
+ SUCCESS = "SUCCESS"
+ PORT_IN_USE = "PORT_IN_USE"
+ OTHER_FAILURE = "OTHER_FAILURE"
+
+
+class EditServerConfig(core_utils.Record, core_utils.EnumValidatedRecordMixin):
+ server_port: Optional[int] = None
+ aiconfig_path: str = "my_aiconfig.aiconfig.json"
+ log_level: str | int = "INFO"
+ log_file_path: str = "editor_server_v2.log"
+ server_mode: ServerMode = ServerMode.PROD
+ parsers_module_path: str = "aiconfig_model_registry.py"
+
+
+class Response(core_utils.Record):
+ instance_id: str
+ message: str
+ is_success: bool
+ aiconfig_instance: AIConfigRuntime | None
+ # TODO: make this a more constrained type
+ data: Any | None = None
+
+ def to_json(self) -> core_utils.JSONObject:
+ return core_utils.JSONObject(
+ {
+ "instance_id": self.instance_id,
+ "message": self.message,
+ "is_success": self.is_success,
+ "data": self.data,
+ "aiconfig": aiconfig_to_json(self.aiconfig_instance),
+ }
+ )
+
+ def serialize(self) -> str:
+ return json.dumps(self.to_json())
+
+ @staticmethod
+ def from_error_message(instance_id: str, message: str) -> "Response":
+ return Response(
+ instance_id=instance_id,
+ message=message,
+ is_success=False,
+ aiconfig_instance=None,
+ )
+
+
+def aiconfig_to_json(aiconfig_instance: AIConfigRuntime | None) -> core_utils.JSONObject | None:
+ if aiconfig_instance is None:
+ return None
+ else:
+ EXCLUDE_OPTIONS = {
+ "prompt_index": True,
+ "file_path": True,
+ "callback_manager": True,
+ }
+ return aiconfig_instance.model_dump(exclude=EXCLUDE_OPTIONS)
+
+
+class OperationOutput(core_utils.Record):
+ # TODO: change the fields
+ instance_id: str
+ message: str
+ is_success: bool
+ aiconfig_instance: AIConfigRuntime | None
+ # TODO: make this a more constrained type
+ data: Any | None = None
+
+ @staticmethod
+ def from_method_output(
+ instance_id: str, aiconfig_instance: AIConfigRuntime, method_output: Result[Any, str], message_suffix: str = ""
+ ) -> "OperationOutput":
+ match method_output:
+ case Ok(output_ok):
+ out = OperationOutput(
+ instance_id=instance_id,
+ message=message_suffix,
+ is_success=True,
+ aiconfig_instance=aiconfig_instance,
+ data={"output": str(output_ok)},
+ )
+ logger.info(f"{out.instance_id=}, {out.message=}")
+ return out
+ case Err(e):
+ logger.error(f"{e=}")
+ return OperationOutput(
+ instance_id=instance_id,
+ message=f"Failed to run prompt: {e}\n{message_suffix}",
+ is_success=False,
+ aiconfig_instance=None,
+ )
+
+ def to_json(self) -> core_utils.JSONObject:
+ return core_utils.JSONObject(
+ {
+ "instance_id": self.instance_id,
+ "message": self.message,
+ "is_success": self.is_success,
+ "data": self.data,
+ "aiconfig": aiconfig_to_json(self.aiconfig_instance),
+ }
+ )
+
+
+class InstanceState(core_utils.Record):
+ instance_id: str
+ aiconfig_instance: AIConfigRuntime
+ aiconfig_path: UnvalidatedPath
+
+
+@dataclass
+class OperationOutcome:
+ operation_output: OperationOutput
+ instance_state: InstanceState
+
+
+@dataclass
+class ConnectionState:
+ websocket: WebSocket
+
+
+def resolve_path(path: str) -> str:
+ return os.path.abspath(os.path.expanduser(path))
+
+
+def get_validated_path(raw_path: str | None) -> Result[ValidatedPath, str]:
+ if not raw_path:
+ return Err("No path provided")
+ resolved = resolve_path(raw_path)
+ if not os.path.isfile(resolved):
+ return Err(f"File does not exist: {resolved}")
+ return Ok(ValidatedPath(resolved))
+
+
+T_TaskOutcome = TypeVar("T_TaskOutcome", OperationOutcome, Result[Command, str])
+
+
+class DoneTask(Generic[T_TaskOutcome]):
+ def __init__(self, task: asyncio.Task[T_TaskOutcome]):
+ self.task = task
+
+
+@dataclass
+class GlobalState:
+ # TODO: is there a better way to pass this into websocket connections?
+ editor_config: EditServerConfig
+ active_instances: dict[str, ConnectionState]
diff --git a/python/src/aiconfig/editor/server/server_v2_operation_lib.py b/python/src/aiconfig/editor/server/server_v2_operation_lib.py
new file mode 100644
index 000000000..ff58edbf5
--- /dev/null
+++ b/python/src/aiconfig/editor/server/server_v2_operation_lib.py
@@ -0,0 +1,181 @@
+## SECTION: Imports
+
+from abc import abstractmethod
+import os
+import sys
+import importlib
+import importlib.util
+
+from types import ModuleType
+from typing import Callable, Generic, Protocol, TypeVar
+
+import aiconfig.editor.server.server_v2_common as server_common
+import lastmile_utils.lib.core.api as core_utils
+from aiconfig.Config import AIConfigRuntime
+from result import Err, Ok, Result
+
+from aiconfig.schema import ExecuteResult
+
+logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log")
+
+## SECTION: Types
+
+T_MethodOutput = TypeVar("T_MethodOutput", None, list[ExecuteResult], covariant=True)
+
+
+class RunMethodFn(Protocol, Generic[server_common.T_Operation, T_MethodOutput]):
+ @abstractmethod
+ async def __call__(self, aiconfig_instance: AIConfigRuntime, inputs: server_common.T_Operation) -> T_MethodOutput: # type: ignore[fixme]
+ pass
+
+
+class RunOperationFn(Protocol, Generic[server_common.T_Operation]):
+ @abstractmethod
+ async def __call__(
+ self, aiconfig_instance: AIConfigRuntime, instance_id: str, inputs: server_common.T_Operation
+ ) -> server_common.OperationOutput:
+ pass
+
+
+class AIConfigFn(Protocol, Generic[server_common.P, server_common.T_Output]):
+ @abstractmethod
+ async def __call__(
+ self, aiconfig_instance: AIConfigRuntime, *args: server_common.P.args, **kwargs: server_common.P.kwargs
+ ) -> server_common.T_Output:
+ pass
+
+
+def aiconfig_result_to_operation_output(
+ res_aiconfig: Result[AIConfigRuntime, str], instance_id: str, message_suffix: str = ""
+) -> server_common.OperationOutput:
+ match res_aiconfig:
+ case Ok(aiconfig_instance):
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=message_suffix,
+ is_success=True,
+ aiconfig_instance=aiconfig_instance,
+ )
+ case Err(e):
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=f"No AIConfig: {e}" + message_suffix,
+ is_success=False,
+ aiconfig_instance=None,
+ )
+
+
+def operation_output_to_response(operation_output: server_common.OperationOutput) -> server_common.Response:
+ return server_common.Response(
+ instance_id=operation_output.instance_id,
+ message=operation_output.message,
+ is_success=operation_output.is_success,
+ aiconfig_instance=operation_output.aiconfig_instance,
+ data=operation_output.data,
+ )
+
+
+def operation_to_aiconfig_path(operation: server_common.Operation) -> server_common.UnvalidatedPath | None:
+ match operation:
+ case server_common.Load(path=path_raw):
+ return path_raw
+ case server_common.Save(path=path_raw):
+ return path_raw
+ case _:
+ return None
+
+
+def operation_input_to_output(run_method_fn: RunMethodFn[server_common.T_Operation, T_MethodOutput]) -> RunOperationFn[server_common.T_Operation]:
+ """Decorator to make a function:
+ (a) robust to exceptions,
+ (b) Convert an arbitrary output into a (standard) OperationOutput.
+
+ The input function takes an AIConfigRuntime instance and one of the Operation subtypes
+ and returns some value depending on which operation was run.
+
+ The output (decorated) function does essentially the same thing, but with the properties listed above.
+ The output function also automatically accepts the instance_id, which maps 1:1 with the aiconfig,
+ and bundles it into the operation output.
+
+ See `run_add_prompt() for example`.
+
+ """
+
+ async def _new_fn(aiconfig_instance: AIConfigRuntime, instance_id: str, inputs: server_common.T_Operation) -> server_common.OperationOutput:
+ @core_utils.safe_run_fn_async
+ async def _wrap_input_fn(aiconfig_instance: AIConfigRuntime, inputs: server_common.T_Operation) -> T_MethodOutput:
+ return await run_method_fn(aiconfig_instance, inputs)
+
+ method_output = await _wrap_input_fn(aiconfig_instance, inputs)
+
+ logger.info(f"Ran operation: {inputs}")
+ out = server_common.OperationOutput.from_method_output(instance_id, aiconfig_instance, method_output, f"Ran operation: {inputs}")
+ logger.info(f"{out.instance_id=}, {out.message=}")
+ return out
+
+ return _new_fn
+
+
+@core_utils.safe_run_fn
+def safe_run_create() -> AIConfigRuntime:
+ out = AIConfigRuntime.create() # type: ignore
+ return out
+
+
+@core_utils.safe_run_fn_async
+async def safe_save_to_disk(aiconfig_instance: AIConfigRuntime, path: server_common.UnvalidatedPath) -> None:
+ return aiconfig_instance.save(path)
+
+
+@core_utils.safe_run_fn
+def safe_load_from_disk(aiconfig_path: server_common.ValidatedPath) -> AIConfigRuntime:
+ aiconfig = AIConfigRuntime.load(aiconfig_path) # type: ignore
+ return aiconfig
+
+
+def _import_module_from_path(path_to_module: str) -> Result[ModuleType, str]:
+ logger.debug(f"{path_to_module=}")
+ resolved_path = server_common.resolve_path(path_to_module)
+ logger.debug(f"{resolved_path=}")
+ module_name = os.path.basename(resolved_path).replace(".py", "")
+
+ try:
+ spec = importlib.util.spec_from_file_location(module_name, resolved_path)
+ if spec is None:
+ return Err(f"Could not import module from path: {resolved_path}")
+ elif spec.loader is None:
+ return Err(f"Could not import module from path: {resolved_path} (no loader)")
+ else:
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module)
+ return Ok(module)
+ except Exception as e:
+ return core_utils.ErrWithTraceback(e)
+
+
+@core_utils.safe_run_fn_async
+async def _register_user_model_parsers(user_register_fn: Callable[[], None]) -> None:
+ out = user_register_fn()
+ return out
+
+
+def _load_register_fn_from_user_module(user_module: ModuleType) -> Result[Callable[[], None], str]:
+ if not hasattr(user_module, "register_model_parsers"):
+ return Err(f"User module {user_module} does not have a register_model_parsers function.")
+ register_fn = getattr(user_module, "register_model_parsers")
+ if not callable(register_fn):
+ return Err(f"User module {user_module} does not have a register_model_parsers function")
+ else:
+ return Ok(register_fn)
+
+
+async def load_user_parser_module(path_to_module: str) -> Result[None, str]:
+ logger.info(f"Importing parsers module from {path_to_module}")
+ res_user_module = _import_module_from_path(path_to_module)
+ register_result = await (
+ res_user_module.and_then(_load_register_fn_from_user_module)
+ #
+ .and_then_async(_register_user_model_parsers)
+ )
+ return register_result
diff --git a/python/src/aiconfig/editor/server/server_v2_run_operation.py b/python/src/aiconfig/editor/server/server_v2_run_operation.py
new file mode 100644
index 000000000..534798fd8
--- /dev/null
+++ b/python/src/aiconfig/editor/server/server_v2_run_operation.py
@@ -0,0 +1,298 @@
+## SECTION: Imports
+
+import asyncio
+import dataclasses
+import json
+import threading
+
+from queue import Empty, Queue
+from typing import Any, cast
+
+import aiconfig.editor.server.server_v2_common as server_common
+import aiconfig.editor.server.server_v2_operation_lib as operation_lib
+import lastmile_utils.lib.core.api as core_utils
+from aiconfig.Config import AIConfigRuntime
+from aiconfig.model_parser import InferenceOptions
+from aiconfig.registry import ModelParserRegistry
+from fastapi import WebSocket
+from result import Err, Ok, Result
+
+from aiconfig.schema import ExecuteResult, Prompt
+
+logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log")
+
+
+## SECTION: AIConfig operation run functions (following operations sent over websocket)
+
+
+async def run_operation(
+ operation: server_common.Operation, aiconfig_instance: AIConfigRuntime, instance_id: str, websocket: WebSocket
+) -> server_common.OperationOutput:
+ match operation:
+ case server_common.GetInstanceStatus():
+ return await run_get_instance_status(instance_id)
+ case server_common.ListModels():
+ return await run_list_models(instance_id)
+ case server_common.LoadModelParserModule(path=path_raw):
+ return await run_load_model_parser_module(instance_id, path_raw)
+ case server_common.Create():
+ return await run_create(instance_id)
+ case server_common.Load(path=path_raw):
+ return await run_load(aiconfig_instance, instance_id, path_raw)
+ case server_common.Run():
+ return await run_run(aiconfig_instance, instance_id, operation, websocket)
+ case server_common.AddPrompt():
+ return await run_add_prompt(aiconfig_instance, instance_id, operation)
+ case server_common.UpdatePrompt():
+ return await run_update_prompt(aiconfig_instance, instance_id, operation)
+ case server_common.DeletePrompt():
+ return await run_delete_prompt(aiconfig_instance, instance_id, operation)
+ case server_common.Save():
+ return await run_save(aiconfig_instance, instance_id, operation)
+ case server_common.UpdateModel():
+ return await run_update_model(aiconfig_instance, instance_id, operation)
+ case server_common.SetParameter():
+ return await run_set_parameter(aiconfig_instance, instance_id, operation)
+ case server_common.SetParameters():
+ return await run_set_parameters(aiconfig_instance, instance_id, operation)
+ case server_common.DeleteParameter():
+ return await run_delete_parameter(aiconfig_instance, instance_id, operation)
+ case server_common.SetName():
+ return await run_set_name(aiconfig_instance, instance_id, operation)
+ case server_common.SetDescription():
+ return await run_set_description(aiconfig_instance, instance_id, operation)
+ case server_common.MockRun():
+ return await run_mock_run(aiconfig_instance, instance_id, operation)
+
+
+async def run_get_instance_status(instance_id: str) -> server_common.OperationOutput:
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message="See `data` field for instance status",
+ is_success=True,
+ aiconfig_instance=None,
+ data={"status": "OK"},
+ )
+
+
+async def run_list_models(
+ instance_id: str,
+) -> server_common.OperationOutput:
+ ids: list[str] = ModelParserRegistry.parser_ids() # type: ignore
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message="Listed models",
+ is_success=True,
+ aiconfig_instance=None,
+ data={"ids": ids},
+ )
+
+
+async def run_load_model_parser_module(instance_id: str, path_raw: server_common.UnvalidatedPath) -> server_common.OperationOutput:
+ load_module_result = await server_common.get_validated_path(path_raw).and_then_async(operation_lib.load_user_parser_module)
+ match load_module_result:
+ case Ok(_module):
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=f"Loaded module {path_raw}, output: {_module}",
+ is_success=True,
+ aiconfig_instance=None,
+ )
+ case Err(e):
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=f"Failed to load module {path_raw}: {e}",
+ is_success=False,
+ aiconfig_instance=None,
+ )
+
+
+async def run_create(instance_id: str) -> server_common.OperationOutput:
+ aiconfig_instance = operation_lib.safe_run_create()
+ return operation_lib.aiconfig_result_to_operation_output(aiconfig_instance, instance_id)
+
+
+async def run_load(
+ aiconfig_instance: AIConfigRuntime,
+ instance_id: str,
+ path_raw: server_common.UnvalidatedPath | None,
+) -> server_common.OperationOutput:
+ if path_raw is None:
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message="No path given, but AIConfig is loaded into memory. Here it is!",
+ is_success=True,
+ aiconfig_instance=aiconfig_instance,
+ )
+ else:
+ res_path_val = server_common.get_validated_path(path_raw)
+ res_aiconfig = res_path_val.and_then(operation_lib.safe_load_from_disk)
+ message = f"Loaded AIConfig from {res_path_val}. This may have overwritten in-memory changes."
+ logger.warning(message)
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=message,
+ is_success=res_aiconfig.is_ok(),
+ aiconfig_instance=res_aiconfig.unwrap_or(None),
+ )
+
+
+async def run_run(
+ aiconfig_instance: AIConfigRuntime, instance_id: str, inputs: server_common.Run, websocket: WebSocket
+) -> server_common.OperationOutput:
+ if inputs.stream:
+ logger.info(f"Running streaming operation: {inputs}")
+
+ out_queue: Queue[core_utils.JSONObject | None] = Queue()
+
+ def _stream_callback_queue_put(data: core_utils.JSONObject, accumulated_data: Any, index: int) -> None:
+ logger.debug(f"[stream callback]put {data=}")
+ # time.sleep(0.1)
+ out_queue.put(data)
+
+ inference_options = InferenceOptions(stream=True, stream_callback=_stream_callback_queue_put)
+
+ @core_utils.safe_run_fn_async
+ async def _run_streaming_inner(aiconfig_instance: AIConfigRuntime, inputs: server_common.Run) -> list[ExecuteResult]:
+ return await aiconfig_instance.run(inputs.prompt_name, inputs.params, inference_options) # type: ignore
+
+ # This function gets asyncio.run() to type check
+ async def _run_inner_wrap():
+ return await _run_streaming_inner(aiconfig_instance, inputs)
+
+ @dataclasses.dataclass
+ class ThreadOutput:
+ value: Result[list[ExecuteResult], str] = Err("Not set")
+
+ def _run_thread_main(out_queue: Queue[dict[str, str] | None], output: ThreadOutput) -> Result[list[ExecuteResult], str]:
+ method_output = asyncio.run(_run_inner_wrap())
+ logger.info(f"Ran operation: {inputs}")
+ output.value = method_output
+ out_queue.put(None)
+ return method_output
+
+ async def _send(data: core_utils.JSONObject) -> None:
+ logger.debug(f"[send]{data=}")
+ response = json.dumps(data)
+ send_res = await websocket.send_text(response)
+ logger.debug(f"sent {response=}, {send_res=}")
+
+ thread_output = ThreadOutput()
+ thread = threading.Thread(target=_run_thread_main, args=(out_queue, thread_output))
+ thread.start()
+
+ async def _read_queue_and_send_until_empty() -> None:
+ while True:
+ try:
+ data = out_queue.get(block=True, timeout=10)
+ logger.debug(f"[get]{data=}")
+ if data is None:
+ return
+ await _send(data)
+ except Empty as e:
+ raise ValueError(f"Timeout waiting for output from operation") from e
+
+ await _read_queue_and_send_until_empty()
+ thread.join()
+ return server_common.OperationOutput.from_method_output(instance_id, aiconfig_instance, thread_output.value, "Method: run with streaming")
+ else:
+
+ @core_utils.safe_run_fn_async
+ async def _run_not_stream(aiconfig_instance: AIConfigRuntime, inputs: server_common.Run) -> list[ExecuteResult]:
+ out: list[ExecuteResult] = cast(
+ list[ExecuteResult], await aiconfig_instance.run(inputs.prompt_name, inputs.params, InferenceOptions(stream=False)) # type: ignore
+ )
+ return out
+
+ result_run = await _run_not_stream(aiconfig_instance, inputs)
+ return server_common.OperationOutput.from_method_output(instance_id, aiconfig_instance, result_run, "Method: run w/o streaming")
+
+
+@operation_lib.operation_input_to_output
+async def run_add_prompt(aiconfig_instance: AIConfigRuntime, inputs: server_common.AddPrompt) -> None:
+ return aiconfig_instance.add_prompt(inputs.prompt_name, inputs.prompt_data, inputs.index)
+
+
+@operation_lib.operation_input_to_output
+async def run_update_prompt(aiconfig_instance: AIConfigRuntime, inputs: server_common.UpdatePrompt) -> None:
+ return aiconfig_instance.update_prompt(inputs.prompt_name, inputs.prompt_data)
+
+
+@operation_lib.operation_input_to_output
+async def run_delete_prompt(aiconfig_instance: AIConfigRuntime, inputs: server_common.DeletePrompt) -> None:
+ return aiconfig_instance.delete_prompt(inputs.prompt_name)
+
+
+@operation_lib.operation_input_to_output
+async def run_save(aiconfig_instance: AIConfigRuntime, inputs: server_common.Save) -> None:
+ return aiconfig_instance.save(inputs.path)
+
+
+@operation_lib.operation_input_to_output
+async def run_update_model(aiconfig_instance: AIConfigRuntime, inputs: server_common.UpdateModel) -> None:
+ return aiconfig_instance.update_model(inputs.model_name, inputs.settings, inputs.prompt_name)
+
+
+@operation_lib.operation_input_to_output
+async def run_set_parameter(aiconfig_instance: AIConfigRuntime, inputs: server_common.SetParameter) -> None:
+ return aiconfig_instance.set_parameter(inputs.parameter_name, inputs.parameter_value, inputs.prompt_name)
+
+
+@operation_lib.operation_input_to_output
+async def run_set_parameters(aiconfig_instance: AIConfigRuntime, inputs: server_common.SetParameters) -> None:
+ return aiconfig_instance.set_parameters(inputs.parameters, inputs.prompt_name)
+
+
+@operation_lib.operation_input_to_output
+async def run_delete_parameter(aiconfig_instance: AIConfigRuntime, inputs: server_common.DeleteParameter) -> None:
+ return aiconfig_instance.delete_parameter(inputs.parameter_name, inputs.prompt_name) # type: ignore
+
+
+@operation_lib.operation_input_to_output
+async def run_set_name(aiconfig_instance: AIConfigRuntime, inputs: server_common.SetName) -> None:
+ return aiconfig_instance.set_name(inputs.name)
+
+
+@operation_lib.operation_input_to_output
+async def run_set_description(aiconfig_instance: AIConfigRuntime, inputs: server_common.SetDescription) -> None:
+ return aiconfig_instance.set_description(inputs.description)
+
+
+async def run_mock_run(aiconfig_instance: AIConfigRuntime, instance_id: str, inputs: server_common.MockRun) -> server_common.OperationOutput:
+ """Sleep for `seconds` and add two test prompts to the aiconfig."""
+ logger.info(f"Running operation: {inputs}")
+ s = inputs.seconds
+ if s < 0.2:
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=f"Sleep time must be at least 0.2 second, got {s}",
+ is_success=False,
+ aiconfig_instance=None,
+ )
+
+ SLEEP_PART_1 = 0.1
+ SLEEP_PART_2 = s - SLEEP_PART_1
+ await asyncio.sleep(SLEEP_PART_1)
+ try:
+ last = 0
+ if len(aiconfig_instance.prompt_index) > 0:
+ last = max(int(k) for k in aiconfig_instance.prompt_index.keys())
+ next_ = last + 1
+ aiconfig_instance.add_prompt(str(next_), Prompt(name=str(next_), input=f"mock_prompt_{next_}"))
+ await asyncio.sleep(SLEEP_PART_2)
+ aiconfig_instance.add_prompt(str(next_ + 1), Prompt(name=str(next_ + 1), input=f"mock_prompt_{next_+1}"))
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=f"Blocked for {s} seconds and added prompts {next_}, {next_+1}",
+ is_success=True,
+ aiconfig_instance=aiconfig_instance,
+ )
+
+ except ValueError as e:
+ err = core_utils.ErrWithTraceback(e)
+ return server_common.OperationOutput(
+ instance_id=instance_id,
+ message=f"Test aiconfig is invalid. All prompt names must be ints. Got {err}, {aiconfig_instance.prompt_index.keys()}",
+ is_success=False,
+ aiconfig_instance=None,
+ )
diff --git a/python/src/aiconfig/editor/server/server_v2_utils.py b/python/src/aiconfig/editor/server/server_v2_utils.py
new file mode 100644
index 000000000..4a3bad3e5
--- /dev/null
+++ b/python/src/aiconfig/editor/server/server_v2_utils.py
@@ -0,0 +1,250 @@
+## SECTION: Imports and Constants
+
+import asyncio
+from dataclasses import dataclass
+import logging
+import os
+import uuid
+
+import lastmile_utils.lib.core.api as core_utils
+from fastapi import WebSocket
+from result import Err, Ok, Result
+import result
+
+import aiconfig.editor.server.server_v2_operation_lib as operation_lib
+import aiconfig.editor.server.server_v2_run_operation as operations
+import aiconfig.editor.server.server_v2_common as server_common
+
+logger: logging.Logger = core_utils.get_logger(__name__, log_file_path="editor_server_v2.log")
+
+
+@dataclass
+class LoopState:
+ instance_state: server_common.InstanceState
+ operation_task: asyncio.Task[server_common.OperationOutcome] | None
+ recv_task: asyncio.Task[Result[server_common.Command, str]]
+
+ @staticmethod
+ async def new(websocket: WebSocket, edit_config: server_common.EditServerConfig) -> Result["LoopState", str]:
+ instance_state = await _init_websocket_instance(edit_config)
+ return result.do(
+ Ok(
+ LoopState(
+ #
+ instance_state=instance_state_ok,
+ operation_task=None,
+ recv_task=schedule_receive_task(websocket),
+ )
+ )
+ for instance_state_ok in instance_state
+ )
+
+
+async def _init_websocket_instance(edit_config: server_common.EditServerConfig) -> Result[server_common.InstanceState, str]:
+ instance_id = str(uuid.uuid4())
+ logger.info(f"Starting websocket connection. {instance_id=}")
+ aiconfig_path = server_common.UnvalidatedPath(edit_config.aiconfig_path)
+ if os.path.exists(aiconfig_path):
+ return result.do(
+ Ok(server_common.InstanceState(instance_id=instance_id, aiconfig_instance=aiconfig_instance_ok, aiconfig_path=aiconfig_path))
+ for val_path in server_common.get_validated_path(aiconfig_path)
+ for aiconfig_instance_ok in operation_lib.safe_load_from_disk(val_path)
+ )
+ else:
+ return await result.do_async(
+ #
+ Ok(server_common.InstanceState(instance_id=instance_id, aiconfig_instance=aiconfig_instance_ok, aiconfig_path=aiconfig_path))
+ for aiconfig_instance_ok in operation_lib.safe_run_create()
+ for _ in await operation_lib.safe_save_to_disk(aiconfig_instance_ok, aiconfig_path)
+ )
+
+
+## SECTION: Websocket control and AIConfig instance state management
+
+
+async def _receive_command(websocket: WebSocket) -> Result[server_common.Command, str]:
+ data = await websocket.receive_text()
+ logger.debug(f"DATA:#\n{data}#, type: {type(data)}")
+ return result.do(Ok(res_cmd.command) for res_cmd in core_utils.safe_model_validate_json(data, server_common.SerializableCommand))
+
+
+def schedule_operation_task(
+ operation: server_common.Operation, instance_state: server_common.InstanceState, websocket: WebSocket
+) -> asyncio.Task[server_common.OperationOutcome]:
+ logger.info("Running create task")
+
+ async def _operation_task() -> server_common.OperationOutcome:
+ return await _get_operation_outcome(operation, instance_state, websocket)
+
+ # Enter run mode
+ operation_task = asyncio.create_task(_operation_task())
+ logger.info("Created task")
+ return operation_task
+
+
+def schedule_receive_task(websocket: WebSocket) -> asyncio.Task[Result[server_common.Command, str]]:
+ async def _task() -> Result[server_common.Command, str]:
+ return await _receive_command(websocket)
+
+ return asyncio.create_task(_task())
+
+
+async def handle_websocket_loop_iteration(
+ loop_state: LoopState, websocket: WebSocket
+) -> Result[tuple[server_common.Response | None, LoopState], str]:
+ if loop_state.operation_task is None:
+ return await _handle_new_operation_case(loop_state, websocket)
+ else:
+ return await _handle_existing_operation_case(loop_state, loop_state.operation_task, websocket)
+
+
+async def _handle_new_operation_case(
+ current_loop_state: LoopState,
+ websocket: WebSocket,
+) -> Result[tuple[server_common.Response | None, LoopState], str]:
+ instance_state = current_loop_state.instance_state
+ instance_id = instance_state.instance_id
+ res_command = await current_loop_state.recv_task
+ new_recv_task = schedule_receive_task(websocket)
+ logger.info(f"{res_command=}")
+ match res_command:
+ case Ok(command):
+ match command:
+ case server_common.Cancel():
+ logger.info("Received cancel command but no operation is running")
+ response = server_common.Response.from_error_message(instance_id, "Received cancel command but no operation is running")
+ new_loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=new_recv_task)
+ return Ok((response, new_loop_state))
+ case _:
+ loop_state = LoopState(
+ instance_state=instance_state,
+ operation_task=schedule_operation_task(command, instance_state, websocket),
+ recv_task=new_recv_task,
+ )
+ return Ok((None, loop_state))
+ case Err(e):
+ logger.error(f"Failed to parse command: {e}")
+ response = server_common.Response.from_error_message(instance_id, f"Failed to parse command: {e}")
+ loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=new_recv_task)
+ return Ok((response, loop_state))
+
+
+async def _handle_existing_operation_case(
+ current_loop_state: LoopState, current_operation_task: asyncio.Task[server_common.OperationOutcome], websocket: WebSocket
+) -> Result[tuple[server_common.Response | None, LoopState], str]:
+ # both recv task and operation task are running
+ # Wait for one to finish
+ logger.info("Both recv task and operation task running. Waiting for one to finish.")
+ instance_state = current_loop_state.instance_state
+
+ done, pending = await asyncio.wait([current_operation_task, current_loop_state.recv_task], return_when=asyncio.FIRST_COMPLETED)
+ # At least one is now done
+ if not any(t.done() for t in [current_loop_state.recv_task, current_operation_task]):
+ return Err(f"Got done and pending sets, but tasks are still running. This should not happen. {done=}, {pending=}")
+ else:
+ if current_operation_task.done():
+ done_task = server_common.DoneTask(current_operation_task)
+ return Ok(_handle_operation_task_done(current_loop_state, done_task, instance_state))
+ else:
+ done_task = server_common.DoneTask(current_loop_state.recv_task)
+ return Ok(_handle_recv_task_done(current_operation_task, done_task, instance_state, websocket))
+
+
+def _handle_operation_task_done(
+ current_loop_state: LoopState,
+ current_operation_done_task: server_common.DoneTask[server_common.OperationOutcome],
+ instance_state: server_common.InstanceState,
+) -> tuple[server_common.Response | None, LoopState]:
+ current_operation_task = current_operation_done_task.task
+ if current_operation_task.cancelled():
+ logger.info("Operation task cancelled")
+ loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=current_loop_state.recv_task)
+ return (None, loop_state)
+ else:
+ logger.info("Operation task done")
+ task_result = current_operation_task.result()
+ loop_state = LoopState(instance_state=task_result.instance_state, operation_task=None, recv_task=current_loop_state.recv_task)
+ response = operation_lib.operation_output_to_response(task_result.operation_output)
+ return (response, loop_state)
+
+
+def _handle_recv_task_done(
+ current_operation_task: asyncio.Task[server_common.OperationOutcome],
+ current_recv_done_task: server_common.DoneTask[Result[server_common.Command, str]],
+ instance_state: server_common.InstanceState,
+ websocket: WebSocket,
+) -> tuple[server_common.Response | None, LoopState]:
+ current_recv_task = current_recv_done_task.task
+ instance_id = instance_state.instance_id
+ if current_recv_task.cancelled():
+ logger.info("Recv task cancelled")
+ loop_state = LoopState(instance_state=instance_state, operation_task=current_operation_task, recv_task=schedule_receive_task(websocket))
+ return (None, loop_state)
+ else:
+ logger.info("Recv task done")
+ res_command = current_recv_task.result()
+ match res_command:
+ case Ok(command):
+ return _handle_new_command_while_operation_task_running(command, current_operation_task, instance_state, websocket)
+ case Err(e):
+ logger.error(f"Failed to parse command: {e}")
+ response = server_common.Response.from_error_message(instance_id, f"Failed to parse command: {e}")
+ loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=schedule_receive_task(websocket))
+ return (response, loop_state)
+
+
+def _handle_new_command_while_operation_task_running(
+ command: server_common.Command,
+ current_operation_task: asyncio.Task[server_common.OperationOutcome],
+ instance_state: server_common.InstanceState,
+ websocket: WebSocket,
+) -> tuple[server_common.Response | None, LoopState]:
+ instance_id = instance_state.instance_id
+ match command:
+ case server_common.Cancel():
+ logger.info("Received cancel command while operation task is running. Cancelling operation.")
+ current_operation_task.cancel()
+ loop_state = LoopState(instance_state=instance_state, operation_task=None, recv_task=schedule_receive_task(websocket))
+ response = server_common.Response(
+ instance_id=instance_id,
+ message="Cancelling command",
+ is_success=True,
+ aiconfig_instance=None,
+ )
+ return (response, loop_state)
+ case _:
+ # TODO: something other than _
+ logger.info("Received operation while operation task is running. Ignoring request.")
+ loop_state = LoopState(
+ instance_state=instance_state,
+ operation_task=current_operation_task,
+ recv_task=schedule_receive_task(websocket),
+ )
+ response = server_common.Response.from_error_message(
+ #
+ instance_id,
+ "Received operation while operation task is running. Ignoring request.",
+ )
+ return (response, loop_state)
+
+
+async def _get_operation_outcome(
+ operation: server_common.Operation,
+ instance_state: server_common.InstanceState,
+ websocket: WebSocket,
+) -> server_common.OperationOutcome:
+ current_aiconfig_instance = instance_state.aiconfig_instance
+ current_aiconfig_path = instance_state.aiconfig_path
+ instance_id = instance_state.instance_id
+ operation_output = await operations.run_operation(operation, current_aiconfig_instance, instance_id, websocket)
+
+ aiconfig_instance_updated = operation_output.aiconfig_instance if operation_output.aiconfig_instance is not None else current_aiconfig_instance
+ aiconfig_path = operation_lib.operation_to_aiconfig_path(operation)
+ aiconfig_path_updated = aiconfig_path or current_aiconfig_path
+ logger.debug("Updated instance: %s", aiconfig_instance_updated)
+ return server_common.OperationOutcome(
+ operation_output=operation_output,
+ instance_state=server_common.InstanceState(
+ instance_id=instance_id, aiconfig_instance=aiconfig_instance_updated, aiconfig_path=aiconfig_path_updated
+ ),
+ )
diff --git a/python/src/aiconfig/scripts/aiconfig_cli_v2.py b/python/src/aiconfig/scripts/aiconfig_cli_v2.py
new file mode 100644
index 000000000..8f0b2074c
--- /dev/null
+++ b/python/src/aiconfig/scripts/aiconfig_cli_v2.py
@@ -0,0 +1,126 @@
+import asyncio
+import logging
+import signal
+import subprocess
+import sys
+
+import lastmile_utils.lib.core.api as core_utils
+import result
+from result import Err, Ok, Result
+
+import aiconfig.editor.server.server_v2_common as server_common
+from aiconfig.editor.server.server_v2 import run_backend_server
+
+
+class AIConfigCLIConfig(core_utils.Record):
+ log_level: str | int = "WARNING"
+
+
+logging.basicConfig(format=core_utils.LOGGER_FMT)
+LOGGER = logging.getLogger(__name__)
+
+
+async def main_with_args(argv: list[str]) -> int:
+ final_result = await run_subcommand(argv)
+ match final_result:
+ case Ok(msg):
+ LOGGER.info("Final result: Ok:\n%s", msg)
+ return 0
+ case Err(e):
+ LOGGER.critical("Final result err: %s", e)
+ return core_utils.result_to_exitcode(Err(e))
+
+
+async def run_subcommand(argv: list[str]) -> Result[str, str]:
+ LOGGER.info("Running subcommand")
+ subparser_record_types = {"edit": server_common.EditServerConfig}
+ main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_record_types=subparser_record_types)
+
+ res_cli_config = core_utils.parse_args(main_parser, argv[1:], AIConfigCLIConfig)
+ res_cli_config.and_then(_process_cli_config)
+
+ subparser_name = core_utils.get_subparser_name(main_parser, argv[1:])
+ LOGGER.info(f"Running subcommand: {subparser_name}")
+
+ if subparser_name == "edit":
+ LOGGER.debug("Running edit subcommand")
+ res_edit_config = core_utils.parse_args(main_parser, argv[1:], server_common.EditServerConfig)
+ LOGGER.debug(f"{res_edit_config.is_ok()=}")
+ res_servers = await res_edit_config.and_then_async(_run_editor_servers)
+ out: Result[str, str] = await result.do_async(
+ #
+ Ok(",".join(res_servers_ok))
+ #
+ for res_servers_ok in res_servers
+ )
+ return out
+ else:
+ return Err(f"Unknown subparser: {subparser_name}")
+
+
+def _sigint(procs: list[subprocess.Popen[bytes]]) -> Result[str, str]:
+ LOGGER.info("sigint")
+ for p in procs:
+ p.send_signal(signal.SIGINT)
+ return Ok("Sent SIGINT to frontend servers.")
+
+
+async def _run_editor_servers(edit_config: server_common.EditServerConfig) -> Result[list[str], str]:
+ LOGGER.info("Running editor servers")
+ frontend_procs = _run_frontend_server_background() if edit_config.server_mode in [server_common.ServerMode.DEBUG_SERVERS] else Ok([])
+ match frontend_procs:
+ case Ok(_):
+ pass
+ case Err(e):
+ return Err(e)
+
+ results: list[Result[str, str]] = []
+ backend_res = await run_backend_server(edit_config)
+ match backend_res:
+ case Ok(_):
+ pass
+ case Err(e):
+ return Err(e)
+
+ results.append(backend_res)
+
+ sigint_res = frontend_procs.and_then(_sigint)
+ results.append(sigint_res)
+ return core_utils.result_reduce_list_all_ok(results)
+
+
+def _process_cli_config(cli_config: AIConfigCLIConfig) -> Result[bool, str]:
+ LOGGER.setLevel(cli_config.log_level)
+ return Ok(True)
+
+
+def _run_frontend_server_background() -> Result[list[subprocess.Popen[bytes]], str]:
+ LOGGER.info("Running frontend server in background")
+ p1, p2 = None, None
+ try:
+ p1 = subprocess.Popen(["yarn"], cwd="python/src/aiconfig/editor/client")
+ except Exception as e:
+ return core_utils.ErrWithTraceback(e)
+
+ try:
+ p2 = subprocess.Popen(["yarn", "start"], cwd="python/src/aiconfig/editor/client", stdin=subprocess.PIPE)
+ except Exception as e:
+ return core_utils.ErrWithTraceback(e)
+
+ try:
+ assert p2.stdin is not None
+ p2.stdin.write(b"n\n")
+ except Exception as e:
+ return core_utils.ErrWithTraceback(e)
+
+ return Ok([p1, p2])
+
+
+def main() -> int:
+ argv = sys.argv
+ return asyncio.run(main_with_args(argv))
+
+
+if __name__ == "__main__":
+ retcode: int = main()
+ sys.exit(retcode)
diff --git a/python/tests/test_editor_server_v2.py b/python/tests/test_editor_server_v2.py
new file mode 100644
index 000000000..f44d8671f
--- /dev/null
+++ b/python/tests/test_editor_server_v2.py
@@ -0,0 +1,260 @@
+from abc import abstractmethod
+from dataclasses import dataclass
+import json
+from pathlib import Path
+import shlex
+import time
+from typing import Any, Callable, NewType, Protocol
+import lastmile_utils.lib.core.api as core_utils
+
+import dotenv
+import pytest
+from websocket import WebSocket, create_connection # type: ignore
+import aiconfig.editor.server.server_v2_common as server_common
+from xprocess import ProcessStarter, XProcess
+
+RunningServerConfig = NewType("RunningServerConfig", server_common.EditServerConfig)
+
+
+@dataclass
+class ConnectedWebsocket:
+ websocket: WebSocket
+ running_server: RunningServerConfig
+
+
+class GetConnectedWebsocketFn(Protocol):
+ @abstractmethod
+ def __call__(self, edit_config: server_common.EditServerConfig) -> ConnectedWebsocket:
+ pass
+
+
+def _get_cli_command(subcommand_cfg: server_common.EditServerConfig) -> list[str]:
+ match subcommand_cfg:
+ case server_common.EditServerConfig(
+ #
+ server_port=server_port_,
+ aiconfig_path=aiconfig_path_,
+ server_mode=server_mode_,
+ parsers_module_path=parsers_module_path_,
+ ):
+ subcommand = "edit"
+ str_server_mode = server_mode_.name.lower()
+ cmd = shlex.split(
+ f"""
+ python -m 'aiconfig.scripts.aiconfig_cli_v2' {subcommand} \
+ --server-port={server_port_} \
+ --server-mode={str_server_mode} \
+ --aiconfig-path={aiconfig_path_} \
+ --parsers-module-path={parsers_module_path_}
+ """
+ )
+ return cmd
+
+
+def _make_edit_config(**kwargs) -> server_common.EditServerConfig: # type: ignore
+ TEST_PORT = 8011
+ defaults = dict(
+ server_port=TEST_PORT,
+ server_mode=server_common.ServerMode.DEBUG_BACKEND,
+ parsers_module_path="aiconfig.model_parser",
+ )
+ given_kwargs: dict[str, Any] = kwargs
+ new_kwargs = core_utils.dict_union_allow_replace(defaults, given_kwargs, on_conflict="replace")
+ return server_common.EditServerConfig(**new_kwargs)
+
+
+def _make_connected_websocket(get_connected_websocket: GetConnectedWebsocketFn, edit_config: server_common.EditServerConfig) -> ConnectedWebsocket:
+ connected_websocket = get_connected_websocket(edit_config)
+ running_server = connected_websocket.running_server
+ assert running_server.aiconfig_path == edit_config.aiconfig_path
+ return connected_websocket
+
+
+def _make_default_connected_websocket(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path) -> ConnectedWebsocket:
+ the_path = (tmp_path / "the.aiconfig.json").as_posix()
+ edit_config = _make_edit_config(aiconfig_path=the_path)
+ return _make_connected_websocket(get_connected_websocket, edit_config)
+
+
+def _load_aiconfig(websocket: WebSocket) -> dict[str, Any]:
+ ws_send_command(websocket, "load", dict())
+ resp = ws_receive_response(websocket)
+ assert resp is not None
+ return resp["aiconfig"]
+
+
+@pytest.fixture
+def get_running_server(xprocess: XProcess, request: pytest.FixtureRequest):
+ def _make(edit_config: server_common.EditServerConfig):
+ class Starter(ProcessStarter):
+ @property
+ def pattern(self): # type: ignore
+ return "Running on http://127.0.0.1"
+
+ # command to start process
+ @property
+ def args(self): # type: ignore
+ return _get_cli_command(edit_config)
+
+ terminate_on_interrupt = True
+
+ # ensure process is running and return its logfile
+ pid, logfile = xprocess.ensure("myserver", Starter) # type: ignore
+ print(f"{logfile=}")
+
+ dotenv.load_dotenv()
+
+ def cleanup():
+ # clean up whole process tree afterwards
+ xprocess.getinfo("myserver").terminate() # type: ignore
+
+ request.addfinalizer(cleanup)
+
+ return RunningServerConfig(edit_config)
+
+ return _make
+
+
+@pytest.fixture
+def get_connected_websocket(get_running_server: Callable[[server_common.EditServerConfig], RunningServerConfig], request: pytest.FixtureRequest):
+ def _make(edit_config: server_common.EditServerConfig):
+ running_server = get_running_server(edit_config)
+ url = f"ws://localhost:{running_server.server_port}/ws_manage_aiconfig_instance"
+ ws = create_connection(url)
+ synchronize_with_server(ws)
+
+ def cleanup():
+ ws.close()
+
+ request.addfinalizer(cleanup)
+ return ConnectedWebsocket(websocket=ws, running_server=running_server)
+
+ return _make
+
+
+def ws_send_command(websocket: WebSocket, command_name: str, command_params: dict[str, Any]):
+ command_json_obj = dict(command_name=command_name, **command_params)
+ message_obj = {"command": command_json_obj}
+ cmd_str = json.dumps(message_obj)
+ res = websocket.send_text(cmd_str)
+ print(f"sent {cmd_str=}, {res=}")
+ return res
+
+
+def ws_receive_response(websocket: WebSocket) -> dict[str, Any] | None:
+ resp_str = websocket.recv()
+ if not resp_str:
+ return None
+ resp = json.loads(resp_str)
+ return resp
+
+
+def synchronize_with_server(websocket: WebSocket):
+ """Send a command and wait for response
+ to make sure server is initialized and websocket is set up."""
+ ws_send_command(
+ websocket,
+ "mock_run",
+ command_params=dict(seconds=0.2),
+ )
+ _sync_resp = ws_receive_response(websocket)
+ # print(f"{sync_resp=}")
+
+
+def test_editor_server_start_new_file(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path):
+ the_path = (tmp_path / "the.aiconfig.json").as_posix()
+ edit_config = _make_edit_config(aiconfig_path=the_path)
+ assert core_utils.load_json_file(the_path).is_err()
+ connected_websocket = _make_connected_websocket(get_connected_websocket, edit_config)
+ assert connected_websocket.running_server.aiconfig_path == the_path
+
+ running_server = connected_websocket.running_server
+
+ assert core_utils.load_json_file(running_server.aiconfig_path).is_ok()
+
+
+def _get_mock_default_prompt(prompt_num: int) -> dict[str, Any]:
+ return dict(name=str(prompt_num), input=f"mock_prompt_{prompt_num}", metadata=None, outputs=[])
+
+
+def test_editor_mock_run_simple(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path):
+ connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path)
+ websocket = connected_websocket.websocket
+ ws_send_command(websocket, "mock_run", dict(seconds=0.3))
+ resp = ws_receive_response(websocket)
+ assert resp is not None
+ message, is_success, aiconfig, _data = resp["message"], resp["is_success"], resp["aiconfig"], resp["data"]
+ assert is_success
+ # synchronize_with_server runs mock_run with 0.2 seconds, which adds
+ # prompts 1 and 2.
+ # This command is expected to add 3 and 4.
+ assert message == "Blocked for 0.3 seconds and added prompts 3, 4"
+ prompts = aiconfig.get("prompts", [])
+ assert prompts == [_get_mock_default_prompt(i + 1) for i in range(4)]
+ print("Done")
+
+
+def test_editor_mock_run_cancel_rollback_1(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path):
+ """Cancel immediately, before any mutation can happen"""
+ connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path)
+ websocket = connected_websocket.websocket
+
+ aiconfig_before = _load_aiconfig(websocket)
+ print(f"{aiconfig_before=}")
+
+ ws_send_command(websocket, "mock_run", dict(seconds=0.3))
+ ws_send_command(websocket, "cancel", dict())
+ cancel_resp = ws_receive_response(websocket)
+ print(f"{cancel_resp=}")
+ ws_send_command(websocket, "load", dict())
+ aiconfig_after = _load_aiconfig(websocket)
+ assert aiconfig_before == aiconfig_after
+
+
+def test_editor_mock_run_cancel_rollback_2(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path):
+ """wait until some mutation happens, then cancel"""
+ connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path)
+ websocket = connected_websocket.websocket
+
+ aiconfig_before = _load_aiconfig(websocket)
+ print(f"{aiconfig_before=}")
+
+ ws_send_command(websocket, "mock_run", dict(seconds=0.3))
+
+ # mock run does its first mutation after 0.1 seconds.
+ # Waiting 0.2 before canceling allows it to happen, testing rollback.
+ time.sleep(0.2)
+ ws_send_command(websocket, "cancel", dict())
+ cancel_resp = ws_receive_response(websocket)
+ print(f"{cancel_resp=}")
+ aiconfig_after = _load_aiconfig(websocket)
+ assert aiconfig_before == aiconfig_after, f"{aiconfig_before=}, {aiconfig_after=}"
+
+
+def test_editor_mock_run_cancel_fast_walltime(get_connected_websocket: GetConnectedWebsocketFn, tmp_path: Path):
+ """wait until some mutation happens, then cancel"""
+ connected_websocket = _make_default_connected_websocket(get_connected_websocket, tmp_path)
+ websocket = connected_websocket.websocket
+
+ # Control: block for 0.3 seconds, don't cancel
+ ts_start = time.time()
+ ws_send_command(websocket, "mock_run", dict(seconds=0.3))
+ _mock_run_resp = ws_receive_response(websocket)
+ ts_end = time.time()
+ s_elapsed = ts_end - ts_start
+ print(f"[control]{s_elapsed=}")
+ assert s_elapsed >= 0.3
+
+ ts_start = time.time()
+ ws_send_command(websocket, "mock_run", dict(seconds=0.3))
+ # Cancel immediately
+ ws_send_command(websocket, "cancel", dict())
+ cancel_resp = ws_receive_response(websocket)
+ ts_end = time.time()
+ s_elapsed = ts_end - ts_start
+ print(f"[cancel]{s_elapsed=}")
+ assert s_elapsed < 0.01
+
+ assert cancel_resp is not None
+ assert cancel_resp["message"] == "Cancelling command"
+ assert cancel_resp["is_success"] == True