Add changes to support FSDP (#598)
vivekgoe authored Jan 23, 2024
1 parent 1cca12a commit e238bca
Showing 13 changed files with 526 additions and 29 deletions.
12 changes: 12 additions & 0 deletions examples/language-modeling/fsdp_config.json
@@ -0,0 +1,12 @@
{
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_forward_prefetch": false,
"fsdp_offload_params": false,
"fsdp_sharding_strategy": 1,
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_sync_module_states": true,
"fsdp_use_orig_params": true,
"transformer_layer_cls_to_wrap": "GaudiLlamaDecoderLayer",
"fsdp_activation_checkpointing": false
}
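
For orientation (not part of the commit): this JSON is the Trainer-style `fsdp_config` that, together with the `--fsdp` flag, typically ends up exposed to Accelerate as `FSDP_*` environment variables, which the `GaudiFullyShardedDataParallelPlugin` added later in this commit reads. Below is a minimal sketch of that mapping; `export_fsdp_env` is a hypothetical helper, and the real translation is done inside transformers/accelerate and may rename some keys.

```python
# Illustrative sketch only: roughly how the keys above can be surfaced as the FSDP_*
# environment variables read by GaudiFullyShardedDataParallelPlugin.__post_init__
# (added further down in this commit). Not library code.
import json
import os


def export_fsdp_env(config_path: str) -> None:
    with open(config_path) as f:
        cfg = json.load(f)
    for key, value in cfg.items():
        # e.g. "fsdp_sharding_strategy" -> "FSDP_SHARDING_STRATEGY"
        env_key = key.upper() if key.startswith("fsdp_") else f"FSDP_{key.upper()}"
        os.environ[env_key] = str(value)


export_fsdp_env("examples/language-modeling/fsdp_config.json")
print(os.environ["FSDP_SHARDING_STRATEGY"])  # "1", i.e. ShardingStrategy.FULL_SHARD
```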
13 changes: 8 additions & 5 deletions examples/language-modeling/run_lora_clm.py
@@ -30,11 +30,8 @@
import torch
import transformers
from datasets import load_dataset
from peft import (
LoraConfig,
TaskType,
get_peft_model,
)
from peft import LoraConfig, TaskType, get_peft_model, tuners
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -45,6 +42,7 @@
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
from optimum.habana.peft.layer import GaudiLoraLayerLinearForward
from optimum.habana.utils import set_seed


@@ -674,6 +672,7 @@ def compute_metrics(eval_preds):
)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward
lora_model = get_peft_model(model, peft_config)
if training_args.bf16:
lora_model = lora_model.to(torch.bfloat16)
@@ -695,6 +694,10 @@ def compute_metrics(eval_preds):
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
)

# Solution for https://github.com/huggingface/peft/blob/v0.6.2/README.md#caveats (1)
if training_args.fsdp and training_args.fsdp_config["auto_wrap_policy"] == "TRANSFORMER_BASED_WRAP":
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(lora_model)

if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
12 changes: 12 additions & 0 deletions examples/question-answering/fsdp_config.json
@@ -0,0 +1,12 @@
{
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_forward_prefetch": false,
"fsdp_offload_params": false,
"fsdp_sharding_strategy": 1,
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_sync_module_states": true,
"fsdp_use_orig_params": true,
"transformer_layer_cls_to_wrap": "BertLayer",
"fsdp_activation_checkpointing": false
}
96 changes: 92 additions & 4 deletions optimum/habana/accelerate/accelerator.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import contextlib
import functools
import math
import os
import sys
@@ -37,7 +38,6 @@
DistributedDataParallelKwargs,
DistributedType,
FP8RecipeKwargs,
FullyShardedDataParallelPlugin,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
@@ -50,8 +50,10 @@
check_os_kernel,
convert_outputs_to_fp32,
is_deepspeed_available,
is_torch_version,
parse_choice_from_env,
)
from accelerate.utils.constants import FSDP_PYTORCH_VERSION
from accelerate.utils.operations import _gpu_gather
from accelerate.utils.other import is_compiled_module
from torch.optim.lr_scheduler import LRScheduler
@@ -68,7 +70,12 @@

from .data_loader import gaudi_prepare_data_loader
from .state import GaudiAcceleratorState, GaudiPartialState
from .utils import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin
from .utils import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)


logger = get_logger(__name__)
@@ -87,7 +94,7 @@ def __init__(
gradient_accumulation_steps: int = 1,
cpu: bool = False,
deepspeed_plugin: DeepSpeedPlugin | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
fsdp_plugin: GaudiFullyShardedDataParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -142,6 +149,27 @@ def __init__(
deepspeed_plugin.set_mixed_precision(mixed_precision)
deepspeed_plugin.set_deepspeed_weakref()

if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
fsdp_plugin, GaudiFullyShardedDataParallelPlugin
):
import importlib.metadata

torch_version = importlib.metadata.version("torch")
torch_version = torch_version[5:]
if is_torch_version("<", FSDP_PYTORCH_VERSION + torch_version):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
GaudiFullyShardedDataParallelPlugin()
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
else None
)
else:
if not isinstance(fsdp_plugin, GaudiFullyShardedDataParallelPlugin):
raise TypeError("`fsdp_plugin` must be a GaudiFullyShardedDataParallelPlugin object.")
os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided

# Kwargs handlers
self.ddp_handler = None
self.scaler_handler = None
@@ -370,6 +398,54 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif self.distributed_type == GaudiDistributedType.FSDP:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

# Check if the model is already an FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again.
# In case the model is already compiled using PyTorch 2.0 and the module wrapped in it
# is an FSDP model, don't wrap it again.
is_type_fsdp = isinstance(model, FSDP) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
)

if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin
kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"cpu_offload": fsdp_plugin.cpu_offload,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"mixed_precision": fsdp_plugin.mixed_precision_policy,
"sync_module_states": fsdp_plugin.sync_module_states,
"backward_prefetch": fsdp_plugin.backward_prefetch,
"forward_prefetch": fsdp_plugin.forward_prefetch,
"use_orig_params": fsdp_plugin.use_orig_params,
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": torch.device("hpu"),
}
model = FSDP(model, **kwargs)
if fsdp_plugin.activation_checkpointing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)

apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
),
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
)
# if the previous and current models are the same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model):
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
@@ -672,7 +748,11 @@ def gather(self, tensor):
tensor([0, 1, 2, 3])
```
"""
if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]:
if GaudiPartialState().distributed_type in [
GaudiDistributedType.MULTI_HPU,
GaudiDistributedType.DEEPSPEED,
GaudiDistributedType.FSDP,
]:
return _gpu_gather(tensor)
else:
return tensor
@@ -719,6 +799,14 @@ def get_state_dict(self, model, unwrap=True):
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())
# copied from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/accelerator.py#L3057
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
state_dict = model.state_dict()
else:
if unwrap:
model = self.unwrap_model(model)
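
To see how the new accelerator pieces fit together outside the Trainer, here is a hedged usage sketch. Assumptions: an Intel Gaudi (HPU) PyTorch build, a distributed launch that sets rank and world size, and `GaudiAccelerator` importable from `optimum.habana.accelerate` (otherwise import it from `optimum.habana.accelerate.accelerator`).

```python
# Usage sketch, not library code. Requires an HPU environment and a distributed launch.
import os

import torch

from optimum.habana.accelerate import GaudiAccelerator
from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin

os.environ.setdefault("FSDP_SHARDING_STRATEGY", "1")  # 1 == ShardingStrategy.FULL_SHARD

# Passing a plugin sets ACCELERATE_USE_FSDP=true, so prepare() goes through the FSDP
# branch of prepare_model() above and wraps the model with device_id=torch.device("hpu").
plugin = GaudiFullyShardedDataParallelPlugin()
accelerator = GaudiAccelerator(fsdp_plugin=plugin)

model = torch.nn.Linear(1024, 1024)
model = accelerator.prepare(model)  # returns an FSDP-wrapped module placed on HPU
```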
10 changes: 10 additions & 0 deletions optimum/habana/accelerate/state.py
@@ -67,6 +67,11 @@ def __init__(self, cpu: bool = False, **kwargs):
deepspeed.init_distributed(dist_backend=self.backend, **kwargs)
logger.info("DeepSpeed is enabled.")
self._mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config
elif os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
self.distributed_type = GaudiDistributedType.FSDP
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=self.backend, rank=rank, world_size=world_size)
logger.info("Enabled distributed run.")
else:
self.distributed_type = GaudiDistributedType.MULTI_HPU
if not torch.distributed.is_initialized():
@@ -115,6 +120,7 @@ def wait_for_everyone(self):
GaudiDistributedType.MULTI_CPU,
GaudiDistributedType.DEEPSPEED,
GaudiDistributedType.MULTI_HPU,
GaudiDistributedType.FSDP,
):
torch.distributed.barrier()

@@ -171,6 +177,10 @@ def __init__(
)
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
self.deepspeed_plugin = deepspeed_plugin
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" and not cpu:
if self._mixed_precision != "no":
fsdp_plugin.set_mixed_precision(self._mixed_precision)
self.fsdp_plugin = fsdp_plugin
GaudiPartialState._shared_state["distributed_type"] = self.distributed_type
self.use_ipex = False

7 changes: 6 additions & 1 deletion optimum/habana/accelerate/utils/__init__.py
@@ -1 +1,6 @@
from .dataclasses import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin
from .dataclasses import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)
38 changes: 38 additions & 0 deletions optimum/habana/accelerate/utils/dataclasses.py
@@ -17,6 +17,9 @@
from dataclasses import dataclass
from enum import Enum

import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin
from accelerate.utils.environment import str_to_bool

@@ -31,12 +34,14 @@ class GaudiDistributedType(str, Enum):
- **NO** -- Not a distributed environment, just a single process.
- **MULTI_HPU** -- Distributed on multiple HPUs.
- **DEEPSPEED** -- Using DeepSpeed.
- **FSDP** -- Using FSDP.
"""

# Subclassing str as well as Enum allows the `GaudiDistributedType` to be JSON-serializable out of the box.
NO = "NO"
MULTI_HPU = "MULTI_HPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"


class GaudiDynamoBackend(str, BaseEnum):
@@ -106,3 +111,36 @@ def __post_init__(self):
self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1
if self.dynamic is None:
self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1


@dataclass
class GaudiFullyShardedDataParallelPlugin(FullyShardedDataParallelPlugin):
def __post_init__(self):
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy

prefix = "FSDP_"
if self.sharding_strategy is None:
self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1)))

if self.cpu_offload is None:
if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
self.cpu_offload = CPUOffload(offload_params=True)
else:
self.cpu_offload = CPUOffload(offload_params=False)

if self.backward_prefetch is None:
prefetch_policy = os.environ.get(prefix + "BACKWARD_PREFETCH", "NO_PREFETCH")
if prefetch_policy != FSDP_BACKWARD_PREFETCH[-1]:
self.backward_prefetch = BackwardPrefetch(FSDP_BACKWARD_PREFETCH.index(prefetch_policy) + 1)

if self.state_dict_type is None:
state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT")
self.set_state_dict_type(state_dict_type_policy)
self.use_orig_params = str_to_bool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
self.sync_module_states = str_to_bool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

if self.sync_module_states:
device = torch.device("hpu")
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
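
A quick way to see what `__post_init__` resolves from the environment. This is a sketch mirroring the example `fsdp_config.json` values; it assumes `optimum-habana` is importable, which on most setups implies an HPU build of PyTorch.

```python
# Sketch: populate the FSDP_* variables consumed above and inspect the resolved fields.
import os

os.environ["FSDP_SHARDING_STRATEGY"] = "1"             # -> ShardingStrategy.FULL_SHARD
os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE"  # -> BackwardPrefetch.BACKWARD_PRE
os.environ["FSDP_OFFLOAD_PARAMS"] = "False"            # -> CPUOffload(offload_params=False)
os.environ["FSDP_STATE_DICT_TYPE"] = "FULL_STATE_DICT"
os.environ["FSDP_USE_ORIG_PARAMS"] = "True"
os.environ["FSDP_SYNC_MODULE_STATES"] = "True"

from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin

plugin = GaudiFullyShardedDataParallelPlugin()
print(plugin.sharding_strategy)   # ShardingStrategy.FULL_SHARD
print(plugin.backward_prefetch)   # BackwardPrefetch.BACKWARD_PRE
print(plugin.use_orig_params)     # True
# Because sync_module_states is True, param_init_fn moves meta-device parameters
# to HPU via to_empty(device="hpu", recurse=False).
print(plugin.param_init_fn is not None)  # True
```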
1 change: 1 addition & 0 deletions optimum/habana/peft/__init__.py
@@ -0,0 +1 @@
from .layer import GaudiLoraLayerLinearForward
31 changes: 31 additions & 0 deletions optimum/habana/peft/layer.py
@@ -0,0 +1,31 @@
from typing import Any

import torch


def GaudiLoraLayerLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
# https://github.com/huggingface/peft/blob/4b02148af252c17e36b0a4b995f9e8519806fbb5/src/peft/tuners/lora/layer.py#L354C1-L376C22
# The only differences from the upstream forward are avoiding the in-place update of "result",
# which trips torch Dynamo in torch.compile execution mode, and replacing self.base_layer with self._linear.
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self._linear(x, *args, **kwargs)
elif self.merged:
result = self._linear(x, *args, **kwargs)
else:
result = self._linear(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
result = result.clone() + lora_B(lora_A(dropout(x))) * scaling

result = result.to(previous_dtype)
return result
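
For context, run_lora_clm.py earlier in this commit wires this forward in by monkey-patching peft before the PEFT model is created. Below is a sketch with illustrative LoRA hyperparameters; `base_model` stands in for whatever causal LM is being fine-tuned and is not defined here.

```python
# Sketch mirroring run_lora_clm.py above: patch peft's LoRA Linear forward *before*
# creating the PEFT model so every LoRA linear uses the clone-based update, which keeps
# torch Dynamo happy in torch.compile mode. `base_model` is a placeholder.
from peft import LoraConfig, TaskType, get_peft_model, tuners

from optimum.habana.peft.layer import GaudiLoraLayerLinearForward

tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                   # illustrative values, not the example's defaults
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
# lora_model = get_peft_model(base_model, peft_config)
```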
10 changes: 8 additions & 2 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -60,8 +60,14 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states):
- override RMSNorm with Habana fused RMSNorm
"""
if hidden_states.device.type == "hpu" and FusedRMSNorm:
hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon)
return hidden_states
# FusedRMSNorm does not handle mixed dtypes well; both inputs need to have the same dtype
if hidden_states.dtype != self.weight.dtype:
orig_dtype = hidden_states.dtype
hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon)
return hidden_states.to(orig_dtype)
else:
hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon)
return hidden_states
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
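
For reference, the non-fused fallback truncated above follows the standard Llama RMSNorm: upcast to float32, normalize, cast back. Below is a CPU-runnable sketch (not the Habana kernel) that also shows how dtypes can end up mixed when bf16 activations meet an fp32 norm weight.

```python
# Reference sketch of the non-fused path (standard Llama RMSNorm, runnable on CPU):
# the computation is done in float32 and cast back, while the comment above notes
# that the fused HPU kernel needs both inputs in the same dtype.
import torch


def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)  # e.g. bf16 activations
w = torch.ones(8)                               # fp32 norm weight
print(rmsnorm_reference(x, w).dtype)            # torch.float32 (promoted by the fp32 weight)
```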