Skip to content

Commit

Permalink
Fix multiprocessing dataloader checkpointing and use it in the train script
Browse files Browse the repository at this point in the history

Summary:

Test Plan:
  • Loading branch information
EntilZha committed Feb 11, 2025
1 parent fe45f69 commit 38cc67a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 10 deletions.
10 changes: 10 additions & 0 deletions bytelatent/data/iterators/abstract_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
@abc.abstractmethod
def build(self) -> StatefulIterator[T, C]:
pass


def get_state_and_refresh(iterator: StatefulIterator):
    """Checkpoint *iterator*, then rebuild a usable dataloader/iterator pair.

    Re-initializing the dataloader and python iterator is necessary because
    calling ``get_state()`` on a multiprocess iterator shuts down its worker
    process in order to persist state correctly, so iteration can only
    continue on objects rebuilt from that saved state.

    Returns:
        A ``(state, data_loader, py_iterator)`` triple.
    """
    checkpoint_state = iterator.get_state()
    rebuilt_loader = checkpoint_state.build()
    rebuilt_iter = rebuilt_loader.create_iter()
    return checkpoint_state, rebuilt_loader, rebuilt_iter
4 changes: 2 additions & 2 deletions bytelatent/data/iterators/arrow_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def create_iter(

def _set_row_num(self, target_row_num: int):
logger.info(
f"Setting arrow position to {target_row_num} for {self.dataset_files}"
f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
)
if target_row_num is None or target_row_num == 0:
self.row_num = 0
Expand Down Expand Up @@ -286,5 +286,5 @@ def _set_row_num(self, target_row_num: int):
curr_remaining -= len(batch)
self.row_num = target_row_num
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
)
27 changes: 19 additions & 8 deletions bytelatent/data/iterators/multiprocess_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ def start_work_from_state(
if stop_event.is_set():
# Signal the end of output, this ensures that even if the queue takes a while to
# buffer, that the main thread receives everything (and tosses this fake batch)
logging.info(
logging.debug(
"Worker thread: Stop event detected, outputting is_final=True batch"
)
logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
batch_queue.put(
Batch(
x=np.zeros((1, 1)),
Expand All @@ -67,14 +68,17 @@ def start_work_from_state(
ngram_ids=None,
)
)
logging.debug(
"Worker thread: is_final=True batch put in queue, breaking from loop."
)
break

try:
logging.info("Worker thread: outputting state")
state_queue.put(iterator.get_state(), timeout=1)
logging.info("Worker thread: state dump complete")
logging.debug("Worker thread: outputting state")
state_queue.put(stateful_iterator.get_state(), timeout=1)
logging.debug("Worker thread: state dump complete")
state_dumped_event.set()
logging.info("Worker thread: set state_dump_event")
logging.debug("Worker thread: set state_dump_event")
except Full:
raise ValueError(
"Attempted to dump state into the state queue, but it was full"
Expand Down Expand Up @@ -156,23 +160,30 @@ def get_state(self) -> MultiprocessIteratorState:
serialized_prefetch_buffer=serialized_prefetch_buffer,
)
else:
logging.info("Main thread: Sending stop iteration event")
logging.debug("Main thread: Sending stop iteration event")
self.stop_iterating_event.set()
logging.info("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
logging.debug(
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
)
self.prefetch_buffer = []
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
logging.debug(
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
)
final_batch_received = True
break
self.prefetch_buffer.append(batch)
except Empty:
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received

logging.debug("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()

try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)
Expand Down
10 changes: 10 additions & 0 deletions bytelatent/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from bytelatent.args import TrainArgs, parse_args
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
MultiprocessIteratorState,
Expand Down Expand Up @@ -699,6 +700,9 @@ def train(args: TrainArgs):
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
saved = checkpoint.save(
model,
optimizer,
Expand Down Expand Up @@ -740,6 +744,9 @@ def train(args: TrainArgs):

if preemption_flag["flag"]:
if not saved:
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
checkpoint.save(
model,
optimizer,
Expand All @@ -751,6 +758,9 @@ def train(args: TrainArgs):
sys.exit(0)

if not saved:
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
checkpoint.save(
model,
optimizer,
Expand Down

0 comments on commit 38cc67a

Please sign in to comment.