diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py
index b054465..bd965b2 100644
--- a/dmlcloud/core/metrics.py
+++ b/dmlcloud/core/metrics.py
@@ -170,28 +170,28 @@ def __init__(self):
         self.external_metrics = torch.nn.ModuleDict()
 
     def add_metric(self, name: str, metric: torchmetrics.Metric):
-        if name in self.external_metrics or name in self.metrics:
+        if name in self.metrics:
             raise ValueError(f'Metric {name} already exists')
 
-        self.external_metrics[name] = metric
+        self.metrics[name] = metric
 
-    def log(self, name: str, value: Any, reduction: str = 'mean'):
-        if reduction not in ['mean', 'sum', 'min', 'max']:
-            raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max')
-
-        if name in self.external_metrics:
-            raise ValueError(f'Metric {name} is a external metric. Please use the .update() method yourself.')
+    def log(self, name: str, value: Any, reduction: str = 'mean', **kwargs):
+        if reduction not in ['mean', 'sum', 'min', 'max', 'cat']:
+            raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max, cat')
 
         if name not in self.metrics:
             if reduction == 'mean':
-                metric = torchmetrics.MeanMetric()
+                metric = torchmetrics.MeanMetric(**kwargs)
             elif reduction == 'sum':
-                metric = torchmetrics.SumMetric()
+                metric = torchmetrics.SumMetric(**kwargs)
             elif reduction == 'min':
-                metric = torchmetrics.MinMetric()
+                metric = torchmetrics.MinMetric(**kwargs)
             elif reduction == 'max':
-                metric = torchmetrics.MaxMetric()
-            self.metrics[name] = metric.to(value.device)
+                metric = torchmetrics.MaxMetric(**kwargs)
+            elif reduction == 'cat':
+                metric = torchmetrics.CatMetric(**kwargs)
+            device = value.device if torch.is_tensor(value) else torch.device('cpu')
+            self.add_metric(name, metric.to(device))
 
         self.metrics[name].update(value)
 
@@ -200,15 +200,9 @@ def reduce(self):
         for name, metric in self.metrics.items():
             values[name] = metric.compute()
             metric.reset()
-        for name, metric in self.external_metrics.items():
-            values[name] = metric.compute()
-            metric.reset()
         return values
 
     def clear(self):
         for metric in self.metrics.values():
             metric.reset()
-        for metric in self.external_metrics.values():
-            metric.reset()
-        self.metrics.clear()
-        self.external_metrics.clear()
+        self.metrics.clear()
\ No newline at end of file