Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the socket to obtain available ports to resolve port number conflicts #673

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion swanlab/cli/commands/dashboard/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,36 @@
import click
import os
import sys
import socket


def get_free_port(address='0.0.0.0', default_port=5092) -> int:
"""
获取一个可用端口
NOTE: 默认情况下,返回5092端口,如果端口被占用,返回一个随机可用端口
WARNING: 不能保证独占,极稀有情况下两个程序占用到此端口
---
Args:
address: 主机(host)地址
default_port: 默认端口号

Return:
port: 一个可用端口
"""
# 判断是否占用
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((address, default_port))
except OSError:
pass
else:
return default_port
# 如果占用就返回一个随机可用端口
sock = socket.socket()
sock.bind((address, 0))
ip, port = sock.getsockname()
sock.close()
return port


@click.command()
Expand Down Expand Up @@ -40,7 +70,7 @@
@click.option(
"--port",
"-p",
default=lambda: os.environ.get(SwanLabEnv.SWANBOARD_PROT.value, 5092),
default=lambda: os.environ.get(SwanLabEnv.SWANBOARD_PROT.value, get_free_port()),
nargs=1,
type=click.IntRange(1, 65535),
help="The port of swanlab web, default by 5092",
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/run/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _watch_tip_print(self):
"""
swanlog.info(
"🌟 Run `"
+ FONT.bold("swanlab watch -l {}".format(self.fmt_windows_path(self.settings.swanlog_dir)))
+ FONT.bold("swanlab watch {}".format(self.fmt_windows_path(self.settings.swanlog_dir)))
+ "` to view SwanLab Experiment Dashboard locally"
)

Expand Down
120 changes: 17 additions & 103 deletions test/unit/cli/test_cli_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,106 +7,20 @@
@Description:
测试cli的watch命令
"""
# import os
# import time
# import pytest
# from swanlab.cli import cli
# from click.testing import CliRunner
# import multiprocessing
# import requests
# import swanlab
# from tutils import TEMP_PATH
# from swanlab.env import get_swanlog_dir, SwanLabEnv
#
#
# def mock_swanlog(path=None):
# if path is None:
# path = get_swanlog_dir()
# os.makedirs(path, exist_ok=True)
# del os.environ[SwanLabEnv.SWANLOG_FOLDER.value]
# swanlab.init(logdir=path, mode="local")
# swanlab.log({"test": 1})
# swanlab.finish()
#
#
# # 运行任务
# # noinspection PyTypeChecker
# def runner_watch(*args):
# runner = CliRunner()
# return runner.invoke(cli, ["watch", *args])
#
#
# def runner_watch_wrapper(*args):
# from tutils import reset_some_env
# reset_some_env()
# r = runner_watch(*args)
# if r.exit_code != 0:
# raise Exception(r.output)
#
#
# # 测试能否ping通
# def ping(host="127.0.0.1", port=5092):
# url = f"http://{host}:{port}" # noqa
# time.sleep(3)
# try:
# response = requests.get(url, timeout=10)
# assert response.status_code == 200
# except Exception as e:
# raise Exception("ping failed", str(e))
#
#
# @pytest.mark.parametrize("logdir, ping_args, args", [
# # 无参数
# [None, [], []],
# # 指定logdir
# [os.path.join(TEMP_PATH, "watch", "swanlog"), [], ["--logdir", os.path.join(TEMP_PATH, "watch", "swanlog")]],
# # 指定host和port
# [None, ["0.0.0.0", "5093"], ["--host", "0.0.0.0", "--port", "5093"]],
# # 直接watch
# [os.path.join(TEMP_PATH, "watch", "swanlog"), [], [os.path.join(TEMP_PATH, "watch", "swanlog")]],
# ])
# def test_watch_ok(logdir, ping_args, args):
# """
# 测试watch命令,正常情况
# """
# mock_swanlog(logdir)
# p1 = multiprocessing.Process(target=runner_watch_wrapper, args=args)
# p2 = multiprocessing.Process(target=ping, args=ping_args)
# p1.start()
# p2.start()
# p2.join()
# p1.kill()
# assert p2.exitcode == 0
#
#
# def test_watch_wrong_logdir():
# """
# 测试watch命令,logdir不存在
# """
# result = runner_watch("wrong_logdir")
# assert result.exit_code == 2
# result = runner_watch("--logdir", "wrong_logdir")
# assert result.exit_code == 2
# os.environ[SwanLabEnv.SWANLOG_FOLDER.value] = "wrong_logdir"
# result = runner_watch()
# assert result.exit_code == 2
#
#
# def test_watch_wrong_host_port():
# """
# 测试watch命令,host和port错误
# """
# mock_swanlog()
# result = runner_watch("--host", "")
# assert result.exit_code == 6
# result = runner_watch("--port", "0")
# assert result.exit_code == 2
# result = runner_watch("--port", "65536")
# assert result.exit_code == 2
# # 如果ip被占用,会报错
# p1 = multiprocessing.Process(target=runner_watch_wrapper, args=[get_swanlog_dir(), "--port", "5092"])
# p1.start()
# time.sleep(3)
# result = runner_watch("--port", "5092")
# assert result.exit_code == 7
# p1.kill()
from swanlab.cli.commands.dashboard.watch import get_free_port
import socket


def test_get_free_port():
"""
测试get_free_port函数是否能正确获取一个可用端口
"""
port = get_free_port()
assert isinstance(port, int)
assert port == 5092
# 占用端口
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('0.0.0.0', 5092))
port = get_free_port()
assert isinstance(port, int)
assert port != 5092