From 04a22a0086c4f45063879ef7795d39a498a84fe5 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 27 Jan 2025 11:47:43 -0800 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- run_llama_train.sh | 4 + torchtitan/checkpoint.py | 205 +++++++++++++---------- torchtitan/config_manager.py | 13 ++ torchtitan/ft.py | 59 +++++++ torchtitan/optimizer.py | 27 ++- torchtitan/parallelisms/parallel_dims.py | 21 ++- torchtitan/utils.py | 8 + train.py | 31 ++-- train_configs/debug_model.toml | 2 +- 9 files changed, 264 insertions(+), 106 deletions(-) create mode 100644 torchtitan/ft.py diff --git a/run_llama_train.sh b/run_llama_train.sh index a69c967a7..cbd98a9f8 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then overrides="$*" fi +TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"} + PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +TORCHFT_LIGHTHOUSE=http://localhost:29510 \ +TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index db54ccd9b..0971de2c4 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -13,12 +13,13 @@ from dataclasses import dataclass, field from io import BytesIO from multiprocessing import get_context -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, set_model_state_dict, @@ -143,49 +144,28 @@ def __init__( lr_schedulers: SchedulersContainer, states: Dict[str, Any], job_config: JobConfig, + ft_manager: Optional[Any] = None, ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.keep_latest_k = ckpt_config.keep_latest_k + self.ft_manager = ft_manager - if not self.enable_checkpoint: + if not self.enable_checkpoint and self.ft_manager is None: return - """ - Note: Pipeline Parallelism and Virtual Stages - - 1. even for simple PP schedules, there is a separate optimizer each PP rank. - rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model. - rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1. - When saving, these collide and one of them is lost. Then when reloading, only one stage can - restore its optimizer states, others will error. - - The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan - by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer. - - 2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also - requiring us to reason about multiple 'optim' objects locally. - - We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object - into one state dict before saving/loading. We rely on the individual state_dicts to not collide, - which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening - support described in (1). - - 3. 
LR schedulers also index model states like optimizers and would need to be flattened properly to support - resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like - optimizers do, so it's hard to write a generic 'flattener' utility. - - TODO: This is currently unsolved and needs a fix. - """ - self.states = states - self.states.update( - { - "model": ModelWrapper(model_parts), - "optimizer": optimizers, - "dataloader": dataloader, - } + self._initialize_states( + states, dataloader, model_parts, optimizers, lr_schedulers ) - self.states.update(lr_schedulers.get_lr_scheduler_state()) + + async_mode = ckpt_config.async_mode.lower() + self.enable_staging = ( + self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM + ) or self.ft_manager + self.staging = False + self.sending_to_checkpoint_mp = False + self.staging_id = None + self.cpu_offload_state_dict = None + self.staging_stream = torch.cuda.Stream() if self.enable_staging else None self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) self.interval_type = ( @@ -199,11 +179,11 @@ def __init__( self.time_sync_result = None self.pg = dist.new_group(backend="gloo") + self.keep_latest_k = ckpt_config.keep_latest_k self.model_weights_only = ckpt_config.model_weights_only self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] self.mp = None - async_mode = ckpt_config.async_mode.lower() if async_mode == AsyncMode.DISABLED: self.async_mode = AsyncMode.DISABLED elif async_mode == AsyncMode.ASYNC: @@ -223,10 +203,6 @@ def __init__( daemon=True, ) self.mp.start() - self.cpu_offload_state_dict = None - self.staging = False - self.staging_id = None - self.staging_stream = torch.cuda.Stream() else: raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") @@ -240,8 +216,61 @@ def __del__(self): self.mp.join() def reset(self) -> None: + # We need to stage the local state if another replicate joins during the + # first step. + if self.ft_manager: + self.cpu_staging(None) self.begin_time = time.monotonic() + def _initialize_states( + self, + states: Dict[str, Any], + dataloader: DataLoader, + model_parts: List[nn.Module], + optimizers: OptimizersContainer, + lr_schedulers: SchedulersContainer, + ) -> None: + """ + Note: Pipeline Parallelism and Virtual Stages + + 1. Even for simple PP schedules, there is a separate optimizer each PP rank. + rank0's optimizer would have a param_group[0] which refers to layers.0 in the + original model. rank1's would _also_ have a param_group[0], since it's index based, + but referring to layers.1. + When saving, these collide and one of them is lost. Then when reloading, only one + stage can restore its optimizer states, others will error. + + The solution to this problem is optimizer flattening: it landed in #127071 + and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict' + kwarg to DCP functions called in the OptimizerContainer. + + 2. With complex PP schedules, we have multiple model chunks per pp rank. This + compounds challenge (1) by also requiring us to reason about multiple 'optim' + objects locally. + + We solve this in the Model and Optimizer wrapper classes by flattening the + state dicts from each object into one state dict before saving/loading. + We rely on the individual state_dicts to not collide, which is gauranteed for + the model by correct pipeline splitting and for the optimizer by the flattening + support described in (1). + + 3. 
LR schedulers also index model states like optimizers and would need to be + flattened properly to support resharding. Unfortunately, the implementations of + different lr_schedulers do not follow a clear pattern like optimizers do, so it's + hard to write a generic 'flattener' utility. + + TODO: This is currently unsolved and needs a fix. + """ + self.states = states + self.states.update( + { + "model": ModelWrapper(model_parts), + "optimizer": optimizers, + "dataloader": dataloader, + } + ) + self.states.update(lr_schedulers.get_lr_scheduler_state()) + def _create_checkpoint_id(self, step: int) -> str: return os.path.join(self.folder, f"step-{step}") @@ -324,31 +353,8 @@ def _async_wait(self) -> None: self.async_future.result() def _async_with_pinned_memory(self, checkpoint_id: str) -> None: - try: - from torch.distributed._state_dict_utils import ( - _copy_state_dict, - _create_cpu_state_dict, - ) - except ImportError as e: - raise ImportError( - "Please install the latest PyTorch nightly to use async checkpointing with pinned memory." - ) from e - state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states) - if self.cpu_offload_state_dict is None: - logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f") - self.cpu_offload_state_dict = _create_cpu_state_dict( - state_dict, pin_memory=True, share_memory=True - ) - - logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f") - with torch.cuda.stream(self.staging_stream): - self.cpu_offload_state_dict = _copy_state_dict( - state_dict, - self.cpu_offload_state_dict, - non_blocking=True, - ) - self.staging = True - self.staging_id = checkpoint_id + self.cpu_staging(checkpoint_id) + self.sending_to_checkpoint_mp = True def save(self, curr_step: int, force: bool = False) -> None: """ @@ -358,6 +364,8 @@ def save(self, curr_step: int, force: bool = False) -> None: for initial seed checkpoint. """ if not self._should_save(curr_step, force): + if self.ft_manager: + self.cpu_staging(None) return begin = time.monotonic() @@ -381,26 +389,51 @@ def save(self, curr_step: int, force: bool = False) -> None: f"in {time.monotonic() - begin:.2f} seconds." 
) + def cpu_staging(self, checkpoint_id: Optional[str]) -> None: + """Offload state_dict to CPU memory""" + state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states) + if self.cpu_offload_state_dict is None: + logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f") + self.cpu_offload_state_dict = _create_cpu_state_dict( + state_dict, pin_memory=True, share_memory=True + ) + + logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f") + with torch.cuda.stream(self.staging_stream): + self.cpu_offload_state_dict = _copy_state_dict( + state_dict, + self.cpu_offload_state_dict, + non_blocking=True, + ) + self.staging = True + self.staging_id = checkpoint_id + + def wait_for_staging(self) -> None: + if not self.staging_stream.query(): + self.staging_stream.synchronize() + self.staging = False + + def staging_results(self) -> Dict[str, Any]: + self.maybe_wait_for_staging() + return self.cpu_offload_state_dict + def maybe_wait_for_staging(self) -> None: - if ( - self.enable_checkpoint - and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM - and self.staging - ): - if not self.staging_stream.query(): - self.staging_stream.synchronize() - - def sync_func(): - self.mp_queue_send.put_nowait( - (self.cpu_offload_state_dict, self.staging_id) - ) - - # This may be a faster way to do zero-overhead checkpointing staging - # checkpointing but we need more thorough investigation before - # swithing to this method. - # self.my_thread = threading.Thread(target=func).start() - sync_func() - self.staging = False + if self.enable_staging and self.staging: + self.wait_for_staging() + + if self.sending_to_checkpoint_mp: + # Copy the sync staging result to another process. + def sync_func(): + self.mp_queue_send.put_nowait( + (self.cpu_offload_state_dict, self.staging_id) + ) + + # This may be a faster way to do zero-overhead checkpointing staging + # checkpointing but we need more thorough investigation before + # swithing to this method. + # self.my_thread = threading.Thread(target=func).start() + sync_func() + self.sending_to_checkpoint_mp = False def load(self, step: int = -1) -> bool: if not self.enable_checkpoint: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d59e34bc6..ce295b996 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -604,6 +604,19 @@ def __init__(self): action="store_true", ) + self.parser.add_argument( + "--experimental.enable_torchft", + action="store_true", + help="Enable TorchFT integration.", + ) + + self.parser.add_argument( + "--experimental.ft_replica_group_id", + type=int, + default=-1, + help="The FT replicate group of this run.", + ) + def to_dict(self): return self.args_dict diff --git a/torchtitan/ft.py b/torchtitan/ft.py new file mode 100644 index 000000000..b620fb9a6 --- /dev/null +++ b/torchtitan/ft.py @@ -0,0 +1,59 @@ +from typing import Any, Callable, Optional +import importlib + +from torchtitan.config_manager import JobConfig +from torch.distributed._state_dict_utils import ( + _copy_state_dict, + _create_cpu_state_dict, +) + +if importlib.util.find_spec("torchft") is not None: + import torchft as ft + has_torchft = True +else: + has_torchft = False + + +def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]: + """ + Initialize the FT manager for the given job. + """ + if not job.experimental.enable_torchft: + return None + + if not has_torchft: + raise ImportError("torchft is not installed. 
Please install it.") + + pg = ft.ProcessGroupBabyNCCL() + manager = ft.Manager( + pg=pg, + min_replica_size=1, + load_state_dict=None, + state_dict=None, + use_async_quorum=True, + replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}", + ) + + return manager + + +def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None: + """ + Set the state dict for the given manager. + """ + if manager is None: + return + + def state_dict(): + ret = {} + for k, v in ckpt_manager.staging_results().items(): + if k in {"model", "optimizer", "lr_schedulers"}: + ret[k] = v + return ret + + def load_state_dict(state_dict): + assert state_dict is not None + for k, v in state_dict.items(): + ckpt_manager.states[k].load_state_dict(v) + + manager.set_state_dict_fns(load_state_dict, state_dict) diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index 8927125fd..fe99cf3d8 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import functools -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch import torch.nn as nn @@ -25,7 +25,11 @@ class OptimizersContainer(Stateful): """ def __init__( - self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str + self, + model_parts: List[nn.Module], + optimizer_kwargs: Dict[str, Any], + name: str, + ft_manager: Optional[Any] = None, ) -> None: self.optimizers = [] self.model_parts = model_parts @@ -38,6 +42,19 @@ def __init__( else: raise NotImplementedError(f"Optimizer {name} not added.") self.optimizers.append(optimizer) + if ft_manager: + import torchft as ft + + # Force to initialize the optimizer state so that `optim.step()` + # won't be called by ft.Optimizer.step(). + _ = { + k: v + for sd in map(get_optimizer_state_dict, model_parts, self.optimizers) + for k, v in sd.items() + } + self.optimizers = [ + ft.Optimizer(ft_manager, optim) for optim in self.optimizers + ] self._validate_length(len(self.model_parts)) def _validate_length(self, expected_length) -> None: @@ -128,7 +145,7 @@ def zero_grad(self) -> None: # consider split between PP and non-PP def build_optimizers( - model_parts: List[nn.Module], job_config: JobConfig + model_parts: List[nn.Module], job_config: JobConfig, ft_manager: Optional[Any] ) -> OptimizersContainer: """Wrap one optimizer per model part in an OptimizersContainer which provides a single step() and zero_grad() method for all the child optimizers. 
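For context, a minimal usage sketch of the extended builder (mirroring what train.py does later in this patch; `job_config` and `model_parts` are assumed to be constructed as usual, and the quorum/commit behaviour of `ft.Optimizer.step()` is torchft's, not something defined here):

    # Usage sketch only: building optimizers with an optional ft.Manager.
    from torchtitan.ft import init_ft_manager
    from torchtitan.optimizer import build_optimizers

    ft_manager = init_ft_manager(job_config)  # None unless --experimental.enable_torchft is set
    optimizers = build_optimizers(model_parts, job_config, ft_manager)

    optimizers.zero_grad()
    # ... forward / backward ...
    optimizers.step()  # with ft_manager set, each child optimizer is wrapped in ft.Optimizer,
                       # so the step is gated by torchft's quorum/commit handling
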
@@ -148,9 +165,11 @@ def build_optimizers( "fused": fused, "foreach": not fused, } + if optim_in_bwd and ft_manager: + raise NotImplementedError("TorchFT currently doesn't support optim_in_bwd.") return ( - OptimizersContainer(model_parts, optimizer_kwargs, name) + OptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager) if not optim_in_bwd else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name) ) diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 13d066a84..e1a1a3437 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from functools import cached_property +from typing import Any, Optional from torch.distributed.device_mesh import init_device_mesh from torchtitan.logging import logger @@ -20,6 +21,7 @@ class ParallelDims: pp: int world_size: int enable_loss_parallel: bool + ft_manager: Optional[Any] def __post_init__(self): self._validate() @@ -52,13 +54,24 @@ def build_mesh(self, device_type): [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], ["pp", "dp_replicate", "dp_shard", "cp", "tp"], ): - if d > 1: + if d > 1 or (name == "dp_replicate" and self.ft_manager is not None): dims.append(d) names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + if self.ft_manager is None: + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + else: + from torchft.process_group import ft_init_device_mesh + + mesh = ft_init_device_mesh( + device_type=device_type, + mesh_shape=dims, + mesh_dim_names=names, + replicate_dim=names.index("dp_replicate"), + manager=self.ft_manager, + ) # Create all the submesh here to ensure all required process groups are # initialized: @@ -69,7 +82,7 @@ def build_mesh(self, device_type): # Mesh for loss all-reduce dp_cp_mesh_dim_names = [] - if self.dp_replicate_enabled: + if self.dp_replicate_enabled or ft_manager is not None: dp_mesh_dim_names.append("dp_replicate") dp_cp_mesh_dim_names.append("dp_replicate") if self.dp_shard_enabled: @@ -97,7 +110,7 @@ def dp_enabled(self): @property def dp_replicate_enabled(self): - return self.dp_replicate > 1 + return self.dp_replicate > 1 or self.ft_manager is not None @property def dp_shard_enabled(self): diff --git a/torchtitan/utils.py b/torchtitan/utils.py index c9dcf2fac..da389dca3 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -35,6 +35,13 @@ def get_device_info(): def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float: + import torchft as ft + + if isinstance(mesh, ft.process_group._FlattenDeviceMesh): + torch.distributed.all_reduce(x, group=mesh.managed_mesh.replicate_pg) + # x = funcol.all_reduce(x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg) + mesh = mesh.managed_mesh.mesh + if isinstance(x, DTensor): # functional collectives do not support DTensor inputs x = x.full_tensor() @@ -399,6 +406,7 @@ def clip_grad_norm_( if isinstance(total_norm, DTensor): # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. 
+ assert False, total_norm.placements total_norm = total_norm.full_tensor() if pp_mesh is not None: diff --git a/train.py b/train.py index bac227722..be3dbdcba 100644 --- a/train.py +++ b/train.py @@ -17,6 +17,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_hf_data_loader, build_tokenizer from torchtitan.float8 import Float8Handler +from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config @@ -45,6 +46,8 @@ def main(job_config: JobConfig): # take control of garbage collection to avoid stragglers gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) + ft_manager = init_ft_manager(job_config) + # init distributed world_size = int(os.environ["WORLD_SIZE"]) parallel_dims = ParallelDims( @@ -55,6 +58,7 @@ def main(job_config: JobConfig): pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=not job_config.training.disable_loss_parallel, + ft_manager=ft_manager, ) device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") device_module.set_device(device) @@ -154,7 +158,8 @@ def loss_fn(pred, labels): pp_schedule, model_parts = models_pipelining_fns[model_name]( model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn ) - # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead + # when PP is enabled, `model` obj is no longer used after this point, + # model_parts is used instead del model # For PP with looped schedules, each item in model_parts is one stage-model-chunk. 
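To make the mesh handling above (parallel_dims.py) concrete, here is an illustrative sketch of what ParallelDims.build_mesh produces when a manager is present; the shapes are made up for a single 8-GPU replica group, and ft_init_device_mesh is used with exactly the keyword arguments imported in parallel_dims.py:

    # Sketch only: with TorchFT enabled, "dp_replicate" is materialized even at degree 1
    # so the replicate dimension can be backed by the ft.Manager's dynamic process group.
    from torchft.process_group import ft_init_device_mesh

    dims, names = [1, 8], ("dp_replicate", "dp_shard")  # e.g. FSDP inside one replica group
    mesh = ft_init_device_mesh(
        device_type="cuda",
        mesh_shape=dims,
        mesh_dim_names=names,
        replicate_dim=names.index("dp_replicate"),
        manager=ft_manager,  # the manager returned by init_ft_manager(job_config)
    )
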
@@ -185,7 +190,7 @@ def loss_fn(pred, labels): ) # build optimizer after applying parallelisms to the model - optimizers = build_optimizers(model_parts, job_config) + optimizers = build_optimizers(model_parts, job_config, ft_manager) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) train_state = TrainState() @@ -198,7 +203,9 @@ def loss_fn(pred, labels): lr_schedulers=lr_schedulers, states={"train_state": train_state}, job_config=job_config, + ft_manager=ft_manager, ) + set_ft_state_dict_fns(ft_manager, checkpoint) if job_config.checkpoint.create_seed_checkpoint: assert ( @@ -313,16 +320,18 @@ def loss_fn(pred, labels): del pred loss.backward() + # TODO(torchft): fix this # clip gradients - utils.clip_grad_norm_( - [p for m in model_parts for p in m.parameters()], - job_config.training.max_norm, - foreach=True, - pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, - ) - + # utils.clip_grad_norm_( + # [p for m in model_parts for p in m.parameters()], + # job_config.training.max_norm, + # foreach=True, + # pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, + # ) + + # TODO(torchft): fix this # sync float8 amaxes and scales - float8_handler.sync_float8_amax_and_scale_history(model_parts) + # float8_handler.sync_float8_amax_and_scale_history(model_parts) # optimizer step checkpoint.maybe_wait_for_staging() @@ -331,7 +340,7 @@ def loss_fn(pred, labels): # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts) + # float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts) # log metrics if ( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 733bc0ae4..732efb951 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -36,7 +36,7 @@ batch_size = 8 seq_len = 2048 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 100 data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 tensor_parallel_degree = 1 From f024dff41d967ca5e93f59e0e3e0ec82bbdda9c1 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 27 Jan 2025 13:08:02 -0800 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- torchtitan/checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 0971de2c4..1f5f47660 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -149,6 +149,9 @@ def __init__( ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint self.ft_manager = ft_manager + self.enable_staging = ( + self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM + ) or self.ft_manager if not self.enable_checkpoint and self.ft_manager is None: return @@ -158,9 +161,6 @@ def __init__( ) async_mode = ckpt_config.async_mode.lower() - self.enable_staging = ( - self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM - ) or self.ft_manager self.staging = False self.sending_to_checkpoint_mp = False self.staging_id = None From 4d4d554bf9b1b8f393f8ded6e15e65b829421ab3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 27 Jan 2025 17:23:02 -0800 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- torchtitan/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/utils.py b/torchtitan/utils.py index da389dca3..b9545d79c 
100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -36,6 +36,7 @@ def get_device_info(): def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float: import torchft as ft + return 0.1 if isinstance(mesh, ft.process_group._FlattenDeviceMesh): torch.distributed.all_reduce(x, group=mesh.managed_mesh.replicate_pg) From 4b2edcb0a81a8ead586c021ff664c24d7fe70e70 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 28 Jan 2025 17:05:40 -0800 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- torchtitan/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/utils.py b/torchtitan/utils.py index b9545d79c..da389dca3 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -36,7 +36,6 @@ def get_device_info(): def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float: import torchft as ft - return 0.1 if isinstance(mesh, ft.process_group._FlattenDeviceMesh): torch.distributed.all_reduce(x, group=mesh.managed_mesh.replicate_pg) From 955e6569060e183e47bc5dd2160d1b7ed7b14900 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 31 Jan 2025 14:31:02 -0800 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- torchtitan/ft.py | 9 ++++----- torchtitan/utils.py | 15 ++++++++++++--- train.py | 13 ++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/torchtitan/ft.py b/torchtitan/ft.py index b620fb9a6..f342a41a1 100644 --- a/torchtitan/ft.py +++ b/torchtitan/ft.py @@ -1,14 +1,13 @@ -from typing import Any, Callable, Optional import importlib +from typing import Any, Callable, Optional + +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict from torchtitan.config_manager import JobConfig -from torch.distributed._state_dict_utils import ( - _copy_state_dict, - _create_cpu_state_dict, -) if importlib.util.find_spec("torchft") is not None: import torchft as ft + has_torchft = True else: has_torchft = False diff --git a/torchtitan/utils.py b/torchtitan/utils.py index da389dca3..5e4078326 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import copy import gc import math import os @@ -16,6 +17,7 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d +import torchft as ft from torch import distributed as dist from torch._utils import _get_available_device_type, _get_device_module from torch.distributed.device_mesh import DeviceMesh @@ -35,8 +37,6 @@ def get_device_info(): def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float: - import torchft as ft - if isinstance(mesh, ft.process_group._FlattenDeviceMesh): torch.distributed.all_reduce(x, group=mesh.managed_mesh.replicate_pg) # x = funcol.all_reduce(x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg) @@ -406,7 +406,16 @@ def clip_grad_norm_( if isinstance(total_norm, DTensor): # Will reach here if any non-PP parallelism is used. # If only using PP, total_norm will be a local tensor. 
- assert False, total_norm.placements + mesh = total_norm._spec.mesh + if isinstance(mesh, ft.process_group.ManagedDeviceMesh): + local_tensor = total_norm.to_local() + dist.all_reduce(local_tensor, op=dist.ReduceOp.AVG, group=mesh.replicate_pg) + + placements = list(copy.copy(total_norm._spec.placements)) + placements.pop(mesh.replicate_dim) + mesh = mesh.mesh + total_norm = DTensor.from_local(local_tensor, mesh, placements) + total_norm = total_norm.full_tensor() if pp_mesh is not None: diff --git a/train.py b/train.py index 80b16742f..981425d61 100644 --- a/train.py +++ b/train.py @@ -320,14 +320,13 @@ def loss_fn(pred, labels): del pred loss.backward() - # TODO(torchft): fix this # clip gradients - # utils.clip_grad_norm_( - # [p for m in model_parts for p in m.parameters()], - # job_config.training.max_norm, - # foreach=True, - # pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, - # ) + utils.clip_grad_norm_( + [p for m in model_parts for p in m.parameters()], + job_config.training.max_norm, + foreach=True, + pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, + ) # TODO(torchft): fix this # sync float8 amaxes and scales
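The ManagedDeviceMesh branch added to clip_grad_norm_ above relies on data-parallel replicas holding already-synchronized, identical gradients, so averaging the per-rank norm contribution over the replicate group should leave its value unchanged; the explicit all-reduce mainly routes the reduction through torchft's managed process group before the replicate placement is dropped and full_tensor() reduces over the remaining static mesh dims. A tiny self-contained check of that reasoning (toy tensors, not torchtitan code):

    # Toy illustration: averaging identical per-replica norms is a no-op,
    # so dropping the replicate dim before full_tensor() preserves the result.
    import torch

    grads = [torch.tensor([3.0, 4.0]), torch.tensor([3.0, 4.0])]  # two replicas, same grads
    partial_norms = torch.stack([g.norm(2) for g in grads])       # 5.0 on each replica
    averaged = partial_norms.mean()                               # still 5.0
    assert torch.allclose(averaged, partial_norms[0])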