From bc39591032f26054dbeb8baa303f648fcabef828 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 4 Feb 2025 18:19:41 +0000 Subject: [PATCH] Several changes to enable entropy model training/eval Summary: - Make arrow iterator able to read from jsonl files, the entropies are omitted in this case - Make the data/checkpoint code fsspec compatible - Fix issues with all reduce with non-bf16 in dist_sum and norm computation. - Minimal fixes to get eval to run, it is slow currently - Add bpb numbers during training Test Plan: Run ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/args.py | 48 ++++++++- bytelatent/checkpoint.py | 97 +++++++++-------- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++-------- bytelatent/distributed.py | 18 +++- bytelatent/eval.py | 66 ++++++++---- bytelatent/generate.py | 6 +- bytelatent/norms.py | 100 ++++++++++++++++++ bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/train.py | 59 ++++++++++- 9 files changed, 382 insertions(+), 135 deletions(-) create mode 100644 bytelatent/norms.py diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..33d12d6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,43 @@ def parse_args(args_cls): return pydantic_args +def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any: + with fs.open(path, "rt") as f: + if path.endswith(".json"): + return json.load(f) + elif path.endswith(".yaml"): + return yaml.load(f) + else: + raise ValueError("Invalid args file format") + + +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +99,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> 
ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ def __init__(self, args: CheckpointArgs): self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ def clean_up(self): eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ def clean_up(self): if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif 
self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ def save( model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ def save( if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ def load( optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ def load( # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( 
model=model, optimizer=optimizer, diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ def build(self) -> "ArrowFileIterator": dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ def __init__( dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ def __init__( self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ def get_state(self) -> ArrowFileIteratorState: dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ def create_iter( else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ def create_iter( if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ def create_iter( for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - 
entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ def _set_row_num(self, target_row_num: int): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ def _set_row_num(self, target_row_num: int): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..f4b57e2 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: 
Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" @@ -429,6 +439,8 @@ def parallelize_model( device_mesh["dp_shard"].size() == 1 ), "dp_shard must be 1 for no_shard fsdp_type" + # TODO: Remove with something better + # model = model.to(param_dtype) fsdp_config = dict( mp_policy=( MixedPrecisionPolicy( diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 - - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + srcs.append(path) + + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = 
generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..81d1652 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,100 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. 
+ error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..16c6865 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -12,6 +13,7 @@ from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -25,6 +27,7 @@ from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -33,7 +36,9 @@ from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -47,6 +52,7 @@ from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -295,8 +301,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - 
args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -364,6 +373,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -385,6 +397,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -459,7 +489,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -470,8 +500,14 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + + # grad_norm = torch.nn.utils.clip_grad_norm_( + grad_norm = fixed_clip_grad_norm_( + model.parameters(), + max_norm=args.optim.clip, # foreach=True ) grad_norm = ( @@ -559,20 +595,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
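---

Note on the bpb numbers added in `bytelatent/train.py`: the loop accumulates per-token losses and a raw byte count, sums both across ranks with `dist_sum`, and then logs `avg_bpb = total_tok_loss / math.log(2) / total_n_bytes`. Below is a minimal single-process sketch of that arithmetic, assuming the per-token losses are negative log-likelihoods in nats; the helper name `bits_per_byte` and the toy numbers are illustrative only and not part of the patch.

```python
import math

import torch


def bits_per_byte(tok_losses: torch.Tensor, n_bytes: int) -> float:
    # Sum of per-token negative log-likelihoods, in nats.
    total_nll_nats = tok_losses.sum().item()
    # Convert nats to bits (divide by ln 2), then normalize by the raw byte
    # count, mirroring `avg_bpb = total_tok_loss / math.log(2) / total_n_bytes`.
    return total_nll_nats / math.log(2) / n_bytes


# Toy usage: 4 target tokens at ~1.5 nats of loss each, covering 10 raw bytes.
losses = torch.tensor([1.4, 1.6, 1.5, 1.5])
print(f"bpb = {bits_per_byte(losses, n_bytes=10):.3f}")  # ~0.866
```

In the actual training loop the two inputs to this division are the distributed sums (via `dist_sum`) of `step_tok_losses` and `n_bytes`, so the logged value is a global average rather than the per-rank one shown here.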