Skip to content

Commit

Permalink
Automatically change port when conflict occurs (#673)
Browse files Browse the repository at this point in the history
* add test
* automatically change port

---------

Co-authored-by: KAAANG <[email protected]>
  • Loading branch information
Puiching-Memory and SAKURA-CAT authored Aug 15, 2024
1 parent a31da69 commit cbbb21e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 105 deletions.
32 changes: 31 additions & 1 deletion swanlab/cli/commands/dashboard/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,36 @@
import click
import os
import sys
import socket


def get_free_port(address='0.0.0.0', default_port=5092) -> int:
"""
获取一个可用端口
NOTE: 默认情况下,返回5092端口,如果端口被占用,返回一个随机可用端口
WARNING: 不能保证独占,极稀有情况下两个程序占用到此端口
---
Args:
address: 主机(host)地址
default_port: 默认端口号
Return:
port: 一个可用端口
"""
# 判断是否占用
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((address, default_port))
except OSError:
pass
else:
return default_port
# 如果占用就返回一个随机可用端口
sock = socket.socket()
sock.bind((address, 0))
ip, port = sock.getsockname()
sock.close()
return port


@click.command()
Expand Down Expand Up @@ -40,7 +70,7 @@
@click.option(
"--port",
"-p",
default=lambda: os.environ.get(SwanLabEnv.SWANBOARD_PROT.value, 5092),
default=lambda: os.environ.get(SwanLabEnv.SWANBOARD_PROT.value, get_free_port()),
nargs=1,
type=click.IntRange(1, 65535),
help="The port of swanlab web, default by 5092",
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/run/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _watch_tip_print(self):
"""
swanlog.info(
"🌟 Run `"
+ FONT.bold("swanlab watch -l {}".format(self.fmt_windows_path(self.settings.swanlog_dir)))
+ FONT.bold("swanlab watch {}".format(self.fmt_windows_path(self.settings.swanlog_dir)))
+ "` to view SwanLab Experiment Dashboard locally"
)

Expand Down
120 changes: 17 additions & 103 deletions test/unit/cli/test_cli_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,106 +7,20 @@
@Description:
测试cli的watch命令
"""
# import os
# import time
# import pytest
# from swanlab.cli import cli
# from click.testing import CliRunner
# import multiprocessing
# import requests
# import swanlab
# from tutils import TEMP_PATH
# from swanlab.env import get_swanlog_dir, SwanLabEnv
#
#
# def mock_swanlog(path=None):
# if path is None:
# path = get_swanlog_dir()
# os.makedirs(path, exist_ok=True)
# del os.environ[SwanLabEnv.SWANLOG_FOLDER.value]
# swanlab.init(logdir=path, mode="local")
# swanlab.log({"test": 1})
# swanlab.finish()
#
#
# # 运行任务
# # noinspection PyTypeChecker
# def runner_watch(*args):
# runner = CliRunner()
# return runner.invoke(cli, ["watch", *args])
#
#
# def runner_watch_wrapper(*args):
# from tutils import reset_some_env
# reset_some_env()
# r = runner_watch(*args)
# if r.exit_code != 0:
# raise Exception(r.output)
#
#
# # 测试能否ping通
# def ping(host="127.0.0.1", port=5092):
# url = f"http://{host}:{port}" # noqa
# time.sleep(3)
# try:
# response = requests.get(url, timeout=10)
# assert response.status_code == 200
# except Exception as e:
# raise Exception("ping failed", str(e))
#
#
# @pytest.mark.parametrize("logdir, ping_args, args", [
# # 无参数
# [None, [], []],
# # 指定logdir
# [os.path.join(TEMP_PATH, "watch", "swanlog"), [], ["--logdir", os.path.join(TEMP_PATH, "watch", "swanlog")]],
# # 指定host和port
# [None, ["0.0.0.0", "5093"], ["--host", "0.0.0.0", "--port", "5093"]],
# # 直接watch
# [os.path.join(TEMP_PATH, "watch", "swanlog"), [], [os.path.join(TEMP_PATH, "watch", "swanlog")]],
# ])
# def test_watch_ok(logdir, ping_args, args):
# """
# 测试watch命令,正常情况
# """
# mock_swanlog(logdir)
# p1 = multiprocessing.Process(target=runner_watch_wrapper, args=args)
# p2 = multiprocessing.Process(target=ping, args=ping_args)
# p1.start()
# p2.start()
# p2.join()
# p1.kill()
# assert p2.exitcode == 0
#
#
# def test_watch_wrong_logdir():
# """
# 测试watch命令,logdir不存在
# """
# result = runner_watch("wrong_logdir")
# assert result.exit_code == 2
# result = runner_watch("--logdir", "wrong_logdir")
# assert result.exit_code == 2
# os.environ[SwanLabEnv.SWANLOG_FOLDER.value] = "wrong_logdir"
# result = runner_watch()
# assert result.exit_code == 2
#
#
# def test_watch_wrong_host_port():
# """
# 测试watch命令,host和port错误
# """
# mock_swanlog()
# result = runner_watch("--host", "")
# assert result.exit_code == 6
# result = runner_watch("--port", "0")
# assert result.exit_code == 2
# result = runner_watch("--port", "65536")
# assert result.exit_code == 2
# # 如果ip被占用,会报错
# p1 = multiprocessing.Process(target=runner_watch_wrapper, args=[get_swanlog_dir(), "--port", "5092"])
# p1.start()
# time.sleep(3)
# result = runner_watch("--port", "5092")
# assert result.exit_code == 7
# p1.kill()
from swanlab.cli.commands.dashboard.watch import get_free_port
import socket


def test_get_free_port():
"""
测试get_free_port函数是否能正确获取一个可用端口
"""
port = get_free_port()
assert isinstance(port, int)
assert port == 5092
# 占用端口
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('0.0.0.0', 5092))
port = get_free_port()
assert isinstance(port, int)
assert port != 5092

0 comments on commit cbbb21e

Please sign in to comment.