Update checkpointing to use fsspec
Summary:

- Make the Arrow iterator able to read from jsonl files; entropies are omitted in that case
- Make the data/checkpoint code fsspec-compatible (see the sketch after this list)
- Fix issues with all-reduce on non-bf16 dtypes in dist_sum and norm computation
- Minimal fixes to get eval to run; it is currently slow
- Add bits-per-byte (bpb) numbers during training
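
A minimal sketch of the fsspec pattern the checkpoint change builds on (illustrative code, not from this PR; `s3://` URLs additionally assume `s3fs` is installed):

```python
# Sketch: fsspec dispatches on the URL scheme, so the same code path serves
# local disk and object stores such as S3.
import fsspec

# Resolve a URL to a concrete filesystem plus a backend-relative path.
fs, root = fsspec.core.url_to_fs("s3://blt/scratch/checkpoint-test/")
print(type(fs).__name__, root)  # e.g. "S3FileSystem blt/scratch/checkpoint-test"

# fsspec.open behaves like the builtin open() against any backend.
with fsspec.open("s3://blt/scratch/checkpoint-test/example.txt", "w") as f:
    f.write("hello")
```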


Test Plan:

Run the unit tests and the commands below:

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```
EntilZha committed Feb 5, 2025
1 parent 7044771 commit b2058fb
Showing 1 changed file with 50 additions and 47 deletions.

bytelatent/checkpoint.py:

```diff
@@ -4,8 +4,6 @@
 import logging
 import os
 import re
-from pathlib import Path
-from typing import List, Optional, Tuple
 
 import fsspec
 import torch
```

```diff
@@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
     Returns the path to the consolidated checkpoint
     """
-    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
-    if not (consolidate_path / CONSOLIDATE_NAME).exists():
-        consolidate_path.mkdir(exist_ok=True)
-        logger.info(f"Consolidating to: {str(consolidate_path)}")
-        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
-        (consolidate_path / CONFIG_NAME).write_text(
-            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
+    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
+    if not fs.exists(consolidate_name):
+        fs.mkdirs(consolidate_path, exist_ok=True)
+        logger.info(f"Consolidating to: {consolidate_path}")
+        dcp_to_torch_save(ckpt_dir, consolidate_name)
+        fs.write_text(
+            os.path.join(consolidate_path, CONFIG_NAME),
+            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
         )
     logger.info("Consolidated !")
     return consolidate_path
 
 
 def load_from_checkpoint(
+    fs: fsspec.AbstractFileSystem,
     ckpt_dir: str,
     model: nn.Module,
-    optimizer: Optional[torch.optim.Optimizer] = None,
+    optimizer: torch.optim.Optimizer | None = None,
     model_key: str = "model",
     optim_key: str = "optim",
 ):
-    if not (Path(ckpt_dir) / ".metadata").exists():
+    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
         raise ValueError(
             f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
         )
```
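
For context on the `.metadata` check: `torch_save_to_dcp` (named in the error message) and `dcp_to_torch_save` (used above) are both in `torch.distributed.checkpoint.format_utils`. A hedged sketch of the round trip, with illustrative paths:

```python
# Sketch: converting between torch.save and DCP (distcp) checkpoint formats.
from torch.distributed.checkpoint.format_utils import (
    dcp_to_torch_save,
    torch_save_to_dcp,
)

# Single-file torch.save checkpoint -> sharded DCP directory; this creates
# the .metadata file that load_from_checkpoint looks for.
torch_save_to_dcp("checkpoint.pth", "ckpt_dir/")

# The reverse direction, which is what consolidate_checkpoints performs.
dcp_to_torch_save("ckpt_dir/", "consolidated/consolidated.pth")
```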

```diff
@@ -121,13 +122,13 @@ def __init__(self, args: CheckpointArgs):
 
         self.existing_saves = self.get_existing_saves()
 
-    def get_existing_saves(self) -> List[Path]:
+    def get_existing_saves(self) -> list[str]:
         folders = [
             p
-            for p in Path(self.path).iterdir()
-            if p.is_dir() and re.match(RE_FOLDER, p.name)
+            for p in self.fs.ls(self.path)
+            if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
         ]
-        folders.sort(key=lambda p: _get_key_step(p.name))
+        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
         return folders
 
     def clean_up(self):
```
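
One behavioral difference drives the `os.path.basename` calls here: `pathlib.Path.iterdir()` yields entries with a bare `.name`, while `AbstractFileSystem.ls()` returns full paths. A local-filesystem sketch (the `RE_FOLDER` pattern is an assumption, not the repo's actual value):

```python
# Sketch of the listing pattern above, run against the local filesystem.
import os
import re

import fsspec

RE_FOLDER = r"\d{10}$"  # assumed step-folder pattern, for illustration only

fs = fsspec.filesystem("file")
folders = [
    p
    for p in fs.ls("/tmp/checkpoints")  # full paths, not bare names
    if fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
]
```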

```diff
@@ -136,8 +137,9 @@ def clean_up(self):
         eval_folders = []
         other_folders = []
         for p in self.existing_saves:
-            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
-            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            assert isinstance(p, str), f"Base path type: {p}"
+            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
+            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
             if is_dump:
                 dump_folders.append(p)
             if is_eval:
```

```diff
@@ -161,40 +163,39 @@ def clean_up(self):
 
         if dist.get_rank() == 0:
             for folder in folder_to_remove:
-                for file in folder.iterdir():
-                    if file.is_file():
-                        file.unlink()
-                    elif file.is_dir():
-                        assert file.name in [CONSOLIDATE_FOLDER]
-                        for f in file.iterdir():
-                            f.unlink()
-                        file.rmdir()
-                folder.rmdir()
+                for file in self.fs.ls(folder):
+                    if self.fs.isfile(file):
+                        self.fs.rm_file(file)
+                    elif self.fs.isdir(file):
+                        assert os.path.basename(file) in [CONSOLIDATE_FOLDER]
+                        for f in self.fs.ls(file):
+                            self.fs.rm(f)
+                        self.fs.rmdir(file)
+                self.fs.rmdir(folder)
 
         dist.barrier()
 
         self.existing_saves = list(folder_to_keep)
-        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
 
-    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
         path = None
         for p in reversed(self.existing_saves):
-            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                 path = p
                 break
         return path
 
-    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
-        folder = base_path / folder_name
+    def _create_folder(self, base_path: str, folder_name: str) -> str:
+        folder = os.path.join(base_path, folder_name)
         if get_is_master():
-            folder.mkdir(parents=False, exist_ok=True)
+            self.fs.mkdirs(folder, exist_ok=True)
         if dist.is_initialized():
             dist.barrier()
         return folder
 
-    def _get_dp_tp_mesh(
-        self, device_mesh: Optional[DeviceMesh] = None
-    ) -> Tuple[int, int]:
+    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
         dp_rank = 0
         tp_rank = 0
         if device_mesh is not None:
```
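
The explicit loop keeps the `CONSOLIDATE_FOLDER` assertion as a guard against deleting unexpected directories; without that guard, fsspec's recursive delete would be a shorter equivalent (a sketch of the alternative, not what this commit does):

```python
# Alternative sketch: delete a checkpoint folder in one call, skipping the
# per-entry file/dir checks (and the safety assertion) used above.
fs.rm(folder, recursive=True)
```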

```diff
@@ -222,14 +223,14 @@ def save(
         model,
         optimizer,
         train_state,
-        config,
-        device_mesh: Optional[DeviceMesh] = None,
+        config: BaseModel,
+        device_mesh: DeviceMesh | None = None,
     ) -> bool:
 
         # When creating directory check if only rank0 or is there other solution
-        path = Path(self.path)
+        path = self.path
         curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
-        logger.info(f"Saving to: {str(curr_save_dir)}")
+        logger.info(f"Saving to: {curr_save_dir}")
 
         if dist.is_initialized():
             dist.barrier()
```

```diff
@@ -242,17 +243,19 @@ def save(
         if dist.is_initialized():
             dist.barrier()
 
+        print("config type", type(config))
         if get_is_master():
-            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+            self.fs.write_text(
+                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
+            )
 
         # Add json dump here
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         if tp_rank == 0:
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
-            logger.info(
-                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
-            )
-            with open(curr_save_dir / train_state_name, "w") as f:
+            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
+            logger.info(f"Saving train state to: {train_state_full_path}")
+            with self.fs.open(train_state_full_path, "w") as f:
                 json.dump(train_state.state_dict(), f)
             logger.info("Train state saved !")
```


```diff
@@ -271,7 +274,7 @@ def load(
         optimizer,
         train_state,
         device_mesh: DeviceMesh,
-        path: Optional[Path] = None,
+        path: str | None = None,
     ):
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         # Loading tries to load the provided path, if not available the last saved step and finally from the init path
```

```diff
@@ -284,12 +287,12 @@ def load(
         # Only load train state if it's provided, the files exist and we're not loading from init path
         train_state_name = TRAIN_STATE_NAME.format(dp_rank)
         logger.info("Reloading train state")
-        with open(path / train_state_name, "r") as f:
+        with self.fs.open(os.path.join(path, train_state_name), "r") as f:
             train_state_dict = json.load(f)
         train_state.load_state_dict(train_state_dict)
         logger.info("Train state reloaded")
 
-        logger.info(f"Loading from: {str(path)}")
+        logger.info(f"Loading from: {path}")
         state_dict = self.get_state_dict(
             model=model,
             optimizer=optimizer,
```
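
The load side in miniature: `self.fs.open` behaves like the builtin `open`, so the `json.load` call is unchanged (the path and train-state filename below are illustrative, not the repo's actual `TRAIN_STATE_NAME` pattern):

```python
# Sketch: read a train-state JSON through fsspec instead of builtin open().
import json

import fsspec

fs = fsspec.filesystem("file")
with fs.open("/tmp/checkpoints/0000000100/train_state_00000.json", "r") as f:
    train_state_dict = json.load(f)
```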
