Skip to content

Commit

Permalink
Added generic repo cloning (mainly for beta users). Some float format…
Browse files Browse the repository at this point in the history
…ting. Add --rebuild for containers. Some help text updates.
  • Loading branch information
monoxgas committed Dec 11, 2024
1 parent 34e6f71 commit 5598f06
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 30 deletions.
90 changes: 68 additions & 22 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_repo_archive_source_path(source_dir: pathlib.Path) -> pathlib.Path:
return source_dir


@cli.command(help="List all available templates with their descriptions")
@cli.command(help="List available agent templates with their descriptions")
@pretty_cli
def templates() -> None:
print(format_templates())
Expand Down Expand Up @@ -117,6 +117,14 @@ def init(
),
] = None,
) -> None:
try:
AgentConfig.read(directory)
if Prompt.ask(":axe: Agent config exists, overwrite?", choices=["y", "n"], default="n") == "n":
return
print()
except Exception:
pass

print(f":coffee: Fetching strike '{strike}' ...")

client = api.create_client()
Expand All @@ -134,14 +142,6 @@ def init(

directory.mkdir(exist_ok=True)

try:
AgentConfig.read(directory)
if Prompt.ask(":axe: Agent config exists, overwrite?", choices=["y", "n"], default="n") == "n":
return
print()
except Exception:
pass

context = {"project_name": project_name, "strike": strike_response}

if source is None:
Expand All @@ -167,9 +167,8 @@ def init(

# This could be a private repo that the user can access
# by getting an access token from our API
elif github_repo.namespace == "dreadnode" and (
github_access_token := client.get_github_access_token([github_repo.repo])
):
elif github_repo.namespace == "dreadnode":
github_access_token = client.get_github_access_token([github_repo.repo])
print(":key: Accessed private repository")
source_dir = download_and_unzip_archive(
github_repo.api_zip_url, headers={"Authorization": f"Bearer {github_access_token.token}"}
Expand Down Expand Up @@ -207,7 +206,7 @@ def init(
print(f"Initialized [b]{directory}[/]")


@cli.command(help="Push a new version of the agent.")
@cli.command(help="Push a new version of the active agent")
@pretty_cli
def push(
directory: t.Annotated[
Expand All @@ -221,6 +220,7 @@ def push(
] = None,
new: t.Annotated[bool, typer.Option("--new", "-n", help="Create a new agent instead of a new version")] = False,
notes: t.Annotated[str | None, typer.Option("--message", "-m", help="Notes for the new version")] = None,
rebuild: t.Annotated[bool, typer.Option("--rebuild", "-r", help="Force rebuild the agent image")] = False,
) -> None:
env = {env_var.split("=")[0]: env_var.split("=")[1] for env_var in env_vars or []}

Expand All @@ -243,7 +243,7 @@ def push(

print()
print(f":wrench: Building agent from [b]{directory}[/] ...")
image = docker.build(directory)
image = docker.build(directory, force_rebuild=rebuild)
repository = f"{registry}/{server_config.username}/agents/{agent_config.project_name}"
tag = tag or image.id[-8:]

Expand Down Expand Up @@ -288,7 +288,7 @@ def push(
print(":tada: Agent pushed. use [bold]dreadnode agent deploy[/] to start a new run.")


@cli.command(help="Start a new run using the latest agent version")
@cli.command(help="Start a new run using the latest active agent version")
@pretty_cli
def deploy(
model: t.Annotated[
Expand Down Expand Up @@ -355,15 +355,15 @@ def models(
print(format_models(strike_response.models))


@cli.command(help="List all strikes")
@cli.command(help="List available strikes")
@pretty_cli
def strikes() -> None:
client = api.create_client()
strikes = client.list_strikes()
print(format_strikes(strikes))


@cli.command(help="Show the latest run of the currently active agent")
@cli.command(help="Show the latest run of the active agent")
@pretty_cli
def latest(
directory: t.Annotated[
Expand Down Expand Up @@ -393,7 +393,7 @@ def latest(
print(format_run(run, verbose=verbose, include_logs=logs))


@cli.command(help="Show the status of the currently active agent")
@cli.command(help="Show the status of the active agent")
@pretty_cli
def show(
directory: t.Annotated[
Expand All @@ -409,7 +409,7 @@ def show(
print(format_agent(agent))


@cli.command(help="List historical versions of this agent")
@cli.command(help="List historical versions of the active agent")
@pretty_cli
def versions(
directory: t.Annotated[
Expand All @@ -424,7 +424,7 @@ def versions(
print(format_agent_versions(agent))


@cli.command(help="List all runs for the currently active agent")
@cli.command(help="List runs for the active agent")
@pretty_cli
def runs(
directory: t.Annotated[
Expand All @@ -447,7 +447,7 @@ def runs(
print(format_runs(runs))


@cli.command(help="List all available links")
@cli.command(help="List available agent links")
@pretty_cli
def links(
directory: t.Annotated[
Expand Down Expand Up @@ -480,7 +480,7 @@ def links(
print(table)


@cli.command(help="Switch/link to a different agent")
@cli.command(help="Switch to a different agent link")
@pretty_cli
def switch(
agent_or_profile: t.Annotated[str, typer.Argument(help="Agent key/id or profile name")],
Expand All @@ -500,3 +500,49 @@ def switch(
return

print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")


@cli.command(help="Clone a github repository")
@pretty_cli
def clone(
repo: t.Annotated[str, typer.Argument(help="Repository name or URL")],
target: t.Annotated[
pathlib.Path | None, typer.Argument(help="The target directory", file_okay=False, resolve_path=True)
] = None,
) -> None:
github_repo = GithubRepo(repo)

# Check if the target directory exists
target = target or pathlib.Path(github_repo.repo)
if target.exists():
if Prompt.ask(f":axe: Overwrite {target.absolute()}?", choices=["y", "n"], default="n") == "n":
return
print()
shutil.rmtree(target)

# Check if the repo is accessible
if repo_exists(github_repo):
temp_dir = download_and_unzip_archive(github_repo.zip_url)

# This could be a private repo that the user can access
# by getting an access token from our API
elif github_repo.namespace == "dreadnode":
github_access_token = api.create_client().get_github_access_token([github_repo.repo])
print(":key: Accessed private repository")
temp_dir = download_and_unzip_archive(
github_repo.api_zip_url, headers={"Authorization": f"Bearer {github_access_token.token}"}
)

else:
raise Exception(f"Repository '{github_repo}' not found or inaccessible")

# We assume the repo download results in a single
# child folder which is the real target
sub_dirs = list(temp_dir.iterdir())
if len(sub_dirs) == 1 and sub_dirs[0].is_dir():
temp_dir = sub_dirs[0]

shutil.move(temp_dir, target)

print()
print(f":tada: Cloned [b]{repo}[/] to [b]{target.absolute()}[/]")
6 changes: 4 additions & 2 deletions dreadnode_cli/agent/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ def login(registry: str, username: str, password: str) -> None:
client.api.login(username=username, password=password, registry=registry)


def build(directory: str | pathlib.Path) -> Image:
def build(directory: str | pathlib.Path, *, force_rebuild: bool = False) -> Image:
if client is None:
raise Exception("Docker not available")

id: str | None = None
for item in client.api.build(path=str(directory), platform="linux/amd64", decode=True):
for item in client.api.build(
path=str(directory), platform="linux/amd64", decode=True, nocache=force_rebuild, pull=force_rebuild
):
if "error" in item:
print()
raise Exception(item["error"])
Expand Down
20 changes: 14 additions & 6 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def get_status_style(status: api.Client.StrikeRunStatus | api.Client.StrikeRunZo

def get_model_provider_style(provider: str) -> str:
return {
"OpenAI": "spring_green4",
"Hugging Face": "blue",
"Dreadnode (private)": "magenta",
"Anthropic": "tan",
"OpenAI": "turquoise4",
"Hugging Face": "dark_cyan",
"Anthropic": "cornflower_blue",
"Google": "cyan",
"MistralAI": "orange_red1",
"Groq": "red",
"MistralAI": "light_salmon3",
"Groq": "grey63",
}.get(provider, "")


Expand Down Expand Up @@ -165,6 +164,9 @@ def format_zones_summary(zones: list[api.Client.StrikeRunZone]) -> RenderableTyp
output.score.value if hasattr(output, "score") and output.score else 0 for output in zone.outputs
)

if isinstance(zone_score, float):
zone_score = round(zone_score, 2)

table.add_row(
zone.key,
Text(zone.status, style=get_status_style(zone.status)),
Expand All @@ -187,6 +189,9 @@ def format_zones_verbose(zones: list[api.Client.StrikeRunZone], *, include_logs:

zone_score = sum(output.score.value if output.score else 0 for output in zone.outputs)

if isinstance(zone_score, float):
zone_score = round(zone_score, 2)

table.add_row("id", f"[dim]{zone.id}[/]")
table.add_row("score", f"[yellow]{zone_score}[/]" if zone_score else "[dim]0[/]")
table.add_row("outputs", f"[magenta]{len(zone.outputs)}[/]" if zone.outputs else "[dim]0[/]")
Expand All @@ -206,6 +211,9 @@ def format_zones_verbose(zones: list[api.Client.StrikeRunZone], *, include_logs:
outputs_table.add_column("explanation")

for output in zone.outputs:
if output.score and isinstance(output.score.value, float):
output.score.value = round(output.score.value, 3)

outputs_table.add_row(
str(output.score.value) if output.score else "-",
Pretty(output.data),
Expand Down

0 comments on commit 5598f06

Please sign in to comment.