diff --git a/swanlab/api/auth/login.py b/swanlab/api/auth/login.py index baf22107..58fa3344 100644 --- a/swanlab/api/auth/login.py +++ b/swanlab/api/auth/login.py @@ -8,17 +8,19 @@ 用户登录接口,输入用户的apikey,保存用户token到本地 进行一些交互定义和数据请求 """ -from swanlab.error import ValidationError, APIKeyFormatError -from swankit.log import FONT +import getpass +import os +import sys + +import requests from swankit.env import is_windows -from swanlab.package import get_user_setting_path, get_host_api +from swankit.log import FONT + from swanlab.api.info import LoginInfo -from swanlab.log import swanlog from swanlab.env import in_jupyter, SwanLabEnv -import getpass -import requests -import sys -import os +from swanlab.error import ValidationError, APIKeyFormatError +from swanlab.log import swanlog +from swanlab.package import get_user_setting_path, get_host_api def login_request(api_key: str, timeout: int = 20) -> requests.Response: @@ -64,7 +66,6 @@ def input_api_key( _t = sys.excepthook sys.excepthook = _abort_tip if not again: - swanlog.info("Logging into swanlab cloud.") swanlog.info("You can find your API key at: " + FONT.yellow(get_user_setting_path())) # windows 额外打印提示信息 if is_windows(): diff --git a/swanlab/data/callback_cloud.py b/swanlab/data/callback_cloud.py index a8e65423..2235d947 100644 --- a/swanlab/data/callback_cloud.py +++ b/swanlab/data/callback_cloud.py @@ -7,7 +7,6 @@ @Description: 云端回调 """ -import io import json import os import sys @@ -23,7 +22,7 @@ from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel from swanlab.data.cloud import ThreadPool from swanlab.data.cloud import UploadType -from swanlab.env import in_jupyter, SwanLabEnv +from swanlab.env import in_jupyter, SwanLabEnv, is_interactive from swanlab.error import KeyFileError from swanlab.log import swanlog from swanlab.package import ( @@ -144,7 +143,7 @@ def __init__(self, public: bool): self.public = public @classmethod - def get_login_info(cls): + def create_login_info(cls): """ 发起登录,获取登录信息,执行此方法会覆盖原有的login_info """ @@ -152,15 +151,11 @@ def get_login_info(cls): try: key = get_key() except KeyFileError: - try: - fd = sys.stdin.fileno() - # 不是标准终端,且非jupyter环境,无法控制其回显 - if not os.isatty(fd) and not in_jupyter(): - raise KeyFileError("The key file is not found, call `swanlab.login()` or use `swanlab login` ") - # 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation - # 这种情况下为用户自定义情况 - except io.UnsupportedOperation: - pass + pass + if key is None and not is_interactive(): + raise KeyFileError( + "api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`." + ) return terminal_login(key) @staticmethod @@ -208,12 +203,11 @@ def __str__(self): def on_init(self, project: str, workspace: str, logdir: str = None, **kwargs) -> int: super(CloudRunCallback, self).on_init(project, workspace, logdir) - # 检测是否有最新的版本 - self._get_package_latest_version() if self.login_info is None: swanlog.debug("Login info is None, get login info.") - self.login_info = self.get_login_info() - + self.login_info = self.create_login_info() + # 检测是否有最新的版本 + self._get_package_latest_version() http = create_http(self.login_info) return http.mount_project(project, workspace, self.public).history_exp_count diff --git a/swanlab/data/sdk.py b/swanlab/data/sdk.py index 9f37324e..34cf6ea7 100644 --- a/swanlab/data/sdk.py +++ b/swanlab/data/sdk.py @@ -12,9 +12,10 @@ from swanboard import SwanBoardCallback from swankit.env import SwanLabMode +from swankit.log import FONT -from swanlab.api import code_login -from swanlab.env import SwanLabEnv +from swanlab.api import code_login, terminal_login +from swanlab.env import SwanLabEnv, is_interactive from swanlab.log import swanlog from .callback_cloud import CloudRunCallback from .callback_local import LocalRunCallback @@ -27,6 +28,8 @@ get_run, ) from .run.helper import SwanLabRunOperator +from ..error import KeyFileError +from ..package import get_key, get_host_web def _check_proj_name(name: str) -> str: @@ -68,7 +71,10 @@ def login(api_key: str = None): """ if SwanLabRun.is_started(): raise RuntimeError("You must call swanlab.login() before using init()") - CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.get_login_info() + CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.create_login_info() + + +MODES = Literal["disabled", "cloud", "local"] def init( @@ -78,7 +84,7 @@ def init( description: str = None, config: Union[dict, str] = None, logdir: str = None, - mode: Literal["disabled", "cloud", "local"] = None, + mode: MODES = None, load: str = None, public: bool = None, **kwargs, @@ -152,13 +158,9 @@ def init( project = _load_data(load_data, "project", project) workspace = _load_data(load_data, "workspace", workspace) public = _load_data(load_data, "private", public) - - # ---------------------------------- 模式选择 ---------------------------------- - # for - - # ---------------------------------- helper初始化 ---------------------------------- - operator, c = _create_operator(mode, public) project = _check_proj_name(project if project else os.path.basename(os.getcwd())) # 默认实验名称为当前目录名 + # ---------------------------------- 启动操作员 ---------------------------------- + operator, c = _create_operator(mode, public) exp_num = SwanLabRunOperator.parse_return( operator.on_init(project, workspace, logdir=logdir), key=c.__str__() if c else None, @@ -238,6 +240,8 @@ def _init_mode(mode: str = None): 传入的mode必须为SwanLabMode枚举中的一个值,否则报错ValueError 如果环境变量和传入的mode参数都为None,则默认为cloud + 从环境变量中提取mode参数以后,还有一步让用户选择运行模式的交互,详见issue: https://github.com/SwanHubX/SwanLab/issues/632 + :param mode: str, optional 传入的mode参数 :return: str mode @@ -252,8 +256,45 @@ def _init_mode(mode: str = None): if mode is not None and mode not in allowed: raise ValueError(f"`mode` must be one of {allowed}, but got {mode}") mode = "cloud" if mode is None else mode + # 如果mode为cloud,且没找到 api key或者未登录,则提示用户输入 + try: + get_key() + no_api_key = False + except KeyFileError: + no_api_key = True + login_info = None + if mode == "cloud" and no_api_key: + # 判断当前进程是否在交互模式下 + if is_interactive(): + swanlog.info( + f"Using SwanLab to track your experiments. Please refer to {FONT.yellow('https://docs.swanlab.cn')} for more information." + ) + swanlog.info("(1) Create a SwanLab account.") + swanlog.info("(2) Use an existing SwanLab account.") + swanlog.info("(3) Don't visualize my results.") + + # 交互选择 + tip = FONT.swanlab("Enter your choice: ") + code = input(tip) + while code not in ["1", "2", "3"]: + swanlog.warning("Invalid choice, please enter again.") + code = input(tip) + if code == "3": + mode = "local" + elif code == "2": + swanlog.info("You chose 'Use an existing swanlab account'") + swanlog.info("Logging into " + FONT.yellow(get_host_web())) + login_info = terminal_login() + elif code == "1": + swanlog.info("You chose 'Create a swanlab account'") + swanlog.info("Create a SwanLab account here: " + FONT.yellow(get_host_web() + "/login")) + login_info = terminal_login() + else: + raise ValueError("Invalid choice") + + # 如果不在就不管 os.environ[mode_key] = mode - return mode + return mode, login_info def _init_config(config: Union[dict, str]): @@ -284,7 +325,9 @@ def _create_operator(mode: str, public: bool) -> Tuple[SwanLabRunOperator, Optio :param public: 是否公开 :return: SwanLabRunOperator, CloudRunCallback """ - mode = _init_mode(mode) + mode, login_info = _init_mode(mode) + CloudRunCallback.login_info = login_info + if mode == SwanLabMode.DISABLED.value: swanlog.warning("SwanLab run disabled, the data will not be saved or uploaded.") return SwanLabRunOperator(), None diff --git a/swanlab/env.py b/swanlab/env.py index 0472c293..a51c8f80 100644 --- a/swanlab/env.py +++ b/swanlab/env.py @@ -8,7 +8,9 @@ 除了utils和error模块,其他模块都可以使用这个模块 """ import enum +import io import os +import sys from typing import List import swankit.env as E @@ -127,3 +129,17 @@ def in_jupyter() -> bool: return True except NameError: return False + + +def is_interactive(): + """ + 是否为可交互式环境(输入连接tty设备) + 特殊的环境:jupyter notebook + """ + try: + fd = sys.stdin.fileno() + return os.isatty(fd) or in_jupyter() + # 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation + # 多为测试情况,可交互 + except io.UnsupportedOperation: + return True diff --git a/swanlab/package.py b/swanlab/package.py index 8aec21a1..eeea3618 100644 --- a/swanlab/package.py +++ b/swanlab/package.py @@ -7,19 +7,22 @@ @Description: 用于管理swanlab的包管理器的模块,做一些封装 """ -from .env import get_save_dir, SwanLabEnv -from .error import KeyFileError -from typing import Optional -import requests -import netrc import json +import netrc import os +from typing import Optional + +import requests + +from .env import get_save_dir, SwanLabEnv +from .error import KeyFileError package_path = os.path.join(os.path.dirname(__file__), "package.json") # ---------------------------------- 版本号相关 ---------------------------------- + def get_package_version() -> str: """获取swanlab的版本号 :return: swanlab的版本号 @@ -69,7 +72,7 @@ def get_user_setting_path() -> str: """获取用户设置的url :return: 用户设置的url """ - return get_host_web() + "/settings" + return get_host_web() + "/space/~/settings" def get_project_url(username: str, projname: str) -> str: diff --git a/test/unit/data/test_sdk.py b/test/unit/data/test_sdk.py index 4109ebad..9d555e32 100644 --- a/test/unit/data/test_sdk.py +++ b/test/unit/data/test_sdk.py @@ -46,10 +46,13 @@ def test_init_error_mode(self): S._init_mode("123456") # noqa @pytest.mark.parametrize("mode", ["disabled", "local", "cloud"]) - def test_init_mode(self, mode): + def test_init_mode(self, mode, monkeypatch): """ 初始化时mode参数正确 """ + if mode == 'cloud': + mode = 'local' + monkeypatch.setattr("builtins.input", lambda _: "3") S._init_mode(mode) assert os.environ[MODE] == mode del os.environ[MODE] @@ -58,14 +61,88 @@ def test_init_mode(self, mode): # assert os.environ[MODE] == mode @pytest.mark.parametrize("mode", ["disabled", "local", "cloud"]) - def test_overwrite_mode(self, mode): + def test_overwrite_mode(self, mode, monkeypatch): """ 初始化时mode参数正确,覆盖环境变量 """ + if mode == 'cloud': + mode = 'local' + monkeypatch.setattr("builtins.input", lambda _: "3") os.environ[MODE] = "123456" S._init_mode(mode) assert os.environ[MODE] == mode + def test_no_api_key_to_cloud(self, monkeypatch): + """ + 初始化时mode为cloud,但是没有设置apikey + """ + if SwanLabEnv.API_KEY.value in os.environ: + del os.environ[SwanLabEnv.API_KEY.value] + monkeypatch.setattr("builtins.input", lambda _: "3") + mode, login_info = S._init_mode("cloud") + assert mode == "local" + assert login_info is None + + @pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") + def test_init_cloud_with_no_api_key(self, monkeypatch): + """ + 初始化时mode为cloud,但是没有设置apikey + """ + api_key = os.environ[SwanLabEnv.API_KEY.value] + del os.environ[SwanLabEnv.API_KEY.value] + # 在测试时默认会在交互模式下 + # 接下来需要模拟终端连接,使用monkeypatch + # 三种选择方式: + # 1. 输入api key + # 2. 创建账号 + # 3. 使用本地版 + + # 选择第三种 + monkeypatch.setattr("builtins.input", lambda _: "3") + mode, login_info = S._init_mode("cloud") + assert mode == "local" + assert login_info is None + + # 选择第二种 + monkeypatch.setattr("builtins.input", lambda _: "2") + monkeypatch.setattr("getpass.getpass", lambda _: api_key) + mode, login_info = S._init_mode("cloud") + assert mode == "cloud" + assert login_info is not None + + # 登录后会保存key,测试时需要删除 + os.remove(os.path.join(get_save_dir(), ".netrc")) + + # 选择第一种 + monkeypatch.setattr("builtins.input", lambda _: "1") + monkeypatch.setattr("getpass.getpass", lambda _: api_key) + mode, login_info = S._init_mode("cloud") + assert mode == "cloud" + assert login_info is not None + + # 登录后会保存key,测试时需要删除 + os.remove(os.path.join(get_save_dir(), ".netrc")) + + # 选择其他的 + def create_other_input(): + first = True + + def oi(): + nonlocal first + if first: + first = False + return "123456" + else: + return "3" + + return oi + + other_input = create_other_input() + monkeypatch.setattr("builtins.input", lambda _: other_input()) + mode, login_info = S._init_mode("cloud") + assert mode == "local" + assert login_info is None + class TestInitMode: """ diff --git a/test/unit/test_env.py b/test/unit/test_env.py index 5827631d..3db74234 100644 --- a/test/unit/test_env.py +++ b/test/unit/test_env.py @@ -7,11 +7,12 @@ @Description: 测试swanlab.env模块 """ +import os + import pytest -from swanlab.env import SwanLabEnv import swanlab -import os +from swanlab.env import SwanLabEnv, is_interactive def test_default(): @@ -37,3 +38,8 @@ def test_check(): os.environ[SwanLabEnv.RUNTIME.value] = "124" with pytest.raises(ValueError): SwanLabEnv.check() + + +def test_is_interactive(): + # 测试时默认返回true + assert is_interactive() == True diff --git a/test/unit/test_package.py b/test/unit/test_package.py index 8a6f0a73..94ba6060 100644 --- a/test/unit/test_package.py +++ b/test/unit/test_package.py @@ -1,10 +1,12 @@ +import json +import netrc +import os + +import nanoid import pytest + import swanlab.package as P from swanlab.env import SwanLabEnv, get_save_dir -import nanoid -import netrc -import json -import os _ = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) package_data = json.load(open(os.path.join(_, "swanlab", "package.json"))) @@ -47,7 +49,7 @@ def test_get_user_setting_path(): """ 测试获取用户设置文件路径 """ - assert P.get_user_setting_path() == P.get_host_web() + "/settings" + assert P.get_user_setting_path() == P.get_host_web() + "/space/~/settings" def test_get_project_url(): @@ -66,14 +68,15 @@ def test_get_experiment_url(): username = nanoid.generate() projname = nanoid.generate() expid = nanoid.generate() - assert P.get_experiment_url( - username, projname, - expid - ) == P.get_host_web() + "/@" + username + "/" + projname + "/runs/" + expid + assert ( + P.get_experiment_url(username, projname, expid) + == P.get_host_web() + "/@" + username + "/" + projname + "/runs/" + expid + ) # ---------------------------------- 登录部分 ---------------------------------- + class TestGetKey: @staticmethod def remove_env_key(): @@ -102,12 +105,14 @@ def test_no_file(self): """ self.remove_env_key() from swanlab.error import KeyFileError + with pytest.raises(KeyFileError) as e: P.get_key() assert str(e.value) == "The file does not exist" def test_no_host(self): from swanlab.error import KeyFileError + self.test_ok() # 此时删除了环境变量 host = nanoid.generate() os.environ[SwanLabEnv.API_HOST.value] = host