diff --git a/MaxText/train.py b/MaxText/train.py
index 8c3ed1331..711a5e82d 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -963,12 +963,13 @@ def train_loop(config, state=None):
     # pytype: disable=attribute-error
     compiled = p_train_step.lower(state, example_batch, nextrng).compile()
     compiled_stats = compiled.memory_analysis()
-    max_logging.log(
-        f"Output size: {compiled_stats.output_size_in_bytes}, "
-        f"temp size: {compiled_stats.temp_size_in_bytes}, "
-        f"argument size: {compiled_stats.argument_size_in_bytes}, "
-        f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
-    )
+    if compiled_stats is not None:
+      max_logging.log(
+          f"Output size: {compiled_stats.output_size_in_bytes}, "
+          f"temp size: {compiled_stats.temp_size_in_bytes}, "
+          f"argument size: {compiled_stats.argument_size_in_bytes}, "
+          f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes."
+      )
   return state