Skip to content

Commit

Permalink
Revert "new: implemented user models support (ENG-652)"
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket authored Jan 8, 2025
1 parent 894b39e commit 773303e
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 96 deletions.
34 changes: 9 additions & 25 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
from dreadnode_cli.agent.format import (
format_agent,
format_agent_versions,
format_models,
format_run,
format_runs,
format_strike_models,
format_strikes,
format_user_models,
)
from dreadnode_cli.agent.templates import cli as templates_cli
from dreadnode_cli.agent.templates.format import format_templates
from dreadnode_cli.agent.templates.manager import TemplateManager
from dreadnode_cli.config import UserConfig, UserModel, UserModels
from dreadnode_cli.config import UserConfig
from dreadnode_cli.profile.cli import switch as switch_profile
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, get_repo_archive_source_path, pretty_cli
Expand Down Expand Up @@ -340,22 +339,14 @@ def deploy(
if strike is None:
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

user_models = UserModels.read()
user_model: UserModel | None = None

# Verify the model if it was supplied
if model is not None:
# check if it's a user model
user_model = next((m for m in user_models.models if m.key == model), None)
if not user_model:
# check if it's a strike model
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
models(directory, strike=strike)
print()
raise Exception(f"Model '{model}' is not a user model nor was found in strike '{strike_response.name}'")

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model)
strike_response = client.get_strike(strike)
if not any(m.key == model for m in strike_response.models):
print(format_models(strike_response.models))
raise Exception(f"Model '{model}' not found in strike '{strike_response.name}'")

run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model)
agent_config.add_run(run.id).write(directory)
formatted = format_run(run)

Expand All @@ -378,11 +369,6 @@ def models(
] = pathlib.Path("."),
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to query")] = None,
) -> None:
user_models = UserModels.read()
if user_models.models:
print("[bold]User models:[/]\n")
print(format_user_models(user_models.models))

if strike is None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
Expand All @@ -392,9 +378,7 @@ def models(
raise Exception("No strike specified, use -s/--strike or set the strike in the agent config")

strike_response = api.create_client().get_strike(strike)
if user_models.models:
print("\n[bold]Strike models:[/]\n")
print(format_strike_models(strike_response.models))
print(format_models(strike_response.models))


@cli.command(help="List available strikes")
Expand Down
28 changes: 3 additions & 25 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rich.text import Text

from dreadnode_cli import api
from dreadnode_cli.config import UserModel

P = t.ParamSpec("P")

Expand Down Expand Up @@ -63,26 +62,7 @@ def format_time(dt: datetime | None) -> str:
return dt.astimezone().strftime("%c") if dt else "-"


def format_user_models(models: list[UserModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
table.add_column("provider")
table.add_column("api_key")

for model in models:
provider_style = get_model_provider_style(model.provider)
table.add_row(
Text(model.key),
Text(model.name, style=f"bold {provider_style}"),
Text(model.provider, style=provider_style),
Text("yes" if model.api_key else "no", style="green" if model.api_key else "dim"),
)

return table


def format_strike_models(models: list[api.Client.StrikeModel]) -> RenderableType:
def format_models(models: list[api.Client.StrikeModel]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("key")
table.add_column("name")
Expand Down Expand Up @@ -292,8 +272,7 @@ def format_run(run: api.Client.StrikeRunResponse, *, verbose: bool = False, incl
agent_name = f"[bold magenta]{run.agent_key}[/]"

table.add_row("", "")
# um@ is added to indicate a user model
table.add_row("model", run.model.replace("um@", "") if run.model else "<default>")
table.add_row("model", run.model or "<default>")
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
table.add_row("notes", run.agent_version.notes or "-")
Expand Down Expand Up @@ -325,8 +304,7 @@ def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableTy
str(run.id),
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
Text(run.status, style="bold " + get_status_style(run.status)),
# um@ is added to indicate a user model
Text(run.model.replace("um@", "") if run.model else "-"),
Text(run.model or "-"),
format_time(run.start),
Text(format_duration(run.start, run.end), style="bold cyan"),
)
Expand Down
10 changes: 2 additions & 8 deletions dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich import print

from dreadnode_cli import __version__, utils
from dreadnode_cli.config import UserConfig, UserModel
from dreadnode_cli.config import UserConfig
from dreadnode_cli.defaults import (
DEBUG,
DEFAULT_MAX_POLL_TIME,
Expand Down Expand Up @@ -430,20 +430,14 @@ def create_strike_agent_version(
return self.StrikeAgentResponse(**response.json())

def start_strike_run(
self,
agent_version_id: UUID,
*,
model: str | None = None,
user_model: UserModel | None = None,
strike: UUID | str | None = None,
self, agent_version_id: UUID, *, model: str | None = None, strike: UUID | str | None = None
) -> StrikeRunResponse:
response = self.request(
"POST",
"/api/strikes/runs",
json_data={
"agent_version_id": str(agent_version_id),
"model": model,
"user_model": user_model.model_dump(mode="json") if user_model else None,
"strike": str(strike) if strike else None,
},
)
Expand Down
36 changes: 4 additions & 32 deletions dreadnode_cli/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from pydantic import BaseModel
import pydantic
from rich import print
from ruamel.yaml import YAML

from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH, USER_MODELS_CONFIG_PATH
from dreadnode_cli.defaults import DEFAULT_PROFILE_NAME, USER_CONFIG_PATH


class ServerConfig(BaseModel):
class ServerConfig(pydantic.BaseModel):
"""Server specific authentication data and API URL."""

url: str
Expand All @@ -16,7 +16,7 @@ class ServerConfig(BaseModel):
refresh_token: str


class UserConfig(BaseModel):
class UserConfig(pydantic.BaseModel):
"""User configuration supporting multiple server profiles."""

active: str | None = None
Expand Down Expand Up @@ -74,31 +74,3 @@ def set_server_config(self, config: ServerConfig, profile: str | None = None) ->
profile = profile or self.active or DEFAULT_PROFILE_NAME
self.servers[profile] = config
return self


class UserModel(BaseModel):
"""
A user defined model.
"""

key: str
name: str
provider: str
generator_id: str
api_key: str | None = None


class UserModels(BaseModel):
"""User models configuration."""

models: list[UserModel] = []

@classmethod
def read(cls) -> "UserModels":
"""Read the user models configuration from the file system or return an empty instance."""

if not USER_MODELS_CONFIG_PATH.exists():
return cls()

with USER_MODELS_CONFIG_PATH.open("r") as f:
return cls.model_validate(YAML().load(f))
6 changes: 0 additions & 6 deletions dreadnode_cli/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config"
)

# path to the user models configuration file
USER_MODELS_CONFIG_PATH = pathlib.Path(
# allow overriding the user config file via env variable
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "models.yml"
)

# path to the templates directory
TEMPLATES_PATH = pathlib.Path(
# allow overriding the templates path via env variable
Expand Down

0 comments on commit 773303e

Please sign in to comment.