Skip to content

Commit

Permalink
Merge pull request #20 from dreadnode/feature/agent-link-improvements
Browse files Browse the repository at this point in the history
feat: Agent link improvements
  • Loading branch information
evilsocket authored Dec 7, 2024
2 parents 7cd2c05 + e630b46 commit 661eddc
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 25 deletions.
95 changes: 73 additions & 22 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,48 @@
)
from dreadnode_cli.agent.templates import Template, install_template, install_template_from_dir
from dreadnode_cli.config import UserConfig
from dreadnode_cli.profile.cli import switch as switch_profile
from dreadnode_cli.types import GithubRepo
from dreadnode_cli.utils import download_and_unzip_archive, pretty_cli, repo_exists

cli = typer.Typer(no_args_is_help=True)


def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None = None) -> None:
"""Ensure the active agent link matches the current server profile."""

user_config = user_config or UserConfig.read()

if not user_config.active_profile_name:
raise Exception("No server profile is set, use [bold]dreadnode login[/] to authenticate")

if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name):
linked_profiles = ", ".join(agent_config.linked_profiles)
plural = "s" if len(agent_config.linked_profiles) > 1 else ""
raise Exception(
f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, "
f"but the current server profile is [yellow]{user_config.active_profile_name}[/], ",
"use [bold]dreadnode agent push[/] to create a new link with this profile.",
)

if agent_config.active_link.profile != user_config.active_profile_name:
if (
Prompt.ask(
f"Current agent link points to the [yellow]{agent_config.active_link.profile}[/] server profile, "
f"would you like to switch to it?",
choices=["y", "n"],
default="y",
)
== "n"
):
print()
raise Exception(
"Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
)

switch_profile(agent_config.active_link.profile)


@cli.command(help="List all available templates with their descriptions")
@pretty_cli
def templates() -> None:
Expand Down Expand Up @@ -164,12 +200,9 @@ def push(
if not user_config.active_profile_name:
raise Exception("No server profile is set, use [bold]dreadnode login[/] to authenticate")

if agent_config.links and not agent_config.is_linked_to_profile(user_config.active_profile_name):
linked_profiles = ", ".join(agent_config.linked_profiles)
plural = "s" if len(agent_config.linked_profiles) > 1 else ""
raise Exception(
f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, but the current server profile is [yellow]{user_config.active_profile_name}[/], create the agent again."
)
if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name):
print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]")
new = True

server_config = user_config.get_server_config()

Expand Down Expand Up @@ -239,6 +272,8 @@ def deploy(
watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True,
) -> None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)

active_link = agent_config.active_link

client = api.create_client()
Expand Down Expand Up @@ -280,6 +315,7 @@ def models(
) -> None:
if strike is None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)

strike = strike or agent_config.strike
if strike is None:
Expand Down Expand Up @@ -310,7 +346,10 @@ def latest(
] = False,
raw: t.Annotated[bool, typer.Option("--raw", help="Show raw JSON output")] = False,
) -> None:
active_link = AgentConfig.read(directory).active_link
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)

active_link = agent_config.active_link
if not active_link.runs:
print(":exclamation: No runs yet, use [bold]dreadnode agent deploy[/]")
return
Expand All @@ -332,9 +371,11 @@ def show(
typer.Option("--dir", "-d", help="The agent directory", file_okay=False, resolve_path=True),
] = pathlib.Path("."),
) -> None:
active_link = AgentConfig.read(directory).active_link
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)

client = api.create_client()
agent = client.get_strike_agent(active_link.id)
agent = client.get_strike_agent(agent_config.active_link.id)
print(format_agent(agent))


Expand All @@ -345,9 +386,11 @@ def versions(
pathlib.Path, typer.Argument(help="The agent directory", file_okay=False, resolve_path=True)
] = pathlib.Path("."),
) -> None:
active_link = AgentConfig.read(directory).active_link
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)

client = api.create_client()
agent = client.get_strike_agent(active_link.id)
agent = client.get_strike_agent(agent_config.active_link.id)
print(format_agent_versions(agent))


Expand All @@ -358,10 +401,13 @@ def runs(
pathlib.Path, typer.Argument(help="The agent directory", file_okay=False, resolve_path=True)
] = pathlib.Path("."),
) -> None:
active_link = AgentConfig.read(directory).active_link
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)

client = api.create_client()
runs = [run for run in client.list_strike_runs() if run.id in active_link.runs and run.start is not None]
runs = [
run for run in client.list_strike_runs() if run.id in agent_config.active_link.runs and run.start is not None
]
runs = sorted(runs, key=lambda r: r.start or 0, reverse=True)

if not runs:
Expand All @@ -379,23 +425,26 @@ def links(
] = pathlib.Path("."),
) -> None:
agent_config = AgentConfig.read(directory)
client = api.create_client()

user_config = UserConfig.read()
_ = agent_config.active_link

table = Table(box=box.ROUNDED)
table.add_column("Key", style="magenta")
table.add_column("Name", style="cyan")
table.add_column("Profile")
table.add_column("ID")

for key, link in agent_config.links.items():
active = key == agent_config.active
active_link = key == agent_config.active
mismatched_profile = active_link and user_config.active_profile_name != link.profile
client = api.create_client(profile=agent_config.links[key].profile)
agent = client.get_strike_agent(link.id)
table.add_row(
agent.key + ("*" if active else ""),
agent.key + ("*" if active_link else ""),
agent.name or "N/A",
link.profile + ("[bold red]* (not-active)[/]" if mismatched_profile else ""),
f"[dim]{agent.id}[/]",
style="bold" if active else None,
style="bold" if active_link else None,
)

print(table)
Expand All @@ -404,18 +453,20 @@ def links(
@cli.command(help="Switch/link to a different agent")
@pretty_cli
def switch(
agent: t.Annotated[str, typer.Argument(help="Agent key or id")],
agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")],
directory: t.Annotated[
pathlib.Path, typer.Argument(help="The agent directory", file_okay=False, resolve_path=True)
] = pathlib.Path("."),
) -> None:
agent_config = AgentConfig.read(directory)

for key, link in agent_config.links.items():
if agent in (key, link.id):
print(f":robot: Switched to link [bold magenta]{key}[/] ([dim]{link.id}[/])")
if agent_or_profile in (key, link.id) or agent_or_profile == link.profile:
print(
f":robot: Switched to link [bold magenta]{key}[/] for profile [cyan]{link.profile}[/] ([dim]{link.id}[/])"
)
agent_config.active = key
agent_config.write(directory)
return

print(f":exclamation: Agent '{agent}' not found, use [bold]dreadnode agent links[/]")
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")
2 changes: 1 addition & 1 deletion dreadnode_cli/agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def add_link(self, key: str, id: UUID, profile: str) -> "AgentConfig":
def linked_profiles(self) -> list[str]:
return list({link.profile for link in self.links.values()})

def is_linked_to_profile(self, profile: str) -> bool:
def has_link_to_profile(self, profile: str) -> bool:
return any(link.profile == profile for link in self.links.values())

def add_run(self, id: UUID) -> "AgentConfig":
Expand Down
53 changes: 51 additions & 2 deletions dreadnode_cli/agent/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pathlib import Path
from unittest.mock import patch
from uuid import UUID

import pytest

from dreadnode_cli.agent.cli import ensure_profile
from dreadnode_cli.agent.config import AgentConfig
from dreadnode_cli.config import ServerConfig, UserConfig


def test_agent_config_read_not_initialized(tmp_path: Path) -> None:
Expand All @@ -28,8 +31,8 @@ def test_agent_config_add_link() -> None:
assert config.links["test"].runs == []
assert config.links["test"].profile == "test"
assert config.linked_profiles == ["test"]
assert config.is_linked_to_profile("test")
assert not config.is_linked_to_profile("other")
assert config.has_link_to_profile("test")
assert not config.has_link_to_profile("other")


def test_agent_config_add_run() -> None:
Expand Down Expand Up @@ -82,3 +85,49 @@ def test_agent_config_update_active() -> None:
config.links.clear()
config._update_active()
assert config.active is None


def test_ensure_profile() -> None:
agent_config = AgentConfig(project_name="test")
user_config = UserConfig()

# We don't have any profiles
with pytest.raises(Exception, match="No server profile is set"):
ensure_profile(agent_config, user_config=user_config)

server_config = ServerConfig(
url="http://test",
email="[email protected]",
username="test",
api_key="test",
access_token="test",
refresh_token="test",
)

user_config.set_server_config(server_config, profile="main")
user_config.set_server_config(server_config, profile="other")
user_config.active = "main"

# We have no links
with pytest.raises(Exception, match="No agent is currently linked"):
ensure_profile(agent_config, user_config=user_config)

# We have a link, but none are available for the current profile
agent_config.add_link("test-other", UUID("00000000-0000-0000-0000-000000000000"), "other")
with pytest.raises(Exception, match="This agent is linked to the"):
ensure_profile(agent_config, user_config=user_config)

# We have another link, but the profiles don't match
agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main")
agent_config.active = "test-other"
with patch("rich.prompt.Prompt.ask", return_value="n"):
with pytest.raises(Exception, match="Agent link does not match the current server profile"):
ensure_profile(agent_config, user_config=user_config)

# We should switch if the user agrees
assert user_config.active == "main"
with patch("rich.prompt.Prompt.ask", return_value="y"), patch("dreadnode_cli.config.UserConfig.write"), patch(
"dreadnode_cli.config.UserConfig.read", return_value=user_config
):
ensure_profile(agent_config, user_config=user_config)
assert user_config.active == "other"
1 change: 1 addition & 0 deletions dreadnode_cli/profile/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def switch(profile: t.Annotated[str, typer.Argument(help="Profile to switch to")
print(f"|- email: [bold]{config.servers[profile].email}[/]")
print(f"|- username: {config.servers[profile].username}")
print(f"|- url: {config.servers[profile].url}")
print()


@cli.command(help="Remove a server profile")
Expand Down

0 comments on commit 661eddc

Please sign in to comment.