Adding FSDP Support to Training Library (#213)
Adds support for FSDP and FSDP with CPU offloading.

Introduces accelerate as a distributed backend abstraction (for FSDP/DeepSpeed).
Also fixes the Mistral chat template and cleans up data processing.

---------

Signed-off-by: aldo pareja-cardona <[email protected]>
Signed-off-by: Oleg S <[email protected]>
Signed-off-by: Mustafa Eyceoz <[email protected]>
Co-authored-by: aldo pareja-cardona <[email protected]>
Co-authored-by: Oleg S <[email protected]>
Co-authored-by: Mustafa Eyceoz <[email protected]>
4 people authored Sep 26, 2024
1 parent b37c8ce commit 7b7fa12
Showing 12 changed files with 660 additions and 333 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
@@ -472,7 +472,8 @@ disable=raw-checker-failed,
consider-using-generator,
broad-exception-caught,
super-init-not-called,
-duplicate-code
+duplicate-code,
+too-many-positional-arguments

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
23 changes: 22 additions & 1 deletion README.md
@@ -131,6 +131,11 @@ Here is a breakdown of the general options:
| mock_data_len | Max length of a single mock data sample. Equivalent to `max_seq_len` but for mock data. |
| deepspeed_options | Config options to specify for the DeepSpeed optimizer. |
| lora | Options to specify if you intend to perform a LoRA train instead of a full fine-tune. |
| chat_tmpl_path | Specifies the chat template / special tokens for training. |
| checkpoint_at_epoch | Whether or not we should save a checkpoint at the end of each epoch. |
| fsdp_options | The settings for controlling FSDP when it's selected as the distributed backend. |
| distributed_backend | Specifies which distributed training backend to use. Supported options are "fsdp" and "deepspeed". |
| disable_flash_attn | Disables flash attention when set to true. This allows for training on older devices. |
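
As an illustrative sketch (not part of this commit's README), the new fields could be wired into `TrainingArgs` roughly as follows. Field names come from the table above and the `config.py` changes in this commit; the remaining required `TrainingArgs` fields (model, data, and output paths, etc.) are omitted.

```python
# Sketch only: option names are taken from the table above and config.py.
from instructlab.training import DistributedBackend, FSDPOptions, TrainingArgs

new_options = dict(
    distributed_backend=DistributedBackend.FSDP,  # or DistributedBackend.DEEPSPEED (the default)
    fsdp_options=FSDPOptions(),                   # defaults: SHARD_GRAD_OP sharding, no CPU offload
    checkpoint_at_epoch=True,                     # save a checkpoint at the end of every epoch
    disable_flash_attn=False,                     # set to True to train on devices without flash attention
)

# train_args = TrainingArgs(..., **new_options)  # '...' stands for the existing required fields
```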

#### `DeepSpeedOptions`

@@ -141,8 +146,24 @@ allow you to customize aspects of the ZeRO stage 2 optimizer.
| Field | Description |
| --- | --- |
| cpu_offload_optimizer | Whether or not to do CPU offloading in DeepSpeed stage 2. |
| cpu_offload_optimizer_ratio | Floating point between 0 and 1. Specifies the ratio of parameter updates (i.e., the optimizer step) performed on the CPU side. |
| cpu_offload_optimizer_pin_memory | If true, offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. |
| save_samples | The number of samples to see before saving a DeepSpeed checkpoint. |
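
As a minimal sketch (not part of this commit's README), a CPU-offloaded DeepSpeed configuration might look like this, using only the fields listed above:

```python
from instructlab.training import DeepSpeedOptions

ds_options = DeepSpeedOptions(
    cpu_offload_optimizer=True,             # offload the stage 2 optimizer to the CPU
    cpu_offload_optimizer_ratio=1.0,        # run all optimizer updates on the CPU side
    cpu_offload_optimizer_pin_memory=True,  # page-locked CPU memory; faster transfers, more memory overhead
    save_samples=None,                      # optionally checkpoint after this many samples
)
# ds_options would then be passed to TrainingArgs as deepspeed_options=ds_options.
```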

-#### `loraOptions`
+#### `FSDPOptions`

As with DeepSpeed, we only expose a small number of FSDP parameters for you to modify.
They are listed below:

| Field | Description |
| --- | --- |
| cpu_offload_params | When set to true, offload parameters from the accelerator onto the CPU. This is an all-or-nothing option. |
| sharding_strategy | Specifies the model sharding strategy that FSDP should use. Valid options are: `FULL_SHARD` (ZeRO-3), `HYBRID_SHARD` (ZeRO-3*), `SHARD_GRAD_OP` (ZeRO-2), and `NO_SHARD`. |

> [!NOTE]
> For `sharding_strategy` - Only `SHARD_GRAD_OP` has been extensively tested and is actively supported by this library.
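
As a minimal sketch (not part of this commit's README), enabling FSDP with CPU offloading would combine these fields with the `distributed_backend` option described earlier:

```python
from instructlab.training import DistributedBackend, FSDPOptions, ShardingStrategies

fsdp_options = FSDPOptions(
    cpu_offload_params=True,                             # all-or-nothing parameter offload to the CPU
    sharding_strategy=ShardingStrategies.SHARD_GRAD_OP,  # ZeRO-2 style sharding; the tested default
)
# Pass this to TrainingArgs along with distributed_backend=DistributedBackend.FSDP.
```
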
#### `LoraOptions`

If you'd like to do a LoRA train, you can specify a LoRA
option to `TrainingArgs` via the `LoraOptions` object.
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ py-cpuinfo
# replace custom pytorch images with the 2.3.0
torch>=2.3.0a0
transformers>=4.41.2
accelerate>=0.34.2
datasets>=2.15.0
numba
# Note: numpy ranges copied from instructlab/instructlab
6 changes: 6 additions & 0 deletions src/instructlab/training/__init__.py
@@ -7,15 +7,21 @@
"TorchrunArgs",
"TrainingArgs",
"run_training",
"FSDPOptions",
"ShardingStrategies",
"DistributedBackend",
)

# Local
from .config import (
DataProcessArgs,
DeepSpeedOffloadStrategy,
DeepSpeedOptions,
DistributedBackend,
FSDPOptions,
LoraOptions,
QuantizeDataType,
ShardingStrategies,
TorchrunArgs,
TrainingArgs,
)
13 changes: 7 additions & 6 deletions src/instructlab/training/chat_templates/ibm_generic_tmpl.py
@@ -1,14 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

# First Party
-from instructlab.training.tokenizer_utils import SpecialTokens
+from instructlab.training.tokenizer_utils import SpecialTokens, TokenInfo

SPECIAL_TOKENS = SpecialTokens(
system="<|system|>",
user="<|user|>",
assistant="<|assistant|>",
eos="<|endoftext|>",
pad="<|pad|>",
system=TokenInfo("<|system|>", add_to_tokenizer=True),
user=TokenInfo("<|user|>", add_to_tokenizer=True),
assistant=TokenInfo("<|assistant|>", add_to_tokenizer=True),
eos=TokenInfo("<|endoftext|>", add_to_tokenizer=True),
pad=TokenInfo("<|pad|>", add_to_tokenizer=True),
bos=TokenInfo("<|begginingoftext|>", add_to_tokenizer=True),
)

CHAT_TEMPLATE = (
43 changes: 29 additions & 14 deletions src/instructlab/training/chat_templates/mistral_tmpl.py
@@ -1,24 +1,39 @@
# SPDX-License-Identifier: Apache-2.0

# First Party
-from instructlab.training.tokenizer_utils import SpecialTokens
+from instructlab.training.tokenizer_utils import SpecialTokens, TokenInfo

SPECIAL_TOKENS = SpecialTokens(
bos="<s>",
eos="</s>",
user="[INST]",
assistant="[/INST]",
bos=TokenInfo("<s>", add_to_tokenizer=True),
eos=TokenInfo("</s>", add_to_tokenizer=True),
user=TokenInfo("[INST]", add_to_tokenizer=False),
assistant=TokenInfo("[/INST]", add_to_tokenizer=False),
)

CHAT_TEMPLATE = (
"{%- if messages[0]['role'] == 'system' %}"
"{%- set system_message = messages[0]['content'] %}"
"{%- set loop_messages = messages[1:] %}"
"{%- else %}"
"{%- set loop_messages = messages %}"
"{%- endif %}"
"{{ '<s>' }}"
"{% for message in messages %}"
"{% if message['role'] == 'pretraining' %}"
"{{'<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>'}}"
"{% elif message['role'] == 'user' %}"
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ message['content'] + '</s>'}}"
"{% endif %}"
"{% endfor %}"
"{%- for message in loop_messages %}"
"{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}"
"{%- endif %}"
"{%- if message['role'] == 'user' %}"
"{%- if loop.first and system_message is defined %}"
"{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}"
"{%- else %}"
"{{- ' [INST] ' + message['content'] + ' [/INST]' }}"
"{%- endif %}"
"{%- elif message['role'] == 'pretraining' %}"
"{{- '<|pretrain|>' + message['content'] + '</s>' + '<|/pretrain|>' }}"
"{%- elif message['role'] == 'assistant' %}"
"{{- ' ' + message['content'] + '</s>'}}"
"{%- else %}"
"{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}"
"{%- endif %}"
"{%- endfor %}"
)
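
As a quick illustration (not part of this commit), the fixed Mistral template can be rendered directly with Jinja2 to inspect the formatting it produces. This assumes the updated `instructlab.training.chat_templates.mistral_tmpl` module is importable and that `jinja2` is installed; `raise_exception` is defined by hand because plain Jinja2 lacks the helper that `transformers` injects when applying chat templates.

```python
# Sketch only: render the fixed Mistral chat template on a short conversation.
from jinja2 import Environment

from instructlab.training.chat_templates.mistral_tmpl import CHAT_TEMPLATE

def raise_exception(message):
    # The template calls this when roles do not alternate user/assistant.
    raise ValueError(message)

env = Environment()
env.globals["raise_exception"] = raise_exception

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]

print(env.from_string(CHAT_TEMPLATE).render(messages=messages))
# Expected output (newlines shown escaped):
# <s> [INST] You are a helpful assistant.\n\nHello! [/INST] Hi there.</s>
```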
30 changes: 30 additions & 0 deletions src/instructlab/training/config.py
@@ -28,6 +28,12 @@ class DeepSpeedOffloadStrategy(Enum):
NONE = None


# public API
class DistributedBackend(Enum):
FSDP: str = "fsdp"
DEEPSPEED: str = "deepspeed"


# public API
class QuantizeDataType(Enum):
"""
@@ -111,6 +117,24 @@ class DeepSpeedOptions(BaseModel):
save_samples: int | None = None


# public API
class ShardingStrategies(Enum):
FULL_SHARD = "FULL_SHARD"
SHARD_GRAD_OP = "SHARD_GRAD_OP"
NO_SHARD = "NO_SHARD"
HYBRID_SHARD = "HYBRID_SHARD"


# public API
class FSDPOptions(BaseModel):
"""
Represents the options for configuring FSDP which are exposed by the Training Library
"""

cpu_offload_params: Optional[bool] = False
sharding_strategy: ShardingStrategies = ShardingStrategies.SHARD_GRAD_OP


# public API
class TrainingArgs(BaseModel):
"""
@@ -157,6 +181,12 @@ class TrainingArgs(BaseModel):
cpu_offload_optimizer_pin_memory=False,
)
)
fsdp_options: FSDPOptions = Field(
default_factory=lambda: FSDPOptions(
cpu_offload_params=False, sharding_strategy=ShardingStrategies.SHARD_GRAD_OP
)
)
distributed_backend: DistributedBackend = DistributedBackend.DEEPSPEED

disable_flash_attn: Optional[bool] = False
