Dynamic loading timeout (#80)
MsRandom authored Oct 31, 2024
1 parent 59287a0 commit 6fa06dd
Showing 5 changed files with 48 additions and 20 deletions.
3 changes: 1 addition & 2 deletions neuron/neuron/random_inputs.py
@@ -6,8 +6,7 @@
 
 from pipelines import TextToImageRequest
 
-INFERENCE_SOCKET_TIMEOUT = 240
-TIMEZONE = ZoneInfo("US/Pacific")
+TIMEZONE = ZoneInfo("America/Los_Angeles")
 INPUTS_ENDPOINT = os.getenv("INPUTS_ENDPOINT", "https://edge-inputs.api.wombo.ai")
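
Note: "US/Pacific" is a deprecated alias in the IANA time zone database; "America/Los_Angeles" is the canonical key for the same zone, so the rename does not change behavior. The removed INFERENCE_SOCKET_TIMEOUT constant is superseded by the per-call load_timeout introduced below. A minimal sketch (Python 3.9+, standard-library zoneinfo):

from datetime import datetime
from zoneinfo import ZoneInfo

# Both keys resolve to the same zone rules; only the canonical name differs.
legacy = datetime(2024, 10, 31, 12, 0, tzinfo=ZoneInfo("US/Pacific"))
canonical = datetime(2024, 10, 31, 12, 0, tzinfo=ZoneInfo("America/Los_Angeles"))
assert legacy.utcoffset() == canonical.utcoffset()
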
29 changes: 18 additions & 11 deletions neuron/neuron/submission_tester/inference_sandbox.py
@@ -1,20 +1,19 @@
 import logging
 import os
 import sys
-import time
 from io import TextIOWrapper
 from multiprocessing.connection import Client, Connection
 from os.path import abspath
 from pathlib import Path
 from subprocess import Popen, run, TimeoutExpired, PIPE
 from threading import Thread
+from time import perf_counter, sleep
 from typing import Generic, TypeVar
 
 from pydantic import BaseModel
 
 from .setup_inference_sandbox import setup_sandbox, InvalidSubmissionError, NETWORK_JAIL
 from ..contest import ModelRepositoryInfo
-from ..random_inputs import INFERENCE_SOCKET_TIMEOUT
 
 logger = logging.getLogger(__name__)

@@ -26,8 +25,16 @@ class InferenceSandbox(Generic[RequestT]):
 
     _client: Connection
     _process: Popen
 
-    def __init__(self, repository_info: ModelRepositoryInfo, baseline: bool, sandbox_directory: Path, switch_user: bool):
+    load_time: float
+
+    def __init__(
+        self,
+        repository_info: ModelRepositoryInfo,
+        baseline: bool,
+        sandbox_directory: Path,
+        switch_user: bool,
+        load_timeout: int,
+    ):
         self._repository = repository_info
         self._baseline = baseline
         self._sandbox_directory = sandbox_directory
@@ -67,21 +74,21 @@ def __init__(self, repository_info: ModelRepositoryInfo, baseline: bool, sandbox
 
         logger.info("Inference process starting")
 
-        for _ in range(INFERENCE_SOCKET_TIMEOUT):
-            if os.path.exists(socket_path):
-                break
-
-            time.sleep(1)
-
+        start = perf_counter()
+        for _ in range(load_timeout):
+            if os.path.exists(socket_path):
+                break
+
+            sleep(1)
             self._check_exit()
         else:
-            self.fail(f"Timed out after {INFERENCE_SOCKET_TIMEOUT} seconds")
+            self.fail(f"Timed out after {load_timeout} seconds")
 
         logger.info("Connecting to socket")
         try:
             self._client = Client(socket_path)
         except ConnectionRefusedError:
             self.fail("Failed to connect to socket")
+        self.load_time = perf_counter() - start
+        logger.info(f"Connected to socket in {self.load_time:.2f} seconds")
 
     @property
     def _user(self) -> str:
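
Note: the startup wait above uses Python's for...else: the else branch runs only when the loop completes without a break, i.e. the socket file never appeared within load_timeout seconds. A standalone sketch of the idiom, with socket_path and load_timeout as illustrative stand-ins for the instance's values:

import os
from time import perf_counter, sleep

socket_path = "/tmp/inference.sock"  # illustrative; the real path comes from the sandbox
load_timeout = 500                   # seconds, matching DEFAULT_LOAD_TIMEOUT

start = perf_counter()
for _ in range(load_timeout):
    if os.path.exists(socket_path):
        break  # socket appeared; skips the else clause
    sleep(1)
else:
    # Reached only if the loop never broke: the socket never appeared.
    raise TimeoutError(f"Timed out after {load_timeout} seconds")

load_time = perf_counter() - start
print(f"Connected to socket in {load_time:.2f} seconds")
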
1 change: 1 addition & 0 deletions neuron/neuron/submission_tester/metrics.py
@@ -13,6 +13,7 @@ class MetricData(BaseModel):
     size: int
     vram_used: float
     watts_used: float
+    load_time: float
 
 
 class BaselineBenchmark(BaseModel):
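
Note: MetricData is a pydantic model, so load_time becomes a required field; both construction sites in testing.py below now supply it from the sandbox. A minimal sketch with illustrative values (fields above this hunk are elided):

from pydantic import BaseModel

class MetricData(BaseModel):
    size: int
    vram_used: float
    watts_used: float
    load_time: float  # seconds from process start to socket connection

metrics = MetricData(size=6_900_000_000, vram_used=1.2e10, watts_used=350.0, load_time=47.2)
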
24 changes: 22 additions & 2 deletions neuron/neuron/submission_tester/testing.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from concurrent.futures import ThreadPoolExecutor, CancelledError
 from pathlib import Path
 from statistics import mean
@@ -19,9 +20,12 @@
 
 SANDBOX_DIRECTORY = Path("/sandbox")
 BASELINE_SANDBOX_DIRECTORY = Path("/baseline-sandbox")
+DEFAULT_LOAD_TIMEOUT = 500
+MIN_LOAD_TIMEOUT = 120
 
 EXECUTOR = ThreadPoolExecutor(max_workers=2)
 
+debug = int(os.getenv("VALIDATOR_DEBUG", 0)) > 0
 logger = logging.getLogger(__name__)


@@ -56,7 +60,13 @@ def generate_baseline(
 ) -> BaselineBenchmark:
     outputs: list[GenerationOutput] = []
 
-    with InferenceSandbox(CURRENT_CONTEST.baseline_repository, True, sandbox_directory, switch_user) as sandbox:
+    with InferenceSandbox(
+        repository_info=CURRENT_CONTEST.baseline_repository,
+        baseline=True,
+        sandbox_directory=sandbox_directory,
+        switch_user=switch_user,
+        load_timeout=DEFAULT_LOAD_TIMEOUT,
+    ) as sandbox:
         size = sandbox.model_size
 
         for index, request in enumerate(inputs):
@@ -86,6 +96,7 @@ def generate_baseline(
             size=size,
             vram_used=vram_used,
             watts_used=watts_used,
+            load_time=sandbox.load_time,
         ),
     )

@@ -96,13 +107,20 @@ def compare_checkpoints(
     baseline: BaselineBenchmark,
     sandbox_directory: Path = SANDBOX_DIRECTORY,
     switch_user: bool = True,
+    load_timeout: int = DEFAULT_LOAD_TIMEOUT,
     cancelled_event: Event | None = None,
 ) -> CheckpointBenchmark | None:
     logger.info("Generating model samples")
 
     outputs: list[GenerationOutput] = []
 
-    with InferenceSandbox(submission, False, sandbox_directory, switch_user) as sandbox:
+    with InferenceSandbox(
+        repository_info=submission,
+        baseline=False,
+        sandbox_directory=sandbox_directory,
+        switch_user=switch_user,
+        load_timeout=max(load_timeout, MIN_LOAD_TIMEOUT if not debug else DEFAULT_LOAD_TIMEOUT),
+    ) as sandbox:
         size = sandbox.model_size
 
         try:
@@ -164,6 +182,7 @@ def calculate_similarity(comparator: OutputComparator, baseline_output: Generati
                 size=size,
                 vram_used=vram_used,
                 watts_used=watts_used,
+                load_time=sandbox.load_time,
             ),
             average_similarity=average_similarity,
             min_similarity=min_similarity,
@@ -176,6 +195,7 @@ def calculate_similarity(comparator: OutputComparator, baseline_output: Generati
         f"Min Similarity: {min_similarity}\n"
         f"Average Generation Time: {average_time}s\n"
         f"Model Size: {size}b\n"
+        f"Model Load Time: {sandbox.load_time}s\n"
         f"Max VRAM Usage: {vram_used}b\n"
         f"Max Power Usage: {watts_used}W"
     )
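
Note: the timeout handed to InferenceSandbox in compare_checkpoints is floored, so callers cannot push it below MIN_LOAD_TIMEOUT (or below DEFAULT_LOAD_TIMEOUT when the VALIDATOR_DEBUG environment variable is set). A worked example of that expression:

DEFAULT_LOAD_TIMEOUT = 500
MIN_LOAD_TIMEOUT = 120

def effective_timeout(load_timeout: int, debug: bool) -> int:
    # Mirrors the load_timeout= expression in compare_checkpoints above.
    return max(load_timeout, MIN_LOAD_TIMEOUT if not debug else DEFAULT_LOAD_TIMEOUT)

assert effective_timeout(80, debug=False) == 120   # clamped up to the floor
assert effective_timeout(300, debug=False) == 300  # passed through unchanged
assert effective_timeout(80, debug=True) == 500    # debug floor is the default
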
11 changes: 6 additions & 5 deletions validator/submission_tester/benchmarker.py
@@ -5,6 +5,7 @@
 from random import choice
 from threading import Lock, Event
 from time import perf_counter
+from typing import cast
 
 from neuron.submission_tester import (
     CheckpointBenchmark,
@@ -61,10 +62,10 @@ def _benchmark_key(self, hotkey: Key):
 
         try:
             self.benchmarks[hotkey] = compare_checkpoints(
-                submission,
-                self.inputs,
-                self.baseline,
-                cancelled_event=self.cancelled_event,
+                submission=submission,
+                inputs=self.inputs,
+                baseline=self.baseline,
+                load_timeout=int(cast(MetricData, self.get_baseline_metrics()).load_time * 2),
             )
         except InvalidSubmissionError as e:
             logger.error(f"Skipping invalid submission '{submission}': '{e}'")
@@ -129,4 +130,4 @@ def start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo]):
         self.benchmark_future = EXECUTOR.submit(self._start_benchmarking, submissions)
 
     def get_baseline_metrics(self) -> MetricData | None:
-        return self.baseline.metric_data if self.baseline else None
\ No newline at end of file
+        return self.baseline.metric_data if self.baseline else None
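
Note: this is the change that makes the timeout dynamic: each submission's load budget is twice the baseline model's measured load time, which compare_checkpoints then clamps to at least MIN_LOAD_TIMEOUT. The typing.cast narrows get_baseline_metrics()'s MetricData | None return, which the code evidently assumes is non-None once benchmarking starts. With illustrative numbers:

# If the baseline loaded in 45.3 s, a submission gets int(45.3 * 2) == 90 s,
# which compare_checkpoints raises to the 120 s floor; a 200 s baseline load
# would yield a 400 s budget instead.
baseline_load_time = 45.3
load_timeout = int(baseline_load_time * 2)  # 90
effective = max(load_timeout, 120)          # 120 (MIN_LOAD_TIMEOUT)
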
