Skip to content

Commit

Permalink
Merge pull request #31 from dreadnode/simone/eng-652-support-user-pro…
Browse files Browse the repository at this point in the history
…vided-api-keys-and-model-configurations-for

new: implemented support for user models (ENG-652)
  • Loading branch information
evilsocket authored Jan 10, 2025
2 parents 210f8e1 + 4335a76 commit c4f4c64
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 18 deletions.
34 changes: 25 additions & 9 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
from dreadnode_cli.agent.format import (
format_agent,
format_agent_versions,
format_models,
format_run,
format_runs,
format_strike_models,
format_strikes,
format_user_models,
)
from dreadnode_cli.agent.templates import cli as templates_cli
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.config import UserConfig
from dreadnode_cli.config import UserConfig, UserModel, UserModels
from dreadnode_cli.profile.cli import switch as switch_profile
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli
Expand Down Expand Up @@ -339,14 +340,22 @@ def deploy(
if strike is None:
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

user_models = UserModels.read()
user_model: UserModel | None = None

# Verify the model if it was supplied
if model is not None:
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
print(format_models(strike_response.models))
raise Exception(f"Model '{model}' not found in strike '{strike_response.name}'")

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model)
# check if it's a user model
user_model = next((m for m in user_models.models if m.key == model), None)
if not user_model:
# check if it's a strike model
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
agent_config.add_run(run.id).write(directory)
formatted = format_run(run)

Expand All @@ -369,6 +378,11 @@ def models(
] = pathlib.Path("."),
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to query")] = None,
) -> None:
user_models = UserModels.read()
if user_models.models:
print("[bold]User models:[/]\n")
print(format_user_models(user_models.models))

if strike is None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
Expand All @@ -378,7 +392,9 @@ def models(
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

strike_response = api.create_client().get_strike(strike)
print(format_models(strike_response.models))
if user_models.models:
print("\n[bold]Strike models:[/]\n")
print(format_strike_models(strike_response.models))


@cli.command(help="List available strikes")
Expand Down
29 changes: 26 additions & 3 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
from rich.text import Text

from dreadnode_cli import api
from dreadnode_cli.config import UserModel

P = t.ParamSpec("P")

# um@ is added to indicate a user model
USER_MODEL_PREFIX: str = "um@"


def get_status_style(status: api.Client.StrikeRunStatus | api.Client.StrikeRunZoneStatus | None) -> str:
return (
Expand Down Expand Up @@ -62,7 +66,26 @@ def format_time(dt: datetime | None) -> str:
return dt.astimezone().strftime("%c") if dt else "-"


def format_models(models: list[api.Client.StrikeModel]) -> RenderableType:
def format_user_models(models: list[UserModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
table.add_column("provider")
table.add_column("api_key")

for model in models:
provider_style = get_model_provider_style(model.provider)
table.add_row(
Text(model.key),
Text(model.name, style=f"bold {provider_style}"),
Text(model.provider, style=provider_style),
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
)

return table


def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
Expand Down Expand Up @@ -272,7 +295,7 @@ def format_run(run: api.Client.StrikeRunResponse, *, verbose: bool = False, incl
agent_name = f"[bold magenta]{run.agent_key}[/]"

table.add_row("", "")
table.add_row("model", run.model or "<default>")
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
table.add_row("notes", run.agent_version.notes or "-")
Expand Down Expand Up @@ -304,7 +327,7 @@ def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableTy
str(run.id),
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
Text(run.status, style="bold " + get_status_style(run.status)),
Text(run.model or "-"),
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
format_time(run.start),
Text(format_duration(run.start, run.end), style="bold cyan"),
)
Expand Down
10 changes: 8 additions & 2 deletions dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich import print

from dreadnode_cli import __version__, utils
from dreadnode_cli.config import UserConfig
from dreadnode_cli.config import UserConfig, UserModel
from dreadnode_cli.defaults import (
DEBUG,
DEFAULT_MAX_POLL_TIME,
Expand Down Expand Up @@ -430,14 +430,20 @@ def create_strike_agent_version(
return self.StrikeAgentResponse(**response.json())

def start_strike_run(
self, agent_version_id: UUID, *, model: str | None = None, strike: UUID | str | None = None
self,
agent_version_id: UUID,
*,
model: str | None = None,
user_model: UserModel | None = None,
strike: UUID | str | None = None,
) -> StrikeRunResponse:
response = self.request(
"POST",
"/api/strikes/runs",
json_data={
"agent_version_id": str(agent_version_id),
"model": model,
"user_model": user_model.model_dump(mode="json") if user_model else None,
"strike": str(strike) if strike else None,
},
)
Expand Down
36 changes: 32 additions & 4 deletions dreadnode_cli/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pydantic
from pydantic import BaseModel
from rich import print
from ruamel.yaml import YAML

from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH


class ServerConfig(pydantic.BaseModel):
class ServerConfig(BaseModel):
"""Server specific authentication data and API URL."""

url: str
Expand All @@ -16,7 +16,7 @@ class ServerConfig(pydantic.BaseModel):
refresh_token: str


class UserConfig(pydantic.BaseModel):
class UserConfig(BaseModel):
"""User configuration supporting multiple server profiles."""

active: str | None = None
Expand Down Expand Up @@ -74,3 +74,31 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
profile = profile or self.active or DEFAULT_PROFILE_NAME
self.servers[profile] = config
return self


class UserModel(BaseModel):
"""
A user defined model.
"""

key: str
name: str
provider: str
generator_id: str
api_key: str | None = None


class UserModels(BaseModel):
"""User models configuration."""

models: list[UserModel] = []

@classmethod
def read(cls) -> "UserModels":
"""Read the user models configuration from the file system or return an empty instance."""

if not USER_MODELS_CONFIG_PATH.exists():
return cls()

with USER_MODELS_CONFIG_PATH.open("r") as f:
return cls.model_validate(YAML().load(f))
6 changes: 6 additions & 0 deletions dreadnode_cli/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config"
)

# path to the user models configuration file
USER_MODELS_CONFIG_PATH = pathlib.Path(
# allow overriding the user config file via env variable
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "models.yml"
)

# path to the templates directory
TEMPLATES_PATH = pathlib.Path(
# allow overriding the templates path via env variable
Expand Down

0 comments on commit c4f4c64

Please sign in to comment.