Better vram monitor (#95)
MsRandom authored Nov 15, 2024
1 parent 49cfd08 commit f7512fc
Showing 3 changed files with 15 additions and 18 deletions.
6 changes: 4 additions & 2 deletions neuron/neuron/submission_tester/testing.py
@@ -57,6 +57,7 @@ def generate_baseline(
 ) -> BaselineBenchmark:
     outputs: list[GenerationOutput] = []
 
+    start_vram = CURRENT_CONTEST.get_vram_used()
     with InferenceSandbox(
         repository_info=CURRENT_CONTEST.baseline_repository,
         baseline=True,
@@ -82,7 +83,7 @@ def generate_baseline(
             outputs.append(output)
 
     generation_time = mean(output.generation_time for output in outputs)
-    vram_used = max(output.vram_used for output in outputs)
+    vram_used = max(output.vram_used for output in outputs) - start_vram
     watts_used = max(output.watts_used for output in outputs)
 
     return BaselineBenchmark(
@@ -111,6 +112,7 @@ def compare_checkpoints(
 
     outputs: list[GenerationOutput] = []
 
+    start_vram = CURRENT_CONTEST.get_vram_used()
     with InferenceSandbox(
         repository_info=submission,
         baseline=False,
@@ -144,7 +146,7 @@ def compare_checkpoints(
         raise InvalidSubmissionError(f"Failed to run inference") from e
 
     average_time = sum(output.generation_time for output in outputs) / len(outputs)
-    vram_used = max(output.vram_used for output in outputs)
+    vram_used = max(output.vram_used for output in outputs) - start_vram
     watts_used = max(output.watts_used for output in outputs)
 
     with CURRENT_CONTEST.output_comparator() as output_comparator:
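Note on the testing.py change: capturing start_vram before the InferenceSandbox starts and subtracting it from the per-output peak means the reported vram_used reflects only memory added during the run, not whatever was already resident on the device. A minimal sketch of that pattern, with an illustrative vram_sampler standing in for CURRENT_CONTEST.get_vram_used (the names and units below are assumptions, not the repository's API):

    from typing import Callable, Iterable


    def measure_peak_vram_delta(
        vram_sampler: Callable[[], int],          # assumed to return currently used VRAM in bytes
        workloads: Iterable[Callable[[], None]],  # one callable per generation, purely illustrative
    ) -> int:
        start_vram = vram_sampler()               # memory already in use before the sandbox starts
        peak = 0
        for run in workloads:
            run()
            peak = max(peak, vram_sampler())      # mirrors max(output.vram_used for output in outputs)
        return peak - start_vram                  # only the run's own footprint is reported

The subtraction assumes the device's background usage stays roughly constant for the duration of the run; if other processes free memory mid-run, the reported delta shrinks accordingly.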
25 changes: 10 additions & 15 deletions neuron/neuron/submission_tester/vram_monitor.py
@@ -1,32 +1,27 @@
-import threading
-import time
+from threading import Thread, Event
+from time import sleep
 
 from .. import Contest
 
-POLL_RATE_SECONDS = 0.1
-
+SAMPLE_RATE_MS = 10
 
 class VRamMonitor:
-    _thread: threading.Thread
     _contest: Contest
+    _thread: Thread
+    _stop_flag: Event
     _vram_usage: int = 0
-    _stop_flag: threading.Event
 
     def __init__(self, contest: Contest):
         self._contest = contest
-        self._stop_flag = threading.Event()
+        self._stop_flag = Event()
 
-        self._thread = threading.Thread(target=self.monitor)
+        self._thread = Thread(target=self._monitor)
         self._thread.start()
 
-    def monitor(self):
+    def _monitor(self):
         while not self._stop_flag.is_set():
-            vram = self._contest.get_vram_used()
-
-            if self._vram_usage < vram:
-                self._vram_usage = vram
-
-            time.sleep(POLL_RATE_SECONDS)
+            self._vram_usage = max(self._vram_usage, self._contest.get_vram_used())
+            sleep(SAMPLE_RATE_MS / 1000)
 
     def complete(self) -> int:
         self._stop_flag.set()
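Note on the vram_monitor.py rewrite: the background thread now samples Contest.get_vram_used() every SAMPLE_RATE_MS milliseconds and folds the readings through max(), and complete() sets the stop flag (the lines truncated above presumably join the thread and return the recorded peak, given the int return type). A hedged usage sketch; FakeContest is a stand-in because the real Contest interface isn't shown in this diff, and the import path is assumed from the file layout:

    from neuron.submission_tester.vram_monitor import VRamMonitor  # assumed import path


    class FakeContest:
        """Stand-in exposing the one method VRamMonitor relies on."""

        def __init__(self) -> None:
            self._used = 0

        def get_vram_used(self) -> int:
            self._used += 1024            # pretend usage grows while the workload runs
            return self._used


    monitor = VRamMonitor(FakeContest())  # the sampling thread starts in __init__
    # ... run the inference workload being measured ...
    peak_vram = monitor.complete()        # stop sampling; presumably returns the peak sample

At a 10 ms interval the monitor can still miss very short allocation spikes that rise and fall between two polls; max() only sees what the sampler happens to observe.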
2 changes: 1 addition & 1 deletion validator/submission_tester/api.py
@@ -53,7 +53,7 @@ def flush(self):
 sys.stdout = LogsIO(sys.stdout, "out")
 sys.stderr = LogsIO(sys.stderr, "err")
 logging.basicConfig(
-    level=logging.DEBUG,
+    level=logging.INFO,
     format="%(asctime)s - %(levelname)s - %(filename)s: %(message)s",
     datefmt="%Y-%m-%d %H:%M:%S",
     stream=sys.stdout
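Note on the api.py change: raising the root logger's threshold from DEBUG to INFO drops debug-level records from the validator's stdout log while leaving the format, date format, and stream untouched. A small self-contained illustration of the effect (generic logging behaviour, nothing specific to this repository):

    import logging
    import sys

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(filename)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stdout,
    )

    logging.debug("suppressed: below the INFO threshold")
    logging.info("emitted: at or above the INFO threshold")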
