diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index 8aa5570..c22dc8e 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -1,3 +1,4 @@ +import os import pathlib import shutil import time @@ -20,19 +21,21 @@ format_runs, format_strike_models, format_strikes, - format_user_models, ) from dreadnode_cli.agent.templates import cli as templates_cli from dreadnode_cli.agent.templates.format import format_templates from dreadnode_cli.agent.templates.manager import TemplateManager -from dreadnode_cli.config import UserConfig, UserModel, UserModels +from dreadnode_cli.api import Client +from dreadnode_cli.config import UserConfig +from dreadnode_cli.model.config import UserModels +from dreadnode_cli.model.format import format_user_models from dreadnode_cli.profile.cli import switch as switch_profile from dreadnode_cli.types import GithubRepo from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli cli = typer.Typer(no_args_is_help=True) -cli.add_typer(templates_cli, name="templates", help="Interact with Strike templates") +cli.add_typer(templates_cli, name="templates", help="Manage Agent templates") def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None = None) -> None: @@ -48,8 +51,8 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None plural = "s" if len(agent_config.linked_profiles) > 1 else "" raise Exception( f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, " - f"but the current server profile is [yellow]{user_config.active_profile_name}[/], ", - "use [bold]dreadnode agent push[/] to create a new link with this profile.", + f"but the current server profile is [yellow]{user_config.active_profile_name}[/], " + "use [bold]dreadnode agent push[/] to create a new link with this profile." ) if agent_config.active_link.profile != user_config.active_profile_name: @@ -70,7 +73,7 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None switch_profile(agent_config.active_link.profile) -@cli.command(help="Initialize a new agent project") +@cli.command(help="Initialize a new agent project", no_args_is_help=True) @pretty_cli def init( strike: t.Annotated[str, typer.Argument(help="The target strike")], @@ -341,19 +344,34 @@ def deploy( raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") user_models = UserModels.read() - user_model: UserModel | None = None - - # Verify the model if it was supplied - if model is not None: - # check if it's a user model - user_model = next((m for m in user_models.models if m.key == model), None) - if not user_model: - # check if it's a strike model - strike_response = client.get_strike(strike) - if not any(m.key == model for m in strike_response.models): - models(directory, strike=strike) - print() - raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'") + user_model: Client.UserModel | None = None + + # Check for a user-defined model + if model in user_models.models: + user_model = Client.UserModel( + key=model, + generator_id=user_models.models[model].generator_id, + api_key=user_models.models[model].api_key, + ) + + # Resolve the API key from env vars + if user_model.api_key.startswith("$"): + try: + user_model.api_key = os.environ[user_model.api_key[1:]] + except KeyError as e: + raise Exception( + f"API key cannot be read from '{user_model.api_key}', environment variable not found." + ) from e + + # Otherwise we'll ensure this is a valid strike-native model + if user_model is None and model is not None: + strike_response = client.get_strike(strike) + if not any(m.key == model for m in strike_response.models): + models(directory, strike=strike) + print() + raise Exception( + f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'" + ) run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model) agent_config.add_run(run.id).write(directory) @@ -380,20 +398,21 @@ def models( ) -> None: user_models = UserModels.read() if user_models.models: - print("[bold]User models:[/]\n") + print("[bold]User-defined models:[/]\n") print(format_user_models(user_models.models)) + print() if strike is None: agent_config = AgentConfig.read(directory) ensure_profile(agent_config) + strike = agent_config.strike - strike = strike or agent_config.strike if strike is None: raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") strike_response = api.create_client().get_strike(strike) if user_models.models: - print("\n[bold]Strike models:[/]\n") + print("\n[bold]Dreadnode-provided models:[/]\n") print(format_strike_models(strike_response.models)) @@ -522,7 +541,7 @@ def links( print(table) -@cli.command(help="Switch to a different agent link") +@cli.command(help="Switch to a different agent link", no_args_is_help=True) @pretty_cli def switch( agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")], @@ -544,7 +563,7 @@ def switch( print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]") -@cli.command(help="Clone a github repository") +@cli.command(help="Clone a github repository", no_args_is_help=True) @pretty_cli def clone( repo: t.Annotated[str, typer.Argument(help="Repository name or URL")], diff --git a/dreadnode_cli/agent/format.py b/dreadnode_cli/agent/format.py index 3f71362..3d6da46 100644 --- a/dreadnode_cli/agent/format.py +++ b/dreadnode_cli/agent/format.py @@ -9,7 +9,6 @@ from rich.text import Text from dreadnode_cli import api -from dreadnode_cli.config import UserModel P = t.ParamSpec("P") @@ -66,25 +65,6 @@ def format_time(dt: datetime | None) -> str: return dt.astimezone().strftime("%c") if dt else "-" -def format_user_models(models: list[UserModel]) -> RenderableType: - table = Table(box=box.ROUNDED) - table.add_column("key") - table.add_column("name") - table.add_column("provider") - table.add_column("api_key") - - for model in models: - provider_style = get_model_provider_style(model.provider) - table.add_row( - Text(model.key), - Text(model.name, style=f"bold {provider_style}"), - Text(model.provider, style=provider_style), - Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"), - ) - - return table - - def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType: table = Table(box=box.ROUNDED) table.add_column("key") diff --git a/dreadnode_cli/agent/templates/cli.py b/dreadnode_cli/agent/templates/cli.py index fa872a6..9c77105 100644 --- a/dreadnode_cli/agent/templates/cli.py +++ b/dreadnode_cli/agent/templates/cli.py @@ -9,13 +9,14 @@ from dreadnode_cli.agent.templates.format import format_templates from dreadnode_cli.agent.templates.manager import TemplateManager from dreadnode_cli.defaults import TEMPLATES_DEFAULT_REPO +from dreadnode_cli.ext.typer import AliasGroup from dreadnode_cli.types import GithubRepo from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli -cli = typer.Typer(no_args_is_help=True) +cli = typer.Typer(no_args_is_help=True, cls=AliasGroup) -@cli.command(help="List available agent templates with their descriptions") +@cli.command("show|list", help="List available agent templates with their descriptions") @pretty_cli def show() -> None: template_manager = TemplateManager() diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index a7a9f3d..abf9335 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -11,7 +11,7 @@ from rich import print from dreadnode_cli import __version__, utils -from dreadnode_cli.config import UserConfig, UserModel +from dreadnode_cli.config import UserConfig from dreadnode_cli.defaults import ( DEBUG, DEFAULT_MAX_POLL_TIME, @@ -377,6 +377,11 @@ class StrikeRunSummaryResponse(_StrikeRun): class StrikeRunResponse(_StrikeRun): zones: list["Client.StrikeRunZone"] + class UserModel(BaseModel): + key: str + generator_id: str + api_key: str + def get_strike(self, strike: str) -> StrikeResponse: response = self.request("GET", f"/api/strikes/{strike}") return self.StrikeResponse(**response.json()) diff --git a/dreadnode_cli/cli.py b/dreadnode_cli/cli.py index 3423fe2..433380f 100644 --- a/dreadnode_cli/cli.py +++ b/dreadnode_cli/cli.py @@ -9,6 +9,7 @@ from dreadnode_cli.challenge import cli as challenge_cli from dreadnode_cli.config import ServerConfig, UserConfig from dreadnode_cli.defaults import PLATFORM_BASE_URL +from dreadnode_cli.model import cli as models_cli from dreadnode_cli.profile import cli as profile_cli from dreadnode_cli.utils import pretty_cli @@ -21,6 +22,7 @@ cli.add_typer(profile_cli, name="profile", help="Manage server profiles") cli.add_typer(challenge_cli, name="challenge", help="Interact with Crucible challenges") cli.add_typer(agent_cli, name="agent", help="Interact with Strike agents") +cli.add_typer(models_cli, name="model", help="Manage user-defined inference models") @cli.command(help="Authenticate to the platform.") diff --git a/dreadnode_cli/config.py b/dreadnode_cli/config.py index b9c901a..0fb0b9b 100644 --- a/dreadnode_cli/config.py +++ b/dreadnode_cli/config.py @@ -2,7 +2,7 @@ from rich import print from ruamel.yaml import YAML -from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH +from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH class ServerConfig(BaseModel): @@ -74,31 +74,3 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) -> profile = profile or self.active or DEFAULT_PROFILE_NAME self.servers[profile] = config return self - - -class UserModel(BaseModel): - """ - A user defined model. - """ - - key: str - name: str - provider: str - generator_id: str - api_key: str | None = None - - -class UserModels(BaseModel): - """User models configuration.""" - - models: list[UserModel] = [] - - @classmethod - def read(cls) -> "UserModels": - """Read the user models configuration from the file system or return an empty instance.""" - - if not USER_MODELS_CONFIG_PATH.exists(): - return cls() - - with USER_MODELS_CONFIG_PATH.open("r") as f: - return cls.model_validate(YAML().load(f)) diff --git a/dreadnode_cli/ext/typer.py b/dreadnode_cli/ext/typer.py new file mode 100644 index 0000000..f5dd18f --- /dev/null +++ b/dreadnode_cli/ext/typer.py @@ -0,0 +1,21 @@ +import re + +from click import Command, Context +from typer.core import TyperGroup + +# https://github.com/fastapi/typer/issues/132 + + +class AliasGroup(TyperGroup): + _CMD_SPLIT_P = re.compile(r" ?[,|] ?") + + def get_command(self, ctx: Context, cmd_name: str) -> Command | None: + cmd_name = self._group_cmd_name(cmd_name) + return super().get_command(ctx, cmd_name) + + def _group_cmd_name(self, default_name: str) -> str: + for cmd in self.commands.values(): + name = cmd.name + if name and default_name in self._CMD_SPLIT_P.split(name): + return name + return default_name diff --git a/dreadnode_cli/model/__init__.py b/dreadnode_cli/model/__init__.py new file mode 100644 index 0000000..c00b94a --- /dev/null +++ b/dreadnode_cli/model/__init__.py @@ -0,0 +1,3 @@ +from dreadnode_cli.model.cli import cli + +__all__ = ["cli"] diff --git a/dreadnode_cli/model/cli.py b/dreadnode_cli/model/cli.py new file mode 100644 index 0000000..44f8581 --- /dev/null +++ b/dreadnode_cli/model/cli.py @@ -0,0 +1,66 @@ +import typing as t + +import typer +from rich import print + +from dreadnode_cli.defaults import USER_MODELS_CONFIG_PATH +from dreadnode_cli.ext.typer import AliasGroup +from dreadnode_cli.model.config import UserModel, UserModels +from dreadnode_cli.model.format import format_user_models +from dreadnode_cli.utils import pretty_cli + +cli = typer.Typer(no_args_is_help=True, cls=AliasGroup) + + +@cli.command("show|list", help="List all configured models") +@pretty_cli +def show() -> None: + config = UserModels.read() + if not config.models: + print(":exclamation: No models are configured, use [bold]dreadnode models add[/].") + return + + print(format_user_models(config.models)) + + +@cli.command( + help="Add a new inference model", + epilog="If $ENV_VAR syntax is used for the api key, it will be replaced with the environment value when used.", + no_args_is_help=True, +) +@pretty_cli +def add( + id: t.Annotated[str, typer.Option("--id", help="Identifier for referencing this model")], + generator_id: t.Annotated[str, typer.Option("--generator-id", "-g", help="Rigging (LiteLLM) generator id")], + api_key: t.Annotated[ + str, typer.Option("--api-key", "-k", help="API key for the inference provider (supports $ENV_VAR syntax)") + ], + name: t.Annotated[str | None, typer.Option("--name", "-n", help="Friendly name")] = None, + provider: t.Annotated[str | None, typer.Option("--provider", "-p", help="Provider name")] = None, + update: t.Annotated[bool, typer.Option("--update", "-u", help="Update an existing model if it exists")] = False, +) -> None: + config = UserModels.read() + exists = id in config.models + + if exists and not update: + print(f":exclamation: Model with id [bold]{id}[/] already exists (use -u/--update to modify)") + return + + config.models[id] = UserModel(name=name, provider=provider, generator_id=generator_id, api_key=api_key) + config.write() + + print(f":wrench: {'Updated' if exists else 'Added'} model [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}") + + +@cli.command(help="Remove an user inference model", no_args_is_help=True) +@pretty_cli +def forget(id: t.Annotated[str, typer.Argument(help="Model to remove")]) -> None: + config = UserModels.read() + if id not in config.models: + print(f":exclamation: Model with id [bold]{id}[/] does not exist") + return + + del config.models[id] + config.write() + + print(f":axe: Forgot about [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}") diff --git a/dreadnode_cli/model/config.py b/dreadnode_cli/model/config.py new file mode 100644 index 0000000..8d86359 --- /dev/null +++ b/dreadnode_cli/model/config.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel, field_validator +from rich import print +from ruamel.yaml import YAML + +from dreadnode_cli.defaults import USER_MODELS_CONFIG_PATH + + +class UserModel(BaseModel): + """ + A user defined inference model. + """ + + name: str | None = None + provider: str | None = None + generator_id: str + api_key: str + + @field_validator("generator_id", mode="after") + def check_for_api_key_in_generator_id(cls, value: str) -> str: + """Print a warning if an API key is included in the generator ID.""" + + if ",api_key=" in value: + print(f":heavy_exclamation_mark: API keys should not be included in generator ids: [bold]{value}[/]") + print() + + return value + + +class UserModels(BaseModel): + """User models configuration.""" + + models: dict[str, UserModel] = {} + + @classmethod + def read(cls) -> "UserModels": + """Read the user models configuration from the file system or return an empty instance.""" + + if not USER_MODELS_CONFIG_PATH.exists(): + return cls() + + with USER_MODELS_CONFIG_PATH.open("r") as f: + return cls.model_validate(YAML().load(f)) + + def write(self) -> None: + """Write the user models configuration to the file system.""" + + with USER_MODELS_CONFIG_PATH.open("w") as f: + YAML().dump(self.model_dump(mode="json", exclude_none=True), f) diff --git a/dreadnode_cli/model/format.py b/dreadnode_cli/model/format.py new file mode 100644 index 0000000..e632fda --- /dev/null +++ b/dreadnode_cli/model/format.py @@ -0,0 +1,36 @@ +import typing as t + +from rich import box +from rich.console import RenderableType +from rich.table import Table +from rich.text import Text + +from dreadnode_cli.model.config import UserModel + +P = t.ParamSpec("P") + + +def format_api_key(api_key: str) -> RenderableType: + if api_key.startswith("$"): # Environment variable + return Text(api_key, style="blue") + return Text(api_key[:5] + "***" if len(api_key) > 5 else "***", style="magenta") + + +def format_user_models(models: dict[str, UserModel]) -> RenderableType: + table = Table(box=box.ROUNDED) + table.add_column("ID", style="bold cyan") + table.add_column("Name") + table.add_column("Provider") + table.add_column("Generator ID") + table.add_column("API Key") + + for model_id, model in models.items(): + table.add_row( + Text(model_id, style="bold"), + Text(model.name or "-", style="dim" if not model.name else ""), + Text(model.provider or "-", style="dim" if not model.provider else ""), + Text(model.generator_id), + format_api_key(model.api_key), + ) + + return table diff --git a/dreadnode_cli/profile/cli.py b/dreadnode_cli/profile/cli.py index c60e747..e81d06f 100644 --- a/dreadnode_cli/profile/cli.py +++ b/dreadnode_cli/profile/cli.py @@ -7,14 +7,15 @@ from dreadnode_cli import utils from dreadnode_cli.api import Token from dreadnode_cli.config import UserConfig +from dreadnode_cli.ext.typer import AliasGroup from dreadnode_cli.utils import pretty_cli -cli = typer.Typer(no_args_is_help=True) +cli = typer.Typer(no_args_is_help=True, cls=AliasGroup) -@cli.command(help="List all server profiles") +@cli.command("show|list", help="List all server profiles") @pretty_cli -def list() -> None: +def show() -> None: config = UserConfig.read() if not config.servers: print(":exclamation: No server profiles are configured") @@ -45,7 +46,7 @@ def list() -> None: print(table) -@cli.command(help="Set the active server profile") +@cli.command(help="Set the active server profile", no_args_is_help=True) @pretty_cli def switch(profile: t.Annotated[str, typer.Argument(help="Profile to switch to")]) -> None: config = UserConfig.read() @@ -63,7 +64,7 @@ def switch(profile: t.Annotated[str, typer.Argument(help="Profile to switch to") print() -@cli.command(help="Remove a server profile") +@cli.command(help="Remove a server profile", no_args_is_help=True) @pretty_cli def forget(profile: t.Annotated[str, typer.Argument(help="Profile of the server to remove")]) -> None: config = UserConfig.read()