Skip to content

Commit

Permalink
Fix type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MsRandom committed Oct 16, 2024
1 parent 495e4c6 commit 1e2ee36
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 56 deletions.
6 changes: 3 additions & 3 deletions miner/miner/benchmarker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from multiprocessing.connection import Client
from multiprocessing.connection import Client, Connection
from os.path import abspath
from pathlib import Path
from subprocess import Popen
Expand Down Expand Up @@ -40,7 +40,7 @@ def wait_for_socket(socket_path: str, process: Popen):
raise RuntimeError(f"Socket file '{socket_path}' not found after {safe_timeout} seconds. Your pipeline is taking too long to load. Please optimize and avoid precompiling if possible.")


def test(contest: Contest, client: Client):
def test(contest: Contest, client: Connection):
outputs: list[GenerationOutput] = []
inputs = random_inputs()

Expand Down Expand Up @@ -70,7 +70,7 @@ def test(contest: Contest, client: Client):
)


def benchmark(contest: Contest, client: Client, request: TextToImageRequest):
def benchmark(contest: Contest, client: Connection, request: TextToImageRequest):
start_joules = contest.get_joules()
vram_monitor = VRamMonitor(contest)
start = perf_counter()
Expand Down
2 changes: 1 addition & 1 deletion miner/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion miner/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ edge-maxxing-neuron = { path = "../neuron", develop = true }
submit_model = 'miner.submit:main'

[tool.poetry.dev-dependencies]
pytype = "2024.4.11"
pytype = "2024.10.11"

[build-system]
requires = ["poetry-core"]
Expand Down
5 changes: 1 addition & 4 deletions neuron/neuron/contest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from functools import partial
from io import BytesIO
from typing import TypeVar, Callable

from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPVisionModelWithProjection

RequestT = TypeVar("RequestT", bound=BaseModel)
ResponseT = TypeVar("ResponseT")


class ModelRepositoryInfo(BaseModel):
url: str
Expand Down
2 changes: 1 addition & 1 deletion neuron/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ opencv-python = "4.10.0.84"
edge-maxxing-pipelines = { path = "../pipelines", develop = true }

[tool.poetry.dev-dependencies]
pytype = "2024.4.11"
pytype = "2024.10.11"

[build-system]
requires = ["poetry-core"]
Expand Down
2 changes: 1 addition & 1 deletion pipelines/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ python = ">=3.10,<3.13"
pydantic = ">=2"

[tool.poetry.dev-dependencies]
pytype = "2024.4.11"
pytype = "2024.10.11"

[build-system]
requires = ["poetry-core"]
Expand Down
52 changes: 25 additions & 27 deletions validator/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions validator/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ edge-maxxing-neuron = { path = "../neuron", develop = true }
start_validator = 'weight_setting.validator:main'

[tool.poetry.dev-dependencies]
pytype = "2024.4.11"
pytype = "2024.10.11"

[build-system]
requires = ["poetry-core"]
Expand All @@ -54,4 +54,4 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry-monorepo.deps]

[tool.pytype]
inputs = ["validator"]
inputs = ["base_validator", "submission_tester", "weight_setting"]
8 changes: 4 additions & 4 deletions validator/submission_tester/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import asynccontextmanager
from io import TextIOWrapper
from queue import Queue
from typing import Annotated
from typing import Annotated, TextIO

from fastapi import FastAPI, WebSocket, Request, Header, HTTPException
from starlette import status
Expand All @@ -29,10 +29,10 @@


class LogsIO(TextIOWrapper):
old_stdout: TextIOWrapper
old_stdout: TextIO
log_type: str

def __init__(self, old_stdout, log_type: str):
def __init__(self, old_stdout: TextIO, log_type: str):
super().__init__(old_stdout.buffer, encoding=old_stdout.encoding, errors=old_stdout.errors, newline=old_stdout.newlines)
self.old_stdout = old_stdout
self.log_type = log_type
Expand Down Expand Up @@ -115,7 +115,7 @@ async def start_benchmarking(

benchmarker.start_timestamp = timestamp

await benchmarker.start_benchmarking(submissions)
await benchmarker.start_benchmarking(submissions)


@app.get("/state")
Expand Down
8 changes: 4 additions & 4 deletions validator/submission_tester/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def __init__(self):
self.benchmarks = {}
self.baseline = None
self.inputs = []
self.started = False
self.done = True
self.start_timestamp = 0
self.lock = Lock()
Expand Down Expand Up @@ -67,7 +66,6 @@ async def _start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo])
self.benchmarks = {}
self.submission_times = []
self.inputs = random_inputs()
self.started = True
self.done = False

if not self.baseline or self.baseline.inputs != self.inputs:
Expand Down Expand Up @@ -98,8 +96,10 @@ async def _start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo])
self.done = True

async def start_benchmarking(self, submissions: dict[Key, ModelRepositoryInfo]):
if not self.done and self.started:
self.benchmark_task.cancel()
benchmark_task = self.benchmark_task

if not self.done and benchmark_task:
benchmark_task.cancel()

self.submissions = submissions
self.benchmarks = {}
Expand Down
11 changes: 7 additions & 4 deletions validator/submission_tester/inference_sandbox.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import logging
import os
import socket
import sys
import time
from io import TextIOWrapper
from multiprocessing.connection import Client, Connection
from os.path import abspath
from pathlib import Path
from subprocess import Popen, run, TimeoutExpired, PIPE
from typing import Generic
from threading import Thread
from typing import Generic, TypeVar

from pydantic import BaseModel

from neuron import (
RequestT,
INFERENCE_SOCKET_TIMEOUT,
ModelRepositoryInfo,
setup_sandbox,
InvalidSubmissionError, SPEC_VERSION,
InvalidSubmissionError,
)

SANDBOX_DIRECTORY = Path("/sandbox")
Expand All @@ -25,6 +25,9 @@
logger = logging.getLogger(__name__)


RequestT = TypeVar("RequestT", bound=BaseModel)


def sandbox_args(user: str):
return [
"/bin/sudo",
Expand Down
9 changes: 5 additions & 4 deletions validator/submission_tester/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ModelRepositoryInfo,
CURRENT_CONTEST,
Key,
OutputComparator,
)
from pipelines import TextToImageRequest
from .inference_sandbox import InferenceSandbox, InvalidSubmissionError
Expand Down Expand Up @@ -149,9 +150,9 @@ def compare_checkpoints(
f"Max Power Usage: {watts_used}W"
)

comparator = CURRENT_CONTEST.output_comparator()
output_comparator = CURRENT_CONTEST.output_comparator()

def calculate_similarity(baseline_output: GenerationOutput, optimized_output: GenerationOutput):
def calculate_similarity(comparator: OutputComparator, baseline_output: GenerationOutput, optimized_output: GenerationOutput):
try:
return comparator(baseline_output.output, optimized_output.output)
except:
Expand All @@ -163,11 +164,11 @@ def calculate_similarity(baseline_output: GenerationOutput, optimized_output: Ge
return 0.0

average_similarity = mean(
calculate_similarity(baseline_output, output)
calculate_similarity(output_comparator, baseline_output, output)
for baseline_output, output in zip(baseline.outputs, outputs)
)

del comparator
del output_comparator
CURRENT_CONTEST.clear_cache()

benchmark = CheckpointBenchmark(
Expand Down

0 comments on commit 1e2ee36

Please sign in to comment.