From a16d71755880e61f5ed6850342246faa5a6a7827 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 17 Oct 2024 21:46:31 +0000
Subject: [PATCH] formatting

---
 .../workloads/fastmri/fastmri_pytorch/workload.py          | 4 +++-
 .../workloads/imagenet_resnet/imagenet_pytorch/workload.py | 4 +++-
 .../librispeech_conformer/librispeech_jax/models.py        | 6 +++++-
 .../librispeech_conformer/librispeech_pytorch/workload.py  | 4 +++-
 algorithmic_efficiency/workloads/mnist/workload.py         | 4 +++-
 5 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
index a2f0828e3..8924a9865 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -252,7 +252,9 @@ def _eval_model_on_split(self,
     for _ in range(num_batches):
       batch = next(self._eval_iters[split])
       batch_metrics = self._eval_model(params, batch, model_rng)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index 0ed944191..3729bc53b 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -309,7 +309,9 @@ def _eval_model_on_split(self,
           update_batch_norm=False)
       weights = batch.get('weights')
       batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py
index db92f56d4..2b8250bd8 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -616,7 +616,11 @@ def __call__(self,
         inputs, input_paddings, train)
 
     inputs = inputs + \
-        ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average)
+        ConvolutionBlock(config)(inputs,
+                                 input_paddings,
+                                 train,
+                                 update_batch_norm,
+                                 use_running_average)
 
     inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
         inputs, padding_mask, train)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 31d069e88..0cec4116b 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -330,7 +330,9 @@ def _eval_model_on_split(self,
           'word_errors': word_errors,
           'num_words': num_words,
       }
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index 959228755..5407e8a35 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -214,6 +214,8 @@ def _eval_model_on_split(self,
                                        batch,
                                        model_state,
                                        per_device_model_rngs)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
 
     return self._normalize_eval_metrics(num_examples, total_metrics)