
Commit

Run formatter
vivekgoe committed Jan 5, 2024
1 parent 8c63a3c commit cd026a1
Showing 8 changed files with 59 additions and 39 deletions.
10 changes: 3 additions & 7 deletions examples/language-modeling/run_lora_clm.py
@@ -30,12 +30,7 @@
import torch
import transformers
from datasets import load_dataset
from peft import (
LoraConfig,
TaskType,
get_peft_model,
tuners
)
from peft import LoraConfig, TaskType, get_peft_model, tuners
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -46,8 +41,9 @@
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
from optimum.habana.utils import set_seed
from optimum.habana.peft.layer import GaudiLoraLayerLinearForward
from optimum.habana.utils import set_seed


try:
from optimum.habana.utils import check_optimum_habana_min_version
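
The `tuners` import alongside `GaudiLoraLayerLinearForward` suggests the example patches peft's LoRA `Linear.forward` at runtime. A minimal sketch of that wiring, assuming the patch is gated on torch.compile being enabled (the gating condition and helper name are illustrative, not taken from this diff):

```python
from peft import LoraConfig, TaskType, get_peft_model, tuners

from optimum.habana.peft.layer import GaudiLoraLayerLinearForward


def maybe_patch_lora_forward(use_torch_compile: bool) -> None:
    # Swap in the Gaudi-friendly forward that avoids the in-place update
    # torch Dynamo rejects (see optimum/habana/peft/layer.py below).
    if use_torch_compile:
        tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward


peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
maybe_patch_lora_forward(use_torch_compile=True)
# model = get_peft_model(base_model, peft_config)  # base_model is loaded elsewhere in the script
```
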
22 changes: 17 additions & 5 deletions optimum/habana/accelerate/accelerator.py
@@ -16,11 +16,11 @@
from __future__ import annotations

import contextlib
import functools
import math
import os
import sys
import warnings
import functools
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import make_dataclass
@@ -68,7 +68,12 @@

from .data_loader import gaudi_prepare_data_loader
from .state import GaudiAcceleratorState, GaudiPartialState
from .utils import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin, GaudiFullyShardedDataParallelPlugin
from .utils import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)


logger = get_logger(__name__)
@@ -145,14 +150,17 @@ def __init__(
fsdp_plugin, GaudiFullyShardedDataParallelPlugin
):
import importlib.metadata

torch_version = importlib.metadata.version("torch")
torch_version = torch_version[5:]
if is_torch_version("<", FSDP_PYTORCH_VERSION + torch_version):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
GaudiFullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
GaudiFullyShardedDataParallelPlugin()
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
else None
)
else:
if not isinstance(fsdp_plugin, GaudiFullyShardedDataParallelPlugin):
@@ -413,7 +421,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": torch.device("hpu")
"device_id": torch.device("hpu"),
}
model = FSDP(model, **kwargs)
if fsdp_plugin.activation_checkpointing:
@@ -737,7 +745,11 @@ def gather(self, tensor):
tensor([0, 1, 2, 3])
```
"""
if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED, GaudiDistributedType.FSDP]:
if GaudiPartialState().distributed_type in [
GaudiDistributedType.MULTI_HPU,
GaudiDistributedType.DEEPSPEED,
GaudiDistributedType.FSDP,
]:
return _gpu_gather(tensor)
else:
return tensor
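
The `gather` change above simply adds `GaudiDistributedType.FSDP` to the set of distributed modes that perform a real all-gather. A rough usage sketch, assuming a Gaudi machine with `optimum.habana` installed and `GaudiAccelerator` importable from `optimum.habana.accelerate` (the call is a no-op on a single process):

```python
import torch

from optimum.habana.accelerate import GaudiAccelerator

accelerator = GaudiAccelerator()

# Each process contributes its own rank; under MULTI_HPU, DEEPSPEED and (with this
# commit) FSDP the ranks are concatenated across processes, otherwise the tensor
# is returned unchanged.
local = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered = accelerator.gather(local)
print(gathered)  # e.g. tensor([0, 1, 2, 3]) with 4 processes
```
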
7 changes: 6 additions & 1 deletion optimum/habana/accelerate/utils/__init__.py
@@ -1 +1,6 @@
from .dataclasses import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin, GaudiFullyShardedDataParallelPlugin
from .dataclasses import (
GaudiDistributedType,
GaudiDynamoBackend,
GaudiFullyShardedDataParallelPlugin,
GaudiTorchDynamoPlugin,
)
11 changes: 7 additions & 4 deletions optimum/habana/accelerate/utils/dataclasses.py
@@ -14,14 +14,15 @@
# limitations under the License.

import os
import torch
from dataclasses import dataclass
from enum import Enum

import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin
from accelerate.utils.environment import str_to_bool
from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH
from accelerate.utils import FullyShardedDataParallelPlugin


class GaudiDistributedType(str, Enum):
"""
@@ -42,6 +43,7 @@ class GaudiDistributedType(str, Enum):
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"


class GaudiDynamoBackend(str, BaseEnum):
"""
Represents a dynamo backend (see https://github.com/pytorch/torchdynamo).
@@ -110,6 +112,7 @@ def __post_init__(self):
if self.dynamic is None:
self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1


@dataclass
class GaudiFullyShardedDataParallelPlugin(FullyShardedDataParallelPlugin):
def __post_init__(self):
@@ -140,4 +143,4 @@ def __post_init__(self):

if self.sync_module_states:
device = "hpu:" + str(torch.hpu.current_device())
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
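
The `param_init_fn` assigned above relies on `torch.nn.Module.to_empty` to materialize meta-device parameters on the target device, one module at a time (`recurse=False`), before FSDP syncs the real weights from rank 0. A standalone illustration of that mechanism, run on CPU so it does not require an HPU (on Gaudi the device string would be `"hpu:<index>"`), assuming a recent PyTorch:

```python
import torch
import torch.nn as nn

# Build a module whose parameters only carry shapes/dtypes, with no storage.
with torch.device("meta"):
    layer = nn.Linear(16, 16)

# FSDP calls param_init_fn once per (sub)module, hence recurse=False.
param_init_fn = lambda m: m.to_empty(device="cpu", recurse=False)
param_init_fn(layer)

print(layer.weight.is_meta)  # False: storage is now allocated (values are uninitialized)
print(layer.weight.device)   # cpu
```
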
13 changes: 4 additions & 9 deletions optimum/habana/peft/layer.py
@@ -1,17 +1,12 @@
import math
import warnings
from typing import Any, List, Optional, Union
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D


def GaudiLoraLayerLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
#https://github.com/huggingface/peft/blob/4b02148af252c17e36b0a4b995f9e8519806fbb5/src/peft/tuners/lora/layer.py#L354C1-L376C22
#only differences are avoiding inplace update of "result" to prevent error from torch Dynamo in torch.compile mode of execution
#and replacing self.base_layer by self._linear
# https://github.com/huggingface/peft/blob/4b02148af252c17e36b0a4b995f9e8519806fbb5/src/peft/tuners/lora/layer.py#L354C1-L376C22
# only differences are avoiding inplace update of "result" to prevent error from torch Dynamo in torch.compile mode of execution
# and replacing self.base_layer by self._linear
previous_dtype = x.dtype

if self.disable_adapters:
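
For reference, a hedged sketch of the pattern those comments describe: the adapter contribution is accumulated out of place (`result = result + ...` rather than `result += ...`) and the base projection goes through `self._linear` instead of `self.base_layer`. The body below mirrors the upstream peft layer linked in the comment but is illustrative, not the verbatim optimum-habana implementation:

```python
import torch


def lora_linear_forward_sketch(self, x: torch.Tensor) -> torch.Tensor:
    previous_dtype = x.dtype
    result = self._linear(x)  # upstream peft calls self.base_layer(x) here
    for active_adapter in self.active_adapters:
        if active_adapter not in self.lora_A.keys():
            continue
        lora_A = self.lora_A[active_adapter]
        lora_B = self.lora_B[active_adapter]
        dropout = self.lora_dropout[active_adapter]
        scaling = self.scaling[active_adapter]
        x = x.to(lora_A.weight.dtype)
        # Out-of-place accumulation keeps torch.compile/Dynamo from erroring
        # on an in-place update of a captured tensor.
        result = result + lora_B(lora_A(dropout(x))) * scaling
    return result.to(previous_dtype)
```
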
6 changes: 2 additions & 4 deletions optimum/habana/transformers/trainer.py
@@ -22,7 +22,6 @@
import sys
import time
import warnings
from packaging import version
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -117,9 +116,7 @@
from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available():
from accelerate import __version__ as accelerate_version
from accelerate.utils import (
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
@@ -140,6 +137,7 @@
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class GaudiTrainer(Trainer):
"""
GaudiTrainer is built on top of the transformers' Trainer to enable
@@ -2049,7 +2047,7 @@ def create_accelerator_and_postprocess(self):
dispatch_batches=self.args.dispatch_batches,
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
even_batches= not self.args.dataloader_drop_last,
even_batches=not self.args.dataloader_drop_last,
distribution_strategy=self.args.distribution_strategy,
)

12 changes: 7 additions & 5 deletions optimum/habana/transformers/training_args.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import json
import os
import warnings
from dataclasses import asdict, dataclass, field
from datetime import timedelta
@@ -42,10 +42,11 @@
from optimum.utils import logging

from ..accelerate.state import GaudiAcceleratorState, GaudiPartialState
from ..accelerate.utils import GaudiDistributedType, GaudiFullyShardedDataParallelPlugin
from ..accelerate.utils import GaudiDistributedType
from ..utils import get_habana_frameworks_version
from .gaudi_configuration import GaudiConfig


if is_torch_available():
import torch

@@ -507,7 +508,7 @@ def __post_init__(self):
" during training"
)

# Copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/training_args.py#L1563
# Copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/training_args.py#L1563
# except following changes, (1) Remove XLA specific code & (2) change fsdp_backward_prefetch to backward_prefetch
if isinstance(self.fsdp, bool):
self.fsdp = "full_shard" if self.fsdp else ""
@@ -597,8 +598,9 @@ def __post_init__(self):
os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefect", "false"))
os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(self.fsdp_config.get("sync_module_states", "true"))
os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "false"))
os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str(self.fsdp_config.get("activation_checkpointing", "false"))

os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str(
self.fsdp_config.get("activation_checkpointing", "false")
)

if isinstance(self.debug, str):
self.debug = [DebugOption(s) for s in self.debug.split()]
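
The block above mirrors how transformers exports the parsed `fsdp_config` into `FSDP_*` environment variables so that the accelerate-side plugin (here `GaudiFullyShardedDataParallelPlugin`, whose `__post_init__` appears earlier in this commit) can pick them up. A minimal round-trip sketch, assuming the `FSDP_` prefix used by accelerate:

```python
import os

fsdp_config = {"activation_checkpointing": "true", "sync_module_states": "true"}
prefix = "FSDP_"  # assumed; matches the FSDP_* variables accelerate reads

os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str(
    fsdp_config.get("activation_checkpointing", "false")
)
os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(fsdp_config.get("sync_module_states", "true"))

# Later, a plugin __post_init__ can reconstruct the setting from the environment:
activation_checkpointing = os.environ.get(f"{prefix}ACTIVATION_CHECKPOINTING", "false") == "true"
print(activation_checkpointing)  # True
```
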
17 changes: 13 additions & 4 deletions tests/test_fsdp_examples.py
@@ -13,7 +13,17 @@
# Gaudi2 CI baselines
MODELS_TO_TEST = {
"bf16": [
("bert-base-uncased", "Habana/bert-base-uncased", 2807, 85.4688, "question-answering", 24, 8, "run_qa.py", "full_shard"),
(
"bert-base-uncased",
"Habana/bert-base-uncased",
2807,
85.4688,
"question-answering",
24,
8,
"run_qa.py",
"full_shard",
),
],
}

@@ -31,8 +41,8 @@ def _test_fsdp(
world_size: int = 8,
):
os.environ["PT_HPU_LAZY_MODE"] = "0"
os.environ["PT_HPU_EAGER_4_STAGE_PIPELINE_ENABLE"] = "0" #To be removed later
os.environ["PT_HPU_EAGER_PIPELINE_ENABLE"] = "0" #To be removed later
os.environ["PT_HPU_EAGER_4_STAGE_PIPELINE_ENABLE"] = "0" # To be removed later
os.environ["PT_HPU_EAGER_PIPELINE_ENABLE"] = "0" # To be removed later
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"

# Install question-answering example requirements
@@ -43,7 +53,6 @@

command = ["python3"]


command += [
f"{path_to_example_dir / 'gaudi_spawn.py'}",
"--use_mpi",
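
Outside of pytest, the launch this test builds up looks roughly like the following; the exact `gaudi_spawn.py` and `run_qa.py` flags below are assumptions for illustration, and actually running it requires a Gaudi machine with MPI available:

```python
import os
import subprocess
from pathlib import Path

os.environ["PT_HPU_LAZY_MODE"] = "0"

path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
world_size = 8

command = [
    "python3",
    str(path_to_example_dir / "gaudi_spawn.py"),
    "--use_mpi",
    f"--world_size={world_size}",
    str(path_to_example_dir / "question-answering" / "run_qa.py"),
    "--model_name_or_path", "bert-base-uncased",
    "--gaudi_config_name", "Habana/bert-base-uncased",
    "--fsdp", "full_shard",
    "--do_train",
]
# subprocess.run(command, check=True)  # only meaningful on Gaudi hardware
```
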
