Several changes to enable entropy model training/eval
Summary:

- Make the arrow iterator able to read from JSONL files; entropies are omitted in this case (an illustrative fsspec/JSONL read sketch follows the Test Plan)
- Make the data/checkpoint code fsspec compatible
- Fix issues with all-reduce on non-bf16 tensors in dist_sum and norm computation (a generic dtype-handling sketch follows this list)
- Minimal fixes to get eval to run; it is currently slow
- Add bpb numbers during training (a rough bits-per-byte formula sketch follows this list)
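
On bpb: a minimal sketch of how bits-per-byte can be derived from the training loss, assuming the entropy model predicts one byte per position and the loss is the mean cross-entropy in nats (the actual reporting code in this commit is not shown here; `bits_per_byte` is a hypothetical helper):

```
import math

import torch
import torch.nn.functional as F


def bits_per_byte(logits: torch.Tensor, byte_targets: torch.Tensor) -> float:
    # Mean cross-entropy over byte targets is in nats per byte; dividing by
    # ln(2) converts it to bits per byte (bpb).
    nats_per_byte = F.cross_entropy(logits.flatten(0, 1), byte_targets.flatten())
    return nats_per_byte.item() / math.log(2)
```

On the all-reduce fix: the summary does not spell out the exact failure, but a common dtype-robust pattern is to accumulate the reduction in float32 and cast back. This is a generic sketch, not the code from this commit; `dist_sum_f32` is a hypothetical name:

```
import torch
import torch.distributed as dist


def dist_sum_f32(x: torch.Tensor) -> torch.Tensor:
    # Accumulate the sum in float32 regardless of the input dtype (e.g. bf16),
    # then cast the result back to the input's dtype.
    buf = x.detach().to(torch.float32)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.to(x.dtype)
```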


Test Plan:

Run

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100
```
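
The data/checkpoint changes below route filesystem access through fsspec, and the arrow iterator gains a JSONL path. As a rough illustration of that access pattern (not the actual ArrowFileIterator code; `iter_jsonl_samples` is a hypothetical helper):

```
import json

import fsspec


def iter_jsonl_samples(path: str):
    # fsspec.open resolves local paths and URLs such as s3:// uniformly;
    # plain JSONL inputs simply carry no precomputed entropies.
    with fsspec.open(path, "rt") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)
```
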
EntilZha committed Feb 4, 2025
1 parent 7044771 commit ab399e9
Showing 9 changed files with 382 additions and 135 deletions.
48 changes: 43 additions & 5 deletions bytelatent/args.py
@@ -1,20 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any

import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
find_and_sanitize_chunks,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,43 @@ def parse_args(args_cls):
return pydantic_args


def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any:
with fs.open(path, "rt") as f:
if path.endswith(".json"):
return json.load(f)
elif path.endswith(".yaml"):
return yaml.load(f)
else:
raise ValueError("Invalid args file format")


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)

if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"

assert n_chunks > 0, f"No valid chunks in {dataset_path}"

return dataset_chunks


def distribute_data_to_rank(
*,
dataset_path: str,
@@ -62,9 +99,10 @@ def distribute_data_to_rank(
rank: int,
world_size: int,
s3_profile: str | None = None,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
dataset_chunks = find_and_sanitize_chunks(
dataset_path, world_size, s3_profile=s3_profile
dataset_path, world_size, file_pattern, s3_profile=s3_profile
)
n_workers_per_chunk = world_size // len(dataset_chunks)
rank_to_arrow_iterator_params = []
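
The relocated find_and_sanitize_chunks above can be exercised on its own; a minimal sketch, assuming a hypothetical local directory ./data/entropy containing *.chunk.*.jsonl files whose count divides the world size:

```
from bytelatent.args import TRAIN_DATA_FILE_PATTERN, find_and_sanitize_chunks

# Illustrative path; any fsspec-compatible location (e.g. s3://...) should work.
chunks = find_and_sanitize_chunks(
    "./data/entropy",
    world_size=8,
    file_pattern=TRAIN_DATA_FILE_PATTERN,
)
print(chunks)
```
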
97 changes: 50 additions & 47 deletions bytelatent/checkpoint.py
@@ -4,8 +4,6 @@
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple

import fsspec
import torch
@@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
Returns the path to the consolidated checkpoint
"""
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
if not (consolidate_path / CONSOLIDATE_NAME).exists():
consolidate_path.mkdir(exist_ok=True)
logger.info(f"Consolidating to: {str(consolidate_path)}")
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
(consolidate_path / CONFIG_NAME).write_text(
(Path(ckpt_dir) / CONFIG_NAME).read_text()
consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
if not fs.exists(consolidate_name):
fs.mkdirs(consolidate_path, exist_ok=True)
logger.info(f"Consolidating to: {consolidate_path}")
dcp_to_torch_save(ckpt_dir, consolidate_name)
fs.write_text(
os.path.join(consolidate_path, CONFIG_NAME),
fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
)
logger.info("Consolidated !")
return consolidate_path


def load_from_checkpoint(
fs: fsspec.AbstractFileSystem,
ckpt_dir: str,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: torch.optim.Optimizer | None = None,
model_key: str = "model",
optim_key: str = "optim",
):
if not (Path(ckpt_dir) / ".metadata").exists():
if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
raise ValueError(
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
)
@@ -121,13 +122,13 @@ def __init__(self, args: CheckpointArgs):

self.existing_saves = self.get_existing_saves()

def get_existing_saves(self) -> List[Path]:
def get_existing_saves(self) -> list[str]:
folders = [
p
for p in Path(self.path).iterdir()
if p.is_dir() and re.match(RE_FOLDER, p.name)
for p in self.fs.ls(self.path)
if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
]
folders.sort(key=lambda p: _get_key_step(p.name))
folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
return folders

def clean_up(self):
@@ -136,8 +137,9 @@ def clean_up(self):
eval_folders = []
other_folders = []
for p in self.existing_saves:
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
assert isinstance(p, str), f"Base path type: {p}"
is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
if is_dump:
dump_folders.append(p)
if is_eval:
@@ -161,40 +163,39 @@ def clean_up(self):

if dist.get_rank() == 0:
for folder in folder_to_remove:
for file in folder.iterdir():
if file.is_file():
file.unlink()
elif file.is_dir():
assert file.name in [CONSOLIDATE_FOLDER]
for f in file.iterdir():
f.unlink()
file.rmdir()
folder.rmdir()
for file in self.fs.ls(folder):
if self.fs.isfile(file):
self.fs.rm_file(file)
elif self.fs.isdir(file):
assert os.path.basename(file) in [CONSOLIDATE_FOLDER]
for f in self.fs.ls(file):
self.fs.rm(f)
self.fs.rmdir(file)
self.fs.rmdir(folder)

dist.barrier()

self.existing_saves = list(folder_to_keep)
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))

def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
def get_last_step_path(self, dp_rank: int = 0) -> str | None:
path = None
for p in reversed(self.existing_saves):
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():

if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
path = p
break
return path

def _create_folder(self, base_path: Path, folder_name: str) -> Path:
folder = base_path / folder_name
def _create_folder(self, base_path: str, folder_name: str) -> str:
folder = os.path.join(base_path, folder_name)
if get_is_master():
folder.mkdir(parents=False, exist_ok=True)
self.fs.mkdirs(folder, exist_ok=True)
if dist.is_initialized():
dist.barrier()
return folder

def _get_dp_tp_mesh(
self, device_mesh: Optional[DeviceMesh] = None
) -> Tuple[int, int]:
def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
dp_rank = 0
tp_rank = 0
if device_mesh is not None:
@@ -222,14 +223,14 @@ def save(
model,
optimizer,
train_state,
config,
device_mesh: Optional[DeviceMesh] = None,
config: BaseModel,
device_mesh: DeviceMesh | None = None,
) -> bool:

# When creating directory check if only rank0 or is there other solution
path = Path(self.path)
path = self.path
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
logger.info(f"Saving to: {str(curr_save_dir)}")
logger.info(f"Saving to: {curr_save_dir}")

if dist.is_initialized():
dist.barrier()
@@ -242,17 +243,19 @@ def save(
if dist.is_initialized():
dist.barrier()

print("config type", type(config))
if get_is_master():
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
self.fs.write_text(
os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
)

# Add json dump here
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
if tp_rank == 0:
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info(
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
)
with open(curr_save_dir / train_state_name, "w") as f:
train_state_full_path = os.path.join(curr_save_dir, train_state_name)
logger.info(f"Saving train state to: {train_state_full_path}")
with self.fs.open(train_state_full_path, "w") as f:
json.dump(train_state.state_dict(), f)
logger.info("Train state saved !")

@@ -271,7 +274,7 @@ def load(
optimizer,
train_state,
device_mesh: DeviceMesh,
path: Optional[Path] = None,
path: str | None = None,
):
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
# Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +287,12 @@ def load(
# Only load train state if it's provided, the files exist and we're not loading from init path
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info("Reloading train state")
with open(path / train_state_name, "r") as f:
with self.fs.open(os.path.join(path, train_state_name), "r") as f:
train_state_dict = json.load(f)
train_state.load_state_dict(train_state_dict)
logger.info("Train state reloaded")

logger.info(f"Loading from: {str(path)}")
logger.info(f"Loading from: {path}")
state_dict = self.get_state_dict(
model=model,
optimizer=optimizer,