Skip to content

Commit

Permalink
Enable CPU Offload for Intel GPU (#1324)
Browse files Browse the repository at this point in the history
* feat(cpu-offload): enable CPU Offload for XPU

Signed-off-by: dbyoung18 <[email protected]>

* test(cpu-offload): enable benchmark_low_bit_adam for XPU

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): auto-detect ProfilerActivity

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): replace if-else w/ getattr for device API calls

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): add auto-detect available devices to utils

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): improve auto-detect ProfilerActivity

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): improve device assert

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): fix auto-detect mps

Signed-off-by: dbyoung18 <[email protected]>

* fix(cpu-offload): fix import order

Signed-off-by: dbyoung18 <[email protected]>

* refactor(cpu-offload): use ruff format

Signed-off-by: dbyoung18 <[email protected]>

* doc(cpu-offload): modify README to cover XPU

Signed-off-by: dbyoung18 <[email protected]>

---------

Signed-off-by: dbyoung18 <[email protected]>
  • Loading branch information
dbyoung18 authored Nov 26, 2024
1 parent 6312329 commit 6ff3904
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 72 deletions.
88 changes: 63 additions & 25 deletions benchmarks/benchmark_low_bit_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
# - DeepSpeed (ZeRO-Offload):
# sudo apt install libopenmpi-dev
# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
# DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
#
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
Expand All @@ -31,11 +31,15 @@
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from torchao.utils import get_available_devices
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype import low_bit_optim

_DEVICE = get_available_devices()[-1]
assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)"

OPTIM_MAP = dict(
AdamW=partial(torch.optim.AdamW, fused=True),
AdamW8bitBnb=bnb.optim.AdamW8bit,
Expand All @@ -49,7 +53,9 @@

OPTIM_MAP.update(
AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
AdamW4bitRank1Lpmm=partial(
lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
),
)

except ImportError:
Expand All @@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float:
if step < self.warmup_steps:
return self.lr * step / self.warmup_steps
if step < self.total_steps:
progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
progress = (step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
1 + math.cos(progress * math.pi)
)
return self.final_lr


Expand All @@ -92,7 +102,9 @@ def get_parser():
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
parser.add_argument("--cosine_lr_scheduler", action="store_true")
parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
parser.add_argument(
"--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
)

parser.add_argument("--project")
parser.add_argument("--run_name", default="debug")
Expand All @@ -110,11 +122,15 @@ def get_dloader(args, training: bool):
transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

transforms.append(v2.ToDtype(torch.float32, scale=True))
transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
transforms.append(
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
transforms = v2.Compose(transforms)

# use dataset from HF so download is fast
ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
ds = datasets.load_dataset(
"timm/resisc45", split="train" if training else "validation"
)
ds = ds.select_columns(["image", "label"])
ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

Expand All @@ -128,9 +144,9 @@ def get_dloader(args, training: bool):
)


def get_amp_ctx(amp):
def get_amp_ctx(amp, device):
dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")
return torch.autocast(device, dtype=dtype, enabled=amp != "none")


@torch.no_grad()
Expand All @@ -148,8 +164,8 @@ def evaluate_model(model, args):
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)

with get_amp_ctx(args.amp):
all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())
with get_amp_ctx(args.amp, _DEVICE):
all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu())

all_labels = torch.cat(all_labels, dim=0)
all_preds = torch.cat(all_preds, dim=0)
Expand All @@ -164,8 +180,12 @@ def evaluate_model(model, args):
if args.full_bf16:
assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
if args.optim_cpu_offload == "deepspeed":
assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
assert (
args.amp == "none"
), "When using DeepSpeed ZeRO-Offload, --amp must be none"
assert (
args.optim == "AdamW"
), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
if args.profile:
args.n_epochs = 1
if args.seed is not None:
Expand All @@ -185,14 +205,16 @@ def evaluate_model(model, args):
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
model = timm.create_model(
args.model, pretrained=True, num_classes=45, **args.model_kwargs
)
if args.checkpoint_activations:
model.set_grad_checkpointing()
if args.full_bf16:
model.bfloat16()
if args.channels_last:
model.to(memory_format=torch.channels_last)
model.cuda() # move model to CUDA after optionally convert it to BF16
model.to(_DEVICE) # move model to DEVICE after optionally convert it to BF16
if args.compile:
model.compile(fullgraph=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Expand Down Expand Up @@ -227,9 +249,15 @@ def evaluate_model(model, args):
optim_cls = OPTIM_MAP[args.optim]

if args.optim_cpu_offload == "ao":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
optim_cls = partial(
low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
)
elif args.optim_cpu_offload == "ao_offload_grads":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
optim_cls = partial(
low_bit_optim.CPUOffloadOptimizer,
optimizer_class=optim_cls,
offload_gradients=True,
)

optim = optim_cls(
model.parameters(),
Expand All @@ -239,24 +267,30 @@ def evaluate_model(model, args):
)

lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16")
log_interval = 10
t0 = time.perf_counter()

step = 0
for epoch_idx in range(args.n_epochs):
model.train()
pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
pbar = tqdm(
dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
)

with torch.profiler.profile() if args.profile else nullcontext() as prof:
for batch in pbar:
if args.full_bf16:
batch["image"] = batch["image"].bfloat16()
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)
batch["image"] = batch["image"].to(
memory_format=torch.channels_last
)

with get_amp_ctx(args.amp):
loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
with get_amp_ctx(args.amp, _DEVICE):
loss = F.cross_entropy(
model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
)

if args.optim_cpu_offload == "deepspeed":
model.backward(loss)
Expand All @@ -275,7 +309,9 @@ def evaluate_model(model, args):
log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
if step > 0:
t1 = time.perf_counter()
log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
log_dict["imgs_per_second"] = (
args.batch_size * log_interval / (t1 - t0)
)
t0 = t1
logger.log(log_dict, step=step)

Expand All @@ -296,9 +332,11 @@ def evaluate_model(model, args):

else:
val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
print(
f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
)
logger.log(dict(val_acc=val_acc), step=step)

peak_mem = torch.cuda.max_memory_allocated() / 1e9
peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
print(f"Max memory used: {peak_mem:.02f} GB")
logger.log(dict(max_memory_allocated=peak_mem))
49 changes: 36 additions & 13 deletions test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
from torchao.utils import (
get_available_devices,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
Expand All @@ -42,7 +43,7 @@
lpmm = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_DEVICES = get_available_devices()


class TestQuantize(TestCase):
Expand Down Expand Up @@ -94,7 +95,9 @@ def test_bf16_stochastic_round(self, device, compile):
x = torch.rand(32, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)

func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
func = torch.compile(
_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
)
x_rep_bf16 = func(x_rep)
assert x_rep_bf16.dtype is torch.bfloat16

Expand Down Expand Up @@ -169,8 +172,13 @@ def test_subclass_slice(self, subclass, shape, device):
tensor = subclass.zeros(shape, device=device)
offset = shape[0] // 2

torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())
torch.testing.assert_close(
tensor.dequantize()[:offset], tensor[:offset].dequantize()
)
torch.testing.assert_close(
tensor.dequantize()[offset : offset * 2],
tensor[offset : offset * 2].dequantize(),
)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(
Expand All @@ -188,7 +196,9 @@ def test_optim_8bit_correctness(self, optim_name):
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
optim2 = getattr(low_bit_optim, optim_name)(
model2.parameters(), block_size=block_size
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
Expand Down Expand Up @@ -244,11 +254,12 @@ def test_optim_4bit_correctness(self, optim_name):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
not torch.cuda.is_available() and not torch.xpu.is_available(),
reason="optim CPU offload requires CUDA or XPU",
)
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
device = _DEVICES[-1]
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)

Expand Down Expand Up @@ -279,13 +290,16 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
not torch.cuda.is_available() and not torch.xpu.is_available(),
reason="optim CPU offload requires CUDA or XPU",
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
device = _DEVICES[-1]
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
optim1 = low_bit_optim.CPUOffloadOptimizer(
model1.parameters(), torch.optim.AdamW
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
Expand All @@ -300,7 +314,9 @@ def test_optim_cpu_offload_save_load(self):

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW
)
optim2.load_state_dict(state_dict)

for _ in range(2):
Expand Down Expand Up @@ -384,7 +400,11 @@ def _test_fsdp2(self, optim_cls):
import torch.utils._pytree as pytree
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
TransformerBlock,
)

batch_size = 3
vocab_size = 1024
Expand Down Expand Up @@ -457,7 +477,10 @@ def _test_fsdp2(self, optim_cls):

subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
for v1, v2 in zip(
pytree.tree_iter(resumed_fsdp_optim.state_dict()),
pytree.tree_iter(fsdp_optim.state_dict()),
):
assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
if isinstance(v1, DTensor):
v1 = v1.to_local()
Expand Down
4 changes: 2 additions & 2 deletions torchao/prototype/low_bit_optim/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ All of our low-bit optimizers mentioned above also support `bf16_stochastic_roun

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.
This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA and XPU is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.

```python
import torch
Expand All @@ -97,7 +97,7 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradi

This will reduce GPU memory usage by optimizer state size, and additionally gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer.

For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction CPU->CUDA and CUDA->CPU, in case they are out of sync.)
For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize GPU and CPU params in either direction CPU->GPU and GPU->CPU, in case they are out of sync.)

```python
ckpt = torch.load("checkpoint.pth")
Expand Down
Loading

0 comments on commit 6ff3904

Please sign in to comment.