diff --git a/MaxText/train.py b/MaxText/train.py index 8c3ed1331..711a5e82d 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -963,12 +963,13 @@ def train_loop(config, state=None): # pytype: disable=attribute-error compiled = p_train_step.lower(state, example_batch, nextrng).compile() compiled_stats = compiled.memory_analysis() - max_logging.log( - f"Output size: {compiled_stats.output_size_in_bytes}, " - f"temp size: {compiled_stats.temp_size_in_bytes}, " - f"argument size: {compiled_stats.argument_size_in_bytes}, " - f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes." - ) + if compiled_stats is not None: + max_logging.log( + f"Output size: {compiled_stats.output_size_in_bytes}, " + f"temp size: {compiled_stats.temp_size_in_bytes}, " + f"argument size: {compiled_stats.argument_size_in_bytes}, " + f"host temp size: {compiled_stats.host_temp_size_in_bytes}, in bytes." + ) return state