Add bpb and n_bytes to metric logging (#41)
Summary:

Test Plan:
EntilZha authored Feb 7, 2025
1 parent aebdc48 commit fe45f69
Showing 4 changed files with 123 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -167,3 +167,4 @@ figures/
.DS_Store
internal/
jobs_parallel-copy/
wandb/
16 changes: 13 additions & 3 deletions bytelatent/distributed.py
@@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
return tensor


def dist_sum(
x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
):
tensor = torch.tensor(x).cuda()
if reduce_dtype is not None:
tensor = tensor.to(reduce_dtype)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
return tensor


def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
tensor = torch.tensor(x).cuda()
dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
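For reference, the new dist_sum mirrors dist_max/dist_mean above but reduces with ReduceOp.SUM, so every rank ends up holding the sum of all ranks' values; the optional reduce_dtype cast happens before the all-reduce. A minimal single-process sketch of those all-reduce semantics, assuming a CPU-only gloo process group rather than the CUDA/NCCL setup used in training:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.tensor(3.0)
# In-place sum across all ranks; with N ranks each holding 3.0, t becomes 3.0 * N.
dist.all_reduce(t, op=dist.ReduceOp.SUM)
print(t.item())  # 3.0 here, since world_size=1

dist.destroy_process_group()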
@@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
logger.warning(f"WARNING: Setting {name} to {value}")


def setup_torch_distributed(dist_args):
def setup_torch_distributed(dist_args: DistributedArgs):
"""
Handle single and multi-GPU / multi-node / SLURM jobs.
Initialize the following variables:
@@ -388,14 +398,14 @@ def clean_env():


def parallelize_model(
model,
model: torch.nn.Module,
device_mesh,
model_args,
distributed_args: DistributedArgs,
fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
tp_parallelize=None,
no_recompute_ops=None,
):
) -> torch.nn.Module:
if distributed_args.tp_size > 1:
assert (
distributed_args.fsdp_type == "full_shard"
1 change: 0 additions & 1 deletion bytelatent/metrics.py
@@ -49,7 +49,6 @@ class LoggingArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
freq: int = 10 # Log every freq optimizer steps
acc_freq: int | None = None # Log every acc_freq gradient accumulation steps

wandb: WandbArgs | None = None


134 changes: 109 additions & 25 deletions bytelatent/train.py
@@ -3,6 +3,7 @@

import gc
import logging
import math
import os
import sys
from contextlib import ExitStack
@@ -11,6 +12,7 @@
from timeit import default_timer as timer
from typing import Any, TypeVar

import numpy as np
import torch
import torch.distributed
import torch.nn.functional
@@ -32,7 +34,9 @@
from bytelatent.distributed import (
check_model_value_range,
clean_env,
dist_mean,
dist_mean_dict,
dist_sum,
get_device_mesh,
get_is_master,
get_world_size,
@@ -392,6 +396,9 @@ def train(args: TrainArgs):
time_last_log = timer()
gc.collect()
saved = False
step_losses: list[float] = []
step_tok_losses: list[float] = []
n_bytes: int = 0
while train_state.step < args.steps and (
args.max_steps is None or train_state.step < args.max_steps
):
@@ -413,6 +420,21 @@
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

if args.data.tokenizer_args.name in ["bytes", "blt"]:
n_bytes += batch_y.numel() if mask is None else mask.sum().item()
elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
for example in batch.y:
target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
n_bytes += (
len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+ sum(example == tokenizer.eos_id)
+ sum(example == tokenizer.bos_id)
)
else:
raise ValueError(
f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
)
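To illustrate the subword branch above: the target ids are decoded back to text, the UTF-8 byte length of that text is counted, and one byte is credited for each BOS/EOS id, since those decode to no text. A toy sketch with a hypothetical stand-in tokenizer (the real sp/tiktoken tokenizers come from the repo's tokenizer args):

# Hypothetical toy tokenizer for illustration only.
class ToyTokenizer:
    bos_id, eos_id = 0, 1

    def decode(self, ids, cut_at_eos=False):
        # Map ids > 1 to letters; bos/eos produce no text.
        return "".join(chr(96 + i) for i in ids if i > 1)

tok = ToyTokenizer()
example = [0, 5, 3, 1]  # bos, 'e', 'c', eos
text = tok.decode(example, cut_at_eos=False)
n = len(bytes(text, encoding="utf-8", errors="ignore"))
n += sum(i == tok.eos_id for i in example) + sum(i == tok.bos_id for i in example)
print(n)  # 2 text bytes + 1 eos + 1 bos = 4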

if (
not args.train_entropy_model
and args.model.encoder_enable_byte_ngrams
@@ -487,7 +509,7 @@ def train(args: TrainArgs):
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
)

loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)

# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
@@ -498,6 +520,10 @@
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps

# Undo loss scaling so downstream code doesn't need to worry about it
step_losses.append((loss / train_state.scale).item())
step_tok_losses.append(tok_loss / train_state.scale)
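As a sanity check on the bookkeeping above, a toy numeric sketch (made-up values; this assumes compute_loss multiplies by train_state.scale, which is what the division at append time undoes):

grad_acc_steps, scale = 4, 2.0
raw_loss = 3.0
loss = raw_loss * scale        # compute_loss applies train_state.scale
loss = loss / grad_acc_steps   # scaled down so gradients match any grad_acc_steps
loss = loss * grad_acc_steps   # undone after backward for logging
print(loss / scale)            # 3.0 -- the raw, unscaled loss that gets logged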

world_size = get_world_size()
if 1 < world_size <= 8:
# For some reason, there are errors in reduces due to
@@ -568,50 +594,108 @@
* wps
)

metrics = flatten_dict(
{
"global_step": train_state.step,
"acc_step": train_state.acc_step,
"speed": {
"wps": wps,
"FLOPS": FLOPS,
"curr_iter_time": curr_iter_time,
"data_load_time": data_load_time,
},
"optim": {
"grad_norm": grad_norm,
"lr": curr_lr,
"total_tokens": total_tokens,
},
"memory": gpu_mem_stats._asdict(),
# Below, semantics are:
# per_gpu: Metrics on a given rank
# across_gpus: Metrics averaged/summed across all ranks
# step: Metric at a single step
# interval: Metric averaged/summed across all steps since the last log,
# i.e. every `freq` optimizer steps (10 by default)
step_loss_per_gpu = loss.item()
step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
interval_loss_per_gpu = np.mean(step_losses).item()
interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()

stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
interval_total_tok_loss_across_gpus = dist_sum(
interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
).item()
interval_total_n_bytes_per_gpu = n_bytes
interval_total_n_bytes_across_gpus = dist_sum(
n_bytes, reduce_dtype=torch.bfloat16
).item()
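One property of the bfloat16 reduction above worth illustrating (a standalone sketch, not part of the commit): bfloat16 keeps only an 8-bit significand, so large interval sums are rounded when cast.

import torch

n = torch.tensor(100_003, dtype=torch.float32)
print(n.to(torch.bfloat16).item())  # 99840.0 -- nearest bf16-representable value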

interval_bpb_per_gpu = (
interval_total_tok_loss_per_gpu
/ math.log(2)
/ interval_total_n_bytes_per_gpu
)
interval_bpb_across_gpus = (
interval_total_tok_loss_across_gpus
/ math.log(2)
/ interval_total_n_bytes_across_gpus
)
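This is the standard bits-per-byte conversion: the per-token losses are summed in nats, divided by ln(2) to convert nats to bits, and normalized by the UTF-8 byte count of the targets. A worked example with made-up numbers:

import math

total_tok_loss_nats = 693.1  # made-up: summed per-token NLL over the interval
total_n_bytes = 1000         # made-up: UTF-8 bytes in the targets
bpb = total_tok_loss_nats / math.log(2) / total_n_bytes
print(f"{bpb:.3f}")          # ~1.000 bits per byte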

metric_dict = {
"global_step": train_state.step,
"acc_step": train_state.acc_step,
"speed": {
"wps": wps,
"FLOPS": FLOPS,
"curr_iter_time": curr_iter_time,
"data_load_time": data_load_time,
},
"optim": {
"grad_norm": grad_norm,
"lr": curr_lr,
"total_tokens": total_tokens,
},
"memory": gpu_mem_stats._asdict(),
"loss": {
"step_per_gpu": step_loss_per_gpu,
"step_across_gpu": step_loss_across_gpus,
"interval_per_gpu": interval_loss_per_gpu,
"interval_across_gpu": interval_loss_across_gpus,
},
"bpb": {
"interval_per_gpu": interval_bpb_per_gpu,
"interval_across_gpus": interval_bpb_across_gpus,
},
"n_bytes": {
"interval_per_gpu": interval_total_n_bytes_per_gpu,
"interval_across_gpus": interval_total_n_bytes_across_gpus,
},
}

metrics = flatten_dict(
metric_dict,
sep="/",
)
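flatten_dict collapses the nested metric_dict into flat sep-joined keys such as loss/interval_per_gpu, which is what the metric logger consumes. A hypothetical re-implementation for illustration (the repo's own helper may differ in details):

def flatten_dict(d: dict, sep: str = "/", prefix: str = "") -> dict:
    # Recursively join nested keys with `sep`; leaf values pass through.
    out = {}
    for k, v in d.items():
        key = f"{prefix}{sep}{k}" if prefix else str(k)
        if isinstance(v, dict):
            out.update(flatten_dict(v, sep=sep, prefix=key))
        else:
            out[key] = v
    return out

print(flatten_dict({"loss": {"interval_per_gpu": 1.23}}))
# {'loss/interval_per_gpu': 1.23}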

to_sync = {}
to_sync["loss/out"] = loss.item()
metrics.update(dist_mean_dict(to_sync))

if get_is_master():
metric_logger.log(metrics)

gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()
# Below, semantics are:
# step: Metric at a single step
# interval: Metric aggregated across the logging interval
# _gpu suffix: On this rank only
# _avg / _sum suffix: Averaged / summed across all ranks
logger.info(
f"step: {train_state.step}"
f" acc: {train_state.acc_step}"
f" loss: {round(loss.item(),4):>7}"
f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
f" bpb_gpu: {interval_bpb_per_gpu:3f}"
f" bpb_avg: {interval_bpb_across_gpus:3f}"
f" grad: {grad_norm:.2e}"
f" flops: {FLOPS:.2e}"
f" wps: {wps:.2e}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5}"
f" lr: {curr_lr:.2e}"
f" n_bytes_gpu: {int(interval_total_n_bytes_per_gpu)}"
f" n_bytes_sum: {int(interval_total_n_bytes_across_gpus)}"
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
f" pow: {gpu_mem_stats.power_draw/1000} W"
)

n_bytes = 0
step_losses = []
step_tok_losses = []
gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()

if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):