diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py index b054465..bd965b2 100644 --- a/dmlcloud/core/metrics.py +++ b/dmlcloud/core/metrics.py @@ -170,28 +170,28 @@ def __init__(self): self.external_metrics = torch.nn.ModuleDict() def add_metric(self, name: str, metric: torchmetrics.Metric): - if name in self.external_metrics or name in self.metrics: + if name in self.metrics: raise ValueError(f'Metric {name} already exists') - self.external_metrics[name] = metric + self.metrics[name] = metric - def log(self, name: str, value: Any, reduction: str = 'mean'): - if reduction not in ['mean', 'sum', 'min', 'max']: - raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max') - - if name in self.external_metrics: - raise ValueError(f'Metric {name} is a external metric. Please use the .update() method yourself.') + def log(self, name: str, value: Any, reduction: str = 'mean', **kwargs): + if reduction not in ['mean', 'sum', 'min', 'max', 'cat']: + raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max, cat') if name not in self.metrics: if reduction == 'mean': - metric = torchmetrics.MeanMetric() + metric = torchmetrics.MeanMetric(**kwargs) elif reduction == 'sum': - metric = torchmetrics.SumMetric() + metric = torchmetrics.SumMetric(**kwargs) elif reduction == 'min': - metric = torchmetrics.MinMetric() + metric = torchmetrics.MinMetric(**kwargs) elif reduction == 'max': - metric = torchmetrics.MaxMetric() - self.metrics[name] = metric.to(value.device) + metric = torchmetrics.MaxMetric(**kwargs) + elif reduction == 'cat': + metric = torchmetrics.CatMetric(**kwargs) + device = value.device if torch.is_tensor(value) else torch.device('cpu') + self.add_metric(name, metric.to(device)) self.metrics[name].update(value) @@ -200,15 +200,9 @@ def reduce(self): for name, metric in self.metrics.items(): values[name] = metric.compute() metric.reset() - for name, metric in self.external_metrics.items(): - values[name] = metric.compute() - metric.reset() return values def clear(self): for metric in self.metrics.values(): metric.reset() - for metric in self.external_metrics.values(): - metric.reset() - self.metrics.clear() - self.external_metrics.clear() + self.metrics.clear() \ No newline at end of file