Skip to content

Commit

Permalink
Set weights using previous contest benchmarks (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
MsRandom authored Nov 1, 2024
1 parent 657c7fc commit 1532731
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions validator/weight_setting/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from .wandb_args import add_wandb_args
from .winner_selection import get_scores, get_contestant_scores, get_tiers, get_contestant_tier

VALIDATOR_VERSION: tuple[int, int, int] = (4, 5, 2)
VALIDATOR_VERSION: tuple[int, int, int] = (4, 5, 3)
VALIDATOR_VERSION_STRING = ".".join(map(str, VALIDATOR_VERSION))

WEIGHTS_VERSION = (
Expand Down Expand Up @@ -116,6 +116,7 @@ class Validator:
attempted_set_weights: bool = False

benchmarks: list[CheckpointBenchmark | None]
last_benchmarks: list[CheckpointBenchmark | None]
baseline_metrics: MetricData | None
average_benchmarking_time: float | None
benchmarking_state: BenchmarkState
Expand Down Expand Up @@ -166,6 +167,7 @@ def __init__(self):
self.wandb_run_date = None

self.benchmarks = self.clear_benchmarks()
self.last_benchmarks = self.clear_benchmarks()
self.baseline_metrics = None
self.average_benchmarking_time = None
self.benchmarking_state = BenchmarkState.NOT_STARTED
Expand Down Expand Up @@ -359,6 +361,7 @@ def save_state(self):
"step": self.step,
"hotkeys": self.hotkeys,
"benchmarks": self.benchmarks,
"last_benchmarks": self.last_benchmarks,
"baseline_benchmarks": self.baseline_metrics,
"average_benchmarking_time": self.average_benchmarking_time,
"benchmarking_state": self.benchmarking_state,
Expand Down Expand Up @@ -387,6 +390,7 @@ def load_state(self):
self.step = state["step"]
self.hotkeys = state["hotkeys"]
self.benchmarks = state.get("benchmarks", self.benchmarks)
self.last_benchmarks = state.get("last_benchmarks", self.last_benchmarks)
self.baseline_metrics = state.get("baseline_benchmarks", self.baseline_metrics)
self.average_benchmarking_time = state.get("average_benchmarking_time", self.average_benchmarking_time)
self.benchmarking_state = state.get("benchmarking_state", self.benchmarking_state)
Expand Down Expand Up @@ -493,10 +497,25 @@ def set_weights(self):
return

reuse_weights = False
equal_weights = False

if not self.contest_state:
logger.info("Will not set new weights as the contest state has not been set, setting to all ones")
equal_weights = True

if not self.last_benchmarks:
logger.info("Will not set new weights as the previous day's benchmarks have not been set, setting to all ones")
equal_weights = True

if not self.baseline_metrics:
logger.info("Will not calculate weights as the baseline benchmarks have not been set, reusing old weights")
reuse_weights = True

if self.benchmarking:
logger.info("Not setting new weights as benchmarking is not done, reusing old weights")
reuse_weights = True

if equal_weights:
uids = list(range(len(self.metagraph.nodes)))
weights = [1.0] * len(self.metagraph.nodes)

Expand All @@ -512,14 +531,6 @@ def set_weights(self):

return

if not self.baseline_metrics:
logger.info("Will not calculate weights as the baseline benchmarks have not been set, reusing old weights")
reuse_weights = True

if self.benchmarking:
logger.info("Not setting new weights as benchmarking is not done, reusing old weights")
reuse_weights = True

if reuse_weights:
zipped_weights = get_weights_set_by_node(self.substrate, self.metagraph.netuid, self.uid, self.block)

Expand Down Expand Up @@ -549,7 +560,7 @@ def set_weights(self):
if self.is_blacklisted(blacklisted_keys, hotkey, node.coldkey):
self.reset_miner(self.hotkeys.index(hotkey))

contestants = get_contestant_scores(self.benchmarks, self.baseline_metrics)
contestants = get_contestant_scores(self.last_benchmarks, self.baseline_metrics)
tiers = get_tiers(contestants)
blocks = [info.block if info else None for info in self.contest_state.miner_info]
weights = get_scores(tiers, blocks, len(self.metagraph.nodes))
Expand Down Expand Up @@ -701,6 +712,7 @@ async def do_step(self, block: int):
self.contest = CURRENT_CONTEST

self.contest_state = ContestState(self.contest.id, miner_info)
self.last_benchmarks = self.benchmarks
else:
self.contest_state.miner_info = miner_info

Expand Down

0 comments on commit 1532731

Please sign in to comment.