Add learning rate scheduling support for DeepSpeedStrategy #20320

Open · wants to merge 31 commits into base: master
Changes from all commits · 31 commits
188a45f - Update fabric.py (amorehead, Oct 5, 2024)
baf5988 - Update deepspeed.py (amorehead, Oct 5, 2024)
1f4c18e - Update deepspeed.py (amorehead, Oct 5, 2024)
585e302 - Update fabric.py (amorehead, Oct 5, 2024)
0451761 - Update fsdp.py (amorehead, Oct 5, 2024)
a912aab - Update strategy.py (amorehead, Oct 5, 2024)
d27d4a3 - Update strategy.py (amorehead, Oct 5, 2024)
67089a1 - Update xla_fsdp.py (amorehead, Oct 5, 2024)
1025875 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 5, 2024)
9b45b99 - Update fsdp.py (amorehead, Oct 5, 2024)
a7a5835 - Update strategy.py (amorehead, Oct 5, 2024)
3ece31c - Update xla_fsdp.py (amorehead, Oct 5, 2024)
e48acd2 - Update deepspeed.py (amorehead, Oct 5, 2024)
f13516d - Update seed.py (amorehead, Oct 28, 2024)
80b4a6d - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 28, 2024)
2cab7e2 - Update seed.py (amorehead, Oct 28, 2024)
e9127f4 - Update seed.py (amorehead, Oct 28, 2024)
c127458 - Update seed.py (amorehead, Oct 28, 2024)
31a1fce - Merge branch 'master' into patch-2 (lantiga, Nov 12, 2024)
f215626 - Merge branch 'master' into patch-2 (amorehead, Nov 15, 2024)
dfce07e - Merge branch 'master' into patch-2 (lantiga, Nov 25, 2024)
25e8d48 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 25, 2024)
737162d - Update src/lightning/fabric/strategies/deepspeed.py (lantiga, Nov 25, 2024)
2d347d0 - Update src/lightning/fabric/strategies/fsdp.py (lantiga, Nov 25, 2024)
5d227ff - Update src/lightning/fabric/strategies/strategy.py (lantiga, Nov 25, 2024)
f94efa7 - Update src/lightning/fabric/strategies/xla_fsdp.py (lantiga, Nov 25, 2024)
c2613ec - Merge branch 'Lightning-AI:master' into patch-2 (amorehead, Jan 9, 2025)
56464ed - Update deepspeed.py (amorehead, Jan 9, 2025)
e09941c - Update fabric.py (amorehead, Jan 9, 2025)
3709f1d - Update fabric_methods.rst (amorehead, Jan 10, 2025)
13195a2 - Update wrappers.rst (amorehead, Jan 10, 2025)
4 changes: 4 additions & 0 deletions docs/source-fabric/api/fabric_methods.rst
@@ -40,13 +40,17 @@ Moves the model and optimizer to the correct device automatically.

model = nn.Linear(32, 64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)

# Set up model and optimizer for accelerated training
model, optimizer = fabric.setup(model, optimizer)

# If you don't want Fabric to set the device
model, optimizer = fabric.setup(model, optimizer, move_to_device=False)

# If you want to additionally register a learning rate scheduler with compatible strategies such as DeepSpeed
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)


The setup method also prepares the model for the selected precision choice so that operations during ``forward()`` get
cast automatically. Advanced users should read :doc:`the notes on models wrapped by Fabric <../api/wrappers>`.
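Note: to make the documentation example above concrete, here is a minimal end-to-end sketch of the API this PR adds. It is not part of the diff; the accelerator/device choices are illustrative, it assumes deepspeed is installed, and it uses the scheduler argument as keyword-only, matching the updated Fabric.setup signature further below.

import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, strategy="deepspeed")
fabric.launch()

model = nn.Linear(32, 64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)

# The scheduler is passed by keyword and comes back as the last element of the
# returned tuple; with the DeepSpeed strategy it is forwarded to deepspeed.initialize().
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)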
2 changes: 1 addition & 1 deletion docs/source-fabric/api/wrappers.rst
@@ -124,7 +124,7 @@ If you were to run this model in Fabric with multiple devices (DDP or FSDP), you
# OK: Calling the model directly
output = model(torch.randn(10))

# OK: Calling the model's forward (equivalent to the abvoe)
# OK: Calling the model's forward (equivalent to the above)
output = model.forward(torch.randn(10))

# ERROR: Calling another method that calls forward indirectly
12 changes: 8 additions & 4 deletions src/lightning/fabric/fabric.py
@@ -32,6 +32,7 @@
from lightning_utilities.core.overrides import is_overridden
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

import lightning.fabric
@@ -206,6 +207,7 @@ def setup(
self,
module: nn.Module,
*optimizers: Optimizer,
scheduler: Optional[_LRScheduler] = None,
move_to_device: bool = True,
_reapply_compile: bool = True,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
@@ -214,6 +216,7 @@ def setup(
Args:
module: A :class:`torch.nn.Module` to set up
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -222,7 +225,8 @@
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.

Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
The tuple containing wrapped module, optimizers, and an optional learning rate scheduler,
in the same order they were passed in.

"""
self._validate_setup(module, optimizers)
Expand All @@ -236,8 +240,8 @@ def setup(

# Let accelerator/plugin wrap and connect the models and optimizers
if optimizers:
module, optimizers = self._strategy.setup_module_and_optimizers( # type: ignore[assignment]
module, list(optimizers)
module, optimizers, scheduler = self._strategy.setup_module_and_optimizers( # type: ignore[assignment]
module, list(optimizers), scheduler
)
else:
module = self._strategy.setup_module(module)
@@ -266,7 +270,7 @@

if optimizers:
# join both types in a tuple for API convenience
return (module, *optimizers)
return (module, *optimizers, scheduler) if scheduler is not None else (module, *optimizers)
return module

def setup_module(
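Note: as a reading aid for the return-shape change in fabric.py above, the following standalone helper paraphrases the rule. It is illustrative code, not the actual Fabric implementation: existing call sites that unpack (module, *optimizers) keep working because the scheduler is appended only when one was passed.

from typing import Any, Optional


def assemble_setup_return(module: Any, optimizers: tuple[Any, ...], scheduler: Optional[Any]) -> Any:
    # Paraphrase (illustrative only) of the updated Fabric.setup() return logic:
    # the scheduler, if given, is appended after the optimizers; otherwise the
    # previous (module, *optimizers) shape is preserved.
    if optimizers:
        return (module, *optimizers, scheduler) if scheduler is not None else (module, *optimizers)
    return module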
31 changes: 15 additions & 16 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -27,6 +27,7 @@
from lightning_utilities.core.imports import RequirementCache
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing_extensions import override

from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
@@ -316,25 +317,24 @@ def model(self) -> "DeepSpeedEngine":

@override
def setup_module_and_optimizers(
self, module: Module, optimizers: list[Optimizer]
) -> tuple["DeepSpeedEngine", list[Optimizer]]:
"""Set up a model and multiple optimizers together.

Currently, only a single optimizer is supported.
self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> tuple["DeepSpeedEngine", list[Optimizer], Any]:
"""Set up a model and multiple optimizers together, along with an optional learning rate scheduler. Currently,
only a single optimizer is supported.

Return:
The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
deepspeed optimizer.
The model wrapped into a :class:`deepspeed.DeepSpeedEngine`, a list with a single
deepspeed optimizer, and an optional learning rate scheduler.

"""
if len(optimizers) != 1:
raise ValueError(
f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
)

self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
self._set_deepspeed_activation_checkpointing()
return self._deepspeed_engine, [optimizer]
return self._deepspeed_engine, [optimizer], scheduler

@override
def setup_module(self, module: Module) -> "DeepSpeedEngine":
@@ -343,7 +343,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
For training, see :meth:`setup_module_and_optimizers`.

"""
self._deepspeed_engine, _ = self._initialize_engine(module)
self._deepspeed_engine, _, _ = self._initialize_engine(module)
return self._deepspeed_engine

@override
@@ -596,10 +596,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
)

def _initialize_engine(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
) -> tuple["DeepSpeedEngine", Optimizer]:
self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
) -> tuple["DeepSpeedEngine", Optimizer, Any]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.

This calls ``deepspeed.initialize`` internally.
@@ -608,15 +606,16 @@ def _initialize_engine(
import deepspeed

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
args=argparse.Namespace(device_rank=self.root_device.index),
config=self.config,
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
lr_scheduler=scheduler,
dist_init_required=False,
)
return deepspeed_engine, deepspeed_optimizer
return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler

@override
def setup_environment(self) -> None:
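Note: for context on how the deepspeed.py change above relates to DeepSpeed's own scheduling, here is a hedged sketch of the two routes a learning rate schedule can take with this strategy. It is not part of the diff; the config values, device choices, and variable names are illustrative, and deepspeed must be installed.

import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Route A (already possible before this PR): let DeepSpeed build the scheduler
# from the "scheduler" section of its config. Shown only for contrast; the
# strategy object is not used below.
config_based_strategy = DeepSpeedStrategy(
    config={
        "train_micro_batch_size_per_gpu": 8,
        "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},
    }
)

# Route B (added by this PR): hand an existing torch scheduler to Fabric.setup(),
# which this strategy forwards to deepspeed.initialize(..., lr_scheduler=scheduler).
fabric = Fabric(accelerator="cuda", devices=1, strategy="deepspeed")
fabric.launch()

model = nn.Linear(32, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=10)
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)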
7 changes: 4 additions & 3 deletions src/lightning/fabric/strategies/fsdp.py
@@ -33,6 +33,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing_extensions import TypeGuard, override

from lightning.fabric.accelerators import Accelerator
@@ -261,8 +262,8 @@ def setup_environment(self) -> None:

@override
def setup_module_and_optimizers(
self, module: Module, optimizers: list[Optimizer]
) -> tuple[Module, list[Optimizer]]:
self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
"""Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
use_orig_params = self._fsdp_kwargs.get("use_orig_params")
@@ -274,7 +275,7 @@ def setup_module_and_optimizers(
" call `setup_optimizer`."
)
module = self.setup_module(module)
return module, optimizers
return module, optimizers, scheduler

@override
def setup_module(self, module: Module) -> Module:
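Note: the FSDP change above only passes the scheduler through unchanged, so the caller keeps full control of it. A brief sketch, not part of this diff; the names are illustrative and a multi-GPU CUDA setup is assumed.

import torch
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

model = nn.Linear(32, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# FSDP wraps the module but returns the optimizers and scheduler as-is,
# so the scheduler is stepped manually in the training loop as usual.
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)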
7 changes: 4 additions & 3 deletions src/lightning/fabric/strategies/strategy.py
@@ -21,6 +21,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from lightning.fabric.accelerators import Accelerator
@@ -145,8 +146,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont
return stack

def setup_module_and_optimizers(
self, module: Module, optimizers: list[Optimizer]
) -> tuple[Module, list[Optimizer]]:
self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
"""Set up a model and multiple optimizers together.

The returned objects are expected to be in the same order they were passed in. The default implementation will
@@ -155,7 +156,7 @@ def setup_module_and_optimizers(
"""
module = self.setup_module(module)
optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
return module, optimizers
return module, optimizers, scheduler

def setup_module(self, module: Module) -> Module:
"""Performs setup for the model, e.g., by wrapping it by another class."""
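Note: for strategy authors, here is a sketch of how a subclass might override the extended base-class hook above. MyCustomStrategy is hypothetical and not part of this PR; it only illustrates where a strategy that manages LR scheduling itself could intervene.

from typing import Optional

from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from lightning.fabric.strategies import SingleDeviceStrategy


class MyCustomStrategy(SingleDeviceStrategy):
    # Hypothetical subclass for illustration only.
    def setup_module_and_optimizers(
        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
        module = self.setup_module(module)
        optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
        # The base implementation returns the scheduler untouched; a custom
        # strategy could wrap or replace it here before returning it.
        return module, optimizers, scheduler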
5 changes: 3 additions & 2 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -21,6 +21,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from typing_extensions import override

@@ -196,8 +197,8 @@ def setup_environment(self) -> None:

@override
def setup_module_and_optimizers(
self, module: Module, optimizers: list[Optimizer]
) -> tuple[Module, list[Optimizer]]:
self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
"""Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
raise NotImplementedError(
f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
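Note: since XLA FSDP still rejects joint setup, the scheduler is simply created after the usual two-step module/optimizer setup. A hedged sketch, not part of this diff; it assumes a TPU environment with torch_xla available, and the names are illustrative.

import torch
import torch.nn as nn
from lightning.fabric import Fabric


def run(fabric: Fabric) -> None:
    # XLA FSDP requires setting up the module first, then the optimizer over the
    # wrapped module's (sharded) parameters; the scheduler is built on top of that.
    model = fabric.setup_module(nn.Linear(32, 64))
    optimizer = fabric.setup_optimizers(torch.optim.AdamW(model.parameters(), lr=1e-3))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    # ... training loop using model, optimizer, and scheduler ...


fabric = Fabric(accelerator="tpu", strategy="xla_fsdp")
fabric.launch(run)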