Skip to content

Commit

Permalink
Implement timeout when canceling benchmarking (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
MsRandom authored Nov 4, 2024
1 parent 6022e40 commit 791ac6c
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
4 changes: 1 addition & 3 deletions neuron/neuron/submission_tester/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from concurrent.futures import ThreadPoolExecutor, CancelledError
from concurrent.futures import CancelledError
from pathlib import Path
from statistics import mean
from threading import Event
Expand All @@ -23,8 +23,6 @@
DEFAULT_LOAD_TIMEOUT = 500
MIN_LOAD_TIMEOUT = 240

EXECUTOR = ThreadPoolExecutor(max_workers=2)

debug = int(os.getenv("VALIDATOR_DEBUG") or 0) > 0
logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion validator/base_validator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .metrics import *

API_VERSION = "4.4.3"
API_VERSION = "4.4.4"
2 changes: 1 addition & 1 deletion validator/submission_tester/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def start_benchmarking(
with benchmarker.lock:
timestamp = time.time_ns()

if timestamp - benchmarker.start_timestamp < 10_000_000_000:
if timestamp - benchmarker.start_timestamp < 60_000_000_000:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Started recently",
Expand Down
19 changes: 12 additions & 7 deletions validator/submission_tester/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from pipelines import TextToImageRequest

logger = logging.getLogger(__name__)
EXECUTOR = ThreadPoolExecutor(1)


class Benchmarker:
Expand All @@ -41,6 +40,7 @@ class Benchmarker:
lock: Lock
benchmark_future: Future | None
cancelled_event: Event
executor: ThreadPoolExecutor

def __init__(self):
self.submissions = {}
Expand All @@ -54,6 +54,7 @@ def __init__(self):
self.submission_times = []
self.benchmark_future = None
self.cancelled_event = Event()
self.executor = ThreadPoolExecutor(max_workers=1)

def _benchmark_key(self, hotkey: Key):
submission = self.submissions[hotkey]
Expand Down Expand Up @@ -91,7 +92,7 @@ def _start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo]):
logger.info("Generating baseline samples to compare")
self.baseline = generate_baseline(self.inputs, cancelled_event=self.cancelled_event)

while len(self.benchmarks) != len(self.submissions):
while len(self.benchmarks) != len(self.submissions) and not self.cancelled_event.is_set():
hotkey = choice(list(self.submissions.keys() - self.benchmarks.keys()))

self._benchmark_key(hotkey)
Expand All @@ -116,13 +117,17 @@ def _start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo]):
def start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo]):
benchmark_future = self.benchmark_future

if benchmark_future:
if benchmark_future and not benchmark_future.done():
benchmark_future.cancel()
self.cancelled_event.set()
if not benchmark_future.cancelled():
benchmark_future.result()

self.benchmark_future = EXECUTOR.submit(self._start_benchmarking, submissions)
try:
benchmark_future.result(timeout=60)
except (CancelledError, TimeoutError):
logger.warning("Benchmarking was not stopped gracefully. Forcing shutdown.")
self.executor.shutdown(wait=False)
self.executor = ThreadPoolExecutor(max_workers=1)

self.benchmark_future = self.executor.submit(self._start_benchmarking, submissions)

def get_baseline_metrics(self) -> MetricData | None:
return self.baseline.metric_data if self.baseline else None
2 changes: 1 addition & 1 deletion validator/weight_setting/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from .wandb_args import add_wandb_args
from .winner_selection import get_scores, get_contestant_scores, get_tiers, get_contestant_tier

VALIDATOR_VERSION: tuple[int, int, int] = (4, 5, 7)
VALIDATOR_VERSION: tuple[int, int, int] = (4, 5, 8)
VALIDATOR_VERSION_STRING = ".".join(map(str, VALIDATOR_VERSION))

WEIGHTS_VERSION = (
Expand Down

0 comments on commit 791ac6c

Please sign in to comment.