Skip to content

Commit

Permalink
fix: metric reduction with empty values
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 18, 2024
1 parent c9b027f commit 3c8cf34
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
21 changes: 18 additions & 3 deletions dmlcloud/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def reduce_and_append(self, value):
self.values.append(value)

def reduce_locally(self):
if len(self.values) == 0:
return None

if isinstance(self.dim, list):
dim = [0] + [d + 1 for d in self.dim]
elif isinstance(self.dim, int):
Expand All @@ -115,14 +118,26 @@ def reduce_locally(self):
tensor = reduce_tensor(tensor, reduction=self.reduction, dim=dim)
return tensor

def reduce_globally(self, group=None):
    """
    Reduce the tracked values across all workers of ``group`` and return the result.

    Returns:
        The reduced tensor, or ``None`` if no values were tracked this epoch.

    Raises:
        ValueError: if, in the distributed case, only a subset of workers
            tracked values this epoch (the collectives would deadlock or
            produce garbage otherwise).
    """
    if self.globally:
        # Every rank must participate in the collective, even the empty ones,
        # so first agree across workers on who (if anyone) is empty.
        world_size = dist.get_world_size(group)
        is_empty = [None] * world_size
        dist.all_gather_object(is_empty, len(self.values) == 0, group=group)
        if any(is_empty):
            if world_size > 1 and not all(is_empty):
                raise ValueError('Some workers tracked values this epoch and some did not. This is likely a bug.')
            return None
    elif len(self.values) == 0:
        # Local-only reducer with nothing tracked: nothing to reduce.
        return None

    tensor = self.reduce_locally()
    if not self.globally:
        return tensor

    if self.reduction == Reduction.MEAN:
        # Sum the per-worker results, then divide by the worker count.
        # NOTE(review): this averages per-worker results uniformly, which
        # presumably assumes comparable contributions per worker — confirm
        # against reduce_locally's aggregation.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        tensor /= dist.get_world_size(group)
    else:
        dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group)
    return tensor

def state_dict(self):
Expand Down
4 changes: 2 additions & 2 deletions dmlcloud/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def register_dataset(self, name: str, dataset: Union[DataLoader, Dataset, Sequen
msg += f' - Batches (Total): ~{length * dist.get_world_size()}\n'
msg += f' - Batches (/Worker): {length}\n'
except TypeError: # __len__ not implemented
msg += f' - Batches (Total): N/A\n'
msg += f' - Batches (/Worker): N/A\n'
msg += ' - Batches (Total): N/A\n'
msg += ' - Batches (/Worker): N/A\n'
self.logger.info(msg)

def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Optional[str] = None):
Expand Down
8 changes: 8 additions & 0 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def test_serialization(self):
assert new_reducer.dim == [1, 2, 3]
assert new_reducer.values == reducer.values

def test_empty_reduction(self, torch_distributed):
    """A reducer that never received any values must reduce to None, both locally and globally."""
    reducer = MetricReducer(reduction=Reduction.MIN, globally=True)

    assert reducer.reduce_locally() is None
    assert reducer.reduce_globally() is None


class TestMetricTracker:
def test_dictionary(self):
Expand Down

0 comments on commit 3c8cf34

Please sign in to comment.