Add basic FSDP2 to MNIST
runame committed Jan 20, 2025
1 parent 86d2a0d commit c4e665a
Showing 2 changed files with 10 additions and 13 deletions.
6 changes: 3 additions & 3 deletions algorithmic_efficiency/pytorch_utils.py
@@ -16,11 +16,11 @@


def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
-  use_pytorch_ddp = 'LOCAL_RANK' in os.environ
-  rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0
+  use_pytorch_fsdp2 = 'LOCAL_RANK' in os.environ
+  rank = int(os.environ['LOCAL_RANK']) if use_pytorch_fsdp2 else 0
  device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
  n_gpus = torch.cuda.device_count()
-  return use_pytorch_ddp, rank, device, n_gpus
+  return use_pytorch_fsdp2, rank, device, n_gpus


def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
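For orientation: `pytorch_setup` only inspects the environment. A launcher such as `torchrun` sets `LOCAL_RANK` in every worker process, which is what flips the returned flag from `False` to `True`. Below is a minimal, hypothetical sketch of how a caller might consume the returned tuple, assuming a `torchrun` launch and an NCCL backend; in this repo the actual process-group setup presumably happens in `pytorch_init`, which this diff does not show.

# Hypothetical usage sketch, not part of this commit.
# Launch with: torchrun --standalone --nproc_per_node=2 demo_setup.py
import torch
import torch.distributed as dist

from algorithmic_efficiency.pytorch_utils import pytorch_setup

use_pytorch_fsdp2, rank, device, n_gpus = pytorch_setup()

if use_pytorch_fsdp2:
  # One process per GPU: pin this process to its device and join the default
  # process group (torchrun supplies MASTER_ADDR/PORT, RANK, WORLD_SIZE).
  torch.cuda.set_device(rank)
  dist.init_process_group(backend='nccl')

print(f'rank={rank} device={device} n_gpus={n_gpus}')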
17 changes: 7 additions & 10 deletions algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py
@@ -8,15 +8,15 @@
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
-from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed._composable.fsdp import fully_shard

from algorithmic_efficiency import init_utils
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.pytorch_utils import pytorch_setup
from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload

-USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
+USE_PYTORCH_FSDP2, RANK, DEVICE, N_GPUS = pytorch_setup()


class _Model(nn.Module):
@@ -82,7 +82,7 @@ def _build_input_queue(
    else:
      weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE)
    # Send batch to other devices when using DDP.
-    if USE_PYTORCH_DDP:
+    if USE_PYTORCH_FSDP2:
      dist.broadcast(inputs, src=0)
      inputs = inputs[0]
      dist.broadcast(targets, src=0)
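The only change in this hunk is the flag rename; the data path itself is untouched: rank 0 materializes the batch and `dist.broadcast` copies it to the other ranks. A small illustrative sketch of that collective in isolation, assuming an already-initialized process group (the function and tensor names are made up):

# Illustrative rank-0 broadcast pattern; not the workload's exact code.
import torch
import torch.distributed as dist

def broadcast_batch_from_rank0(shape, device):
  if dist.get_rank() == 0:
    batch = torch.randn(shape, device=device)  # rank 0 builds the real batch
  else:
    batch = torch.empty(shape, device=device)  # other ranks only allocate a buffer
  dist.broadcast(batch, src=0)  # in place: every rank ends up with rank 0's data
  return batch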
@@ -127,10 +127,7 @@ def init_model_fn(
    del aux_dropout_rate

    if hasattr(self, '_model'):
-      if isinstance(self._model, (DDP, torch.nn.DataParallel)):
-        self._model.module.reset_parameters()
-      else:
-        self._model.reset_parameters()
+      self._model.reset_parameters()
      return self._model, None

    torch.random.manual_seed(rng[0])
@@ -139,8 +136,8 @@
    self._param_types = param_utils.pytorch_param_types(self._param_shapes)
    self._model.to(DEVICE)
    if N_GPUS > 1:
-      if USE_PYTORCH_DDP:
-        self._model = DDP(self._model, device_ids=[RANK], output_device=RANK)
+      if USE_PYTORCH_FSDP2:
+        self._model = fully_shard(self._model)
      else:
        self._model = torch.nn.DataParallel(self._model)
    return self._model, None
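This hunk is the core of the commit: instead of wrapping the module in `DistributedDataParallel`, the composable FSDP2 API shards it in place with `fully_shard`. A standalone, hedged sketch of that call on a toy model, assuming a `torchrun` launch with one GPU per process and a PyTorch build that exposes `torch.distributed._composable.fsdp` (the model and sizes are made up):

# FSDP2 sketch, not part of the commit.
# Launch with: torchrun --standalone --nproc_per_node=2 demo_fsdp2.py
import os
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard

def main():
  local_rank = int(os.environ['LOCAL_RANK'])
  device = torch.device(f'cuda:{local_rank}')
  torch.cuda.set_device(device)
  dist.init_process_group(backend='nccl')

  model = torch.nn.Sequential(
      torch.nn.Linear(784, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)).to(device)

  # fully_shard mutates the module in place (and returns it): its parameters
  # become DTensors sharded across the ranks, and forward/backward all-gather
  # and reduce-scatter them automatically.
  fully_shard(model)

  out = model(torch.randn(8, 784, device=device))
  out.sum().backward()
  dist.destroy_process_group()

if __name__ == '__main__':
  main()

For larger models the usual FSDP2 pattern is to call `fully_shard` on each block first and on the root module last, so that only one block's parameters are gathered at a time; sharding only the root, as this commit does for the small MNIST model, is what makes it "basic".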
@@ -229,7 +226,7 @@ def _eval_model(
  def _normalize_eval_metrics(
      self, num_examples: int, total_metrics: Dict[str,
                                                   Any]) -> Dict[str, float]:
-    if USE_PYTORCH_DDP:
+    if USE_PYTORCH_FSDP2:
      for metric in total_metrics.values():
        dist.all_reduce(metric)
    return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}
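As before, `_normalize_eval_metrics` sums the per-device metric totals across ranks before dividing by the global example count; `dist.all_reduce` defaults to a SUM reduction. A tiny illustrative sketch of that reduction, assuming an initialized process group:

# Illustrative metric reduction; not the workload's exact code.
import torch
import torch.distributed as dist

def global_mean(local_sum: torch.Tensor, num_examples: int) -> float:
  dist.all_reduce(local_sum)  # default op is SUM, so every rank sees the global total
  return float(local_sum.item() / num_examples)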
