Commit 7505b1f

⚡ fixing issue with loading from checkpoint

rbturnbull committed Nov 18, 2024
1 parent 60a6473
Showing 2 changed files with 67 additions and 4 deletions.
69 changes: 66 additions & 3 deletions hierarchicalsoftmax/metrics.py
@@ -1,6 +1,11 @@
 from sklearn.metrics import f1_score, precision_score, recall_score
 import torch
 from torchmetrics import Metric
+from typing import Callable
+from collections.abc import Sequence
+from torch.nn import Module
+from torch import Tensor
+from torchmetrics.metric import Metric, apply_to_collection
 
 from . import inference, nodes
 from .inference import ShapeError

@@ -225,22 +230,80 @@ def compute(self):
 
 
 class RankAccuracyTorchMetric(Metric):
-    def __init__(self, root:nodes.SoftmaxNode, ranks:dict[int,str], name:str="rank_accuracy"):
+    def __init__(self, root, ranks: dict[int, str], name: str = "rank_accuracy"):
         super().__init__()
         self.root = root
         self.ranks = ranks
         self.name = name
 
+        # Use `add_state` for metrics to handle distributed reduction and device placement
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
 
         for rank_name in ranks.values():
             self.add_state(rank_name, default=torch.tensor(0), dist_reduce_fx="sum")
 
     def update(self, predictions, targets):
+        # Ensure tensors match the device
+        predictions = predictions.to(self.device)
+        targets = targets.to(self.device)
+
         self.total += targets.size(0)
         depth_accurate_tensor = depth_accurate(predictions, targets, self.root)
 
         for depth, rank_name in self.ranks.items():
             accurate_at_depth = (depth_accurate_tensor >= depth).sum()
             setattr(self, rank_name, getattr(self, rank_name) + accurate_at_depth)
 
     def compute(self):
-        return {rank_name: getattr(self, rank_name) / self.total for rank_name in self.ranks.values()}
+        # Compute final metric values
+        return {
+            rank_name: getattr(self, rank_name) / self.total
+            for rank_name in self.ranks.values()
+        }
+
+    def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
+        """Overwrite `_apply` so that metric states are also moved to the correct device.
+
+        This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`,
+        `.half` etc. methods are called. Dtype conversion is guarded and will only happen through
+        the special `set_dtype` method. Overriding because there is a device issue in the parent class.
+
+        Args:
+            fn: the function to apply
+            exclude_state: list of state variables to exclude from applying the function, which then
+                need to be handled by the metric class itself.
+        """
+        this = super(Metric, self)._apply(fn)
+        fs = str(fn)
+        cond = any(f in fs for f in ["Module.type", "Module.half", "Module.float", "Module.double", "Module.bfloat16"])
+        if not self._dtype_convert and cond:
+            return this
+
+        # Also apply fn to metric states and defaults
+        for key, value in this._defaults.items():
+            if key in exclude_state:
+                continue
+
+            if isinstance(value, Tensor):
+                this._defaults[key] = fn(value)
+            elif isinstance(value, Sequence):
+                this._defaults[key] = [fn(v) for v in value]
+
+            current_val = getattr(this, key)
+            if isinstance(current_val, Tensor):
+                setattr(this, key, fn(current_val))
+            elif isinstance(current_val, Sequence):
+                setattr(this, key, [fn(cur_v) for cur_v in current_val])
+            else:
+                raise TypeError(
+                    f"Expected metric state to be either a Tensor or a list of Tensor, but encountered {current_val}"
+                )
+
+        # Additional apply to forward cache and computed attributes (may be nested)
+        if this._computed is not None:
+            this._computed = apply_to_collection(this._computed, Tensor, fn)
+        if this._forward_cache is not None:
+            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)
+
+        return this
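The override matters when a training run is resumed: restoring a checkpoint and calling `.to(device)` on the metric must also move the dynamically registered per-rank states, which is what the new `_apply` ensures. Below is a minimal sketch of that behaviour, assuming the tree-building API shown in the project README (`SoftmaxNode(name, parent=...)`); the two-level hierarchy and rank names are illustrative only, not taken from this commit.

import torch
from hierarchicalsoftmax import SoftmaxNode
from hierarchicalsoftmax.metrics import RankAccuracyTorchMetric

# Illustrative two-level hierarchy (node and rank names are hypothetical).
root = SoftmaxNode("root")
mammal = SoftmaxNode("mammal", parent=root)
cat = SoftmaxNode("cat", parent=mammal)
dog = SoftmaxNode("dog", parent=mammal)

metric = RankAccuracyTorchMetric(
    root=root,
    ranks={1: "class_accuracy", 2: "species_accuracy"},
)

# Moving the metric (as happens when resuming training on a GPU) now also
# moves the states registered via `add_state`, because the overridden
# `_apply` applies `fn` to `_defaults` and to the live attributes.
if torch.cuda.is_available():
    metric = metric.to("cuda")
    assert metric.total.device.type == "cuda"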
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hierarchicalsoftmax"
-version = "1.2.0"
+version = "1.2.1"
 description = "A Hierarchical Softmax Framework for PyTorch."
 authors = ["Robert Turnbull <[email protected]>"]
 license = "Apache-2.0"