Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: implemented support for user models (ENG-652) #31

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
from dreadnode_cli.agent.format import (
format_agent,
format_agent_versions,
format_models,
format_run,
format_runs,
format_strike_models,
format_strikes,
format_user_models,
)
from dreadnode_cli.agent.templates import cli as templates_cli
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.config import UserConfig
from dreadnode_cli.config import UserConfig, UserModel, UserModels
from dreadnode_cli.profile.cli import switch as switch_profile
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli
Expand Down Expand Up @@ -339,14 +340,22 @@ def deploy(
if strike is None:
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

user_models = UserModels.read()
user_model: UserModel | None = None

# Verify the model if it was supplied
if model is not None:
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
print(format_models(strike_response.models))
raise Exception(f"Model '{model}' not found in strike '{strike_response.name}'")

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model)
# check if it's a user model
user_model = next((m for m in user_models.models if m.key == model), None)
if not user_model:
# check if it's a strike model
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
agent_config.add_run(run.id).write(directory)
formatted = format_run(run)

Expand All @@ -369,6 +378,11 @@ def models(
] = pathlib.Path("."),
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to query")] = None,
) -> None:
user_models = UserModels.read()
if user_models.models:
print("[bold]User models:[/]\n")
print(format_user_models(user_models.models))

if strike is None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
Expand All @@ -378,7 +392,9 @@ def models(
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

strike_response = api.create_client().get_strike(strike)
print(format_models(strike_response.models))
if user_models.models:
print("\n[bold]Strike models:[/]\n")
print(format_strike_models(strike_response.models))


@cli.command(help="List available strikes")
Expand Down
29 changes: 26 additions & 3 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
from rich.text import Text

from dreadnode_cli import api
from dreadnode_cli.config import UserModel

P = t.ParamSpec("P")

# um@ is added to indicate a user model
USER_MODEL_PREFIX: str = "um@"


def get_status_style(status: api.Client.StrikeRunStatus | api.Client.StrikeRunZoneStatus | None) -> str:
return (
Expand Down Expand Up @@ -62,7 +66,26 @@ def format_time(dt: datetime | None) -> str:
return dt.astimezone().strftime("%c") if dt else "-"


def format_models(models: list[api.Client.StrikeModel]) -> RenderableType:
def format_user_models(models: list[UserModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
table.add_column("provider")
table.add_column("api_key")

for model in models:
provider_style = get_model_provider_style(model.provider)
table.add_row(
Text(model.key),
Text(model.name, style=f"bold {provider_style}"),
Text(model.provider, style=provider_style),
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
)

return table


def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
Expand Down Expand Up @@ -272,7 +295,7 @@ def format_run(run: api.Client.StrikeRunResponse, *, verbose: bool = False, incl
agent_name = f"[bold magenta]{run.agent_key}[/]"

table.add_row("", "")
table.add_row("model", run.model or "<default>")
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
table.add_row("notes", run.agent_version.notes or "-")
Expand Down Expand Up @@ -304,7 +327,7 @@ def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableTy
str(run.id),
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
Text(run.status, style="bold " + get_status_style(run.status)),
Text(run.model or "-"),
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
format_time(run.start),
Text(format_duration(run.start, run.end), style="bold cyan"),
)
Expand Down
10 changes: 8 additions & 2 deletions dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich import print

from dreadnode_cli import __version__, utils
from dreadnode_cli.config import UserConfig
from dreadnode_cli.config import UserConfig, UserModel
from dreadnode_cli.defaults import (
DEBUG,
DEFAULT_MAX_POLL_TIME,
Expand Down Expand Up @@ -430,14 +430,20 @@ def create_strike_agent_version(
return self.StrikeAgentResponse(**response.json())

def start_strike_run(
self, agent_version_id: UUID, *, model: str | None = None, strike: UUID | str | None = None
self,
agent_version_id: UUID,
*,
model: str | None = None,
user_model: UserModel | None = None,
strike: UUID | str | None = None,
) -> StrikeRunResponse:
response = self.request(
"POST",
"/api/strikes/runs",
json_data={
"agent_version_id": str(agent_version_id),
"model": model,
"user_model": user_model.model_dump(mode="json") if user_model else None,
"strike": str(strike) if strike else None,
},
)
Expand Down
36 changes: 32 additions & 4 deletions dreadnode_cli/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pydantic
from pydantic import BaseModel
from rich import print
from ruamel.yaml import YAML

from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH


class ServerConfig(pydantic.BaseModel):
class ServerConfig(BaseModel):
"""Server specific authentication data and API URL."""

url: str
Expand All @@ -16,7 +16,7 @@ class ServerConfig(pydantic.BaseModel):
refresh_token: str


class UserConfig(pydantic.BaseModel):
class UserConfig(BaseModel):
"""User configuration supporting multiple server profiles."""

active: str | None = None
Expand Down Expand Up @@ -74,3 +74,31 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
profile = profile or self.active or DEFAULT_PROFILE_NAME
self.servers[profile] = config
return self


class UserModel(BaseModel):
"""
A user defined model.
"""

key: str
name: str
provider: str
generator_id: str
api_key: str | None = None


class UserModels(BaseModel):
"""User models configuration."""

models: list[UserModel] = []

@classmethod
def read(cls) -> "UserModels":
"""Read the user models configuration from the file system or return an empty instance."""

if not USER_MODELS_CONFIG_PATH.exists():
return cls()

with USER_MODELS_CONFIG_PATH.open("r") as f:
return cls.model_validate(YAML().load(f))
6 changes: 6 additions & 0 deletions dreadnode_cli/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config"
)

# path to the user models configuration file
USER_MODELS_CONFIG_PATH = pathlib.Path(
# allow overriding the user config file via env variable
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "models.yml"
)

# path to the templates directory
TEMPLATES_PATH = pathlib.Path(
# allow overriding the templates path via env variable
Expand Down
Loading