From 722a46b8a08fa0402eabe0530cac8be67a0dd94a Mon Sep 17 00:00:00 2001 From: monoxgas Date: Fri, 6 Dec 2024 14:04:44 -0700 Subject: [PATCH 1/2] Improved handling for agent links, mismatched profiles, etc. --- dreadnode_cli/agent/cli.py | 95 ++++++++++++++++++------ dreadnode_cli/agent/config.py | 2 +- dreadnode_cli/agent/tests/test_config.py | 53 ++++++++++++- dreadnode_cli/profile/cli.py | 1 + 4 files changed, 126 insertions(+), 25 deletions(-) diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index 1cf4511..1fbb9fe 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -24,12 +24,48 @@ ) from dreadnode_cli.agent.templates import Template, install_template, install_template_from_dir from dreadnode_cli.config import UserConfig +from dreadnode_cli.profile.cli import switch as switch_profile from dreadnode_cli.types import GithubRepo from dreadnode_cli.utils import download_and_unzip_archive, pretty_cli, repo_exists cli = typer.Typer(no_args_is_help=True) +def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None = None) -> None: + """Ensure the active agent link matches the current server profile.""" + + user_config = user_config or UserConfig.read() + + if not user_config.active_profile_name: + raise Exception("No server profile is set, use [bold]dreadnode login[/] to authenticate") + + if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name): + linked_profiles = ", ".join(agent_config.linked_profiles) + plural = "s" if len(agent_config.linked_profiles) > 1 else "" + raise Exception( + f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, " + f"but the current server profile is [yellow]{user_config.active_profile_name}[/], ", + "use [bold]dreadnode agent push[/] to create a new link with this profile.", + ) + + if agent_config.active_link.profile != user_config.active_profile_name: + if ( + Prompt.ask( + f"Current agent link points to the [yellow]{agent_config.active_link.profile}[/] server profile, " + f"would you like to switch to it?", + choices=["y", "n"], + default="y", + ) + == "n" + ): + print() + raise Exception( + "Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]." + ) + + switch_profile(agent_config.active_link.profile) + + @cli.command(help="List all available templates with their descriptions") @pretty_cli def templates() -> None: @@ -164,12 +200,9 @@ def push( if not user_config.active_profile_name: raise Exception("No server profile is set, use [bold]dreadnode login[/] to authenticate") - if agent_config.links and not agent_config.is_linked_to_profile(user_config.active_profile_name): - linked_profiles = ", ".join(agent_config.linked_profiles) - plural = "s" if len(agent_config.linked_profiles) > 1 else "" - raise Exception( - f"This agent is linked to the [magenta]{linked_profiles}[/] server profile{plural}, but the current server profile is [yellow]{user_config.active_profile_name}[/], create the agent again." - ) + if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name): + print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]") + new = True server_config = user_config.get_server_config() @@ -239,6 +272,8 @@ def deploy( watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True, ) -> None: agent_config = AgentConfig.read(directory) + ensure_profile(agent_config) + active_link = agent_config.active_link client = api.create_client() @@ -280,6 +315,7 @@ def models( ) -> None: if strike is None: agent_config = AgentConfig.read(directory) + ensure_profile(agent_config) strike = strike or agent_config.strike if strike is None: @@ -310,7 +346,10 @@ def latest( ] = False, raw: t.Annotated[bool, typer.Option("--raw", help="Show raw JSON output")] = False, ) -> None: - active_link = AgentConfig.read(directory).active_link + agent_config = AgentConfig.read(directory) + ensure_profile(agent_config) + + active_link = agent_config.active_link if not active_link.runs: print(":exclamation: No runs yet, use [bold]dreadnode agent deploy[/]") return @@ -332,9 +371,11 @@ def show( typer.Option("--dir", "-d", help="The agent directory", file_okay=False, resolve_path=True), ] = pathlib.Path("."), ) -> None: - active_link = AgentConfig.read(directory).active_link + agent_config = AgentConfig.read(directory) + ensure_profile(agent_config) + client = api.create_client() - agent = client.get_strike_agent(active_link.id) + agent = client.get_strike_agent(agent_config.active_link.id) print(format_agent(agent)) @@ -345,9 +386,11 @@ def versions( pathlib.Path, typer.Argument(help="The agent directory", file_okay=False, resolve_path=True) ] = pathlib.Path("."), ) -> None: - active_link = AgentConfig.read(directory).active_link + agent_config = AgentConfig.read(directory) + ensure_profile(agent_config) + client = api.create_client() - agent = client.get_strike_agent(active_link.id) + agent = client.get_strike_agent(agent_config.active_link.id) print(format_agent_versions(agent)) @@ -358,10 +401,13 @@ def runs( pathlib.Path, typer.Argument(help="The agent directory", file_okay=False, resolve_path=True) ] = pathlib.Path("."), ) -> None: - active_link = AgentConfig.read(directory).active_link + agent_config = AgentConfig.read(directory) + ensure_profile(agent_config) client = api.create_client() - runs = [run for run in client.list_strike_runs() if run.id in active_link.runs and run.start is not None] + runs = [ + run for run in client.list_strike_runs() if run.id in agent_config.active_link.runs and run.start is not None + ] runs = sorted(runs, key=lambda r: r.start or 0, reverse=True) if not runs: @@ -379,23 +425,26 @@ def links( ] = pathlib.Path("."), ) -> None: agent_config = AgentConfig.read(directory) - client = api.create_client() - + user_config = UserConfig.read() _ = agent_config.active_link table = Table(box=box.ROUNDED) table.add_column("Key", style="magenta") table.add_column("Name", style="cyan") + table.add_column("Profile") table.add_column("ID") for key, link in agent_config.links.items(): - active = key == agent_config.active + active_link = key == agent_config.active + mismatched_profile = active_link and user_config.active_profile_name != link.profile + client = api.create_client(profile=agent_config.links[key].profile) agent = client.get_strike_agent(link.id) table.add_row( - agent.key + ("*" if active else ""), + agent.key + ("*" if active_link else ""), agent.name or "N/A", + link.profile + ("[bold red]* (not-active)[/]" if mismatched_profile else ""), f"[dim]{agent.id}[/]", - style="bold" if active else None, + style="bold" if active_link else None, ) print(table) @@ -404,7 +453,7 @@ def links( @cli.command(help="Switch/link to a different agent") @pretty_cli def switch( - agent: t.Annotated[str, typer.Argument(help="Agent key or id")], + agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")], directory: t.Annotated[ pathlib.Path, typer.Argument(help="The agent directory", file_okay=False, resolve_path=True) ] = pathlib.Path("."), @@ -412,10 +461,12 @@ def switch( agent_config = AgentConfig.read(directory) for key, link in agent_config.links.items(): - if agent in (key, link.id): - print(f":robot: Switched to link [bold magenta]{key}[/] ([dim]{link.id}[/])") + if agent_or_profile in (key, link.id) or agent_or_profile == link.profile: + print( + f":robot: Switched to link [bold magenta]{key}[/] for profile [cyan]{link.profile}[/] ([dim]{link.id}[/])" + ) agent_config.active = key agent_config.write(directory) return - print(f":exclamation: Agent '{agent}' not found, use [bold]dreadnode agent links[/]") + print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]") diff --git a/dreadnode_cli/agent/config.py b/dreadnode_cli/agent/config.py index aaf4875..7fd2dd9 100644 --- a/dreadnode_cli/agent/config.py +++ b/dreadnode_cli/agent/config.py @@ -53,7 +53,7 @@ def add_link(self, key: str, id: UUID, profile: str) -> "AgentConfig": def linked_profiles(self) -> list[str]: return list({link.profile for link in self.links.values()}) - def is_linked_to_profile(self, profile: str) -> bool: + def has_link_to_profile(self, profile: str) -> bool: return any(link.profile == profile for link in self.links.values()) def add_run(self, id: UUID) -> "AgentConfig": diff --git a/dreadnode_cli/agent/tests/test_config.py b/dreadnode_cli/agent/tests/test_config.py index 5da9de3..85bd61d 100644 --- a/dreadnode_cli/agent/tests/test_config.py +++ b/dreadnode_cli/agent/tests/test_config.py @@ -1,9 +1,12 @@ from pathlib import Path +from unittest.mock import patch from uuid import UUID import pytest +from dreadnode_cli.agent.cli import ensure_active_profile from dreadnode_cli.agent.config import AgentConfig +from dreadnode_cli.config import ServerConfig, UserConfig def test_agent_config_read_not_initialized(tmp_path: Path) -> None: @@ -28,8 +31,8 @@ def test_agent_config_add_link() -> None: assert config.links["test"].runs == [] assert config.links["test"].profile == "test" assert config.linked_profiles == ["test"] - assert config.is_linked_to_profile("test") - assert not config.is_linked_to_profile("other") + assert config.has_link_to_profile("test") + assert not config.has_link_to_profile("other") def test_agent_config_add_run() -> None: @@ -82,3 +85,49 @@ def test_agent_config_update_active() -> None: config.links.clear() config._update_active() assert config.active is None + + +def test_ensure_profile() -> None: + agent_config = AgentConfig(project_name="test") + user_config = UserConfig() + + # We don't have any profiles + with pytest.raises(Exception, match="No server profile is set"): + ensure_active_profile(agent_config, user_config=user_config) + + server_config = ServerConfig( + url="http://test", + email="test@test.com", + username="test", + api_key="test", + access_token="test", + refresh_token="test", + ) + + user_config.set_server_config(server_config, profile="main") + user_config.set_server_config(server_config, profile="other") + user_config.active = "main" + + # We have no links + with pytest.raises(Exception, match="No agent is currently linked"): + ensure_active_profile(agent_config, user_config=user_config) + + # We have a link, but none are available for the current profile + agent_config.add_link("test-other", UUID("00000000-0000-0000-0000-000000000000"), "other") + with pytest.raises(Exception, match="This agent is linked to the"): + ensure_active_profile(agent_config, user_config=user_config) + + # We have another link, but the profiles don't match + agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main") + agent_config.active = "test-other" + with patch("rich.prompt.Prompt.ask", return_value="n"): + with pytest.raises(Exception, match="Agent link does not match the current server profile"): + ensure_active_profile(agent_config, user_config=user_config) + + # We should switch if the user agrees + assert user_config.active == "main" + with patch("rich.prompt.Prompt.ask", return_value="y"), patch("dreadnode_cli.config.UserConfig.write"), patch( + "dreadnode_cli.config.UserConfig.read", return_value=user_config + ): + ensure_active_profile(agent_config, user_config=user_config) + assert user_config.active == "other" diff --git a/dreadnode_cli/profile/cli.py b/dreadnode_cli/profile/cli.py index fb39eb7..c60e747 100644 --- a/dreadnode_cli/profile/cli.py +++ b/dreadnode_cli/profile/cli.py @@ -60,6 +60,7 @@ def switch(profile: t.Annotated[str, typer.Argument(help="Profile to switch to") print(f"|- email: [bold]{config.servers[profile].email}[/]") print(f"|- username: {config.servers[profile].username}") print(f"|- url: {config.servers[profile].url}") + print() @cli.command(help="Remove a server profile") From e630b4672111cc5d065f1ac21f4cda1274a810d9 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Fri, 6 Dec 2024 14:14:58 -0700 Subject: [PATCH 2/2] Fix broken function name --- dreadnode_cli/agent/tests/test_config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dreadnode_cli/agent/tests/test_config.py b/dreadnode_cli/agent/tests/test_config.py index 85bd61d..6cd9e20 100644 --- a/dreadnode_cli/agent/tests/test_config.py +++ b/dreadnode_cli/agent/tests/test_config.py @@ -4,7 +4,7 @@ import pytest -from dreadnode_cli.agent.cli import ensure_active_profile +from dreadnode_cli.agent.cli import ensure_profile from dreadnode_cli.agent.config import AgentConfig from dreadnode_cli.config import ServerConfig, UserConfig @@ -93,7 +93,7 @@ def test_ensure_profile() -> None: # We don't have any profiles with pytest.raises(Exception, match="No server profile is set"): - ensure_active_profile(agent_config, user_config=user_config) + ensure_profile(agent_config, user_config=user_config) server_config = ServerConfig( url="http://test", @@ -110,24 +110,24 @@ def test_ensure_profile() -> None: # We have no links with pytest.raises(Exception, match="No agent is currently linked"): - ensure_active_profile(agent_config, user_config=user_config) + ensure_profile(agent_config, user_config=user_config) # We have a link, but none are available for the current profile agent_config.add_link("test-other", UUID("00000000-0000-0000-0000-000000000000"), "other") with pytest.raises(Exception, match="This agent is linked to the"): - ensure_active_profile(agent_config, user_config=user_config) + ensure_profile(agent_config, user_config=user_config) # We have another link, but the profiles don't match agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main") agent_config.active = "test-other" with patch("rich.prompt.Prompt.ask", return_value="n"): with pytest.raises(Exception, match="Agent link does not match the current server profile"): - ensure_active_profile(agent_config, user_config=user_config) + ensure_profile(agent_config, user_config=user_config) # We should switch if the user agrees assert user_config.active == "main" with patch("rich.prompt.Prompt.ask", return_value="y"), patch("dreadnode_cli.config.UserConfig.write"), patch( "dreadnode_cli.config.UserConfig.read", return_value=user_config ): - ensure_active_profile(agent_config, user_config=user_config) + ensure_profile(agent_config, user_config=user_config) assert user_config.active == "other"