Skip to content

Commit

Permalink
Check popen process, log output
Browse files Browse the repository at this point in the history
  • Loading branch information
MsRandom committed Nov 28, 2024
1 parent 6e71c02 commit 9dffbf7
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions base/testing/inference_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,13 @@ def _setup_sandbox(self) -> int:
return repository_size + huggingface_models_size

@tracer.start_as_current_span("wait_for_socket")
def wait_for_socket(self) -> float:
def wait_for_socket(self, process: Popen) -> float:
start = perf_counter()
for _ in range(LOAD_TIMEOUT):
if self._socket_path.exists():
break
self._stop_flag.wait(0.1)
self._stop_flag.wait(1)
check_process(process)
if self._stop_flag.is_set():
raise CancelledError()
else:
Expand All @@ -197,8 +198,9 @@ def benchmark(self) -> BenchmarkOutput:
stdout=PIPE,
stderr=PIPE,
text=True,
):
load_time = self.wait_for_socket()
bufsize=1,
) as process:
load_time = self.wait_for_socket(process)
with Client(abspath(self._socket_path)) as client:
logger.info(f"Benchmarking {len(self._inputs)} samples")
for i, request in enumerate(self._inputs):
Expand All @@ -225,11 +227,11 @@ def benchmark(self) -> BenchmarkOutput:
load_time=load_time,
))
outputs.append(output)
check_process(process)

average_generation_time = sum(metric.generation_time for metric in metrics) / len(metrics)
vram_used = max(metric.vram_used for metric in metrics) - start_vram
watts_used = max(metric.watts_used for metric in metrics)

return BenchmarkOutput(
metrics=Metrics(
generation_time=average_generation_time,
Expand All @@ -240,3 +242,18 @@ def benchmark(self) -> BenchmarkOutput:
),
outputs=outputs,
)


def check_process(process: Popen):
if process.poll():
log_process(process)
raise InvalidSubmissionError(f"Inference crashed with exit code {process.returncode}")


def log_process(process: Popen):
stdout = process.stdout.read()
stderr = process.stderr.read()
if stdout:
logger.info(stdout)
if stderr:
logger.error(stderr)

0 comments on commit 9dffbf7

Please sign in to comment.