From 7505b1fc91077a8eeb103702df80fdaff63ac397 Mon Sep 17 00:00:00 2001
From: Robert Turnbull
Date: Mon, 18 Nov 2024 16:47:44 +1100
Subject: [PATCH] :zap: fixing issue with loading from checkpoint

---
 hierarchicalsoftmax/metrics.py | 69 ++++++++++++++++++++++++++++++++--
 pyproject.toml                 |  2 +-
 2 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/hierarchicalsoftmax/metrics.py b/hierarchicalsoftmax/metrics.py
index 6017c98..16934af 100644
--- a/hierarchicalsoftmax/metrics.py
+++ b/hierarchicalsoftmax/metrics.py
@@ -1,6 +1,11 @@
 from sklearn.metrics import f1_score, precision_score, recall_score
 import torch
-from torchmetrics import Metric
+from typing import Callable
+from collections.abc import Sequence
+from torch.nn import Module
+from torch import Tensor
+from torchmetrics.metric import Metric, apply_to_collection
+
 from . import inference, nodes
 from .inference import ShapeError
 
@@ -225,22 +230,80 @@ def compute(self):
 
 
 class RankAccuracyTorchMetric(Metric):
-    def __init__(self, root:nodes.SoftmaxNode, ranks:dict[int,str], name:str="rank_accuracy"):
+    def __init__(self, root, ranks: dict[int, str], name: str = "rank_accuracy"):
         super().__init__()
         self.root = root
         self.ranks = ranks
         self.name = name
+
+        # Use `add_state` for metrics to handle distributed reduction and device placement
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
         for rank_name in ranks.values():
             self.add_state(rank_name, default=torch.tensor(0), dist_reduce_fx="sum")
 
     def update(self, predictions, targets):
+        # Ensure tensors match the device
+        predictions = predictions.to(self.device)
+        targets = targets.to(self.device)
+
         self.total += targets.size(0)
         depth_accurate_tensor = depth_accurate(predictions, targets, self.root)
+
         for depth, rank_name in self.ranks.items():
             accurate_at_depth = (depth_accurate_tensor >= depth).sum()
             setattr(self, rank_name, getattr(self, rank_name) + accurate_at_depth)
 
     def compute(self):
-        return {rank_name: getattr(self, rank_name) / self.total for rank_name in self.ranks.values()}
+        # Compute final metric values
+        return {
+            rank_name: getattr(self, rank_name) / self.total
+            for rank_name in self.ranks.values()
+        }
+
+    def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
+        """Overwrite `_apply` function such that we can also move metric states to the correct device.
+
+        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
+        are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
+
+        Overriding because there is an issue with device handling in the parent class.
+
+        Args:
+            fn: the function to apply
+            exclude_state: list of state variables to exclude from applying the function, that then needs to be handled
+                by the metric class itself.
+        """
+        this = super(Metric, self)._apply(fn)
+        fs = str(fn)
+        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
+        if not self._dtype_convert and cond:
+            return this
+
+        # Also apply fn to metric states and defaults
+        for key, value in this._defaults.items():
+            if key in exclude_state:
+                continue
+
+            if isinstance(value, Tensor):
+                this._defaults[key] = fn(value)
+            elif isinstance(value, Sequence):
+                this._defaults[key] = [fn(v) for v in value]
+
+            current_val = getattr(this, key)
+            if isinstance(current_val, Tensor):
+                setattr(this, key, fn(current_val))
+            elif isinstance(current_val, Sequence):
+                setattr(this, key, [fn(cur_v) for cur_v in current_val])
+            else:
+                raise TypeError(
+                    f"Expected metric state to be either a Tensor or a list of Tensor, but encountered {current_val}"
+                )
+
+        # Additional apply to forward cache and computed attributes (may be nested)
+        if this._computed is not None:
+            this._computed = apply_to_collection(this._computed, Tensor, fn)
+        if this._forward_cache is not None:
+            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)
+
+        return this

diff --git a/pyproject.toml b/pyproject.toml
index 8a24216..d067bf1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hierarchicalsoftmax"
-version = "1.2.0"
+version = "1.2.1"
 description = "A Hierarchical Softmax Framework for PyTorch."
 authors = ["Robert Turnbull "]
 license = "Apache-2.0"
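
Note (not part of the patch): a minimal usage sketch of the patched RankAccuracyTorchMetric, assuming a toy SoftmaxNode hierarchy and illustrative rank names. The overridden _apply is what allows .to(...) to carry the add_state tensors ("total" and the per-rank counters) along with the module, which is the failure mode when resuming from a checkpoint on a different device.

import torch
from hierarchicalsoftmax import SoftmaxNode
from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric

# Toy hierarchy (names are illustrative): root -> animal -> {dog, cat}
root = SoftmaxNode("root")
animal = SoftmaxNode("animal", parent=root)
dog = SoftmaxNode("dog", parent=animal)
cat = SoftmaxNode("cat", parent=animal)
root.set_indexes()

# Keys of `ranks` are tree depths; values become add_state attribute names,
# so they must be valid Python identifiers.
metric = RankAccuracyTorchMetric(root=root, ranks={1: "depth_1_accuracy", 2: "depth_2_accuracy"})

# With the overridden _apply, this also moves `total` and the per-rank
# counters, so a metric restored alongside a checkpoint can follow the
# model onto the GPU without device-mismatch errors.
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = metric.to(device)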