Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expand user-defined models #32

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pathlib
import shutil
import time
Expand All @@ -20,19 +21,21 @@
format_runs,
format_strike_models,
format_strikes,
format_user_models,
)
from dreadnode_cli.agent.templates import cli as templates_cli
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.config import UserConfig, UserModel, UserModels
from dreadnode_cli.api import Client
from dreadnode_cli.config import UserConfig
from dreadnode_cli.model.config import UserModels
from dreadnode_cli.model.format import format_user_models
from dreadnode_cli.profile.cli import switch as switch_profile
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli

cli = typer.Typer(no_args_is_help=True)

cli.add_typer(templates_cli, name="templates", help="Interact with Strike templates")
cli.add_typer(templates_cli, name="templates", help="Manage Agent templates")


def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None = None) -> None:
Expand All @@ -48,8 +51,8 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
plural = "s" if len(agent_config.linked_profiles) > 1 else ""
raise Exception(
f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, "
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], ",
"use [bold]dreadnode agent push[/] to create a new link with this profile.",
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], "
"use [bold]dreadnode agent push[/] to create a new link with this profile."
)

if agent_config.active_link.profile != user_config.active_profile_name:
Expand All @@ -70,7 +73,7 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
switch_profile(agent_config.active_link.profile)


@cli.command(help="Initialize a new agent project")
@cli.command(help="Initialize a new agent project", no_args_is_help=True)
@pretty_cli
def init(
strike: t.Annotated[str, typer.Argument(help="The target strike")],
Expand Down Expand Up @@ -341,19 +344,34 @@ def deploy(
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

user_models = UserModels.read()
user_model: UserModel | None = None

# Verify the model if it was supplied
if model is not None:
# check if it's a user model
user_model = next((m for m in user_models.models if m.key == model), None)
if not user_model:
# check if it's a strike model
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")
user_model: Client.UserModel | None = None

# Check for a user-defined model
if model in user_models.models:
user_model = Client.UserModel(
key=model,
generator_id=user_models.models[model].generator_id,
api_key=user_models.models[model].api_key,
)

# Resolve the API key from env vars
if user_model.api_key.startswith("$"):
try:
user_model.api_key = os.environ[user_model.api_key[1:]]
except KeyError as e:
raise Exception(
f"API key cannot be read from '{user_model.api_key}', environment variable not found."
) from e

# Otherwise we'll ensure this is a valid strike-native model
if user_model is None and model is not None:
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(
f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'"
)

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
agent_config.add_run(run.id).write(directory)
Expand All @@ -380,20 +398,21 @@ def models(
) -> None:
user_models = UserModels.read()
if user_models.models:
print("[bold]User models:[/]\n")
print("[bold]User-defined models:[/]\n")
print(format_user_models(user_models.models))
print()

if strike is None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
strike = agent_config.strike

strike = strike or agent_config.strike
if strike is None:
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

strike_response = api.create_client().get_strike(strike)
if user_models.models:
print("\n[bold]Strike models:[/]\n")
print("\n[bold]Dreadnode-provided models:[/]\n")
print(format_strike_models(strike_response.models))


Expand Down Expand Up @@ -522,7 +541,7 @@ def links(
print(table)


@cli.command(help="Switch to a different agent link")
@cli.command(help="Switch to a different agent link", no_args_is_help=True)
@pretty_cli
def switch(
agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")],
Expand All @@ -544,7 +563,7 @@ def switch(
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")


@cli.command(help="Clone a github repository")
@cli.command(help="Clone a github repository", no_args_is_help=True)
@pretty_cli
def clone(
repo: t.Annotated[str, typer.Argument(help="Repository name or URL")],
Expand Down
20 changes: 0 additions & 20 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rich.text import Text

from dreadnode_cli import api
from dreadnode_cli.config import UserModel

P = t.ParamSpec("P")

Expand Down Expand Up @@ -66,25 +65,6 @@ def format_time(dt: datetime | None) -> str:
return dt.astimezone().strftime("%c") if dt else "-"


def format_user_models(models: list[UserModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
table.add_column("provider")
table.add_column("api_key")

for model in models:
provider_style = get_model_provider_style(model.provider)
table.add_row(
Text(model.key),
Text(model.name, style=f"bold {provider_style}"),
Text(model.provider, style=provider_style),
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
)

return table


def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
Expand Down
5 changes: 3 additions & 2 deletions dreadnode_cli/agent/templates/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.defaults import TEMPLATES_DEFAULT_REPO
from dreadnode_cli.ext.typer import AliasGroup
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli

cli = typer.Typer(no_args_is_help=True)
cli = typer.Typer(no_args_is_help=True, cls=AliasGroup)


@cli.command(help="List available agent templates with their descriptions")
@cli.command("show|list", help="List available agent templates with their descriptions")
@pretty_cli
def show() -> None:
template_manager = TemplateManager()
Expand Down
7 changes: 6 additions & 1 deletion dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich import print

from dreadnode_cli import __version__, utils
from dreadnode_cli.config import UserConfig, UserModel
from dreadnode_cli.config import UserConfig
from dreadnode_cli.defaults import (
DEBUG,
DEFAULT_MAX_POLL_TIME,
Expand Down Expand Up @@ -377,6 +377,11 @@ class StrikeRunSummaryResponse(_StrikeRun):
class StrikeRunResponse(_StrikeRun):
zones: list["Client.StrikeRunZone"]

class UserModel(BaseModel):
key: str
generator_id: str
api_key: str

def get_strike(self, strike: str) -> StrikeResponse:
response = self.request("GET", f"/api/strikes/{strike}")
return self.StrikeResponse(**response.json())
Expand Down
2 changes: 2 additions & 0 deletions dreadnode_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dreadnode_cli.challenge import cli as challenge_cli
from dreadnode_cli.config import ServerConfig, UserConfig
from dreadnode_cli.defaults import PLATFORM_BASE_URL
from dreadnode_cli.model import cli as models_cli
from dreadnode_cli.profile import cli as profile_cli
from dreadnode_cli.utils import pretty_cli

Expand All @@ -21,6 +22,7 @@
cli.add_typer(profile_cli, name="profile", help="Manage server profiles")
cli.add_typer(challenge_cli, name="challenge", help="Interact with Crucible challenges")
cli.add_typer(agent_cli, name="agent", help="Interact with Strike agents")
cli.add_typer(models_cli, name="model", help="Manage user-defined inference models")


@cli.command(help="Authenticate to the platform.")
Expand Down
30 changes: 1 addition & 29 deletions dreadnode_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rich import print
from ruamel.yaml import YAML

from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH


class ServerConfig(BaseModel):
Expand Down Expand Up @@ -74,31 +74,3 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
profile = profile or self.active or DEFAULT_PROFILE_NAME
self.servers[profile] = config
return self


class UserModel(BaseModel):
"""
A user defined model.
"""

key: str
name: str
provider: str
generator_id: str
api_key: str | None = None


class UserModels(BaseModel):
"""User models configuration."""

models: list[UserModel] = []

@classmethod
def read(cls) -> "UserModels":
"""Read the user models configuration from the file system or return an empty instance."""

if not USER_MODELS_CONFIG_PATH.exists():
return cls()

with USER_MODELS_CONFIG_PATH.open("r") as f:
return cls.model_validate(YAML().load(f))
21 changes: 21 additions & 0 deletions dreadnode_cli/ext/typer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import re

from click import Command, Context
from typer.core import TyperGroup

# https://github.com/fastapi/typer/issues/132


class AliasGroup(TyperGroup):
_CMD_SPLIT_P = re.compile(r" ?[,|] ?")

def get_command(self, ctx: Context, cmd_name: str) -> Command | None:
cmd_name = self._group_cmd_name(cmd_name)
return super().get_command(ctx, cmd_name)

def _group_cmd_name(self, default_name: str) -> str:
for cmd in self.commands.values():
name = cmd.name
if name and default_name in self._CMD_SPLIT_P.split(name):
return name
return default_name
3 changes: 3 additions & 0 deletions dreadnode_cli/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from dreadnode_cli.model.cli import cli

__all__ = ["cli"]
66 changes: 66 additions & 0 deletions dreadnode_cli/model/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import typing as t

import typer
from rich import print

from dreadnode_cli.defaults import USER_MODELS_CONFIG_PATH
from dreadnode_cli.ext.typer import AliasGroup
from dreadnode_cli.model.config import UserModel, UserModels
from dreadnode_cli.model.format import format_user_models
from dreadnode_cli.utils import pretty_cli

cli = typer.Typer(no_args_is_help=True, cls=AliasGroup)


@cli.command("show|list", help="List all configured models")
@pretty_cli
def show() -> None:
config = UserModels.read()
if not config.models:
print(":exclamation: No models are configured, use [bold]dreadnode models add[/].")
return

print(format_user_models(config.models))


@cli.command(
help="Add a new inference model",
epilog="If $ENV_VAR syntax is used for the api key, it will be replaced with the environment value when used.",
no_args_is_help=True,
)
@pretty_cli
def add(
id: t.Annotated[str, typer.Option("--id", help="Identifier for referencing this model")],
generator_id: t.Annotated[str, typer.Option("--generator-id", "-g", help="Rigging (LiteLLM) generator id")],
api_key: t.Annotated[
str, typer.Option("--api-key", "-k", help="API key for the inference provider (supports $ENV_VAR syntax)")
],
name: t.Annotated[str | None, typer.Option("--name", "-n", help="Friendly name")] = None,
provider: t.Annotated[str | None, typer.Option("--provider", "-p", help="Provider name")] = None,
update: t.Annotated[bool, typer.Option("--update", "-u", help="Update an existing model if it exists")] = False,
) -> None:
config = UserModels.read()
exists = id in config.models

if exists and not update:
print(f":exclamation: Model with id [bold]{id}[/] already exists (use -u/--update to modify)")
return

config.models[id] = UserModel(name=name, provider=provider, generator_id=generator_id, api_key=api_key)
config.write()

print(f":wrench: {'Updated' if exists else 'Added'} model [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}")


@cli.command(help="Remove an user inference model", no_args_is_help=True)
@pretty_cli
def forget(id: t.Annotated[str, typer.Argument(help="Model to remove")]) -> None:
config = UserModels.read()
if id not in config.models:
print(f":exclamation: Model with id [bold]{id}[/] does not exist")
return

del config.models[id]
config.write()

print(f":axe: Forgot about [bold]{id}[/] in {USER_MODELS_CONFIG_PATH}")
Loading
Loading