diff --git a/stable_diffusion/mlperf_logging_utils.py b/stable_diffusion/mlperf_logging_utils.py index b49c8d959..4f4c84d6c 100644 --- a/stable_diffusion/mlperf_logging_utils.py +++ b/stable_diffusion/mlperf_logging_utils.py @@ -105,15 +105,15 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int) -> None: if trainer.global_step % self.train_log_interval == 0: - self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step}) + self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)}) def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: if trainer.global_step % self.train_log_interval == 0: logs = trainer.callback_metrics - self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step}) - self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step}) - self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step}) + self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)}) + self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)}) + self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)}) def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.logger.start(key=mllog_constants.EVAL_START, value=trainer.global_step) @@ -123,11 +123,11 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul if "validation/fid" in logs: self.logger.event(key=mllog_constants.EVAL_ACCURACY, value=logs["validation/fid"].item(), - metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "FID"}) + metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer), "metric": "FID"}) if "validation/clip" in logs: self.logger.event(key=mllog_constants.EVAL_ACCURACY, value=logs["validation/clip"].item(), - metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "CLIP"}) + metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer), "metric": "CLIP"}) self.logger.end(key=mllog_constants.EVAL_STOP, value=trainer.global_step) def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -146,5 +146,11 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin if batch_idx % self.validation_log_interval == 0: self.logger.end(key=mllog_constants.BLOCK_STOP, value="validation_step", metadata={mllog_constants.STEP_NUM: batch_idx}) + def _samples_count(self, trainer: "pl.Trainer") -> int: + batch_size_per_gpu = trainer.train_dataloader.batch_size + num_gpus = trainer.num_gpus if trainer.num_gpus else 1 + global_batch_size = batch_size_per_gpu * num_gpus + + return global_batch_size * trainer.global_step mllogger = SDLogger()