From 6ff3904ad01abb4c1be517578f0c419be79bb838 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Tue, 26 Nov 2024 11:23:10 +0800 Subject: [PATCH] Enable CPU Offload for Intel GPU (#1324) * feat(cpu-offload): enable CPU Offload for XPU Signed-off-by: dbyoung18 * test(cpu-offload): enable benchmark_low_bit_adam for XPU Signed-off-by: dbyoung18 * fix(cpu-offload): auto-detect ProfilerActivity Signed-off-by: dbyoung18 * fix(cpu-offload): replace if-else w/ getattr for device API calls Signed-off-by: dbyoung18 * fix(cpu-offload): add auto-detect available devices to utils Signed-off-by: dbyoung18 * fix(cpu-offload): improve auto-detect ProfilerActivity Signed-off-by: dbyoung18 * fix(cpu-offload): improve device assert Signed-off-by: dbyoung18 * fix(cpu-offload): fix auto-detect mps Signed-off-by: dbyoung18 * fix(cpu-offload): fix import order Signed-off-by: dbyoung18 * refactor(cpu-offload): use ruff format Signed-off-by: dbyoung18 * doc(cpu-offload): modify README to cover XPU Signed-off-by: dbyoung18 --------- Signed-off-by: dbyoung18 --- benchmarks/benchmark_low_bit_adam.py | 88 +++++++++++++------ test/prototype/test_low_bit_optim.py | 49 ++++++++--- torchao/prototype/low_bit_optim/README.md | 4 +- .../prototype/low_bit_optim/cpu_offload.py | 69 ++++++++------- torchao/utils.py | 13 +++ 5 files changed, 151 insertions(+), 72 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index bd31193892..986cc58b4f 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -4,7 +4,7 @@ # - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git # - DeepSpeed (ZeRO-Offload): # sudo apt install libopenmpi-dev -# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p +# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py # DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir # # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core @@ -31,11 +31,15 @@ import torch.nn.functional as F import wandb from torch.utils.data import DataLoader +from torchao.utils import get_available_devices from torchvision.transforms import v2 from tqdm import tqdm from torchao.prototype import low_bit_optim +_DEVICE = get_available_devices()[-1] +assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)" + OPTIM_MAP = dict( AdamW=partial(torch.optim.AdamW, fused=True), AdamW8bitBnb=bnb.optim.AdamW8bit, @@ -49,7 +53,9 @@ OPTIM_MAP.update( AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True), - AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")), + AdamW4bitRank1Lpmm=partial( + lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1") + ), ) except ImportError: @@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float: if step < self.warmup_steps: return self.lr * step / self.warmup_steps if step < self.total_steps: - progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) - return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi)) + progress = (step - self.warmup_steps) / ( + self.total_steps - self.warmup_steps + ) + return self.final_lr + 0.5 * (self.lr - self.final_lr) * ( + 1 + math.cos(progress * math.pi) + ) return self.final_lr @@ -92,7 +102,9 @@ def get_parser(): parser.add_argument("--weight_decay", type=float, default=0) parser.add_argument("--optim_kwargs", type=json.loads, default=dict()) 
parser.add_argument("--cosine_lr_scheduler", action="store_true") - parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]) + parser.add_argument( + "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"] + ) parser.add_argument("--project") parser.add_argument("--run_name", default="debug") @@ -110,11 +122,15 @@ def get_dloader(args, training: bool): transforms.extend([v2.Resize(256), v2.CenterCrop(224)]) transforms.append(v2.ToDtype(torch.float32, scale=True)) - transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + transforms.append( + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) transforms = v2.Compose(transforms) # use dataset from HF so download is fast - ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation") + ds = datasets.load_dataset( + "timm/resisc45", split="train" if training else "validation" + ) ds = ds.select_columns(["image", "label"]) ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"])) @@ -128,9 +144,9 @@ def get_dloader(args, training: bool): ) -def get_amp_ctx(amp): +def get_amp_ctx(amp, device): dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp] - return torch.autocast("cuda", dtype=dtype, enabled=amp != "none") + return torch.autocast(device, dtype=dtype, enabled=amp != "none") @torch.no_grad() @@ -148,8 +164,8 @@ def evaluate_model(model, args): if args.channels_last: batch["image"] = batch["image"].to(memory_format=torch.channels_last) - with get_amp_ctx(args.amp): - all_preds.append(model(batch["image"].cuda()).argmax(1).cpu()) + with get_amp_ctx(args.amp, _DEVICE): + all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu()) all_labels = torch.cat(all_labels, dim=0) all_preds = torch.cat(all_preds, dim=0) @@ -164,8 +180,12 @@ def evaluate_model(model, args): if args.full_bf16: assert args.amp == "none", "When --full_bf16 is set, --amp must be none" if args.optim_cpu_offload == "deepspeed": - assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none" - assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW" + assert ( + args.amp == "none" + ), "When using DeepSpeed ZeRO-Offload, --amp must be none" + assert ( + args.optim == "AdamW" + ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW" if args.profile: args.n_epochs = 1 if args.seed is not None: @@ -185,14 +205,16 @@ def evaluate_model(model, args): dloader = get_dloader(args, True) print(f"Train dataset: {len(dloader.dataset):,} images") - model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs) + model = timm.create_model( + args.model, pretrained=True, num_classes=45, **args.model_kwargs + ) if args.checkpoint_activations: model.set_grad_checkpointing() if args.full_bf16: model.bfloat16() if args.channels_last: model.to(memory_format=torch.channels_last) - model.cuda() # move model to CUDA after optionally convert it to BF16 + model.to(_DEVICE) # move model to DEVICE after optionally convert it to BF16 if args.compile: model.compile(fullgraph=True) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") @@ -227,9 +249,15 @@ def evaluate_model(model, args): optim_cls = OPTIM_MAP[args.optim] if args.optim_cpu_offload == "ao": - optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls) + optim_cls = partial( + low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls + ) 
elif args.optim_cpu_offload == "ao_offload_grads": - optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True) + optim_cls = partial( + low_bit_optim.CPUOffloadOptimizer, + optimizer_class=optim_cls, + offload_gradients=True, + ) optim = optim_cls( model.parameters(), @@ -239,24 +267,30 @@ def evaluate_model(model, args): ) lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) - grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") + grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16") log_interval = 10 t0 = time.perf_counter() step = 0 for epoch_idx in range(args.n_epochs): model.train() - pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}") + pbar = tqdm( + dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}" + ) with torch.profiler.profile() if args.profile else nullcontext() as prof: for batch in pbar: if args.full_bf16: batch["image"] = batch["image"].bfloat16() if args.channels_last: - batch["image"] = batch["image"].to(memory_format=torch.channels_last) + batch["image"] = batch["image"].to( + memory_format=torch.channels_last + ) - with get_amp_ctx(args.amp): - loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda()) + with get_amp_ctx(args.amp, _DEVICE): + loss = F.cross_entropy( + model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE) + ) if args.optim_cpu_offload == "deepspeed": model.backward(loss) @@ -275,7 +309,9 @@ def evaluate_model(model, args): log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]) if step > 0: t1 = time.perf_counter() - log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0) + log_dict["imgs_per_second"] = ( + args.batch_size * log_interval / (t1 - t0) + ) t0 = t1 logger.log(log_dict, step=step) @@ -296,9 +332,11 @@ def evaluate_model(model, args): else: val_acc = evaluate_model(model, args) - print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") + print( + f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}" + ) logger.log(dict(val_acc=val_acc), step=step) - peak_mem = torch.cuda.max_memory_allocated() / 1e9 + peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9 print(f"Max memory used: {peak_mem:.02f} GB") logger.log(dict(max_memory_allocated=peak_mem)) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index a97d1cffdd..e325cb1c14 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -26,6 +26,7 @@ from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8 from torchao.utils import ( + get_available_devices, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, @@ -42,7 +43,7 @@ lpmm = None -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_DEVICES = get_available_devices() class TestQuantize(TestCase): @@ -94,7 +95,9 @@ def test_bf16_stochastic_round(self, device, compile): x = torch.rand(32, device=device) * 100 x_rep = x.view(-1, 1).repeat(1, 100_000) - func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile) + func = torch.compile( + _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile + ) x_rep_bf16 = func(x_rep) assert x_rep_bf16.dtype is torch.bfloat16 @@ -169,8 +172,13 @@ def test_subclass_slice(self, subclass, 
shape, device): tensor = subclass.zeros(shape, device=device) offset = shape[0] // 2 - torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize()) - torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize()) + torch.testing.assert_close( + tensor.dequantize()[:offset], tensor[:offset].dequantize() + ) + torch.testing.assert_close( + tensor.dequantize()[offset : offset * 2], + tensor[offset : offset * 2].dequantize(), + ) @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available") @pytest.mark.skipif( @@ -188,7 +196,9 @@ def test_optim_8bit_correctness(self, optim_name): block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048 optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size) + optim2 = getattr(low_bit_optim, optim_name)( + model2.parameters(), block_size=block_size + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -244,11 +254,12 @@ def test_optim_4bit_correctness(self, optim_name): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) @pytest.mark.skipif( - not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + not torch.cuda.is_available() and not torch.xpu.is_available(), + reason="optim CPU offload requires CUDA or XPU", ) @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)]) def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): - device = "cuda" + device = _DEVICES[-1] model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) model1.to(device) @@ -279,13 +290,16 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): torch.testing.assert_close(p2, p1) @pytest.mark.skipif( - not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + not torch.cuda.is_available() and not torch.xpu.is_available(), + reason="optim CPU offload requires CUDA or XPU", ) def test_optim_cpu_offload_save_load(self): - device = "cuda" + device = _DEVICES[-1] model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) model1.to(device) - optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW) + optim1 = low_bit_optim.CPUOffloadOptimizer( + model1.parameters(), torch.optim.AdamW + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -300,7 +314,9 @@ def test_optim_cpu_offload_save_load(self): # resume training model2 = copy.deepcopy(model1) - optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW) + optim2 = low_bit_optim.CPUOffloadOptimizer( + model2.parameters(), torch.optim.AdamW + ) optim2.load_state_dict(state_dict) for _ in range(2): @@ -384,7 +400,11 @@ def _test_fsdp2(self, optim_cls): import torch.utils._pytree as pytree from torch.distributed._composable.fsdp import fully_shard from torch.distributed.tensor import DTensor - from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock + from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, + TransformerBlock, + ) batch_size = 3 vocab_size = 1024 @@ -457,7 +477,10 @@ def _test_fsdp2(self, optim_cls): subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8) - for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())): + for v1, v2 in zip( + pytree.tree_iter(resumed_fsdp_optim.state_dict()), + 
pytree.tree_iter(fsdp_optim.state_dict()), + ): assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__) if isinstance(v1, DTensor): v1 = v1.to_local() diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index bd66262609..6358574e45 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -80,7 +80,7 @@ All of our low-bit optimizers mentioned above also support `bf16_stochastic_roun ## Optimizer CPU offload -This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload. +This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA and XPU is supported. For multi-GPU training, you can use FSDP's built-in CPU offload. ```python import torch @@ -97,7 +97,7 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradi This will reduce GPU memory usage by optimizer state size, and additionally gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer. -For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction CPU->CUDA and CUDA->CPU, in case they are out of sync.) +For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize GPU and CPU params in either direction CPU->GPU and GPU->CPU, in case they are out of sync.) ```python ckpt = torch.load("checkpoint.pth") diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py index c69932aa4c..90008f67fe 100644 --- a/torchao/prototype/low_bit_optim/cpu_offload.py +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -3,7 +3,7 @@ import torch from torch.optim.optimizer import Optimizer, ParamsT -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices class CPUOffloadOptimizer: @@ -38,51 +38,56 @@ def __init__( if not isinstance(param_groups[0], dict): param_groups = [{"params": param_groups}] - self.param_cuda2cpu_map = dict() + self.param_d2h_map = dict() self.optim_dict = dict() - self.stream = torch.cuda.Stream() + self.device = get_available_devices()[-1] + assert self.device in [ + "cuda", + "xpu", + ], "CPU Offload currently only supports CUDA & XPU" + self.stream = getattr(torch, self.device).Stream() # the queue maintains the order which param we should do optim step on first. 
self.queue = dict() - def backward_hook(p_cuda): - if p_cuda.grad is not None: - p_cpu = self.param_cuda2cpu_map[p_cuda] + def backward_hook(p_device): + if p_device.grad is not None: + p_host = self.param_d2h_map[p_device] # make sure backward for this param finishes - self.stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.stream): - p_cpu.grad.copy_(p_cuda.grad, non_blocking=True) + self.stream.wait_stream(getattr(torch, self.device).current_stream()) + with getattr(torch, self.device).stream(self.stream): + p_host.grad.copy_(p_device.grad, non_blocking=True) # we rely on CPython implementation of dictionary, which preserves insertion order. # if a param is added again (e.g. due to gradient accumulation), it is moved to the # end of the queue by removing and inserting it again. - if p_cuda in self.queue: - del self.queue[p_cuda] - self.queue[p_cuda] = self.stream.record_event() + if p_device in self.queue: + del self.queue[p_device] + self.queue[p_device] = self.stream.record_event() - # deallocate CUDA gradients once D2H transfer finishes. + # deallocate DEVICE gradients once D2H transfer finishes. if offload_gradients: - p_cuda.grad.record_stream(self.stream) - p_cuda.grad = None + p_device.grad.record_stream(self.stream) + p_device.grad = None for param_group in param_groups: params = param_group.pop("params") - for p_cuda in params: - if not p_cuda.requires_grad: + for p_device in params: + if not p_device.requires_grad: continue # pre-allocate CPU params and grads - p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True) - p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True) + p_host = torch.empty_like(p_device, device="cpu", pin_memory=True) + p_host.grad = torch.empty_like(p_host, pin_memory=True) - p_cpu.copy_(p_cuda.detach(), non_blocking=True) - self.param_cuda2cpu_map[p_cuda] = p_cpu + p_host.copy_(p_device.detach(), non_blocking=True) + self.param_d2h_map[p_device] = p_host - p_cuda.register_post_accumulate_grad_hook(backward_hook) - self.optim_dict[p_cuda] = optimizer_class( - [{"params": p_cpu, **param_group}], **kwargs + p_device.register_post_accumulate_grad_hook(backward_hook) + self.optim_dict[p_device] = optimizer_class( + [{"params": p_host, **param_group}], **kwargs ) @torch.no_grad() @@ -91,16 +96,16 @@ def step(self, closure=None): if closure is not None: loss = closure() - for p_cuda, grad_d2h_event in self.queue.items(): + for p_device, grad_d2h_event in self.queue.items(): grad_d2h_event.synchronize() - self.optim_dict[p_cuda].step() + self.optim_dict[p_device].step() # submit more job to self.stream. it guarantees that we only start # moving param H2D once all backwards finish, since self.stream # will wait for current_stream when moving grad D2H. - p_cpu = self.param_cuda2cpu_map[p_cuda] - with torch.cuda.stream(self.stream): - p_cuda.copy_(p_cpu, non_blocking=True) + p_host = self.param_d2h_map[p_device] + with getattr(torch, self.device).stream(self.stream): + p_device.copy_(p_host, non_blocking=True) self.queue.clear() return loss @@ -108,9 +113,9 @@ def step(self, closure=None): def zero_grad(self, set_to_none=True): assert set_to_none - # only clear CUDA grad. CPU grad will always be overwritten by CUDA grad. - for p_cuda in self.param_cuda2cpu_map.keys(): - p_cuda.grad = None + # only clear DEVICE grad. CPU grad will always be overwritten by DEVICE grad. 
+ for p_device in self.param_d2h_map.keys(): + p_device.grad = None @property def param_groups(self): diff --git a/torchao/utils.py b/torchao/utils.py index 2813f0b0b4..ba91fb3fe0 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -13,6 +13,7 @@ __all__ = [ "benchmark_model", "profiler_runner", + "get_available_devices", "get_compute_capability", "skip_if_compute_capability_less_than", "benchmark_torch_function_in_microseconds", @@ -124,6 +125,18 @@ def profiler_runner(path, fn, *args, **kwargs): return result +def get_available_devices(): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + elif torch.xpu.is_available(): + devices.append("xpu") + if TORCH_VERSION_AT_LEAST_2_5: + if torch.mps.is_available(): + devices.append("mps") + return devices + + def get_compute_capability(): if torch.cuda.is_available(): capability = torch.cuda.get_device_capability()
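
---

For reviewers who want to exercise the device-agnostic path end to end, below is a minimal usage sketch (not part of the patch) that combines the new `get_available_devices()` helper with `CPUOffloadOptimizer`, mirroring the updated benchmark and tests. The toy model, step count, and learning rate are illustrative only.

```python
# Illustrative sketch, assuming torchao with this patch applied.
import torch
import torch.nn as nn

from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
from torchao.utils import get_available_devices

# get_available_devices() returns e.g. ["cpu", "cuda"] or ["cpu", "xpu"];
# the benchmark and tests pick the last entry as the preferred accelerator.
device = get_available_devices()[-1]
assert device in ("cuda", "xpu"), "CPU offload currently requires CUDA or XPU"

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)

# Wrap any base optimizer; parameter updates run on pinned CPU copies while
# forward/backward stay on the accelerator. Extra kwargs (e.g. lr) are
# forwarded to the wrapped optimizer class.
optim = CPUOffloadOptimizer(
    model.parameters(), torch.optim.AdamW, offload_gradients=True, lr=1e-3
)

for _ in range(2):
    x = torch.randn(4, 32, device=device)
    model(x).sum().backward()  # grads are copied to CPU by the post-accumulate hook
    optim.step()               # optimizer math runs on CPU, params are copied back H2D
    optim.zero_grad()
```

Note that the test parametrization in this patch exercises `offload_gradients=True` only with a gradient-accumulation factor of 1.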