From 4335a767a47c74dcdfdedf635298e988d158517d Mon Sep 17 00:00:00 2001 From: evilsocket Date: Wed, 8 Jan 2025 23:15:27 +0100 Subject: [PATCH] new: implemented support for user models (ENG-652) --- dreadnode_cli/agent/cli.py | 34 ++++++++++++++++++++++++--------- dreadnode_cli/agent/format.py | 29 +++++++++++++++++++++++++--- dreadnode_cli/api.py | 10 ++++++++-- dreadnode_cli/config.py | 36 +++++++++++++++++++++++++++++++---- dreadnode_cli/defaults.py | 6 ++++++ 5 files changed, 97 insertions(+), 18 deletions(-) diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index fc7625f..8aa5570 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -16,15 +16,16 @@ from dreadnode_cli.agent.format import ( format_agent, format_agent_versions, - format_models, format_run, format_runs, + format_strike_models, format_strikes, + format_user_models, ) from dreadnode_cli.agent.templates import cli as templates_cli from dreadnode_cli.agent.templates.format import format_templates from dreadnode_cli.agent.templates.manager import TemplateManager -from dreadnode_cli.config import UserConfig +from dreadnode_cli.config import UserConfig, UserModel, UserModels from dreadnode_cli.profile.cli import switch as switch_profile from dreadnode_cli.types import GithubRepo from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli @@ -339,14 +340,22 @@ def deploy( if strike is None: raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") + user_models = UserModels.read() + user_model: UserModel | None = None + # Verify the model if it was supplied if model is not None: - strike_response = client.get_strike(strike) - if not any(m.key == model for m in strike_response.models): - print(format_models(strike_response.models)) - raise Exception(f"Model '{model}' not found in strike '{strike_response.name}'") - - run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model) + # check if it's a user model + user_model = next((m for m in user_models.models if m.key == model), None) + if not user_model: + # check if it's a strike model + strike_response = client.get_strike(strike) + if not any(m.key == model for m in strike_response.models): + models(directory, strike=strike) + print() + raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'") + + run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model) agent_config.add_run(run.id).write(directory) formatted = format_run(run) @@ -369,6 +378,11 @@ def models( ] = pathlib.Path("."), strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to query")] = None, ) -> None: + user_models = UserModels.read() + if user_models.models: + print("[bold]User models:[/]\n") + print(format_user_models(user_models.models)) + if strike is None: agent_config = AgentConfig.read(directory) ensure_profile(agent_config) @@ -378,7 +392,9 @@ def models( raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") strike_response = api.create_client().get_strike(strike) - print(format_models(strike_response.models)) + if user_models.models: + print("\n[bold]Strike models:[/]\n") + print(format_strike_models(strike_response.models)) @cli.command(help="List available strikes") diff --git a/dreadnode_cli/agent/format.py b/dreadnode_cli/agent/format.py index f668c1f..3f71362 100644 --- a/dreadnode_cli/agent/format.py +++ b/dreadnode_cli/agent/format.py @@ -9,9 +9,13 @@ from rich.text import Text from dreadnode_cli import api +from dreadnode_cli.config import UserModel P = t.ParamSpec("P") +# um@ is added to indicate a user model +USER_MODEL_PREFIX: str = "um@" + def get_status_style(status: api.Client.StrikeRunStatus | api.Client.StrikeRunZoneStatus | None) -> str: return ( @@ -62,7 +66,26 @@ def format_time(dt: datetime | None) -> str: return dt.astimezone().strftime("%c") if dt else "-" -def format_models(models: list[api.Client.StrikeModel]) -> RenderableType: +def format_user_models(models: list[UserModel]) -> RenderableType: + table = Table(box=box.ROUNDED) + table.add_column("key") + table.add_column("name") + table.add_column("provider") + table.add_column("api_key") + + for model in models: + provider_style = get_model_provider_style(model.provider) + table.add_row( + Text(model.key), + Text(model.name, style=f"bold {provider_style}"), + Text(model.provider, style=provider_style), + Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"), + ) + + return table + + +def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType: table = Table(box=box.ROUNDED) table.add_column("key") table.add_column("name") @@ -272,7 +295,7 @@ def format_run(run: api.Client.StrikeRunResponse, *, verbose: bool = False, incl agent_name = f"[bold magenta]{run.agent_key}[/]" table.add_row("", "") - table.add_row("model", run.model or "") + table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "") table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])") table.add_row("image", Text(run.agent_version.container.image, style="cyan")) table.add_row("notes", run.agent_version.notes or "-") @@ -304,7 +327,7 @@ def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableTy str(run.id), f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]", Text(run.status, style="bold " + get_status_style(run.status)), - Text(run.model or "-"), + Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"), format_time(run.start), Text(format_duration(run.start, run.end), style="bold cyan"), ) diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index 48b8d28..a7a9f3d 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -11,7 +11,7 @@ from rich import print from dreadnode_cli import __version__, utils -from dreadnode_cli.config import UserConfig +from dreadnode_cli.config import UserConfig, UserModel from dreadnode_cli.defaults import ( DEBUG, DEFAULT_MAX_POLL_TIME, @@ -430,7 +430,12 @@ def create_strike_agent_version( return self.StrikeAgentResponse(**response.json()) def start_strike_run( - self, agent_version_id: UUID, *, model: str | None = None, strike: UUID | str | None = None + self, + agent_version_id: UUID, + *, + model: str | None = None, + user_model: UserModel | None = None, + strike: UUID | str | None = None, ) -> StrikeRunResponse: response = self.request( "POST", @@ -438,6 +443,7 @@ def start_strike_run( json_data={ "agent_version_id": str(agent_version_id), "model": model, + "user_model": user_model.model_dump(mode="json") if user_model else None, "strike": str(strike) if strike else None, }, ) diff --git a/dreadnode_cli/config.py b/dreadnode_cli/config.py index b5c5f5f..b9c901a 100644 --- a/dreadnode_cli/config.py +++ b/dreadnode_cli/config.py @@ -1,11 +1,11 @@ -import pydantic +from pydantic import BaseModel from rich import print from ruamel.yaml import YAML -from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH +from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH -class ServerConfig(pydantic.BaseModel): +class ServerConfig(BaseModel): """Server specific authentication data and API URL.""" url: str @@ -16,7 +16,7 @@ class ServerConfig(pydantic.BaseModel): refresh_token: str -class UserConfig(pydantic.BaseModel): +class UserConfig(BaseModel): """User configuration supporting multiple server profiles.""" active: str | None = None @@ -74,3 +74,31 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) -> profile = profile or self.active or DEFAULT_PROFILE_NAME self.servers[profile] = config return self + + +class UserModel(BaseModel): + """ + A user defined model. + """ + + key: str + name: str + provider: str + generator_id: str + api_key: str | None = None + + +class UserModels(BaseModel): + """User models configuration.""" + + models: list[UserModel] = [] + + @classmethod + def read(cls) -> "UserModels": + """Read the user models configuration from the file system or return an empty instance.""" + + if not USER_MODELS_CONFIG_PATH.exists(): + return cls() + + with USER_MODELS_CONFIG_PATH.open("r") as f: + return cls.model_validate(YAML().load(f)) diff --git a/dreadnode_cli/defaults.py b/dreadnode_cli/defaults.py index aab6b78..46607cd 100644 --- a/dreadnode_cli/defaults.py +++ b/dreadnode_cli/defaults.py @@ -25,6 +25,12 @@ os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config" ) +# path to the user models configuration file +USER_MODELS_CONFIG_PATH = pathlib.Path( + # allow overriding the user config file via env variable + os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "models.yml" +) + # path to the templates directory TEMPLATES_PATH = pathlib.Path( # allow overriding the templates path via env variable