Fix multiprocessing dataloader checkpointing and use it in the train script

Summary:

Test Plan:
EntilZha committed Feb 13, 2025
1 parent 85c2f28 commit 0c6cb99
Showing 5 changed files with 41 additions and 13 deletions.
2 changes: 0 additions & 2 deletions bytelatent/args.py

```diff
@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import json
 import logging
 import os
 from typing import Any

 import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
```
10 changes: 10 additions & 0 deletions bytelatent/data/iterators/abstract_iterator.py

```diff
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-initializing the dataloader and iterator is necessary since get_state()
+    # on the MP iterator shuts down multiprocessing to persist state correctly,
+    # so it needs to be restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
```
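
The helper relies on the StatefulIterator contract: get_state() invalidates the live iterator, and the returned state rebuilds an equivalent one via build() and create_iter(). Below is a toy sketch of that contract; CountingIterator and CountingState are hypothetical stand-ins, not classes from this repo:

```python
# Hypothetical minimal StatefulIterator: get_state() captures enough to
# resume, and the state's build() produces a fresh, equivalent iterator.
from dataclasses import dataclass


@dataclass
class CountingState:
    next_value: int

    def build(self) -> "CountingIterator":
        return CountingIterator(self.next_value)


class CountingIterator:
    def __init__(self, start: int = 0):
        self.next_value = start

    def create_iter(self):
        while True:
            yield self.next_value
            self.next_value += 1

    def get_state(self) -> CountingState:
        # The real MP iterator also shuts down its worker process here.
        return CountingState(self.next_value)


def get_state_and_refresh(iterator):
    # Same shape as the new helper: snapshot, rebuild, re-create the iterator.
    state = iterator.get_state()
    data_loader = state.build()
    py_iterator = data_loader.create_iter()
    return state, data_loader, py_iterator


loader = CountingIterator()
it = loader.create_iter()
print(next(it), next(it))  # 0 1
state, loader, it = get_state_and_refresh(loader)
print(next(it))  # 2: the rebuilt iterator resumes where the old one stopped
```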
4 changes: 2 additions & 2 deletions bytelatent/data/iterators/arrow_iterator.py

```diff
@@ -236,7 +236,7 @@ def create_iter(

     def _set_row_num(self, target_row_num: int):
         logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
         )
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
@@ -286,5 +286,5 @@ def _set_row_num(self, target_row_num: int):
             curr_remaining -= len(batch)
         self.row_num = target_row_num
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
         )
```
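
For context, dataset_files can be a long list of shard paths, so interpolating it whole makes these log lines enormous; the change clips its repr to 200 characters. A minimal illustration with made-up shard names:

```python
# Clipping a long list's repr before logging, as the diff does with
# str(self.dataset_files)[:200]; the shard names here are invented.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

dataset_files = [f"data/shard_{i:05d}.arrow" for i in range(1_000)]
target_row_num = 42
logger.info(
    f"Setting arrow position to {target_row_num} for {str(dataset_files)[:200]}"
)
```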
27 changes: 19 additions & 8 deletions bytelatent/data/iterators/multiprocess_iterator.py

```diff
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break

         try:
-            logging.info("Worker thread: outputting state")
-            state_queue.put(iterator.get_state(), timeout=1)
-            logging.info("Worker thread: state dump complete")
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
-            logging.info("Worker thread: set state_dump_event")
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
@@ -156,23 +160,30 @@ def get_state(self) -> MultiprocessIteratorState:
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
+            final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
+                        final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
                 except Empty:
                     logging.warning("Main thread: batch_queue is abnormally empty")
+            assert final_batch_received
+
+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()

            try:
                base_iterator_state = self.state_queue.get(timeout=1)
                assert isinstance(base_iterator_state, IteratorState)
```
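
Two things change in get_state(): the state dump now reads from stateful_iterator, the function's actual parameter (the old iterator name appears to be a stale reference), and the main thread drains the batch queue to the is_final sentinel before waiting on state_dumped_event, asserting the sentinel actually arrived. Below is a minimal sketch of that drain-then-wait handshake, using a thread and queue.Queue in place of the real multiprocessing pieces; every name in it is illustrative:

```python
# A minimal sketch of the shutdown handshake this diff tightens. The main
# thread drains the queue until the final sentinel arrives, asserts it saw
# it, and only then waits for the worker's state dump.
import threading
from queue import Empty, Queue

batch_queue: Queue = Queue(maxsize=4)
stop_event = threading.Event()
state_dumped_event = threading.Event()
SENTINEL = object()  # stands in for the is_final=True batch


def worker() -> None:
    n = 0
    while not stop_event.is_set():
        batch_queue.put(f"batch-{n}")  # produce until asked to stop
        n += 1
    batch_queue.put(SENTINEL)  # signal the end of output
    state_dumped_event.set()   # the real worker dumps its state here


thread = threading.Thread(target=worker)
thread.start()

stop_event.set()  # main thread: begin shutdown
prefetch_buffer = []
final_batch_received = False
while True:
    try:
        batch = batch_queue.get(timeout=1)
        if batch is SENTINEL:
            final_batch_received = True
            break
        prefetch_buffer.append(batch)
    except Empty:
        pass  # the real code logs a warning and keeps draining
assert final_batch_received  # the queue was fully drained
state_dumped_event.wait()    # only now is the worker's state safe to read
thread.join()
print(f"Buffered {len(prefetch_buffer)} in-flight batches")
```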
11 changes: 10 additions & 1 deletion bytelatent/train.py

```diff
@@ -26,6 +26,7 @@
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -35,7 +36,6 @@
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -702,6 +702,9 @@ def train(args: TrainArgs):
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 saved = checkpoint.save(
                     model,
                     optimizer,
@@ -743,6 +746,9 @@ def train(args: TrainArgs):

         if preemption_flag["flag"]:
             if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 checkpoint.save(
                     model,
                     optimizer,
@@ -754,6 +760,9 @@ def train(args: TrainArgs):
             sys.exit(0)

     if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
         checkpoint.save(
             model,
             optimizer,
```
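
Each of the three save sites follows the same pattern: snapshot the dataloader into train_state.data_loader_state so checkpoint.save() can persist it, then rebind data_loader and batch_iterator to the rebuilt handles. The rebinding matters because, per the new helper's comment, get_state() shuts the multiprocess iterator down; continuing the loop with the old handles would iterate a dead loader.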
