From 62d15f96cba0d7cfa529b9727c57a6b438e88c37 Mon Sep 17 00:00:00 2001 From: Ze-Yi LIN <58305964+Zeyi-Lin@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:02:01 +0800 Subject: [PATCH 1/2] bugfix/config_json_serializable (#671) * fix swanlab config --- swanlab/data/run/config.py | 9 +++- swanlab/data/run/main.py | 52 +++++++++++---------- swanlab/data/sdk.py | 22 ++++----- test/unit/data/run/test_config.py | 76 +++++++++++++++++++++---------- 4 files changed, 100 insertions(+), 59 deletions(-) diff --git a/swanlab/data/run/config.py b/swanlab/data/run/config.py index 772781fb1..9ca171d67 100644 --- a/swanlab/data/run/config.py +++ b/swanlab/data/run/config.py @@ -50,7 +50,11 @@ def json_serializable(obj): # 对于可变映射,递归调用此函数处理值,并将key转换为字典 elif isinstance(obj, MutableMapping): return {str(key): json_serializable(value) for key, value in obj.items()} - raise TypeError(f"Object {obj} is not JSON serializable") + + try: + return str(obj) + except Exception: + raise TypeError(f"Object: {obj} is not JSON serializable") def third_party_config_process(data) -> dict: @@ -84,16 +88,19 @@ def parse(config) -> dict: """ if config is None: return {} + # 1. 第三方配置类型判断与转换 try: return third_party_config_process(config) except TypeError: pass + # 2. 将config转换为可被json序列化的字典 try: return json_serializable(config) except TypeError: # noqa pass + # 3. 尝试序列化,序列化成功直接返回 try: return json.loads(json.dumps(config)) diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 9a91e82f4..5610228a6 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -42,15 +42,15 @@ class SwanLabRun: """ def __init__( - self, - project_name: str = None, - experiment_name: str = None, - description: str = None, - run_config=None, - log_level: str = None, - suffix: str = None, - exp_num: int = None, - operator: SwanLabRunOperator = SwanLabRunOperator(), + self, + project_name: str = None, + experiment_name: str = None, + description: str = None, + run_config=None, + log_level: str = None, + suffix: str = None, + exp_num: int = None, + operator: SwanLabRunOperator = SwanLabRunOperator(), ): """ Initializing the SwanLabRun class involves configuring the settings and initiating other logging processes. @@ -86,7 +86,7 @@ def __init__( # ---------------------------------- 初始化类内参数 ---------------------------------- self.__project_name = project_name # 生成一个唯一的id,随机生成一个8位的16进制字符串,小写 - _id = hex(random.randint(0, 2 ** 32 - 1))[2:].zfill(8) + _id = hex(random.randint(0, 2**32 - 1))[2:].zfill(8) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.__run_id = "run-{}-{}".format(timestamp, _id) # 操作员初始化 @@ -101,6 +101,11 @@ def __init__( # ---------------------------------- 初始化日志记录器 ---------------------------------- swanlog.level = self.__check_log_level(log_level) # ---------------------------------- 初始化配置 ---------------------------------- + # 如果config是以下几个类别之一,则抛出异常 + if isinstance(run_config, (int, float, str, bool, list, tuple, set)): + raise TypeError( + f"config: {run_config} (type: {type(run_config)}) is not a json serialized dict (Surpport type is dict, MutableMapping, omegaconf.DictConfig, Argparse.Namespace), please check it" + ) global config config.update(run_config) setattr(config, "_SwanLabConfig__on_setter", self.__operator.on_runtime_info_update) @@ -123,10 +128,11 @@ def _(state: SwanLabRunState): # 执行__save,必须在on_run之后,因为on_run之前部分的信息还没完全初始化 getattr(config, "_SwanLabConfig__save")() # 系统信息采集 - self.__operator.on_runtime_info_update(RuntimeInfo( - requirements=get_requirements(), - metadata=get_system_info(get_package_version(), self.settings.log_dir) - )) + self.__operator.on_runtime_info_update( + RuntimeInfo( + requirements=get_requirements(), metadata=get_system_info(get_package_version(), self.settings.log_dir) + ) + ) @property def operator(self) -> SwanLabRunOperator: @@ -302,10 +308,10 @@ def log(self, data: dict, step: int = None): v = DataWrapper(k, [v]) # 为List[MediaType]或者List[Line]类型,且长度大于0,且所有元素类型相同 elif ( - isinstance(v, list) - and len(v) > 0 - and all([isinstance(i, (Line, MediaType)) for i in v]) - and all([i.__class__ == v[0].__class__ for i in v]) + isinstance(v, list) + and len(v) > 0 + and all([isinstance(i, (Line, MediaType)) for i in v]) + and all([i.__class__ == v[0].__class__ for i in v]) ): v = DataWrapper(k, v) else: @@ -323,11 +329,11 @@ def __str__(self) -> str: return self.__run_id def __register_exp( - self, - experiment_name: str, - description: str = None, - suffix: str = None, - num: int = None, + self, + experiment_name: str, + description: str = None, + suffix: str = None, + num: int = None, ) -> SwanLabExp: """ 注册实验,将实验配置写入数据库中,完成实验配置的初始化 diff --git a/swanlab/data/sdk.py b/swanlab/data/sdk.py index e0a9a2a24..c632e49e5 100644 --- a/swanlab/data/sdk.py +++ b/swanlab/data/sdk.py @@ -70,17 +70,17 @@ def login(api_key: str = None): def init( - project: str = None, - workspace: str = None, - experiment_name: str = None, - description: str = None, - config: Union[dict, str] = None, - logdir: str = None, - suffix: Union[str, None, bool] = "default", - mode: Literal["disabled", "cloud", "local"] = None, - load: str = None, - public: bool = None, - **kwargs, + project: str = None, + workspace: str = None, + experiment_name: str = None, + description: str = None, + config: Union[dict, str] = None, + logdir: str = None, + suffix: Union[str, None, bool] = "default", + mode: Literal["disabled", "cloud", "local"] = None, + load: str = None, + public: bool = None, + **kwargs, ) -> SwanLabRun: """ Start a new run to track and log. Once you have called this function, you can use 'swanlab.log' to log data to diff --git a/test/unit/data/run/test_config.py b/test/unit/data/run/test_config.py index 40051de90..2b2159d34 100644 --- a/test/unit/data/run/test_config.py +++ b/test/unit/data/run/test_config.py @@ -68,6 +68,7 @@ def __iter__(self): config = parse(config_data) assert config["inf"] == Line.inf assert config["nan"] == Line.nan + # ---------------------------------- dataclass support ---------------------------------- @dataclass class MyData: @@ -294,12 +295,14 @@ def test_use_dict(self): """ 正常流程,输入字典 """ - run = SwanLabRun(run_config={ - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - }) + run = SwanLabRun( + run_config={ + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) config = run.config _config = get_config() assert config["a"] == _config["a"] == 1 @@ -310,12 +313,16 @@ def test_use_omegaconf(self): """ 正常流程,输入OmegaConf """ - run = SwanLabRun(run_config=omegaconf.OmegaConf.create({ - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - })) + run = SwanLabRun( + run_config=omegaconf.OmegaConf.create( + { + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) + ) config = run.config _config = get_config() assert config["a"] == _config["a"] == 1 @@ -337,12 +344,16 @@ def test_use_config(self): """ 正常流程,输入SwanLabConfig """ - run = SwanLabRun(run_config=SwanLabConfig({ - "a": 1, - "b": "mnist", - "c": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - })) + run = SwanLabRun( + run_config=SwanLabConfig( + { + "a": 1, + "b": "mnist", + "c": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) + ) config = run.config _config = get_config() assert config["a"] == _config["a"] == 1 @@ -353,14 +364,31 @@ def test_after_finish(self): """ 测试在finish之后config的变化 """ - run = SwanLabRun(run_config={ - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - }) + run = SwanLabRun( + run_config={ + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + ) run.finish() config = run.config _config = get_config() assert len(config) == 4 assert len(_config) == 0 + + def test_error_config_input(self): + """ + 测试错误的输入 + """ + with pytest.raises(TypeError): + SwanLabRun(run_config=1) + with pytest.raises(TypeError): + SwanLabRun(run_config="1") + with pytest.raises(TypeError): + SwanLabRun(run_config=[1, 2, 3]) + with pytest.raises(TypeError): + SwanLabRun(run_config=(1, 2, 3)) + with pytest.raises(TypeError): + SwanLabRun(run_config=True) From a31da6926e065d8ccc171e20f2190499615a9dd0 Mon Sep 17 00:00:00 2001 From: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com> Date: Wed, 14 Aug 2024 17:20:29 +0800 Subject: [PATCH 2/2] Create auto test script (#676) * add retry * auto test * fix some bug --------- Co-authored-by: Zirui Cai <74649535+Feudalman@users.noreply.github.com> --- .github/workflows/test-when-pr.yml | 42 +++++++++++++++++++++++++ requirements-dev.txt | 3 +- requirements.txt | 1 + swanlab/api/http.py | 12 +++++++ swanlab/api/info.py | 2 +- test/unit/_/{setup.py => test_setup.py} | 0 test/unit/api/test_http.py | 35 +++++++++++++++++---- test/unit/test_package.py | 8 +++-- tutils/setup.py | 9 +++--- 9 files changed, 98 insertions(+), 14 deletions(-) create mode 100644 .github/workflows/test-when-pr.yml rename test/unit/_/{setup.py => test_setup.py} (100%) diff --git a/.github/workflows/test-when-pr.yml b/.github/workflows/test-when-pr.yml new file mode 100644 index 000000000..4702fd81b --- /dev/null +++ b/.github/workflows/test-when-pr.yml @@ -0,0 +1,42 @@ +name: Test When PR + +on: + pull_request: + paths: + - swanlab/** + - test/** + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' # 缓存 pip 依赖 + + - name: Install Dependencies + run: | + pip install -r requirements.txt + pip install -r requirements-media.txt + pip install -r requirements-dev.txt + + - name: Test + # 部分环境变量没有用到,不过后续可能会用到,所以先保留 + run: | + export SWANLAB_RUNTIME=test-no-cloud + export SWANLAB_WEB_HOST=${{ secrets.SWANLAB_WEB_HOST }} + export SWANLAB_API_HOST=${{ secrets.SWANLAB_API_HOST }} + export SWANLAB_API_KEY=${{ secrets.SWANLAB_API_KEY }} + export PYTHONPATH=$PYTHONPATH:. + pytest test/unit diff --git a/requirements-dev.txt b/requirements-dev.txt index 5a20368d9..4650ab38d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,4 +10,5 @@ torchvision python-dotenv freezegun build -requests-mock \ No newline at end of file +responses +requests-mock==1.12.1 # 不太好用,即将删除 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c4f8c500f..a4b3d4399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ swankit==0.1.1b1 swanboard==0.1.3b5 cos-python-sdk-v5 +urllib3 requests click pyyaml diff --git a/swanlab/api/http.py b/swanlab/api/http.py index de2119b2c..030b3c852 100644 --- a/swanlab/api/http.py +++ b/swanlab/api/http.py @@ -18,6 +18,8 @@ from swankit.log import FONT from swanlab.log import swanlog import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry import json @@ -118,8 +120,18 @@ def __before_request(self): def __create_session(self): """ 创建会话,这将在HTTP类实例化时调用 + 添加了重试策略 """ session = requests.Session() + retry = Retry( + total=3, + backoff_factor=0.1, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=frozenset(["GET", "POST", "PUT", "DELETE", "PATCH"]), + ) + adapter = HTTPAdapter(max_retries=retry) + session.mount("https://", adapter) + session.headers["swanlab-sdk"] = self.__version session.cookies.update({"sid": self.__login_info.sid}) diff --git a/swanlab/api/info.py b/swanlab/api/info.py index b90bb5f5d..5a7b7ae5e 100644 --- a/swanlab/api/info.py +++ b/swanlab/api/info.py @@ -81,7 +81,7 @@ def __str__(self) -> str: return "Error api key" if self.__resp.reason == "Forbidden": return "You need to be verified first" - return self.__resp.reason + return str(self.__resp.status_code) + " " + self.__resp.reason def save(self): """ diff --git a/test/unit/_/setup.py b/test/unit/_/test_setup.py similarity index 100% rename from test/unit/_/setup.py rename to test/unit/_/test_setup.py diff --git a/test/unit/api/test_http.py b/test/unit/api/test_http.py index 31afa0c8f..26ae895a4 100644 --- a/test/unit/api/test_http.py +++ b/test/unit/api/test_http.py @@ -9,7 +9,6 @@ 开发环境下存储凭证过期时间为3s """ import os -import time import nanoid from swanlab.api.http import create_http, HTTP, CosClient from swanlab.api.auth.login import login_by_key @@ -17,6 +16,8 @@ from tutils import API_KEY, TEMP_PATH, is_skip_cloud_test from tutils.setup import UseMocker, UseSetupHttp import pytest +import responses +from responses import registries alphabet = "abcdefghijklmnopqrstuvwxyz" @@ -32,6 +33,26 @@ def test_decode_response(): assert data == "test" +@responses.activate(registry=registries.OrderedRegistry) +def test_retry(): + """ + 测试重试机制 + """ + from swanlab.package import get_host_api + url = get_host_api() + "/retry" + rsp1 = responses.get(url, body="Error", status=500) + rsp2 = responses.get(url, body="Error", status=500) + rsp3 = responses.get(url, body="Error", status=500) + rsp4 = responses.get(url, body="OK", status=200) + with UseSetupHttp() as http: + data = http.get("/retry") + assert data == "OK" + assert rsp1.call_count == 1 + assert rsp2.call_count == 1 + assert rsp3.call_count == 1 + assert rsp4.call_count == 1 + + @pytest.mark.skipif(is_skip_cloud_test, reason="skip cloud test") class TestCosSuite: http: HTTP = None @@ -67,8 +88,10 @@ def test_cos_upload(self): buffer.write(b"test") buffer.file_name = "test" self.http.upload(buffer) - # 开发版本设置的过期时间为3s,等待过期 - time.sleep(3) - # 重新上传,测试刷新 - assert self.http.cos.should_refresh is True - self.http.upload(buffer) + # 为了开发方便,测试刷新功能关闭 + + # # 开发版本设置的过期时间为3s,等待过期 + # time.sleep(3) + # # 重新上传,测试刷新 + # assert self.http.cos.should_refresh is True + # self.http.upload(buffer) diff --git a/test/unit/test_package.py b/test/unit/test_package.py index d29cad644..8a6f0a73c 100644 --- a/test/unit/test_package.py +++ b/test/unit/test_package.py @@ -75,12 +75,16 @@ def test_get_experiment_url(): # ---------------------------------- 登录部分 ---------------------------------- class TestGetKey: + @staticmethod + def remove_env_key(): + if SwanLabEnv.API_KEY.value in os.environ: + del os.environ[SwanLabEnv.API_KEY.value] def test_ok(self): """ 获取key成功 """ - del os.environ[SwanLabEnv.API_KEY.value] + self.remove_env_key() # 首先需要登录 file = os.path.join(get_save_dir(), ".netrc") with open(file, "w"): @@ -96,7 +100,7 @@ def test_no_file(self): """ 文件不存在 """ - del os.environ[SwanLabEnv.API_KEY.value] + self.remove_env_key() from swanlab.error import KeyFileError with pytest.raises(KeyFileError) as e: P.get_key() diff --git a/tutils/setup.py b/tutils/setup.py index f5906d822..d0dd8548e 100644 --- a/tutils/setup.py +++ b/tutils/setup.py @@ -46,10 +46,10 @@ def mock_login_info( m.post(f"{get_host_api()}/login/api_key", status_code=status_code, reason=error_reason) else: expired_at = datetime.now().isoformat() - expired_at = (datetime.fromisoformat(expired_at) + timedelta(days=7)).isoformat() + 'Z' + # 过期时间为当前时间加8天,主要是时区问题,所以不能7天以内 + expired_at = (datetime.fromisoformat(expired_at) + timedelta(days=8)).isoformat() + 'Z' m.post(f"{get_host_api()}/login/api_key", json={ "sid": nanoid.generate(), - # 时间为当前时间加7天 "expiredAt": expired_at, "userInfo": { "username": username @@ -62,7 +62,7 @@ def mock_login_info( class UseSetupHttp: """ - 用于全局使用的http对象 + 用于全局使用的http对象,模拟登录,退出时重置http 使用with关键字,自动登录,退出时自动重置http 也可以使用del手动释放 """ @@ -93,8 +93,9 @@ class UseMocker(requests_mock.Mocker): 使用request_mock库进行mock测试,由于现在绝大部分请求都在get_host_api上,所以封装一层 """ - def __init__(self, base_url=get_host_api()): + def __init__(self, base_url: str = None): super().__init__() + base_url = base_url or get_host_api() self.base_url = base_url def get(self, router, *args, **kwargs):