[Low-bit optim] Support for dcp.save() and dcp.load() (#1217)
* support dcp.save

* add test for dcp.load()

* fix test

* typo

* implement aten.slice

* skip test

* fix checks

* run ruff

* fix formatting

* remove add safe globals in test

* sort some imports

---------

Co-authored-by: Mark Saroufim <[email protected]>
2 people authored and jainapurva committed Nov 11, 2024
1 parent 14d844a commit 82feeef
Showing 4 changed files with 240 additions and 61 deletions.
156 changes: 98 additions & 58 deletions test/prototype/test_low_bit_optim.py
@@ -1,27 +1,33 @@
import copy
import shutil
import tempfile
from pathlib import Path

import pytest
import torch
from packaging.version import Version
from torch import nn
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim.quant_utils import (
_fp32_to_bf16_sr,
quantize_4bit_with_qmap,
quantize_8bit_with_qmap,
)
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
)

@@ -88,23 +94,15 @@ def test_bf16_stochastic_round(self, device, compile):
x = torch.rand(32, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)

func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
x_rep_bf16 = func(x_rep)
assert x_rep_bf16.dtype is torch.bfloat16

# must cast BF16 tensor back to FP32 so that .mean() is accurate
torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)


class TestOptim(TestCase):
@parametrize(
"optim_name",
["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
@@ -151,29 +149,46 @@ def test_optim_smoke(self, optim_name, dtype, device):
for p1, p2 in zip(model.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

# aten.slice is required for dcp.load() when the world size changes, i.e. re-sharding.
# however, it's cumbersome to test directly, since we would need to run the distributed
# test twice with different world sizes and persist the checkpoint across the two runs.
# thus, we only test for the required op. note that future implementations of dcp.load()
# may use other ops.
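# (e.g. a checkpoint saved at world size 4 but loaded at world size 2 would require each
# rank to assemble its shard from slices of the smaller saved shards.)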
@parametrize("subclass", [OptimState4bit, OptimState8bit, OptimStateFp8])
@parametrize("shape", [(4096,), (256, 256)])
@parametrize("device", _DEVICES)
def test_subclass_slice(self, subclass, shape, device):
if subclass == OptimStateFp8:
if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 CUDA requires compute capability >= 8.9")

tensor = subclass.zeros(shape, device=device)
offset = shape[0] // 2

torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
torch.testing.assert_close(
tensor.dequantize()[offset : offset * 2], tensor[offset : offset * 2].dequantize()
)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
model2 = copy.deepcopy(model1)

# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -196,15 +211,11 @@ def test_optim_8bit_correctness(self, optim_name):
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
)
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
model2 = copy.deepcopy(model1)

# lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -238,12 +249,11 @@ def test_optim_4bit_correctness(self, optim_name):
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)

# make sure it can work in the presence of non-trainable params
model1[0].requires_grad_(False)
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
@@ -273,12 +283,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -293,9 +300,7 @@ def test_optim_cpu_offload_save_load(self):

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2.load_state_dict(state_dict)

for _ in range(2):
@@ -315,16 +320,17 @@ def test_optim_cpu_offload_save_load(self):
def test_optim_bf16_stochastic_round_correctness(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(2024)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
model1.to(device)
model2 = copy.deepcopy(model1).bfloat16()

# small LR so that weight update is small
# when bf16_stochastic_round=False, the test will fail after 1 iteration
optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
optim2 = low_bit_optim._AdamW(
model2.parameters(),
lr=1e-5,
bf16_stochastic_round=True,
)

# overfit on this sample
@@ -350,10 +356,13 @@ def test_optim_bf16_stochastic_round_correctness(self):
)


_FSDP_WORLD_SIZE = 2


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return _FSDP_WORLD_SIZE

@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
@@ -370,12 +379,12 @@ def test_fsdp2(self):
)

def _test_fsdp2(self, optim_cls):
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.utils._pytree as pytree
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
TransformerBlock,
)

batch_size = 3
vocab_size = 1024
@@ -413,9 +422,7 @@ def _test_fsdp2(self, optim_cls):
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

@@ -428,6 +435,39 @@ def _test_fsdp2(self, optim_cls):

self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())

# test for compatibility with dcp.save() and .load()
checkpoint_id = f"_fsdp_low_bit_optim_{optim_cls.__name__}"
if Path(checkpoint_id).exists():
shutil.rmtree(checkpoint_id)
dcp.save(fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)

# normally we would want to use dcp.state_dict.get_optimizer_state_dict() to initialize
# optim states. however, it currently does not respect the tensor-ness of the LR
# (pytorch/pytorch#139575). therefore, we have to initialize the optim state manually here.
resumed_fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
for p in fsdp_model.parameters():
p.grad = torch.zeros_like(p)

# this will change model weights due to weight decay, but since we don't use the model anymore, it's fine.
resumed_fsdp_optim.step()
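# the step above materializes exp_avg/exp_avg_sq, so dcp.load() below can fill them in-place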

dcp.load(resumed_fsdp_optim.state_dict(), checkpoint_id=checkpoint_id)
if dist.get_rank() == 0:
shutil.rmtree(checkpoint_id)

subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

for v1, v2 in zip(
pytree.tree_iter(resumed_fsdp_optim.state_dict()),
pytree.tree_iter(fsdp_optim.state_dict()),
):
assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
if isinstance(v1, DTensor):
v1 = v1.to_local()
v2 = v2.to_local()
assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
if isinstance(v1, subclasses):
v1 = v1.dequantize()
v2 = v2.dequantize()
self.assertEqual(v1, v2)


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
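For context, the end-to-end flow exercised by the new dcp tests looks roughly like the
sketch below. This is an illustrative sketch only, not code from this commit: the
checkpoint id and the AdamW8bit choice are arbitrary, and `model` is assumed to be an
already-sharded nn.Module.

import torch
import torch.distributed.checkpoint as dcp

from torchao.prototype import low_bit_optim

optim = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-2)

# ... a few training steps populate optim.state with quantized subclass tensors ...

dcp.save(optim.state_dict(), checkpoint_id="ckpt_low_bit_optim")

# on resume (possibly at a different world size; re-sharding goes through aten.slice),
# materialize the optimizer state first, then load the checkpoint into it in-place
resumed = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-2)
for p in model.parameters():
    p.grad = torch.zeros_like(p)
resumed.step()
dcp.load(resumed.state_dict(), checkpoint_id="ckpt_low_bit_optim")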
51 changes: 50 additions & 1 deletion torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -177,13 +177,15 @@ def _(func, types, args, kwargs):
)


@OptimState4bit.implements(
[
# required by DTensor.full_tensor()
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
# required by torch.distributed.checkpoint.save
aten.detach.default,
]
)
def _(func, types, args, kwargs):
@@ -201,6 +203,53 @@ def _(func, types, args, kwargs):
return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)


# required by torch.distributed.checkpoint.save
# note that we don't actually implement pin memory for this tensor subclass
# (pin_memory argument is ignored in aten._to_copy)
@OptimState4bit.implements(aten.is_pinned.default)
def _(func, types, args, kwargs):
return (
args[0].codes.is_pinned()
and args[0].scale.is_pinned()
and args[0].qmap.is_pinned()
)


# required by torch.distributed.checkpoint.load when the world size changes, i.e. re-sharding
@OptimState4bit.implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
x, dim, start, end = args[:4]
step = args[4] if len(args) > 4 else 1

# input validation
if dim != 0:
raise ValueError("Only support aten.slice along the first dim")
if step != 1:
raise ValueError("Only support aten.slice with step=1")

block_size = x.block_size
stride = math.prod(x.shape[1:])

# for 1 increment in x along the first dim,
# (flattened) scale will increment by stride / block_size
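# e.g. with a hypothetical shape=(256, 256) and block_size=128: stride=256, so slicing
# rows [64:128) maps to scale[128:256], since 64 * 256 / 128 = 128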
if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
raise ValueError(
f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
f"Make sure start and end align with block boundary. "
f"Received start={start}, end={end}."
)

# note that for 4-bit, we store .codes as flattened buffer
# divide by 2 since we store 2x 4-bit in 1x uint8
codes = x.codes[start * stride // 2 : end * stride // 2]
scale = x.scale[start * stride // block_size : end * stride // block_size]

# adjust the first dim
shape = (x.shape[0] * codes.numel() // x.codes.numel(),) + x.shape[1:]

return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)


if TORCH_VERSION_AT_LEAST_2_5:
from torch.serialization import add_safe_globals
