diff --git a/dmlcloud/metrics.py b/dmlcloud/metrics.py index 09ff494..3a2644a 100644 --- a/dmlcloud/metrics.py +++ b/dmlcloud/metrics.py @@ -105,6 +105,9 @@ def reduce_and_append(self, value): self.values.append(value) def reduce_locally(self): + if len(self.values) == 0: + return None + if isinstance(self.dim, list): dim = [0] + [d + 1 for d in self.dim] elif isinstance(self.dim, int): @@ -115,14 +118,26 @@ def reduce_locally(self): tensor = reduce_tensor(tensor, reduction=self.reduction, dim=dim) return tensor - def reduce_globally(self, group=None, async_op=False): + def reduce_globally(self, group=None): + # if the list of values is empty, the result is None + if self.globally: + empty_workers = [None] * dist.get_world_size(group) + dist.all_gather_object(empty_workers, len(self.values) == 0, group=group) + if any(empty_workers): + if len(empty_workers) > 1 and not all(empty_workers): + raise ValueError('Some workers tracked values this epoch and some did not. This is likely a bug.') + else: + return None + elif len(self.values) == 0: + return None + tensor = self.reduce_locally() if self.globally: if self.reduction == Reduction.MEAN: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) tensor /= dist.get_world_size(group) else: - dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group, async_op=async_op) + dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group) return tensor def state_dict(self): diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py index 7a7ac82..e4a5609 100644 --- a/dmlcloud/pipeline.py +++ b/dmlcloud/pipeline.py @@ -94,8 +94,8 @@ def register_dataset(self, name: str, dataset: Union[DataLoader, Dataset, Sequen msg += f' - Batches (Total): ~{length * dist.get_world_size()}\n' msg += f' - Batches (/Worker): {length}\n' except TypeError: # __len__ not implemented - msg += f' - Batches (Total): N/A\n' - msg += f' - Batches (/Worker): N/A\n' + msg += ' - Batches (Total): N/A\n' + msg += ' - Batches (/Worker): N/A\n' self.logger.info(msg) def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Optional[str] = None): diff --git a/test/test_metrics.py b/test/test_metrics.py index 1deca76..8787517 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -79,6 +79,14 @@ def test_serialization(self): assert new_reducer.dim == [1, 2, 3] assert new_reducer.values == reducer.values + def test_empty_reduction(self, torch_distributed): + reducer = MetricReducer(reduction=Reduction.MIN, globally=True) + result = reducer.reduce_locally() + assert result is None + + result = reducer.reduce_globally() + assert result is None + class TestMetricTracker: def test_dictionary(self):