diff --git a/swanlab/api/http.py b/swanlab/api/http.py index 07f5d1c1a..e88debd19 100644 --- a/swanlab/api/http.py +++ b/swanlab/api/http.py @@ -14,7 +14,7 @@ from .cos import CosClient from swanlab.data.modules import MediaBuffer from swanlab.error import NetworkError, ApiError -from swanlab.package import get_host_api +from swanlab.package import get_host_api, get_package_version from swankit.log import FONT from swanlab.log import swanlog import requests @@ -56,6 +56,7 @@ def __init__(self, login_info: LoginInfo): self.__session: Optional[requests.Session] = None # 当前项目所属的username self.__username = login_info.username + self.__version = get_package_version() # 创建会话 self.__create_session() @@ -115,6 +116,7 @@ def __create_session(self): 创建会话,这将在HTTP类实例化时调用 """ session = requests.Session() + session.headers["swanlab-sdk"] = self.__version session.cookies.update({"sid": self.__login_info.sid}) # 注册响应钩子 diff --git a/swanlab/cli/commands/task/__init__.py b/swanlab/cli/commands/task/__init__.py index c9bf1f77d..3dbcdc520 100644 --- a/swanlab/cli/commands/task/__init__.py +++ b/swanlab/cli/commands/task/__init__.py @@ -10,6 +10,7 @@ """ from .launch import launch from .list import list +from .search import search import click __all__ = ["task"] @@ -17,6 +18,9 @@ @click.group() def task(): + """ + Beta Function: launch a task to train on the cloud! + """ pass @@ -24,3 +28,5 @@ def task(): task.add_command(launch) # noinspection PyTypeChecker task.add_command(list) +# noinspection PyTypeChecker +task.add_command(search) diff --git a/swanlab/cli/commands/task/list.py b/swanlab/cli/commands/task/list.py index f80980a8a..725e6ab15 100644 --- a/swanlab/cli/commands/task/list.py +++ b/swanlab/cli/commands/task/list.py @@ -8,7 +8,6 @@ 列出任务状态 """ import time - import click from typing import List from .utils import login_init_sid @@ -18,7 +17,7 @@ from rich.table import Table from rich.live import Live from swanlab.api import get_http -from swanlab.package import get_experiment_url +from .utils import TaskModel @click.command() @@ -40,41 +39,6 @@ def list(max_num: int): # noqa class ListTasksModel: - class TaskListModel: - """ - 获取到的任务列表模型 - """ - - def __init__(self, username: str, task: dict, ): - self.username = username - self.name = task["name"] - """ - 任务名称 - """ - self.python = task["python"] - """ - 任务的python版本 - """ - self.project_name = task.get("pName", None) - """ - 项目名称 - """ - self.experiment_id = task.get("eId", None) - """ - 实验ID - """ - self.created_at = task["createdAt"] - self.started_at = task.get("startedAt", None) - self.finished_at = task.get("finishedAt", None) - self.status = task["status"] - self.msg = task.get("msg", None) - - @property - def url(self): - if self.project_name is None or self.experiment_id is None: - return None - return get_experiment_url(self.username, self.project_name, self.experiment_id) - def __init__(self, num: int, username: str): """ :param num: 最大显示的任务数 @@ -86,33 +50,35 @@ def __init__(self, num: int, username: str): def __dict__(self): return {"num": self.num} - def list(self) -> List[TaskListModel]: + def list(self) -> List[TaskModel]: tasks = self.http.get("/task", self.__dict__()) - return [self.TaskListModel(self.username, task) for task in tasks] + return [TaskModel(self.username, task) for task in tasks] def table(self): st = Table( expand=True, show_header=True, - header_style="bold", - title="[magenta][b]Now Tasks![/b]", + title="[magenta][b]Now Task[/b]", highlight=True, border_style="magenta", ) - st.add_column("Task Name", justify="right") + st.add_column("Task ID", justify="right") + st.add_column("Task Name", justify="center") st.add_column("Status", justify="center") st.add_column("URL", justify="center") - st.add_column("Python Version", justify="center"), - st.add_column("Created Time", justify="center") st.add_column("Started Time", justify="center") st.add_column("Finished Time", justify="center") for tlm in self.list(): + status = tlm.status + if status == "COMPLETED": + status = f"[green]{status}[/green]" + elif status == "CRASHED": + status = f"[red]{status}[/red]" st.add_row( + tlm.cuid, tlm.name, - tlm.status, + status, tlm.url, - tlm.python, - tlm.created_at, tlm.started_at, tlm.finished_at, ) @@ -169,8 +135,13 @@ def term_output(self): title="[blue][b]Log Messages[/b]", highlight=True, border_style="blue", + show_footer=True, + footer_style="bold", + ) + to.add_column( + "Log Output", + "Run [b][white]swanlab task search [Task ID][/white][/b] to get more task info" ) - to.add_column("Log Output") return to def redraw_term_output(self, ): diff --git a/swanlab/cli/commands/task/search.py b/swanlab/cli/commands/task/search.py new file mode 100644 index 000000000..b4b846845 --- /dev/null +++ b/swanlab/cli/commands/task/search.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/7/26 17:22 +@File: detail.py +@IDE: pycharm +@Description: + 根据cuid获取任务详情 +""" +import click +from swanlab.api import get_http +from .utils import TaskModel, login_init_sid +from rich.syntax import Syntax, Console +import json + + +def validate_six_char_string(_, __, value): + if value is None: + raise click.BadParameter('Parameter is required') + if not isinstance(value, str): + raise click.BadParameter('Value must be a string') + if len(value) != 6: + raise click.BadParameter('String must be exactly 6 characters long') + return value + + +@click.command() +@click.argument("cuid", type=str, callback=validate_six_char_string) +def search(cuid): + """ + Get task detail by cuid + """ + login_info = login_init_sid() + http = get_http() + data = http.get(f"/task/{cuid}") + tm = TaskModel(login_info.username, data) + """ + 任务名称,python版本,入口文件,任务状态,URL,创建时间,执行时间,结束时间,错误信息 + """ + console = Console() + console.print("\n[bold]Task Info[/bold]") + console.print(f"[bold]Task Name:[/bold] [yellow]{tm.name}[/yellow]") + console.print(f"[bold]Python Version:[/bold] [white]{tm.python}[white]") + console.print(f"[bold]Entry File:[/bold] [white]{tm.index}[white]") + icon = '✅' + if tm.status == 'CRASHED': + icon = '❌' + elif tm.status != 'COMPLETED': + icon = '🏃' + console.print(f"[bold]Status:[/bold] {icon} {tm.status}") + tm.url is not None and console.print(f"[bold]SwanLab URL:[/bold] {tm.url}") + console.print(f"[bold]Created At:[/bold] {tm.created_at}") + tm.started_at is not None and console.print(f"[bold]Started At:[/bold] {tm.started_at}") + tm.finished_at is not None and console.print(f"[bold]Finished At:[/bold] {tm.finished_at}") + tm.msg is not None and console.print(f"[bold][red]Task Error[/red]:[/bold] \n\n{tm.msg}\n") diff --git a/swanlab/cli/commands/task/utils.py b/swanlab/cli/commands/task/utils.py index 472d5e7e6..6979c44e3 100644 --- a/swanlab/cli/commands/task/utils.py +++ b/swanlab/cli/commands/task/utils.py @@ -7,9 +7,10 @@ @Description: 任务相关工具函数 """ -from swanlab.package import get_key +from swanlab.package import get_key, get_experiment_url from swanlab.api import terminal_login, create_http, LoginInfo from swanlab.error import KeyFileError +from datetime import datetime def login_init_sid() -> LoginInfo: @@ -21,3 +22,50 @@ def login_init_sid() -> LoginInfo: login_info = terminal_login(key) create_http(login_info) return login_info + + +class TaskModel: + """ + 获取到的任务列表模型 + """ + + def __init__(self, username: str, task: dict, ): + self.cuid = task["cuid"] + self.username = username + self.name = task["name"] + """ + 任务名称 + """ + self.python = task["python"] + """ + 任务的python版本 + """ + self.index = task["index"] + """ + 任务的入口文件 + """ + self.project_name = task.get("pName", None) + """ + 项目名称 + """ + self.experiment_id = task.get("eId", None) + """ + 实验ID + """ + self.created_at = self.fmt_time(task["createdAt"]) + self.started_at = self.fmt_time(task.get("startedAt", None)) + self.finished_at = self.fmt_time(task.get("finishedAt", None)) + self.status = task["status"] + self.msg = task.get("msg", None) + + @property + def url(self): + if self.project_name is None or self.experiment_id is None: + return None + return get_experiment_url(self.username, self.project_name, self.experiment_id) + + @staticmethod + def fmt_time(date: str = None): + if date is None: + return None + return datetime.fromisoformat(date).strftime("%Y-%m-%d %H:%M:%S")