From 82feeef6f97eebef35cfa74d9d9663d8689a7ee4 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 9 Nov 2024 15:11:51 +0800
Subject: [PATCH] [Low-bit optim] Support for `dcp.save()` and `dcp.load()` (#1217)

* support dcp.save

* add test for dcp.load()

* fix test

* typo

* implement aten.slice

* skip test

* fix checks

* run ruff

* fix formatting

* remove add safe globals in test

* sort some imports

---------

Co-authored-by: Mark Saroufim
---
 test/prototype/test_low_bit_optim.py             | 156 +++++++++++-------
 .../prototype/low_bit_optim/subclass_4bit.py     |  51 +++++-
 .../prototype/low_bit_optim/subclass_8bit.py     |  50 +++++-
 .../prototype/low_bit_optim/subclass_fp8.py      |  44 ++++-
 4 files changed, 240 insertions(+), 61 deletions(-)

diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index f0b608b47d..a97d1cffdd 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -1,27 +1,33 @@
 import copy
+import shutil
 import tempfile
+from pathlib import Path
 
 import pytest
 import torch
 from packaging.version import Version
 from torch import nn
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
+
 from torchao.prototype import low_bit_optim
 from torchao.prototype.low_bit_optim.quant_utils import (
-    quantize_8bit_with_qmap,
-    quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
+    quantize_4bit_with_qmap,
+    quantize_8bit_with_qmap,
 )
+from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
+from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
+from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
 )
@@ -88,13 +94,8 @@ def test_bf16_stochastic_round(self, device, compile):
         x = torch.rand(32, device=device) * 100
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
-        if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
-                x_rep
-            )
-        else:
-            x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
-
+        func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
+        x_rep_bf16 = func(x_rep)
         assert x_rep_bf16.dtype is torch.bfloat16
 
         # must cast BF16 tensor back to FP32 so that .mean() is accurate
@@ -102,9 +103,6 @@ class TestOptim(TestCase):
-    @pytest.mark.skipif(
-        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
-    )
     @parametrize(
         "optim_name",
         ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
@@ -151,29 +149,46 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    # aten.slice is required for dcp.load() when world size changes i.e. re-sharding.
+    # however, it's cumbersome to test this directly, since we would need to run the
+    # distributed test twice with different world sizes, and persist the checkpoint
+    # across the 2 runs. thus, we only test for the required op. note that future
+    # implementations of dcp.load() may use other ops.
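+    # (illustrative assumption, not part of the guarantee: a 4096-element 1-D state
+    # saved at world size 2 lives as two 2048-element shards; loading it at world
+    # size 4 may make dcp.load() request e.g. shard[0:1024] and shard[1024:2048],
+    # which dispatch to the aten.slice.Tensor overloads added in this patch.)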
+ @parametrize("subclass", [OptimState4bit, OptimState8bit, OptimStateFp8]) + @parametrize("shape", [(4096,), (256, 256)]) + @parametrize("device", _DEVICES) + def test_subclass_slice(self, subclass, shape, device): + if subclass == OptimStateFp8: + if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5") + if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("FP8 CUDA requires PyTorch >= 2.4") + if device == "cuda" and torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 CUDA requires compute capability >= 8.9") + + tensor = subclass.zeros(shape, device=device) + offset = shape[0] // 2 + + torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize()) + torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize()) + @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available") @pytest.mark.skipif( not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA", ) - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" - ) @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( - device - ) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) + model1.to(device) model2 = copy.deepcopy(model1) # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0 block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048 optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(low_bit_optim, optim_name)( - model2.parameters(), block_size=block_size - ) + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -196,15 +211,11 @@ def test_optim_8bit_correctness(self, optim_name): @pytest.mark.skipif( not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA" ) - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" - ) @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( - device - ) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) + model1.to(device) model2 = copy.deepcopy(model1) # lpmm doesn't have Adam. use AdamW with no weight decay instead. 
@@ -238,12 +249,11 @@
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
-        model1[0].requires_grad_(
-            False
-        )  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
+
+        # make sure it can work in the presence of non-trainable params
+        model1[0].requires_grad_(False)
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
@@ -273,12 +283,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
     )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
-        optim1 = low_bit_optim.CPUOffloadOptimizer(
-            model1.parameters(), torch.optim.AdamW
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -293,9 +300,7 @@ def test_optim_cpu_offload_save_load(self):
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW
-        )
+        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
@@ -315,16 +320,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
-            device
-        )
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        model1.to(device)
         model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
         optim2 = low_bit_optim._AdamW(
-            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+            model2.parameters(),
+            lr=1e-5,
+            bf16_stochastic_round=True,
         )
 
         # overfit on this sample
@@ -350,10 +356,13 @@ def test_optim_bf16_stochastic_round_correctness(self):
     )
 
 
+_FSDP_WORLD_SIZE = 2
+
+
 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
-        return 2
+        return _FSDP_WORLD_SIZE
 
     @pytest.mark.skipif(
         not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
@@ -370,12 +379,12 @@ def test_fsdp2(self):
     )
 
     def _test_fsdp2(self, optim_cls):
+        import torch.distributed as dist
+        import torch.distributed.checkpoint as dcp
+        import torch.utils._pytree as pytree
         from torch.distributed._composable.fsdp import fully_shard
-        from torch.testing._internal.distributed._tensor.common_dtensor import (
-            ModelArgs,
-            Transformer,
-            TransformerBlock,
-        )
+        from torch.distributed.tensor import DTensor
+        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
 
         batch_size = 3
         vocab_size = 1024
@@ -413,9 +422,7 @@ def _test_fsdp2(self, optim_cls):
             base_loss.backward()
             for param in base_model.parameters():
                 if param.grad is not None:
-                    torch.distributed.all_reduce(
-                        param.grad, op=torch.distributed.ReduceOp.AVG
-                    )
+                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
 
@@ -428,6 +435,39 @@ def _test_fsdp2(self, optim_cls):
 
             self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
 
+        # test for compatibility with dcp.save() and dcp.load()
+        checkpoint_id = f"_fsdp_low_bit_optim_{optim_cls.__name__}"
+        if Path(checkpoint_id).exists():
+            shutil.rmtree(checkpoint_id)
+        dcp.save(fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
+
+        # normally we would want to use dcp.state_dict.get_optimizer_state_dict() to initialize optim states.
+        # however, it currently does not respect the tensor-ness of the LR (pytorch/pytorch#139575).
+        # therefore, we have to manually initialize the optim state here.
+        resumed_fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
+        for p in fsdp_model.parameters():
+            p.grad = torch.zeros_like(p)
+
+        # this will change the model weights due to weight decay, but since we don't use the model anymore, it's fine.
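+        # (note: Adam-style optimizers create per-param state lazily on the first step,
+        # so this dummy step is what allocates the quantized exp_avg/exp_avg_sq buffers
+        # that dcp.load() below loads the checkpoint into, in place.)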
+        resumed_fsdp_optim.step()
+
+        dcp.load(resumed_fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
+        if dist.get_rank() == 0:
+            shutil.rmtree(checkpoint_id)
+
+        subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)
+
+        for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
+            assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
+            if isinstance(v1, DTensor):
+                v1 = v1.to_local()
+                v2 = v2.to_local()
+                assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
+            if isinstance(v1, subclasses):
+                v1 = v1.dequantize()
+                v2 = v2.dequantize()
+            self.assertEqual(v1, v2)
+
 
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)
diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py
index e493b978fe..4c05cced87 100644
--- a/torchao/prototype/low_bit_optim/subclass_4bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -177,13 +177,15 @@ def _(func, types, args, kwargs):
     )
 
 
-# this is needed for DTensor.full_tensor()
 @OptimState4bit.implements(
     [
+        # required by DTensor.full_tensor()
         c10d_functional.all_gather_into_tensor.default,
         _c10d_functional.all_gather_into_tensor.default,
         c10d_functional.wait_tensor.default,
         _c10d_functional.wait_tensor.default,
+        # required by torch.distributed.checkpoint.save
+        aten.detach.default,
     ]
 )
 def _(func, types, args, kwargs):
@@ -201,6 +203,53 @@ def _(func, types, args, kwargs):
     return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
 
 
+# required by torch.distributed.checkpoint.save
+# note that we don't actually implement pin memory for this tensor subclass
+# (pin_memory argument is ignored in aten._to_copy)
+@OptimState4bit.implements(aten.is_pinned.default)
+def _(func, types, args, kwargs):
+    return (
+        args[0].codes.is_pinned()
+        and args[0].scale.is_pinned()
+        and args[0].qmap.is_pinned()
+    )
+
+
+# required by torch.distributed.checkpoint.load when world size changes i.e. re-sharding
+@OptimState4bit.implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    x, dim, start, end = args[:4]
+    step = args[4] if len(args) > 4 else 1
+
+    # input validation
+    if dim != 0:
+        raise ValueError("Only support aten.slice along the first dim")
+    if step != 1:
+        raise ValueError("Only support aten.slice with step=1")
+
+    block_size = x.block_size
+    stride = math.prod(x.shape[1:])
+
+    # for 1 increment in x along the first dim,
+    # (flattened) scale will increment by stride / block_size
+    if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
+        raise ValueError(
+            f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
+            f"Make sure start and end align with block boundary. "
+            f"Received start={start}, end={end}."
+        )
+
+    # note that for 4-bit, we store .codes as flattened buffer
+    # divide by 2 since we store 2x 4-bit in 1x uint8
+    codes = x.codes[start * stride // 2 : end * stride // 2]
+    scale = x.scale[start * stride // block_size : end * stride // block_size]
+
+    # adjust the first dim
+    shape = (x.shape[0] * codes.numel() // x.codes.numel(),) + x.shape[1:]
+
+    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     from torch.serialization import add_safe_globals
diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py
index d23d159645..659a43f42d 100644
--- a/torchao/prototype/low_bit_optim/subclass_8bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -1,3 +1,5 @@
+import math
+
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -149,13 +151,15 @@ def _(func, types, args, kwargs):
     return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)
 
 
-# this is needed for DTensor.full_tensor()
 @OptimState8bit.implements(
     [
+        # required by DTensor.full_tensor()
         c10d_functional.all_gather_into_tensor.default,
         _c10d_functional.all_gather_into_tensor.default,
         c10d_functional.wait_tensor.default,
         _c10d_functional.wait_tensor.default,
+        # required by torch.distributed.checkpoint.save
+        aten.detach.default,
     ]
 )
 def _(func, types, args, kwargs):
@@ -172,6 +176,50 @@ def _(func, types, args, kwargs):
     )
 
 
+# required by torch.distributed.checkpoint.save
+# note that we don't actually implement pin memory for this tensor subclass
+# (pin_memory argument is ignored in aten._to_copy)
+@OptimState8bit.implements(aten.is_pinned.default)
+def _(func, types, args, kwargs):
+    return (
+        args[0].codes.is_pinned()
+        and args[0].scale.is_pinned()
+        and args[0].qmap.is_pinned()
+    )
+
+
+# required by torch.distributed.checkpoint.load when world size changes i.e. re-sharding
+@OptimState8bit.implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    x, dim, start, end = args[:4]
+    step = args[4] if len(args) > 4 else 1
+
+    # input validation
+    if dim != 0:
+        raise ValueError("Only support aten.slice along the first dim")
+    if step != 1:
+        raise ValueError("Only support aten.slice with step=1")
+
+    block_size = x.block_size
+    stride = math.prod(x.shape[1:])
+
+    # for 1 increment in x along the first dim,
+    # (flattened) scale will increment by stride / block_size
+    if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
+        raise ValueError(
+            f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
+            f"Make sure start and end align with block boundary. "
+            f"Received start={start}, end={end}."
+        )
+
+    return OptimState8bit(
+        x.codes[start:end],
+        x.scale[start * stride // block_size : end * stride // block_size],
+        x.qmap.clone(),
+        x.signed,
+    )
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     from torch.serialization import add_safe_globals
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index d95b0c2661..b5c8af6c83 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -1,3 +1,5 @@
+import math
+
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -124,13 +126,15 @@ def _(func, types, args, kwargs):
     return OptimStateFp8(x.codes.view(shape), x.scale)
 
 
-# this is needed for DTensor.full_tensor()
 @OptimStateFp8.implements(
     [
+        # required by DTensor.full_tensor()
         c10d_functional.all_gather_into_tensor.default,
         _c10d_functional.all_gather_into_tensor.default,
         c10d_functional.wait_tensor.default,
         _c10d_functional.wait_tensor.default,
+        # required by torch.distributed.checkpoint.save
+        aten.detach.default,
     ]
 )
 def _(func, types, args, kwargs):
@@ -145,6 +149,44 @@ def _(func, types, args, kwargs):
     )
 
 
+# required by torch.distributed.checkpoint.save
+# note that we don't actually implement pin memory for this tensor subclass
+# (pin_memory argument is ignored in aten._to_copy)
+@OptimStateFp8.implements(aten.is_pinned.default)
+def _(func, types, args, kwargs):
+    return args[0].codes.is_pinned() and args[0].scale.is_pinned()
+
+
+# required by torch.distributed.checkpoint.load when world size changes i.e. re-sharding
+@OptimStateFp8.implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    x, dim, start, end = args[:4]
+    step = args[4] if len(args) > 4 else 1
+
+    # input validation
+    if dim != 0:
+        raise ValueError("Only support aten.slice along the first dim")
+    if step != 1:
+        raise ValueError("Only support aten.slice with step=1")
+
+    block_size = x.block_size
+    stride = math.prod(x.shape[1:])
+
+    # for 1 increment in x along the first dim,
+    # (flattened) scale will increment by stride / block_size
+    if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
+        raise ValueError(
+            f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
+            f"Make sure start and end align with block boundary. "
+            f"Received start={start}, end={end}."
+        )
+
+    return OptimStateFp8(
+        x.codes[start:end],
+        x.scale[start * stride // block_size : end * stride // block_size],
+    )
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     from torch.serialization import add_safe_globals
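--

A minimal single-process sketch of what the new aten.slice overloads enable, assuming
the patch above is applied. `OptimState8bit.zeros` is the same constructor the new test
uses; treating 256 as the 8-bit subclass's default block size is an assumption here:

    import torch
    from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit

    # quantized optimizer state: 4096 elements, one scale per block
    state = OptimState8bit.zeros((4096,))

    # slicing along dim 0 at a block boundary dispatches aten.slice.Tensor and
    # returns another OptimState8bit; codes are sliced directly and the flattened
    # scale buffer is sliced proportionally (start * stride // block_size)
    half = state[:2048]
    assert isinstance(half, OptimState8bit)
    torch.testing.assert_close(half.dequantize(), state.dequantize()[:2048])

    # a slice that does not land on a block boundary is rejected
    try:
        state[:100]  # 100 is not a multiple of the block size
    except ValueError as e:
        print(e)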