diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py index 0ea25720d..d75eee3ac 100644 --- a/MaxText/input_pipeline/input_pipeline_interface.py +++ b/MaxText/input_pipeline/input_pipeline_interface.py @@ -116,7 +116,7 @@ def get_process_loading_real_data( batch_cutoff = global_batch_size_to_train_on process_loading_real_data = set() for p, indices in devices_indices_map.items(): - if indices[0].stop <= batch_cutoff: + if not indices[0].stop or indices[0].stop <= batch_cutoff: process_loading_real_data.add(p.process_index) return list(process_loading_real_data)