Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add choices for swanlab.init #783

Merged
merged 6 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions swanlab/api/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@
用户登录接口,输入用户的apikey,保存用户token到本地
进行一些交互定义和数据请求
"""
from swanlab.error import ValidationError, APIKeyFormatError
from swankit.log import FONT
import getpass
import os
import sys

import requests
from swankit.env import is_windows
from swanlab.package import get_user_setting_path, get_host_api
from swankit.log import FONT

from swanlab.api.info import LoginInfo
from swanlab.log import swanlog
from swanlab.env import in_jupyter, SwanLabEnv
import getpass
import requests
import sys
import os
from swanlab.error import ValidationError, APIKeyFormatError
from swanlab.log import swanlog
from swanlab.package import get_user_setting_path, get_host_api


def login_request(api_key: str, timeout: int = 20) -> requests.Response:
Expand Down Expand Up @@ -64,7 +66,6 @@ def input_api_key(
_t = sys.excepthook
sys.excepthook = _abort_tip
if not again:
swanlog.info("Logging into swanlab cloud.")
swanlog.info("You can find your API key at: " + FONT.yellow(get_user_setting_path()))
# windows 额外打印提示信息
if is_windows():
Expand Down
26 changes: 10 additions & 16 deletions swanlab/data/callback_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
@Description:
云端回调
"""
import io
import json
import os
import sys
Expand All @@ -23,7 +22,7 @@
from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel
from swanlab.data.cloud import ThreadPool
from swanlab.data.cloud import UploadType
from swanlab.env import in_jupyter, SwanLabEnv
from swanlab.env import in_jupyter, SwanLabEnv, is_interactive
from swanlab.error import KeyFileError
from swanlab.log import swanlog
from swanlab.package import (
Expand Down Expand Up @@ -144,23 +143,19 @@ def __init__(self, public: bool):
self.public = public

@classmethod
def get_login_info(cls):
def create_login_info(cls):
"""
发起登录,获取登录信息,执行此方法会覆盖原有的login_info
"""
key = None
try:
key = get_key()
except KeyFileError:
try:
fd = sys.stdin.fileno()
# 不是标准终端,且非jupyter环境,无法控制其回显
if not os.isatty(fd) and not in_jupyter():
raise KeyFileError("The key file is not found, call `swanlab.login()` or use `swanlab login` ")
# 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation
# 这种情况下为用户自定义情况
except io.UnsupportedOperation:
pass
pass
if key is None and not is_interactive():
raise KeyFileError(
"api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`."
)
return terminal_login(key)

@staticmethod
Expand Down Expand Up @@ -208,12 +203,11 @@ def __str__(self):

def on_init(self, project: str, workspace: str, logdir: str = None, **kwargs) -> int:
super(CloudRunCallback, self).on_init(project, workspace, logdir)
# 检测是否有最新的版本
self._get_package_latest_version()
if self.login_info is None:
swanlog.debug("Login info is None, get login info.")
self.login_info = self.get_login_info()

self.login_info = self.create_login_info()
# 检测是否有最新的版本
self._get_package_latest_version()
http = create_http(self.login_info)
return http.mount_project(project, workspace, self.public).history_exp_count

Expand Down
67 changes: 55 additions & 12 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

from swanboard import SwanBoardCallback
from swankit.env import SwanLabMode
from swankit.log import FONT

from swanlab.api import code_login
from swanlab.env import SwanLabEnv
from swanlab.api import code_login, terminal_login
from swanlab.env import SwanLabEnv, is_interactive
from swanlab.log import swanlog
from .callback_cloud import CloudRunCallback
from .callback_local import LocalRunCallback
Expand All @@ -27,6 +28,8 @@
get_run,
)
from .run.helper import SwanLabRunOperator
from ..error import KeyFileError
from ..package import get_key, get_host_web


def _check_proj_name(name: str) -> str:
Expand Down Expand Up @@ -68,7 +71,10 @@ def login(api_key: str = None):
"""
if SwanLabRun.is_started():
raise RuntimeError("You must call swanlab.login() before using init()")
CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.get_login_info()
CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.create_login_info()


MODES = Literal["disabled", "cloud", "local"]


def init(
Expand All @@ -78,7 +84,7 @@ def init(
description: str = None,
config: Union[dict, str] = None,
logdir: str = None,
mode: Literal["disabled", "cloud", "local"] = None,
mode: MODES = None,
load: str = None,
public: bool = None,
**kwargs,
Expand Down Expand Up @@ -152,13 +158,9 @@ def init(
project = _load_data(load_data, "project", project)
workspace = _load_data(load_data, "workspace", workspace)
public = _load_data(load_data, "private", public)

# ---------------------------------- 模式选择 ----------------------------------
# for

# ---------------------------------- helper初始化 ----------------------------------
operator, c = _create_operator(mode, public)
project = _check_proj_name(project if project else os.path.basename(os.getcwd())) # 默认实验名称为当前目录名
# ---------------------------------- 启动操作员 ----------------------------------
operator, c = _create_operator(mode, public)
exp_num = SwanLabRunOperator.parse_return(
operator.on_init(project, workspace, logdir=logdir),
key=c.__str__() if c else None,
Expand Down Expand Up @@ -238,6 +240,8 @@ def _init_mode(mode: str = None):
传入的mode必须为SwanLabMode枚举中的一个值,否则报错ValueError
如果环境变量和传入的mode参数都为None,则默认为cloud

从环境变量中提取mode参数以后,还有一步让用户选择运行模式的交互,详见issue: https://github.com/SwanHubX/SwanLab/issues/632

:param mode: str, optional
传入的mode参数
:return: str mode
Expand All @@ -252,8 +256,45 @@ def _init_mode(mode: str = None):
if mode is not None and mode not in allowed:
raise ValueError(f"`mode` must be one of {allowed}, but got {mode}")
mode = "cloud" if mode is None else mode
# 如果mode为cloud,且没找到 api key或者未登录,则提示用户输入
try:
get_key()
no_api_key = False
except KeyFileError:
no_api_key = True
login_info = None
if mode == "cloud" and no_api_key:
# 判断当前进程是否在交互模式下
if is_interactive():
swanlog.info(
f"Using SwanLab to track your experiments. Please refer to {FONT.yellow('https://docs.swanlab.cn')} for more information."
)
swanlog.info("(1) Create a SwanLab account.")
swanlog.info("(2) Use an existing SwanLab account.")
swanlog.info("(3) Don't visualize my results.")

# 交互选择
tip = FONT.swanlab("Enter your choice: ")
code = input(tip)
while code not in ["1", "2", "3"]:
swanlog.warning("Invalid choice, please enter again.")
code = input(tip)
if code == "3":
mode = "local"
elif code == "2":
swanlog.info("You chose 'Use an existing swanlab account'")
swanlog.info("Logging into " + FONT.yellow(get_host_web()))
login_info = terminal_login()
elif code == "1":
swanlog.info("You chose 'Create a swanlab account'")
swanlog.info("Create a SwanLab account here: " + FONT.yellow(get_host_web() + "/login"))
login_info = terminal_login()
else:
raise ValueError("Invalid choice")

# 如果不在就不管
os.environ[mode_key] = mode
return mode
return mode, login_info


def _init_config(config: Union[dict, str]):
Expand Down Expand Up @@ -284,7 +325,9 @@ def _create_operator(mode: str, public: bool) -> Tuple[SwanLabRunOperator, Optio
:param public: 是否公开
:return: SwanLabRunOperator, CloudRunCallback
"""
mode = _init_mode(mode)
mode, login_info = _init_mode(mode)
CloudRunCallback.login_info = login_info

if mode == SwanLabMode.DISABLED.value:
swanlog.warning("SwanLab run disabled, the data will not be saved or uploaded.")
return SwanLabRunOperator(), None
Expand Down
16 changes: 16 additions & 0 deletions swanlab/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
除了utils和error模块,其他模块都可以使用这个模块
"""
import enum
import io
import os
import sys
from typing import List

import swankit.env as E
Expand Down Expand Up @@ -127,3 +129,17 @@ def in_jupyter() -> bool:
return True
except NameError:
return False


def is_interactive():
"""
是否为可交互式环境(输入连接tty设备)
特殊的环境:jupyter notebook
"""
try:
fd = sys.stdin.fileno()
return os.isatty(fd) or in_jupyter()
# 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation
# 多为测试情况,可交互
except io.UnsupportedOperation:
return True
15 changes: 9 additions & 6 deletions swanlab/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@
@Description:
用于管理swanlab的包管理器的模块,做一些封装
"""
from .env import get_save_dir, SwanLabEnv
from .error import KeyFileError
from typing import Optional
import requests
import netrc
import json
import netrc
import os
from typing import Optional

import requests

from .env import get_save_dir, SwanLabEnv
from .error import KeyFileError

package_path = os.path.join(os.path.dirname(__file__), "package.json")


# ---------------------------------- 版本号相关 ----------------------------------


def get_package_version() -> str:
"""获取swanlab的版本号
:return: swanlab的版本号
Expand Down Expand Up @@ -69,7 +72,7 @@ def get_user_setting_path() -> str:
"""获取用户设置的url
:return: 用户设置的url
"""
return get_host_web() + "/settings"
return get_host_web() + "/space/~/settings"


def get_project_url(username: str, projname: str) -> str:
Expand Down
81 changes: 79 additions & 2 deletions test/unit/data/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,13 @@ def test_init_error_mode(self):
S._init_mode("123456") # noqa

@pytest.mark.parametrize("mode", ["disabled", "local", "cloud"])
def test_init_mode(self, mode):
def test_init_mode(self, mode, monkeypatch):
"""
初始化时mode参数正确
"""
if mode == 'cloud':
mode = 'local'
monkeypatch.setattr("builtins.input", lambda _: "3")
S._init_mode(mode)
assert os.environ[MODE] == mode
del os.environ[MODE]
Expand All @@ -58,14 +61,88 @@ def test_init_mode(self, mode):
# assert os.environ[MODE] == mode

@pytest.mark.parametrize("mode", ["disabled", "local", "cloud"])
def test_overwrite_mode(self, mode):
def test_overwrite_mode(self, mode, monkeypatch):
"""
初始化时mode参数正确,覆盖环境变量
"""
if mode == 'cloud':
mode = 'local'
monkeypatch.setattr("builtins.input", lambda _: "3")
os.environ[MODE] = "123456"
S._init_mode(mode)
assert os.environ[MODE] == mode

def test_no_api_key_to_cloud(self, monkeypatch):
"""
初始化时mode为cloud,但是没有设置apikey
"""
if SwanLabEnv.API_KEY.value in os.environ:
del os.environ[SwanLabEnv.API_KEY.value]
monkeypatch.setattr("builtins.input", lambda _: "3")
mode, login_info = S._init_mode("cloud")
assert mode == "local"
assert login_info is None

@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
def test_init_cloud_with_no_api_key(self, monkeypatch):
"""
初始化时mode为cloud,但是没有设置apikey
"""
api_key = os.environ[SwanLabEnv.API_KEY.value]
del os.environ[SwanLabEnv.API_KEY.value]
# 在测试时默认会在交互模式下
# 接下来需要模拟终端连接,使用monkeypatch
# 三种选择方式:
# 1. 输入api key
# 2. 创建账号
# 3. 使用本地版

# 选择第三种
monkeypatch.setattr("builtins.input", lambda _: "3")
mode, login_info = S._init_mode("cloud")
assert mode == "local"
assert login_info is None

# 选择第二种
monkeypatch.setattr("builtins.input", lambda _: "2")
monkeypatch.setattr("getpass.getpass", lambda _: api_key)
mode, login_info = S._init_mode("cloud")
assert mode == "cloud"
assert login_info is not None

# 登录后会保存key,测试时需要删除
os.remove(os.path.join(get_save_dir(), ".netrc"))

# 选择第一种
monkeypatch.setattr("builtins.input", lambda _: "1")
monkeypatch.setattr("getpass.getpass", lambda _: api_key)
mode, login_info = S._init_mode("cloud")
assert mode == "cloud"
assert login_info is not None

# 登录后会保存key,测试时需要删除
os.remove(os.path.join(get_save_dir(), ".netrc"))

# 选择其他的
def create_other_input():
first = True

def oi():
nonlocal first
if first:
first = False
return "123456"
else:
return "3"

return oi

other_input = create_other_input()
monkeypatch.setattr("builtins.input", lambda _: other_input())
mode, login_info = S._init_mode("cloud")
assert mode == "local"
assert login_info is None


class TestInitMode:
"""
Expand Down
Loading
Loading