Skip to content

Commit

Permalink
Merge pull request #3 from Spartan859/main
Browse files Browse the repository at this point in the history
Multiple fixes and features.
  • Loading branch information
Spartan859 authored Feb 4, 2025
2 parents bed4bac + 43cb441 commit df1094e
Show file tree
Hide file tree
Showing 16 changed files with 371 additions and 207 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,5 @@ cython_debug/
record/

fetched_codes/
backend2judger.md
.data/
23 changes: 23 additions & 0 deletions aiohttp_session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import aiohttp
import asyncio


class AiohttpSessionManager:
_instance = None
_session: dict[str, aiohttp.ClientSession] = {}

def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super(AiohttpSessionManager, cls).__new__(
cls, *args, **kwargs
)
return cls._instance

async def __aexit__(self, exc_type, exc, tb):
for session in self._session.values():
await session.close()

def get_session(self, url: str) -> aiohttp.ClientSession:
if url not in self._session:
self._session[url] = aiohttp.ClientSession(base_url=url)
return self._session[url]
5 changes: 5 additions & 0 deletions base_match_judger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ async def list(self) -> Dict[str, MatchResult]:
A dictionary mapping match IDs to their corresponding MatchResult objects
"""
raise NotImplementedError

@abstractmethod
def stop(self) -> None:
"""Stops the judger."""
raise NotImplementedError
15 changes: 14 additions & 1 deletion build_task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Contains the task for building agents."""

import logging
from typing import Optional

from base_agent_code_fetcher import BaseAgentCodeFetcher
from base_compile_result_sender import BaseCompileResultSender
from base_docker_image_builder import BaseDockerImageBuilder
from base_task import BaseTask

Expand All @@ -14,17 +16,20 @@ class BuildTask(BaseTask):
_code_id: str
_fetcher: BaseAgentCodeFetcher
_result: Optional[str] = None
_sender: Optional[BaseCompileResultSender]

def __init__(
self,
code_id: str,
fetcher: BaseAgentCodeFetcher,
builder: BaseDockerImageBuilder,
sender: BaseCompileResultSender = None,
):
self._code_id = code_id

self._fetcher = fetcher
self._builder = builder
self._sender = sender

async def execute(self) -> str:
"""Runs the task.
Expand All @@ -33,7 +38,15 @@ async def execute(self) -> str:
The tag of the built image
"""
tar_file_path = await self._fetcher.fetch(self._code_id)
self._result = await self._builder.build(tar_file_path, self._code_id)
try:
self._result = await self._builder.build(tar_file_path, self._code_id)
if self._sender:
await self._sender.send(self._code_id, True, "")
except Exception as e:
logging.error(f"Failed to build agent {self._code_id}: {e}")
if self._sender:
await self._sender.send(self._code_id, False, str(e))
# raise e
return self._result

@property
Expand Down
55 changes: 0 additions & 55 deletions compile_task.py

This file was deleted.

51 changes: 31 additions & 20 deletions judge_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self._reporter = reporter

self._build_tasks = [
BuildTask(code_id, fetcher, builder) for code_id in player_code_ids
BuildTask(code_id, fetcher, builder, None) for code_id in player_code_ids
]

async def execute(self) -> MatchResult:
Expand All @@ -50,25 +50,36 @@ async def execute(self) -> MatchResult:
Returns:
The match judge result
"""

agent_image_tags = await asyncio.gather(
*[t.execute() for t in self._build_tasks]
)

for tag in agent_image_tags:
if tag.split(":")[0] == "E":
match_result = MatchResult(
self._match_id, scores=[0, 0], record_file_path=""
)
await self._reporter.report(match_result)
return match_result

match_result = await self._judger.judge(
self._match_id, self._game_host_image_tag, agent_image_tags
)
await self._reporter.report(match_result)
self._result = match_result
return match_result
try:
agent_image_tags = await asyncio.gather(
*[t.execute() for t in self._build_tasks]
)
match_result = await self._judger.judge(
self._match_id, self._game_host_image_tag, agent_image_tags
)
await self._reporter.report(match_result)
self._result = match_result
return match_result
except asyncio.CancelledError:
print("Task cancelled.")
self._judger.stop()
raise
except Exception as e:
# If any build task failed, the match is judged as failed.
match_result = MatchResult(
match_id=self._match_id,
success=False,
err_msg=str(e),
scores=[0] * len(self._build_tasks),
record_file_path=None,
states=[
{"position": i, "status": "OK", "code": 0, "stderr": ""}
for i in range(len(self._build_tasks))
],
)
await self._reporter.report(match_result)
self._result = match_result
return match_result

@property
def result(self) -> Optional[MatchResult]:
Expand Down
91 changes: 48 additions & 43 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import time
from aiohttp_session_manager import AiohttpSessionManager
from base_task_scheduler import BaseTaskScheduler
from build_task import BuildTask
from compile_task import CompileTask
from match_result import MatchResult
from thuai_builder import ThuaiBuilder
from thuai_cr_sender import ThuaiCRSender
from thuai_fetcher import ThuaiFetcher
Expand All @@ -12,56 +14,59 @@
from thuai_task_scheduler import ThuaiTaskScheduler
from ws_client import WsClient

BASE_URL = "https://api.dev.saiblo.net/"

async def fetch():
await ThuaiFetcher().fetch("cbd96c3c5a934e0cabac0a3f006a823b")


async def clean():
await ThuaiFetcher().clean()


async def buildTask():
return await BuildTask(
"cbd96c3c5a934e0cabac0a3f006a823b", ThuaiFetcher(), ThuaiBuilder()
).execute()


async def compileTask():
return await CompileTask(
"cbd96c3c5a934e0cabac0a3f006a823b",
ThuaiFetcher(),
ThuaiBuilder(),
ThuaiCRSender(),
).execute()
async def testWsClient():
async with AiohttpSessionManager().get_session(BASE_URL) as http_session:
ws_client = WsClient(
"wss://api.dev.saiblo.net/ws/",
"thuai8judger",
ThuaiTaskScheduler(),
ThuaiFetcher(session=http_session),
ThuaiBuilder(),
ThuaiCRSender(session=http_session),
ThuaiJudger(),
ThuaiReporter(session=http_session),
"thuai7judger:latest",
)
await ws_client.start()
# print("WsClient started")
# # print('qwdhkdjwqieuo')
# time.sleep(10)
# # print("qhjdqkjwhdjk")
# ws_client.stop()
# time.sleep(2)
# ws_client.start()
# time.sleep(20)
# ws_client.stop()


async def testWsClient():
ws_client = WsClient(
"wss://api.dev.saiblo.net/ws/",
"thuai8judger",
ThuaiTaskScheduler(),
ThuaiFetcher(),
ThuaiBuilder(),
ThuaiCRSender(),
ThuaiJudger(),
ThuaiReporter(),
"thuai7judger:latest",
)
await ws_client.start()
print("WsClient started")
# # print('qwdhkdjwqieuo')
# time.sleep(10)
# # print("qhjdqkjwhdjk")
# ws_client.stop()
# time.sleep(2)
# ws_client.start()
# time.sleep(20)
# ws_client.stop()
async def testReporter():
async with AiohttpSessionManager().get_session(BASE_URL) as http_session:
reporter = ThuaiReporter(session=http_session)
match_result = MatchResult(
match_id="7716",
success=False,
scores=[0, 0],
err_msg="Test error message",
record_file_path="test.dat",
states=[
{
"position": i,
"status": "OK",
"code": 0,
"stderr": base64.b64encode(b"test").decode("utf-8"),
}
for i in range(2)
],
)
await reporter.report(match_result)


async def main():
await testWsClient()
# await testReporter()


if __name__ == "__main__":
Expand Down
15 changes: 13 additions & 2 deletions match_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Contains the match result."""

from dataclasses import dataclass
from typing import List
from typing import List, Optional


@dataclass
Expand All @@ -10,12 +10,23 @@ class MatchResult:
Attributes:
match_id: The match ID
success: Whether the match was successfully judged
err_msg: An error message if the match was not successfully judged
scores: A list of scores
achieved by each agent. The index corresponds to the agent's position in the
original agent_paths list. Higher scores typically indicate better performance.
record_file_path: The path to the record file.
states: A list of states for each agent. Each state is a dictionary.
state:
position: the rank of the agent
status: the status of the agent, ["OK", "RE", "TLE", "MLE", "OLE", "STLE", "EXIT", "UE", "CANCEL", "IA"]
code: the exit code of the agent
stderr: the stderr of the agent, base64 encoded
"""

match_id: str
success: bool
err_msg: str
scores: List[float]
record_file_path: str
record_file_path: Optional[str]
states: List[dict]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
coverage==7.6.10
docker==7.1.0
websocket-client==1.8.0
aiohttp==3.10.5
13 changes: 11 additions & 2 deletions thuai_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Contains docker image build for THUAI."""

import asyncio
from io import BytesIO
import tarfile
from typing import Dict
from pathlib import Path
import string
Expand Down Expand Up @@ -29,7 +31,14 @@ def __init__(self):

def _build_image(self, file_path: Path, code_id: str):
"""Block in a separate thread to build Docker image."""
self.client.images.build(path=str(file_path), tag=code_id, rm=True)
with open(file_path, "rb") as tar_file:
self.client.images.build(
fileobj=tar_file,
custom_context=True,
tag=code_id,
rm=True,
forcerm=True,
)

async def build(self, file_path: Path, code_id: str) -> str:
# get all image tags
Expand All @@ -54,7 +63,7 @@ async def build(self, file_path: Path, code_id: str) -> str:
# error_msg += log_line
error_msg += log_line
# print(error_msg)
return f"E:{error_msg}"
raise Exception(error_msg)

self.built_images[file_path] = code_id

Expand Down
Loading

0 comments on commit df1094e

Please sign in to comment.