refactor(cpu-offload): use ruff format
Signed-off-by: dbyoung18 <[email protected]>
dbyoung18 committed Nov 24, 2024
1 parent d9cce7b commit 649b00b
Showing 3 changed files with 84 additions and 27 deletions.
66 changes: 50 additions & 16 deletions benchmarks/benchmark_low_bit_adam.py
@@ -53,7 +53,9 @@

     OPTIM_MAP.update(
         AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
-        AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
+        AdamW4bitRank1Lpmm=partial(
+            lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
+        ),
     )

 except ImportError:
@@ -71,8 +73,12 @@ def get_lr(self, step: int) -> float:
         if step < self.warmup_steps:
             return self.lr * step / self.warmup_steps
         if step < self.total_steps:
-            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
-            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
+            progress = (step - self.warmup_steps) / (
+                self.total_steps - self.warmup_steps
+            )
+            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
+                1 + math.cos(progress * math.pi)
+            )
         return self.final_lr


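For context, the get_lr method reformatted above implements linear warmup followed by half-cosine decay from lr down to final_lr. A minimal standalone sketch of the same curve (the constructor below is an assumption for illustration, not the benchmark's actual signature):

import math

class CosineSchedule:
    # Constructor is assumed for illustration; the benchmark defines its own.
    def __init__(self, lr: float, total_steps: int, warmup_steps: int = 100, final_lr: float = 0.0):
        self.lr = lr
        self.final_lr = final_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps

    def get_lr(self, step: int) -> float:
        # Linear ramp from 0 up to lr during warmup.
        if step < self.warmup_steps:
            return self.lr * step / self.warmup_steps
        # Half-cosine decay from lr down to final_lr afterwards.
        if step < self.total_steps:
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
        return self.final_lr

At progress = 0 the cosine term is 1 and get_lr returns lr; at progress = 1 it is -1 and the schedule bottoms out at final_lr.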
@@ -96,7 +102,9 @@ def get_parser():
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
-    parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
+    parser.add_argument(
+        "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
+    )

     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
@@ -114,11 +122,15 @@ def get_dloader(args, training: bool):
         transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

     transforms.append(v2.ToDtype(torch.float32, scale=True))
-    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+    transforms.append(
+        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    )
     transforms = v2.Compose(transforms)

     # use dataset from HF so download is fast
-    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
+    ds = datasets.load_dataset(
+        "timm/resisc45", split="train" if training else "validation"
+    )
     ds = ds.select_columns(["image", "label"])
     ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

@@ -168,8 +180,12 @@ def evaluate_model(model, args):
     if args.full_bf16:
         assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
     if args.optim_cpu_offload == "deepspeed":
-        assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
-        assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
+        assert (
+            args.amp == "none"
+        ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
+        assert (
+            args.optim == "AdamW"
+        ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
     if args.profile:
         args.n_epochs = 1
     if args.seed is not None:
@@ -189,7 +205,9 @@ def evaluate_model(model, args):
     dloader = get_dloader(args, True)
     print(f"Train dataset: {len(dloader.dataset):,} images")

-    model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
+    model = timm.create_model(
+        args.model, pretrained=True, num_classes=45, **args.model_kwargs
+    )
     if args.checkpoint_activations:
         model.set_grad_checkpointing()
     if args.full_bf16:
@@ -231,9 +249,15 @@ def evaluate_model(model, args):
     optim_cls = OPTIM_MAP[args.optim]

     if args.optim_cpu_offload == "ao":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
+        )
     elif args.optim_cpu_offload == "ao_offload_grads":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer,
+            optimizer_class=optim_cls,
+            offload_gradients=True,
+        )

     optim = optim_cls(
         model.parameters(),
@@ -250,17 +274,23 @@ def evaluate_model(model, args):
     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
-        pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
+        pbar = tqdm(
+            dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
+        )

         with torch.profiler.profile() if args.profile else nullcontext() as prof:
             for batch in pbar:
                 if args.full_bf16:
                     batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
-                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)
+                    batch["image"] = batch["image"].to(
+                        memory_format=torch.channels_last
+                    )

                 with get_amp_ctx(args.amp, _DEVICE):
-                    loss = F.cross_entropy(model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE))
+                    loss = F.cross_entropy(
+                        model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
+                    )

                 if args.optim_cpu_offload == "deepspeed":
                     model.backward(loss)
@@ -279,7 +309,9 @@ def evaluate_model(model, args):
                     log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
                     if step > 0:
                         t1 = time.perf_counter()
-                        log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
+                        log_dict["imgs_per_second"] = (
+                            args.batch_size * log_interval / (t1 - t0)
+                        )
                         t0 = t1
                     logger.log(log_dict, step=step)


Expand All @@ -300,7 +332,9 @@ def evaluate_model(model, args):

else:
val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
print(
f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
)
logger.log(dict(val_acc=val_acc), step=step)

peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
40 changes: 30 additions & 10 deletions test/prototype/test_low_bit_optim.py
@@ -95,7 +95,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x = torch.rand(32, device=device) * 100
         x_rep = x.view(-1, 1).repeat(1, 100_000)

-        func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile)
+        func = torch.compile(
+            _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
+        )
         x_rep_bf16 = func(x_rep)
         assert x_rep_bf16.dtype is torch.bfloat16

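The test above relies on _fp32_to_bf16_sr rounding stochastically, so that averaging 100,000 rounded copies of each value recovers the fp32 original in expectation. A rough sketch of the underlying idea, assuming a plain bit-manipulation approach (this is not torchao's actual implementation and it ignores NaN/inf edge cases):

import torch

def fp32_to_bf16_stochastic(x: torch.Tensor) -> torch.Tensor:
    # bf16 keeps only the top 16 of fp32's 32 bits. Add uniform noise
    # in [0, 2**16) to the bits about to be discarded, then truncate:
    # each value rounds up with probability equal to its dropped fraction,
    # so the rounding error is zero in expectation.
    bits = x.view(torch.int32)
    noise = torch.randint_like(bits, 0, 1 << 16)
    rounded = (bits + noise) & -(1 << 16)  # clear the low 16 bits
    return rounded.view(torch.float32).bfloat16()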
@@ -170,8 +172,13 @@ def test_subclass_slice(self, subclass, shape, device):
         tensor = subclass.zeros(shape, device=device)
         offset = shape[0] // 2

-        torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize())
-        torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize())
+        torch.testing.assert_close(
+            tensor.dequantize()[:offset], tensor[:offset].dequantize()
+        )
+        torch.testing.assert_close(
+            tensor.dequantize()[offset : offset * 2],
+            tensor[offset : offset * 2].dequantize(),
+        )

     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(
@@ -189,7 +196,9 @@ def test_optim_8bit_correctness(self, optim_name):
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -246,7 +255,7 @@ def test_optim_4bit_correctness(self, optim_name):

     @pytest.mark.skipif(
         not torch.cuda.is_available() and not torch.xpu.is_available(),
-        reason="optim CPU offload requires CUDA or XPU"
+        reason="optim CPU offload requires CUDA or XPU",
     )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
@@ -282,13 +291,15 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):

     @pytest.mark.skipif(
         not torch.cuda.is_available() and not torch.xpu.is_available(),
-        reason="optim CPU offload requires CUDA or XPU"
+        reason="optim CPU offload requires CUDA or XPU",
     )
     def test_optim_cpu_offload_save_load(self):
         device = _DEVICES[-1]
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
         model1.to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -303,7 +314,9 @@ def test_optim_cpu_offload_save_load(self):

         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)

         for _ in range(2):
@@ -387,7 +400,11 @@ def _test_fsdp2(self, optim_cls):
         import torch.utils._pytree as pytree
         from torch.distributed._composable.fsdp import fully_shard
         from torch.distributed.tensor import DTensor
-        from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock
+        from torch.testing._internal.distributed._tensor.common_dtensor import (
+            ModelArgs,
+            Transformer,
+            TransformerBlock,
+        )

         batch_size = 3
         vocab_size = 1024
@@ -460,7 +477,10 @@ def _test_fsdp2(self, optim_cls):

         subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8)

-        for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())):
+        for v1, v2 in zip(
+            pytree.tree_iter(resumed_fsdp_optim.state_dict()),
+            pytree.tree_iter(fsdp_optim.state_dict()),
+        ):
             assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__)
             if isinstance(v1, DTensor):
                 v1 = v1.to_local()
5 changes: 4 additions & 1 deletion torchao/prototype/low_bit_optim/cpu_offload.py
@@ -41,7 +41,10 @@ def __init__(
         self.param_d2h_map = dict()
         self.optim_dict = dict()
         self.device = get_available_devices()[-1]
-        assert self.device in ["cuda", "xpu"], "CPU Offload currently only supports CUDA & XPU"
+        assert self.device in [
+            "cuda",
+            "xpu",
+        ], "CPU Offload currently only supports CUDA & XPU"
         self.stream = getattr(torch, self.device).Stream()

         # the queue maintains the order which param we should do optim step on first.
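Putting the pieces together, end-to-end usage of the offload optimizer looks roughly like the sketch below. Layer sizes and hyperparameters are placeholders; extra keyword arguments appear to be forwarded to the wrapped optimizer, as the benchmark above does via partial(..., optimizer_class=optim_cls):

import torch
from torch import nn
from torchao.prototype import low_bit_optim

# CPU offload requires CUDA or XPU, per the assertion above.
model = nn.Linear(512, 512).cuda()

optim = low_bit_optim.CPUOffloadOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    offload_gradients=True,  # optional, as in the "ao_offload_grads" benchmark mode
    lr=1e-4,  # forwarded to the wrapped AdamW (assumed)
)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()

Note that the tests above only exercise offload_gradients=True with a gradient-accumulation factor of 1.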