Skip to content

Commit

Permalink
Merge branch 'main' into pr/674
Browse files Browse the repository at this point in the history
  • Loading branch information
SAKURA-CAT committed Aug 15, 2024
2 parents 208dace + a31da69 commit 890f1dc
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 73 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/test-when-pr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: Test When PR

on:
pull_request:
paths:
- swanlab/**
- test/**
branches:
- main

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Set Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip' # 缓存 pip 依赖

- name: Install Dependencies
run: |
pip install -r requirements.txt
pip install -r requirements-media.txt
pip install -r requirements-dev.txt
- name: Test
# 部分环境变量没有用到,不过后续可能会用到,所以先保留
run: |
export SWANLAB_RUNTIME=test-no-cloud
export SWANLAB_WEB_HOST=${{ secrets.SWANLAB_WEB_HOST }}
export SWANLAB_API_HOST=${{ secrets.SWANLAB_API_HOST }}
export SWANLAB_API_KEY=${{ secrets.SWANLAB_API_KEY }}
export PYTHONPATH=$PYTHONPATH:.
pytest test/unit
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ torchvision
python-dotenv
freezegun
build
requests-mock
responses
requests-mock==1.12.1 # 不太好用,即将删除
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
swankit==0.1.1b1
swanboard==0.1.3b5
cos-python-sdk-v5
urllib3
requests
click
pyyaml
Expand Down
12 changes: 12 additions & 0 deletions swanlab/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from swankit.log import FONT
from swanlab.log import swanlog
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import json


Expand Down Expand Up @@ -118,8 +120,18 @@ def __before_request(self):
def __create_session(self):
"""
创建会话,这将在HTTP类实例化时调用
添加了重试策略
"""
session = requests.Session()
retry = Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=frozenset(["GET", "POST", "PUT", "DELETE", "PATCH"]),
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)

session.headers["swanlab-sdk"] = self.__version
session.cookies.update({"sid": self.__login_info.sid})

Expand Down
2 changes: 1 addition & 1 deletion swanlab/api/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __str__(self) -> str:
return "Error api key"
if self.__resp.reason == "Forbidden":
return "You need to be verified first"
return self.__resp.reason
return str(self.__resp.status_code) + " " + self.__resp.reason

def save(self):
"""
Expand Down
9 changes: 8 additions & 1 deletion swanlab/data/run/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def json_serializable(obj):
# 对于可变映射,递归调用此函数处理值,并将key转换为字典
elif isinstance(obj, MutableMapping):
return {str(key): json_serializable(value) for key, value in obj.items()}
raise TypeError(f"Object {obj} is not JSON serializable")

try:
return str(obj)
except Exception:
raise TypeError(f"Object: {obj} is not JSON serializable")


def third_party_config_process(data) -> dict:
Expand Down Expand Up @@ -84,16 +88,19 @@ def parse(config) -> dict:
"""
if config is None:
return {}

# 1. 第三方配置类型判断与转换
try:
return third_party_config_process(config)
except TypeError:
pass

# 2. 将config转换为可被json序列化的字典
try:
return json_serializable(config)
except TypeError: # noqa
pass

# 3. 尝试序列化,序列化成功直接返回
try:
return json.loads(json.dumps(config))
Expand Down
52 changes: 29 additions & 23 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class SwanLabRun:
"""

def __init__(
self,
project_name: str = None,
experiment_name: str = None,
description: str = None,
run_config=None,
log_level: str = None,
suffix: str = None,
exp_num: int = None,
operator: SwanLabRunOperator = SwanLabRunOperator(),
self,
project_name: str = None,
experiment_name: str = None,
description: str = None,
run_config=None,
log_level: str = None,
suffix: str = None,
exp_num: int = None,
operator: SwanLabRunOperator = SwanLabRunOperator(),
):
"""
Initializing the SwanLabRun class involves configuring the settings and initiating other logging processes.
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
# ---------------------------------- 初始化类内参数 ----------------------------------
self.__project_name = project_name
# 生成一个唯一的id,随机生成一个8位的16进制字符串,小写
_id = hex(random.randint(0, 2 ** 32 - 1))[2:].zfill(8)
_id = hex(random.randint(0, 2**32 - 1))[2:].zfill(8)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.__run_id = "run-{}-{}".format(timestamp, _id)
# 操作员初始化
Expand All @@ -101,6 +101,11 @@ def __init__(
# ---------------------------------- 初始化日志记录器 ----------------------------------
swanlog.level = self.__check_log_level(log_level)
# ---------------------------------- 初始化配置 ----------------------------------
# 如果config是以下几个类别之一,则抛出异常
if isinstance(run_config, (int, float, str, bool, list, tuple, set)):
raise TypeError(
f"config: {run_config} (type: {type(run_config)}) is not a json serialized dict (Surpport type is dict, MutableMapping, omegaconf.DictConfig, Argparse.Namespace), please check it"
)
global config
config.update(run_config)
setattr(config, "_SwanLabConfig__on_setter", self.__operator.on_runtime_info_update)
Expand All @@ -123,10 +128,11 @@ def _(state: SwanLabRunState):
# 执行__save,必须在on_run之后,因为on_run之前部分的信息还没完全初始化
getattr(config, "_SwanLabConfig__save")()
# 系统信息采集
self.__operator.on_runtime_info_update(RuntimeInfo(
requirements=get_requirements(),
metadata=get_system_info(get_package_version(), self.settings.log_dir)
))
self.__operator.on_runtime_info_update(
RuntimeInfo(
requirements=get_requirements(), metadata=get_system_info(get_package_version(), self.settings.log_dir)
)
)

@property
def operator(self) -> SwanLabRunOperator:
Expand Down Expand Up @@ -302,10 +308,10 @@ def log(self, data: dict, step: int = None):
v = DataWrapper(k, [v])
# 为List[MediaType]或者List[Line]类型,且长度大于0,且所有元素类型相同
elif (
isinstance(v, list)
and len(v) > 0
and all([isinstance(i, (Line, MediaType)) for i in v])
and all([i.__class__ == v[0].__class__ for i in v])
isinstance(v, list)
and len(v) > 0
and all([isinstance(i, (Line, MediaType)) for i in v])
and all([i.__class__ == v[0].__class__ for i in v])
):
v = DataWrapper(k, v)
else:
Expand All @@ -323,11 +329,11 @@ def __str__(self) -> str:
return self.__run_id

def __register_exp(
self,
experiment_name: str,
description: str = None,
suffix: str = None,
num: int = None,
self,
experiment_name: str,
description: str = None,
suffix: str = None,
num: int = None,
) -> SwanLabExp:
"""
注册实验,将实验配置写入数据库中,完成实验配置的初始化
Expand Down
22 changes: 11 additions & 11 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def login(api_key: str = None):


def init(
project: str = None,
workspace: str = None,
experiment_name: str = None,
description: str = None,
config: Union[dict, str] = None,
logdir: str = None,
suffix: Union[str, None, bool] = "default",
mode: Literal["disabled", "cloud", "local"] = None,
load: str = None,
public: bool = None,
**kwargs,
project: str = None,
workspace: str = None,
experiment_name: str = None,
description: str = None,
config: Union[dict, str] = None,
logdir: str = None,
suffix: Union[str, None, bool] = "default",
mode: Literal["disabled", "cloud", "local"] = None,
load: str = None,
public: bool = None,
**kwargs,
) -> SwanLabRun:
"""
Start a new run to track and log. Once you have called this function, you can use 'swanlab.log' to log data to
Expand Down
File renamed without changes.
35 changes: 29 additions & 6 deletions test/unit/api/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
开发环境下存储凭证过期时间为3s
"""
import os
import time
import nanoid
from swanlab.api.http import create_http, HTTP, CosClient
from swanlab.api.auth.login import login_by_key
from swanlab.data.modules import MediaBuffer
from tutils import API_KEY, TEMP_PATH, is_skip_cloud_test
from tutils.setup import UseMocker, UseSetupHttp
import pytest
import responses
from responses import registries

alphabet = "abcdefghijklmnopqrstuvwxyz"

Expand All @@ -32,6 +33,26 @@ def test_decode_response():
assert data == "test"


@responses.activate(registry=registries.OrderedRegistry)
def test_retry():
"""
测试重试机制
"""
from swanlab.package import get_host_api
url = get_host_api() + "/retry"
rsp1 = responses.get(url, body="Error", status=500)
rsp2 = responses.get(url, body="Error", status=500)
rsp3 = responses.get(url, body="Error", status=500)
rsp4 = responses.get(url, body="OK", status=200)
with UseSetupHttp() as http:
data = http.get("/retry")
assert data == "OK"
assert rsp1.call_count == 1
assert rsp2.call_count == 1
assert rsp3.call_count == 1
assert rsp4.call_count == 1


@pytest.mark.skipif(is_skip_cloud_test, reason="skip cloud test")
class TestCosSuite:
http: HTTP = None
Expand Down Expand Up @@ -67,8 +88,10 @@ def test_cos_upload(self):
buffer.write(b"test")
buffer.file_name = "test"
self.http.upload(buffer)
# 开发版本设置的过期时间为3s,等待过期
time.sleep(3)
# 重新上传,测试刷新
assert self.http.cos.should_refresh is True
self.http.upload(buffer)
# 为了开发方便,测试刷新功能关闭

# # 开发版本设置的过期时间为3s,等待过期
# time.sleep(3)
# # 重新上传,测试刷新
# assert self.http.cos.should_refresh is True
# self.http.upload(buffer)
Loading

0 comments on commit 890f1dc

Please sign in to comment.