From cbbb21e5488f86845b8cb2c85b6793fc8e7620ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=A6=E5=BD=92=E4=BA=91=E5=B8=86?= <83509039+Puiching-Memory@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:44:07 +0800 Subject: [PATCH] Automatically change port when conflict occurs (#673) * add test * automatically change port --------- Co-authored-by: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com> --- swanlab/cli/commands/dashboard/watch.py | 32 ++++++- swanlab/data/run/callback.py | 2 +- test/unit/cli/test_cli_watch.py | 120 ++++-------------------- 3 files changed, 49 insertions(+), 105 deletions(-) diff --git a/swanlab/cli/commands/dashboard/watch.py b/swanlab/cli/commands/dashboard/watch.py index c80ccd435..641dbf4e0 100644 --- a/swanlab/cli/commands/dashboard/watch.py +++ b/swanlab/cli/commands/dashboard/watch.py @@ -13,6 +13,36 @@ import click import os import sys +import socket + + +def get_free_port(address='0.0.0.0', default_port=5092) -> int: + """ + 获取一个可用端口 + NOTE: 默认情况下,返回5092端口,如果端口被占用,返回一个随机可用端口 + WARNING: 不能保证独占,极稀有情况下两个程序占用到此端口 + --- + Args: + address: 主机(host)地址 + default_port: 默认端口号 + + Return: + port: 一个可用端口 + """ + # 判断是否占用 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((address, default_port)) + except OSError: + pass + else: + return default_port + # 如果占用就返回一个随机可用端口 + sock = socket.socket() + sock.bind((address, 0)) + ip, port = sock.getsockname() + sock.close() + return port @click.command() @@ -40,7 +70,7 @@ @click.option( "--port", "-p", - default=lambda: os.environ.get(SwanLabEnv.SWANBOARD_PROT.value, 5092), + default=lambda: os.environ.get(SwanLabEnv.SWANBOARD_PROT.value, get_free_port()), nargs=1, type=click.IntRange(1, 65535), help="The port of swanlab web, default by 5092", diff --git a/swanlab/data/run/callback.py b/swanlab/data/run/callback.py index 3e3ee3fe1..2e0f24601 100644 --- a/swanlab/data/run/callback.py +++ b/swanlab/data/run/callback.py @@ -70,7 +70,7 @@ def _watch_tip_print(self): """ swanlog.info( "🌟 Run `" - + FONT.bold("swanlab watch -l {}".format(self.fmt_windows_path(self.settings.swanlog_dir))) + + FONT.bold("swanlab watch {}".format(self.fmt_windows_path(self.settings.swanlog_dir))) + "` to view SwanLab Experiment Dashboard locally" ) diff --git a/test/unit/cli/test_cli_watch.py b/test/unit/cli/test_cli_watch.py index 0a845aca1..ab805f7c5 100644 --- a/test/unit/cli/test_cli_watch.py +++ b/test/unit/cli/test_cli_watch.py @@ -7,106 +7,20 @@ @Description: 测试cli的watch命令 """ -# import os -# import time -# import pytest -# from swanlab.cli import cli -# from click.testing import CliRunner -# import multiprocessing -# import requests -# import swanlab -# from tutils import TEMP_PATH -# from swanlab.env import get_swanlog_dir, SwanLabEnv -# -# -# def mock_swanlog(path=None): -# if path is None: -# path = get_swanlog_dir() -# os.makedirs(path, exist_ok=True) -# del os.environ[SwanLabEnv.SWANLOG_FOLDER.value] -# swanlab.init(logdir=path, mode="local") -# swanlab.log({"test": 1}) -# swanlab.finish() -# -# -# # 运行任务 -# # noinspection PyTypeChecker -# def runner_watch(*args): -# runner = CliRunner() -# return runner.invoke(cli, ["watch", *args]) -# -# -# def runner_watch_wrapper(*args): -# from tutils import reset_some_env -# reset_some_env() -# r = runner_watch(*args) -# if r.exit_code != 0: -# raise Exception(r.output) -# -# -# # 测试能否ping通 -# def ping(host="127.0.0.1", port=5092): -# url = f"http://{host}:{port}" # noqa -# time.sleep(3) -# try: -# response = requests.get(url, timeout=10) -# assert response.status_code == 200 -# except Exception as e: -# raise Exception("ping failed", str(e)) -# -# -# @pytest.mark.parametrize("logdir, ping_args, args", [ -# # 无参数 -# [None, [], []], -# # 指定logdir -# [os.path.join(TEMP_PATH, "watch", "swanlog"), [], ["--logdir", os.path.join(TEMP_PATH, "watch", "swanlog")]], -# # 指定host和port -# [None, ["0.0.0.0", "5093"], ["--host", "0.0.0.0", "--port", "5093"]], -# # 直接watch -# [os.path.join(TEMP_PATH, "watch", "swanlog"), [], [os.path.join(TEMP_PATH, "watch", "swanlog")]], -# ]) -# def test_watch_ok(logdir, ping_args, args): -# """ -# 测试watch命令,正常情况 -# """ -# mock_swanlog(logdir) -# p1 = multiprocessing.Process(target=runner_watch_wrapper, args=args) -# p2 = multiprocessing.Process(target=ping, args=ping_args) -# p1.start() -# p2.start() -# p2.join() -# p1.kill() -# assert p2.exitcode == 0 -# -# -# def test_watch_wrong_logdir(): -# """ -# 测试watch命令,logdir不存在 -# """ -# result = runner_watch("wrong_logdir") -# assert result.exit_code == 2 -# result = runner_watch("--logdir", "wrong_logdir") -# assert result.exit_code == 2 -# os.environ[SwanLabEnv.SWANLOG_FOLDER.value] = "wrong_logdir" -# result = runner_watch() -# assert result.exit_code == 2 -# -# -# def test_watch_wrong_host_port(): -# """ -# 测试watch命令,host和port错误 -# """ -# mock_swanlog() -# result = runner_watch("--host", "") -# assert result.exit_code == 6 -# result = runner_watch("--port", "0") -# assert result.exit_code == 2 -# result = runner_watch("--port", "65536") -# assert result.exit_code == 2 -# # 如果ip被占用,会报错 -# p1 = multiprocessing.Process(target=runner_watch_wrapper, args=[get_swanlog_dir(), "--port", "5092"]) -# p1.start() -# time.sleep(3) -# result = runner_watch("--port", "5092") -# assert result.exit_code == 7 -# p1.kill() +from swanlab.cli.commands.dashboard.watch import get_free_port +import socket + + +def test_get_free_port(): + """ + 测试get_free_port函数是否能正确获取一个可用端口 + """ + port = get_free_port() + assert isinstance(port, int) + assert port == 5092 + # 占用端口 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('0.0.0.0', 5092)) + port = get_free_port() + assert isinstance(port, int) + assert port != 5092