Skip to content

Commit

Permalink
Merge pull request #32 from dreadnode/feat/expand-user-defined-models
Browse files Browse the repository at this point in the history
feat: Expand user-defined models
  • Loading branch information
evilsocket authored Jan 13, 2025
2 parents c4f4c64 + 987baea commit dee6828
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 81 deletions.
67 changes: 43 additions & 24 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pathlib
import shutil
import time
Expand All @@ -20,19 +21,21 @@
format_runs,
format_strike_models,
format_strikes,
format_user_models,
)
from dreadnode_cli.agent.templates import cli as templates_cli
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.config import UserConfig, UserModel, UserModels
from dreadnode_cli.api import Client
from dreadnode_cli.config import UserConfig
from dreadnode_cli.model.config import UserModels
from dreadnode_cli.model.format import format_user_models
from dreadnode_cli.profile.cli import switch as switch_profile
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli

cli = typer.Typer(no_args_is_help=True)

cli.add_typer(templates_cli, name="templates", help="Interact with Strike templates")
cli.add_typer(templates_cli, name="templates", help="Manage Agent templates")


def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None = None) -> None:
Expand All @@ -48,8 +51,8 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
plural = "s" if len(agent_config.linked_profiles) > 1 else ""
raise Exception(
f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, "
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], ",
"use [bold]dreadnode agent push[/] to create a new link with this profile.",
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], "
"use [bold]dreadnode agent push[/] to create a new link with this profile."
)

if agent_config.active_link.profile != user_config.active_profile_name:
Expand All @@ -70,7 +73,7 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
switch_profile(agent_config.active_link.profile)


@cli.command(help="Initialize a new agent project")
@cli.command(help="Initialize a new agent project", no_args_is_help=True)
@pretty_cli
def init(
strike: t.Annotated[str, typer.Argument(help="The target strike")],
Expand Down Expand Up @@ -341,19 +344,34 @@ def deploy(
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

user_models = UserModels.read()
user_model: UserModel | None = None

# Verify the model if it was supplied
if model is not None:
# check if it's a user model
user_model = next((m for m in user_models.models if m.key == model), None)
if not user_model:
# check if it's a strike model
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")
user_model: Client.UserModel | None = None

# Check for a user-defined model
if model in user_models.models:
user_model = Client.UserModel(
key=model,
generator_id=user_models.models[model].generator_id,
api_key=user_models.models[model].api_key,
)

# Resolve the API key from env vars
if user_model.api_key.startswith("$"):
try:
user_model.api_key = os.environ[user_model.api_key[1:]]
except KeyError as e:
raise Exception(
f"API key cannot be read from '{user_model.api_key}', environment variable not found."
) from e

# Otherwise we'll ensure this is a valid strike-native model
if user_model is None and model is not None:
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(
f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'"
)

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
agent_config.add_run(run.id).write(directory)
Expand All @@ -380,20 +398,21 @@ def models(
) -> None:
user_models = UserModels.read()
if user_models.models:
print("[bold]User models:[/]\n")
print("[bold]User-defined models:[/]\n")
print(format_user_models(user_models.models))
print()

if strike is None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
strike = agent_config.strike

strike = strike or agent_config.strike
if strike is None:
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

strike_response = api.create_client().get_strike(strike)
if user_models.models:
print("\n[bold]Strike models:[/]\n")
print("\n[bold]Dreadnode-provided models:[/]\n")
print(format_strike_models(strike_response.models))


Expand Down Expand Up @@ -522,7 +541,7 @@ def links(
print(table)


@cli.command(help="Switch to a different agent link")
@cli.command(help="Switch to a different agent link", no_args_is_help=True)
@pretty_cli
def switch(
agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")],
Expand All @@ -544,7 +563,7 @@ def switch(
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")


@cli.command(help="Clone a github repository")
@cli.command(help="Clone a github repository", no_args_is_help=True)
@pretty_cli
def clone(
repo: t.Annotated[str, typer.Argument(help="Repository name or URL")],
Expand Down
20 changes: 0 additions & 20 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rich.text import Text

from dreadnode_cli import api
from dreadnode_cli.config import UserModel

P = t.ParamSpec("P")

Expand Down Expand Up @@ -66,25 +65,6 @@ def format_time(dt: datetime | None) -> str:
return dt.astimezone().strftime("%c") if dt else "-"


def format_user_models(models: list[UserModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
table.add_column("provider")
table.add_column("api_key")

for model in models:
provider_style = get_model_provider_style(model.provider)
table.add_row(
Text(model.key),
Text(model.name, style=f"bold {provider_style}"),
Text(model.provider, style=provider_style),
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
)

return table


def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
Expand Down
5 changes: 3 additions & 2 deletions dreadnode_cli/agent/templates/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.defaults import TEMPLATES_DEFAULT_REPO
from dreadnode_cli.ext.typer import AliasGroup
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli

cli = typer.Typer(no_args_is_help=True)
cli = typer.Typer(no_args_is_help=True, cls=AliasGroup)


@cli.command(help="List available agent templates with their descriptions")
@cli.command("show|list", help="List available agent templates with their descriptions")
@pretty_cli
def show() -> None:
template_manager = TemplateManager()
Expand Down
7 changes: 6 additions & 1 deletion dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich import print

from dreadnode_cli import __version__, utils
from dreadnode_cli.config import UserConfig, UserModel
from dreadnode_cli.config import UserConfig
from dreadnode_cli.defaults import (
DEBUG,
DEFAULT_MAX_POLL_TIME,
Expand Down Expand Up @@ -377,6 +377,11 @@ class StrikeRunSummaryResponse(_StrikeRun):
class StrikeRunResponse(_StrikeRun):
zones: list["Client.StrikeRunZone"]

class UserModel(BaseModel):
key: str
generator_id: str
api_key: str

def get_strike(self, strike: str) -> StrikeResponse:
response = self.request("GET", f"/api/strikes/{strike}")
return self.StrikeResponse(**response.json())
Expand Down
2 changes: 2 additions & 0 deletions dreadnode_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dreadnode_cli.challenge import cli as challenge_cli
from dreadnode_cli.config import ServerConfig, UserConfig
from dreadnode_cli.defaults import PLATFORM_BASE_URL
from dreadnode_cli.model import cli as models_cli
from dreadnode_cli.profile import cli as profile_cli
from dreadnode_cli.utils import pretty_cli

Expand All @@ -21,6 +22,7 @@
cli.add_typer(profile_cli, name="profile", help="Manage server profiles")
cli.add_typer(challenge_cli, name="challenge", help="Interact with Crucible challenges")
cli.add_typer(agent_cli, name="agent", help="Interact with Strike agents")
cli.add_typer(models_cli, name="model", help="Manage user-defined inference models")


@cli.command(help="Authenticate to the platform.")
Expand Down
30 changes: 1 addition & 29 deletions dreadnode_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rich import print
from ruamel.yaml import YAML

from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH


class ServerConfig(BaseModel):
Expand Down Expand Up @@ -74,31 +74,3 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
profile = profile or self.active or DEFAULT_PROFILE_NAME
self.servers[profile] = config
return self


class UserModel(BaseModel):
"""
A user defined model.
"""

key: str
name: str
provider: str
generator_id: str
api_key: str | None = None


class UserModels(BaseModel):
"""User models configuration."""

models: list[UserModel] = []

@classmethod
def read(cls) -> "UserModels":
"""Read the user models configuration from the file system or return an empty instance."""

if not USER_MODELS_CONFIG_PATH.exists():
return cls()

with USER_MODELS_CONFIG_PATH.open("r") as f:
return cls.model_validate(YAML().load(f))
21 changes: 21 additions & 0 deletions dreadnode_cli/ext/typer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import re

from click import Command, Context
from typer.core import TyperGroup

# https://github.com/fastapi/typer/issues/132


class AliasGroup(TyperGroup):
_CMD_SPLIT_P = re.compile(r" ?[,|] ?")

def get_command(self, ctx: Context, cmd_name: str) -> Command | None:
cmd_name = self._group_cmd_name(cmd_name)
return super().get_command(ctx, cmd_name)

def _group_cmd_name(self, default_name: str) -> str:
for cmd in self.commands.values():
name = cmd.name
if name and default_name in self._CMD_SPLIT_P.split(name):
return name
return default_name
3 changes: 3 additions & 0 deletions dreadnode_cli/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from dreadnode_cli.model.cli import cli

__all__ = ["cli"]
66 changes: 66 additions & 0 deletions dreadnode_cli/model/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import typing as t

import typer
from rich import print

from dreadnode_cli.defaults import USER_MODELS_CONFIG_PATH
from dreadnode_cli.ext.typer import AliasGroup
from dreadnode_cli.model.config import UserModel, UserModels
from dreadnode_cli.model.format import format_user_models
from dreadnode_cli.utils import pretty_cli

cli = typer.Typer(no_args_is_help=True, cls=AliasGroup)


@cli.command("show|list", help="List all configured models")
@pretty_cli
def show() -> None:
config = UserModels.read()
if not config.models:
print(":exclamation: No models are configured, use [bold]dreadnode models add[/].")
return

print(format_user_models(config.models))


@cli.command(
help="Add a new inference model",
epilog="If $ENV_VAR syntax is used for the api key, it will be replaced with the environment value when used.",
no_args_is_help=True,
)
@pretty_cli
def add(
id: t.Annotated[str, typer.Option("--id", help="Identifier for referencing this model")],
generator_id: t.Annotated[str, typer.Option("--generator-id", "-g", help="Rigging (LiteLLM) generator id")],
api_key: t.Annotated[
str, typer.Option("--api-key", "-k", help="API key for the inference provider (supports $ENV_VAR syntax)")
],
name: t.Annotated[str | None, typer.Option("--name", "-n", help="Friendly name")] = None,
provider: t.Annotated[str | None, typer.Option("--provider", "-p", help="Provider name")] = None,
update: t.Annotated[bool, typer.Option("--update", "-u", help="Update an existing model if it exists")] = False,
) -> None:
config = UserModels.read()
exists = id in config.models

if exists and not update:
print(f":exclamation: Model with id [bold]{id}[/] already exists (use -u/--update to modify)")
return

config.models[id] = UserModel(name=name, provider=provider, generator_id=generator_id, api_key=api_key)
config.write()

print(f":wrench: {'Updated' if exists else 'Added'} model [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}")


@cli.command(help="Remove an user inference model", no_args_is_help=True)
@pretty_cli
def forget(id: t.Annotated[str, typer.Argument(help="Model to remove")]) -> None:
config = UserModels.read()
if id not in config.models:
print(f":exclamation: Model with id [bold]{id}[/] does not exist")
return

del config.models[id]
config.write()

print(f":axe: Forgot about [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}")
Loading

0 comments on commit dee6828

Please sign in to comment.