From e731c1d4ed554a5db1339efc88f1f2413d2a971d Mon Sep 17 00:00:00 2001 From: futrime <35801754+futrime@users.noreply.github.com> Date: Thu, 6 Feb 2025 22:59:46 +0800 Subject: [PATCH] feat: add tests for agent code fetcher --- agent_code_fetcher.py | 1 - base_agent_code_fetcher.py | 3 -- tests/test_agent_code_fetcher.py | 89 ++++++++++++++++++++++++++++++++ tests/test_path_manager.py | 6 +++ 4 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 tests/test_agent_code_fetcher.py diff --git a/agent_code_fetcher.py b/agent_code_fetcher.py index 157c3fe..6ffd6a0 100644 --- a/agent_code_fetcher.py +++ b/agent_code_fetcher.py @@ -24,7 +24,6 @@ def __init__(self, session: aiohttp.ClientSession): Args: session: The aiohttp client session initialized with the base URL of the API """ - self._session = session async def clean(self) -> None: diff --git a/base_agent_code_fetcher.py b/base_agent_code_fetcher.py index 163062e..680945f 100644 --- a/base_agent_code_fetcher.py +++ b/base_agent_code_fetcher.py @@ -11,7 +11,6 @@ class BaseAgentCodeFetcher(ABC): @abstractmethod async def clean(self) -> None: """Cleans up fetched resources.""" - raise NotImplementedError @abstractmethod async def fetch(self, code_id: str) -> Path: @@ -26,7 +25,6 @@ async def fetch(self, code_id: str) -> Path: Returns: The path to the tarball file where the code should be saved """ - raise NotImplementedError @abstractmethod async def list(self) -> Dict[str, Path]: @@ -35,4 +33,3 @@ async def list(self) -> Dict[str, Path]: Returns: A dictionary mapping code IDs to the paths of their corresponding tarball files """ - raise NotImplementedError diff --git a/tests/test_agent_code_fetcher.py b/tests/test_agent_code_fetcher.py new file mode 100644 index 0000000..94c9419 --- /dev/null +++ b/tests/test_agent_code_fetcher.py @@ -0,0 +1,89 @@ +import shutil +from pathlib import Path +from unittest import IsolatedAsyncioTestCase + +import aiohttp + +import agent_code_fetcher + +CODE_ID = "a09f660a-e0e6-41ac-b721-f8ece8e71f33" +HTTP_BASE_URL = "https://api.dev.saiblo.net" + + +class TestAgentCodeFetcher(IsolatedAsyncioTestCase): + _session: aiohttp.ClientSession + + async def asyncSetUp(self) -> None: + shutil.rmtree( + Path("data"), + ignore_errors=True, + ) + + self._session = aiohttp.ClientSession(HTTP_BASE_URL) + + def tearDown(self) -> None: + shutil.rmtree( + Path("data"), + ignore_errors=True, + ) + + async def test_clean_no_dir(self): + # Arrange. + fetcher = agent_code_fetcher.AgentCodeFetcher(self._session) + + # Act. + await fetcher.clean() + + # Assert. + self.assertTrue(Path("data/agent_code").is_dir()) + + async def test_clean_dir_exists(self): + # Arrange. + Path("data/agent_code").mkdir(parents=True, exist_ok=True) + fetcher = agent_code_fetcher.AgentCodeFetcher(self._session) + + # Act. + await fetcher.clean() + + # Assert. + self.assertTrue( + Path("data/agent_code").is_dir(), + ) + + async def test_fetch_cached(self): + # Arrange. + Path("data/agent_code").mkdir(parents=True, exist_ok=True) + path = Path("data/agent_code") / f"{CODE_ID}.tar" + path.touch() + fetcher = agent_code_fetcher.AgentCodeFetcher(self._session) + + # Act. + result = await fetcher.fetch(CODE_ID) + + # Assert. + self.assertEqual(path.absolute(), result) + + async def test_fetch_not_cached(self): + # Arrange. + path = Path("data/agent_code") / f"{CODE_ID}.tar" + fetcher = agent_code_fetcher.AgentCodeFetcher(self._session) + + # Act. + result = await fetcher.fetch(CODE_ID) + + # Assert. + self.assertEqual(path.absolute(), result) + self.assertTrue(path.is_file()) + + async def test_list(self): + # Arrange. + Path("data/agent_code").mkdir(parents=True, exist_ok=True) + path = Path("data/agent_code") / f"{CODE_ID}.tar" + path.touch() + fetcher = agent_code_fetcher.AgentCodeFetcher(self._session) + + # Act. + result = await fetcher.list() + + # Assert. + self.assertEqual({CODE_ID: path.absolute()}, result) diff --git a/tests/test_path_manager.py b/tests/test_path_manager.py index 0c7fa23..f10345a 100644 --- a/tests/test_path_manager.py +++ b/tests/test_path_manager.py @@ -11,6 +11,12 @@ def test_get_agent_code_base_dir_path(self): path_manager.get_agent_code_base_dir_path(), ) + def test_get_judge_replay_base_dir_path(self): + self.assertEqual( + Path.cwd() / "data/judge_replays", + path_manager.get_judge_replay_base_dir_path(), + ) + def test_get_judge_result_base_dir_path(self): self.assertEqual( Path.cwd() / "data/judge_results",