From 8d7338308e46a14b45b30de499dd714b2ffe8f69 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Fri, 7 Feb 2025 00:26:00 +0000
Subject: [PATCH] Add bpb and n_bytes to metric logging

Summary:

Test Plan:
---
 .gitignore                |   1 +
 bytelatent/distributed.py |  16 ++++-
 bytelatent/metrics.py     |   1 -
 bytelatent/train.py       | 134 +++++++++++++++++++++++++++++++-------
 4 files changed, 123 insertions(+), 29 deletions(-)

diff --git a/.gitignore b/.gitignore
index d1d7c2a..2d0f075 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,4 @@ figures/
 .DS_Store
 internal/
 jobs_parallel-copy/
+wandb/
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 168cb7c..b3f31a6 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_sum(
+    x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
+):
+    tensor = torch.tensor(x).cuda()
+    if reduce_dtype is not None:
+        tensor = tensor.to(reduce_dtype)
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
     tensor = torch.tensor(x).cuda()
     dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
@@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
             logger.warning(f"WARNING: Setting {name} to {value}")
 
 
-def setup_torch_distributed(dist_args):
+def setup_torch_distributed(dist_args: DistributedArgs):
     """
     Handle single and multi-GPU / multi-node / SLURM jobs.
     Initialize the following variables:
@@ -388,14 +398,14 @@ def clean_env():
 
 
 def parallelize_model(
-    model,
+    model: torch.nn.Module,
     device_mesh,
     model_args,
     distributed_args: DistributedArgs,
     fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
     tp_parallelize=None,
     no_recompute_ops=None,
-):
+) -> torch.nn.Module:
     if distributed_args.tp_size > 1:
         assert (
             distributed_args.fsdp_type == "full_shard"
diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py
index fb443d7..ed805e5 100644
--- a/bytelatent/metrics.py
+++ b/bytelatent/metrics.py
@@ -49,7 +49,6 @@ class LoggingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     freq: int = 10  # Log every freq optimizer steps
     acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps
-    wandb: WandbArgs | None = None
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 9bfe12a..ed84233 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -3,6 +3,7 @@
 
 import gc
 import logging
+import math
 import os
 import sys
 from contextlib import ExitStack
@@ -11,6 +12,7 @@
 from timeit import default_timer as timer
 from typing import Any, TypeVar
 
+import numpy as np
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -32,7 +34,9 @@
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
+    dist_mean,
     dist_mean_dict,
+    dist_sum,
     get_device_mesh,
     get_is_master,
     get_world_size,
@@ -392,6 +396,9 @@ def train(args: TrainArgs):
         time_last_log = timer()
         gc.collect()
         saved = False
+        step_losses: list[float] = []
+        step_tok_losses: list[float] = []
+        n_bytes: int = 0
         while train_state.step < args.steps and (
             args.max_steps is None or train_state.step < args.max_steps
         ):
@@ -413,6 +420,21 @@ def train(args: TrainArgs):
             batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
+            if args.data.tokenizer_args.name in ["bytes", "blt"]:
+                n_bytes += batch_y.numel() if mask is None else mask.sum()
+            elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
+                for example in batch.y:
+                    target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
+                    n_bytes += (
+                        len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+                        + sum(example == tokenizer.eos_id)
+                        + sum(example == tokenizer.bos_id)
+                    )
+            else:
+                raise ValueError(
+                    f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
+                )
+
             if (
                 not args.train_entropy_model
                 and args.model.encoder_enable_byte_ngrams
@@ -487,7 +509,7 @@ def train(args: TrainArgs):
                     batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
                 )
 
-            loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
+            loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)
 
             # We scale loss with grad_acc_steps so the gradient is the same
             # regardless of grad_acc_steps
@@ -498,6 +520,10 @@ def train(args: TrainArgs):
             # For logging we undo that scaling
             loss = loss.detach() * args.grad_acc_steps
 
+            # Undo loss scaling so downstream doesn't need to worry about it
+            step_losses.append((loss / train_state.scale).item())
+            step_tok_losses.append(tok_loss / train_state.scale)
+
             world_size = get_world_size()
             if 1 < world_size <= 8:
                 # For some reason, there are errors in reduces due to
@@ -568,50 +594,108 @@ def train(args: TrainArgs):
                     * wps
                 )
 
-                metrics = flatten_dict(
-                    {
-                        "global_step": train_state.step,
-                        "acc_step": train_state.acc_step,
-                        "speed": {
-                            "wps": wps,
-                            "FLOPS": FLOPS,
-                            "curr_iter_time": curr_iter_time,
-                            "data_load_time": data_load_time,
-                        },
-                        "optim": {
-                            "grad_norm": grad_norm,
-                            "lr": curr_lr,
-                            "total_tokens": total_tokens,
-                        },
-                        "memory": gpu_mem_stats._asdict(),
+                # Below, semantics are:
+                # per_gpu: Metrics on a given rank
+                # across_gpus: Metrics averaged/summed across all ranks
+                # step: Metric at a step
+                # interval: Metric averaged/summed across all steps since the last log interval.
+                #   Typically, this is 10
+                step_loss_per_gpu = loss.item()
+                step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
+                interval_loss_per_gpu = np.mean(step_losses).item()
+                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+
+                stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
+                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+                interval_total_tok_loss_across_gpus = dist_sum(
+                    interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
+                ).item()
+                interval_total_n_bytes_per_gpu = n_bytes
+                interval_total_n_bytes_across_gpus = dist_sum(
+                    n_bytes, reduce_dtype=torch.bfloat16
+                ).item()
+
+                interval_bpb_per_gpu = (
+                    interval_total_tok_loss_per_gpu
+                    / math.log(2)
+                    / interval_total_n_bytes_per_gpu
+                )
+                interval_bpb_across_gpus = (
+                    interval_total_tok_loss_across_gpus
+                    / math.log(2)
+                    / interval_total_n_bytes_across_gpus
+                )
+
+                metric_dict = {
+                    "global_step": train_state.step,
+                    "acc_step": train_state.acc_step,
+                    "speed": {
+                        "wps": wps,
+                        "FLOPS": FLOPS,
+                        "curr_iter_time": curr_iter_time,
+                        "data_load_time": data_load_time,
+                    },
+                    "optim": {
+                        "grad_norm": grad_norm,
+                        "lr": curr_lr,
+                        "total_tokens": total_tokens,
+                    },
+                    "memory": gpu_mem_stats._asdict(),
+                    "loss": {
+                        "step_per_gpu": step_loss_per_gpu,
+                        "step_across_gpu": step_loss_across_gpus,
+                        "interval_per_gpu": interval_loss_per_gpu,
+                        "interval_across_gpu": interval_loss_across_gpus,
                     },
+                    "bpb": {
+                        "interval_per_gpu": interval_bpb_per_gpu,
+                        "interval_across_gpus": interval_bpb_across_gpus,
+                    },
+                    "n_bytes": {
+                        "interval_per_gpu": interval_total_n_bytes_per_gpu,
+                        "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                    },
+                }
+
+                metrics = flatten_dict(
+                    metric_dict,
                     sep="/",
                 )
 
-                to_sync = {}
-                to_sync["loss/out"] = loss.item()
-                metrics.update(dist_mean_dict(to_sync))
-
                 if get_is_master():
                     metric_logger.log(metrics)
 
-                gpu_memory_monitor.reset_peak_stats()
-                nwords_since_last_log = 0
-                time_last_log = timer()
+                # Below semantics are:
+                # step=Metrics at a step
+                # interval=Metrics averaged across the logging interval
+                # local=On one rank
+                # global=Across all ranks
                 logger.info(
                     f"step: {train_state.step}"
                     f" acc: {train_state.acc_step}"
-                    f" loss: {round(loss.item(),4):>7}"
+                    f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
+                    f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+                    f" bpb_gpu: {interval_bpb_per_gpu:3f}"
+                    f" bpb_avg: {interval_bpb_across_gpus:3f}"
                     f" grad: {grad_norm:.2e}"
                     f" flops: {FLOPS:.2e}"
                     f" wps: {wps:.2e}"
                     f" iter: {curr_iter_time:>7}"
                     f" data: {data_load_time:>5}"
                     f" lr: {curr_lr:.2e}"
+                    f" n_bytes_gpu: {int(interval_total_n_bytes_per_gpu)}"
+                    f" n_bytes_sum: {int(interval_total_n_bytes_across_gpus)}"
                     f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                     f" pow: {gpu_mem_stats.power_draw/1000} W"
                 )
 
+                n_bytes = 0
+                step_losses = []
+                step_tok_losses = []
+                gpu_memory_monitor.reset_peak_stats()
+                nwords_since_last_log = 0
+                time_last_log = timer()
+
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
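
For reference, the bpb ("bits per byte") values logged above are the interval's summed per-token cross-entropy (tok_loss, in nats) converted to bits by dividing by ln(2) and normalized by the byte count accumulated in n_bytes. A minimal sketch of that conversion, using made-up tensors and a made-up byte count in place of the values the training loop accumulates:

    import math

    import torch

    # Stand-ins for what the loop accumulates between log intervals (hypothetical values):
    step_tok_losses = [torch.rand(4), torch.rand(4)]  # per-token losses in nats
    n_bytes = 32                                      # number of target bytes

    # Same computation as the patch: sum token losses, convert nats -> bits, divide by bytes.
    interval_total_tok_loss = torch.cat(step_tok_losses, dim=0).sum().item()
    bpb = interval_total_tok_loss / math.log(2) / n_bytes
    print(f"bpb: {bpb:.3f}")

For byte-level tokenizers ("bytes", "blt") every unmasked target is one byte, so n_bytes is simply the number of unmasked targets; for "sp"/"tiktoken" the patch decodes each target sequence and counts its UTF-8 bytes plus BOS/EOS tokens.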