
Commit

Add bpb and n_bytes to metric logging
Summary:

Test Plan:
EntilZha committed Feb 5, 2025
1 parent 1450464 commit 2f42633
Showing 2 changed files with 57 additions and 5 deletions.
16 changes: 13 additions & 3 deletions bytelatent/distributed.py
@@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
return tensor


def dist_sum(
x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
):
tensor = torch.tensor(x).cuda()
if reduce_dtype is not None:
tensor = tensor.to(reduce_dtype)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
return tensor


def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
tensor = torch.tensor(x).cuda()
dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
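For context, dist_sum mirrors the existing dist_max and dist_mean helpers, differing only in the reduce op and the optional reduce_dtype cast. A minimal single-rank smoke test, assuming one CUDA device and an env:// rendezvous (the address and port below are stand-ins):

    import os

    import torch
    import torch.distributed as dist

    from bytelatent.distributed import dist_sum

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)

    # With world_size == 1 the all-reduce is an identity, so this is 7.0.
    print(dist_sum(7.0).item())

    # reduce_dtype trades precision for bandwidth: bfloat16 keeps ~8
    # significant bits, so a large byte count rounds (prints 999424.0).
    print(dist_sum(1_000_003, reduce_dtype=torch.bfloat16).item())

    dist.destroy_process_group()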
@@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
logger.warning(f"WARNING: Setting {name} to {value}")


def setup_torch_distributed(dist_args):
def setup_torch_distributed(dist_args: DistributedArgs):
"""
Handle single and multi-GPU / multi-node / SLURM jobs.
Initialize the following variables:
@@ -388,14 +398,14 @@ def clean_env():


def parallelize_model(
model,
model: torch.nn.Module,
device_mesh,
model_args,
distributed_args: DistributedArgs,
fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
tp_parallelize=None,
no_recompute_ops=None,
):
) -> torch.nn.Module:
if distributed_args.tp_size > 1:
assert (
distributed_args.fsdp_type == "full_shard"
46 changes: 44 additions & 2 deletions bytelatent/train.py
@@ -3,6 +3,7 @@

import gc
import logging
import math
import os
import sys
from contextlib import ExitStack
@@ -11,6 +12,7 @@
from timeit import default_timer as timer
from typing import Any, TypeVar

import numpy as np
import torch
import torch.distributed
import torch.nn.functional
@@ -32,7 +34,9 @@
from bytelatent.distributed import (
check_model_value_range,
clean_env,
dist_mean,
dist_mean_dict,
dist_sum,
get_device_mesh,
get_is_master,
get_world_size,
@@ -391,6 +395,9 @@ def train(args: TrainArgs):
time_last_log = timer()
gc.collect()
saved = False
step_losses: list[float] = []
step_tok_losses: list[float] = []
n_bytes: int = 0
while train_state.step < args.steps and (
args.max_steps is None or train_state.step < args.max_steps
):
@@ -412,6 +419,24 @@
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

if args.data.tokenizer_args.name in ["bytes", "blt"]:
if mask is None:
n_bytes += batch_y.numel()
else:
n_bytes += mask.sum()
elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
for example in batch.y:
target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
n_bytes += (
len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+ sum(example == tokenizer.eos_id)
+ sum(example == tokenizer.bos_id)
)
else:
raise ValueError(
f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
)

if (
not args.train_entropy_model
and args.model.encoder_enable_byte_ngrams
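The added block counts target bytes per logging window: for byte-level tokenizers ("bytes", "blt") every unmasked target position is one byte, while for subword tokenizers ("sp", "tiktoken") each example is decoded back to text, its UTF-8 bytes are counted, and one byte is credited per BOS/EOS token, which decode to nothing. A toy illustration of that second rule, using a hypothetical stub tokenizer rather than the real sp/tiktoken wrappers:

    import torch

    class StubTokenizer:
        # Hypothetical stand-in; only the pieces the counting rule touches.
        bos_id, eos_id = 1, 2

        def decode(self, ids, cut_at_eos=False):
            table = {3: "hi", 4: " there"}
            return "".join(table.get(i, "") for i in ids)

    tokenizer = StubTokenizer()
    example = torch.tensor([1, 3, 4, 2])  # BOS, "hi", " there", EOS
    text = tokenizer.decode(example.tolist(), cut_at_eos=False)
    n_bytes = (
        len(bytes(text, encoding="utf-8", errors="ignore"))
        + (example == tokenizer.eos_id).sum().item()
        + (example == tokenizer.bos_id).sum().item()
    )
    print(n_bytes)  # 8 text bytes + 1 BOS + 1 EOS = 10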
@@ -486,7 +511,7 @@ def train(args: TrainArgs):
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
)

loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)

# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
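The change above captures compute_loss's second return value as tok_loss, the per-token losses that feed the bits-per-byte metric below. A hedged sketch of what such a function could look like; the real implementation lives elsewhere in bytelatent and may differ, and the masking and scaling conventions here are assumptions inferred from this diff:

    import torch
    import torch.nn.functional as F

    def compute_loss_sketch(pred, target, mask=None, scale=1.0):
        # pred: [batch, seq, vocab]; target: [batch, seq].
        tok_loss = F.cross_entropy(
            pred.flatten(0, 1), target.flatten(0, 1), reduction="none"
        )
        if mask is not None:
            flat_mask = mask.flatten(0, 1)
            tok_loss = tok_loss * flat_mask
            loss = tok_loss.sum() / flat_mask.sum()
        else:
            loss = tok_loss.mean()
        # Both outputs carry the loss scale; the train loop divides it out.
        return loss * scale, tok_loss * scale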
@@ -497,6 +522,10 @@
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps

# Undo loss scaling so downstream code doesn't need to worry about it
step_losses.append((loss / train_state.scale).item())
step_tok_losses.append(tok_loss / train_state.scale)

world_size = get_world_size()
if 1 < world_size <= 8:
# For some reason, there are errors in reduces due to
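The appended values undo two scalings: the grad_acc_steps factor (multiplied back at loss = loss.detach() * args.grad_acc_steps) and train_state.scale (divided out before appending). A toy check of that arithmetic, all numbers hypothetical, assuming compute_loss returns a scale-multiplied loss:

    grad_acc_steps = 4
    scale = 2.0                 # stand-in for train_state.scale
    raw_loss = 2.5              # the quantity we actually want to log

    scaled = raw_loss * scale                # assumed compute_loss output
    backward_loss = scaled / grad_acc_steps  # value backpropagated
    logged = backward_loss * grad_acc_steps / scale
    assert logged == raw_loss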
@@ -597,20 +626,33 @@ def train(args: TrainArgs):
gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()
stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
total_tok_loss = dist_sum(
stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16
)
total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16)
avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
avg_loss = dist_mean(np.mean(step_losses).item())
logger.info(
f"step: {train_state.step}"
f" acc: {train_state.acc_step}"
f" loss: {round(loss.item(),4):>7}"
f" loss: step={round(loss.item(),4):>7} avg={avg_loss}"
f" bpb: {avg_bpb:3f}"
f" grad: {grad_norm:.2e}"
f" flops: {FLOPS:.2e}"
f" wps: {wps:.2e}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5}"
f" lr: {curr_lr:.2e}"
f" n_bytes={total_n_bytes}"
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
f" pow: {gpu_mem_stats.power_draw/1000} W"
)

n_bytes = 0
step_losses = []
step_tok_losses = []

if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
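The new logging converts the summed token loss, which is in nats, to bits per byte: divide by ln 2 to get bits, then by the number of UTF-8 bytes those tokens cover, with both totals all-reduced across ranks via dist_sum. A worked instance of the formula with hypothetical numbers:

    import math

    total_tok_loss = 1385.0  # hypothetical summed NLL over a log window, in nats
    total_n_bytes = 1000     # hypothetical byte count for the same window

    avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
    print(f"{avg_bpb:.3f}")  # ~1.998 bits per byte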
