Skip to content

Commit

Permalink
fix: Tracker.log() when called with non-tensor arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Dec 26, 2024
1 parent 6f58811 commit 57e10a5
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions dmlcloud/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,28 +170,28 @@ def __init__(self):
self.external_metrics = torch.nn.ModuleDict()

def add_metric(self, name: str, metric: torchmetrics.Metric):
if name in self.external_metrics or name in self.metrics:
if name in self.metrics:
raise ValueError(f'Metric {name} already exists')

self.external_metrics[name] = metric
self.metrics[name] = metric

def log(self, name: str, value: Any, reduction: str = 'mean'):
if reduction not in ['mean', 'sum', 'min', 'max']:
raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max')

if name in self.external_metrics:
raise ValueError(f'Metric {name} is a external metric. Please use the .update() method yourself.')
def log(self, name: str, value: Any, reduction: str = 'mean', **kwargs):
if reduction not in ['mean', 'sum', 'min', 'max', 'cat']:
raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max, cat')

if name not in self.metrics:
if reduction == 'mean':
metric = torchmetrics.MeanMetric()
metric = torchmetrics.MeanMetric(**kwargs)
elif reduction == 'sum':
metric = torchmetrics.SumMetric()
metric = torchmetrics.SumMetric(**kwargs)
elif reduction == 'min':
metric = torchmetrics.MinMetric()
metric = torchmetrics.MinMetric(**kwargs)
elif reduction == 'max':
metric = torchmetrics.MaxMetric()
self.metrics[name] = metric.to(value.device)
metric = torchmetrics.MaxMetric(**kwargs)
elif reduction == 'cat':
metric = torchmetrics.CatMetric(**kwargs)
device = value.device if torch.is_tensor(value) else torch.device('cpu')
self.add_metric(name, metric.to(device))

self.metrics[name].update(value)

Expand All @@ -200,15 +200,9 @@ def reduce(self):
for name, metric in self.metrics.items():
values[name] = metric.compute()
metric.reset()
for name, metric in self.external_metrics.items():
values[name] = metric.compute()
metric.reset()
return values

def clear(self):
for metric in self.metrics.values():
metric.reset()
for metric in self.external_metrics.values():
metric.reset()
self.metrics.clear()
self.external_metrics.clear()
self.metrics.clear()

0 comments on commit 57e10a5

Please sign in to comment.