Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
formatting
Browse files Browse the repository at this point in the history
priyakasimbeg committed Oct 17, 2024
1 parent 16484db commit a16d717
Showing 5 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -252,7 +252,9 @@ def _eval_model_on_split(self,
for _ in range(num_batches):
batch = next(self._eval_iters[split])
batch_metrics = self._eval_model(params, batch, model_rng)
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
Original file line number Diff line number Diff line change
@@ -309,7 +309,9 @@ def _eval_model_on_split(self,
update_batch_norm=False)
weights = batch.get('weights')
batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
Original file line number Diff line number Diff line change
@@ -616,7 +616,11 @@ def __call__(self,
inputs, input_paddings, train)

inputs = inputs + \
ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average)
ConvolutionBlock(config)(inputs,
input_paddings,
train,
update_batch_norm,
use_running_average)

inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
inputs, padding_mask, train)
Original file line number Diff line number Diff line change
@@ -330,7 +330,9 @@ def _eval_model_on_split(self,
'word_errors': word_errors,
'num_words': num_words,
}
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
4 changes: 3 additions & 1 deletion algorithmic_efficiency/workloads/mnist/workload.py
Original file line number Diff line number Diff line change
@@ -214,6 +214,8 @@ def _eval_model_on_split(self,
batch,
model_state,
per_device_model_rngs)
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}

return self._normalize_eval_metrics(num_examples, total_metrics)

0 comments on commit a16d717

Please sign in to comment.