Reviewer feedback:
- Allow 2nd adapter to target subset of 1st
- Update tests to reflect this
- Improve docs
BenjaminBossan committed Feb 3, 2025
1 parent d32d23f commit 8b7e602
Showing 4 changed files with 60 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/source/package_reference/hotswap.md
@@ -67,7 +67,7 @@ Hotswapping works with transformers models and diffusers models. However, there

- Right now, only LoRA is properly supported.
- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- The adapters must be compatible (e.g. same target modules).
- The adapter that is being swapped in must target the same layers as the previous adapter or a subset of those layers. It cannot target new layers. Therefore, if possible, start with the adapter that targets the most layers (see the sketch below).
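A minimal sketch of this constraint, using the tiny test model that appears elsewhere in this commit; the adapter names, ranks, target modules, and temporary paths are illustrative only:

```py
import tempfile

from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model
from peft.utils.hotswap import hotswap_adapter

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"

# create and save two adapters; adapter0 targets a superset of adapter1's modules
# (illustrative configs: any compatible LoRA configs would do)
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_id),
    LoraConfig(init_lora_weights=False, r=8, target_modules=["q_proj", "v_proj"]),
    adapter_name="adapter0",
)
model.add_adapter("adapter1", LoraConfig(init_lora_weights=False, r=8, target_modules=["q_proj"]))

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)

    # load the adapter that targets the most layers first ...
    base_model = AutoModelForCausalLM.from_pretrained(model_id)
    model = PeftModel.from_pretrained(base_model, f"{tmp_dir}/adapter0")
    # ... then swapping in the adapter that targets a subset of those layers works
    hotswap_adapter(model, f"{tmp_dir}/adapter1", adapter_name="default")
    # doing it the other way round (adapter1 first, then adapter0) would raise a
    # RuntimeError, because adapter0 targets v_proj, which adapter1 does not
```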

[[autodoc]] utils.hotswap.hotswap_adapter
- all
41 changes: 33 additions & 8 deletions src/peft/utils/hotswap.py
@@ -277,6 +277,21 @@ def prepare_model_for_compiled_hotswap(
config (`LoraConfig` or `dict[str, LoraConfig]`, *optional*):
Optionally pass the `LoraConfig`s of the LoRA adapters. If passed, the rank in the configs will be updated
to `target_rank`.
Example:
```py
base_model = ...
model = PeftModel.from_pretrained(base_model, path_adapter_0)
# Prepare the model to allow hotswapping even if ranks/scalings of 2nd adapter differ.
# You can skip this step if all ranks and scalings are identical.
prepare_model_for_compiled_hotswap(model, target_rank=highest_lora_rank)
model = torch.compile(model)
# do inference with adapter 0
# replace the "default" lora adapter with the new one
hotswap_adapter(model, path_adapter_1, adapter_name="default", torch_device=device)
# do inference with adapter 1
```
"""
is_compiled = hasattr(model, "_orig_mod")
if is_compiled:
@@ -343,29 +358,39 @@ def hotswap_adapter_from_state_dict(
is_compiled = hasattr(model, "_orig_mod")
# TODO: there is probably a more precise way to identify the adapter keys
missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
unexpected_keys = set()
unexpected_keys = []

# first: dry run, not swapping anything
for key, new_val in state_dict.items():
try:
old_val = attrgetter(key)(model)
except AttributeError:
unexpected_keys.add(key)
unexpected_keys.append(key)
continue

if is_compiled:
missing_keys.remove("_orig_mod." + key)
else:
missing_keys.remove(key)

if missing_keys or unexpected_keys:
msg = "Hot swapping the adapter did not succeed."
if missing_keys:
msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
if unexpected_keys:
msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
# Right now, we don't deal with unexpected keys, i.e. if the adapter being swapped in targets new layers. We could
# probably add LoRA to these layers ad hoc, but that would not work with compiled models.
if unexpected_keys:
msg = f"Hot swapping the adapter did not succeed, unexpected keys found: {', '.join(unexpected_keys)}."
raise RuntimeError(msg)

# If the adapter that is being swapped in is missing some keys, this is fine. We just need to ensure that those LoRA
# weights from the previous adapter are set to 0 so that they don't influence the output. We don't need to worry
# about ranks or alphas.
for key in missing_keys:
# in case it's a compiled model
key = key.removeprefix("_orig_mod.")
# get LoRA parent module name by removing the 'lora_*.<adapter-name>.weight' part
module_name = ".".join(key.split(".")[:-3])
module = model.get_submodule(module_name)
old_val = attrgetter(key)(model)
old_val.data.fill_(0.0)

# actual swapping
for key, new_val in state_dict.items():
# get LoRA parent module name by removing the 'lora_*.<adapter-name>.weight' part
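To see why the zero-fill is sufficient for missing keys: a LoRA layer contributes `scaling * (x @ lora_A.T @ lora_B.T)` on top of the base layer's output, so once the stale adapter's weights on an un-targeted layer are zeroed, that contribution vanishes. A standalone sketch of this in plain torch (illustrative shapes, not the PEFT internals):

```py
import torch

torch.manual_seed(0)
in_features, out_features, rank, alpha = 16, 32, 8, 16

base = torch.nn.Linear(in_features, out_features, bias=False)
# stale LoRA weights left on a layer that the incoming adapter does not target
lora_A = torch.randn(rank, in_features)
lora_B = torch.randn(out_features, rank)
scaling = alpha / rank

x = torch.randn(4, in_features)
with_stale_lora = base(x) + scaling * (x @ lora_A.T @ lora_B.T)

# what the loop above does for every missing key: zero out the old weights
lora_A.fill_(0.0)
lora_B.fill_(0.0)
after_zeroing = base(x) + scaling * (x @ lora_A.T @ lora_B.T)

# the stale adapter no longer influences the output, regardless of its rank or alpha
assert torch.allclose(after_zeroing, base(x))
assert not torch.allclose(with_stale_lora, base(x))
```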
22 changes: 18 additions & 4 deletions tests/run_compiled_model_hotswap.py
@@ -34,6 +34,7 @@


torch_device = infer_device()
inputs = torch.arange(10).view(-1, 1)


def check_hotswap(do_hotswap=True, ranks=(8, 8), alpha_scalings=(16, 16)):
@@ -42,10 +43,21 @@ def check_hotswap(do_hotswap=True, ranks=(8, 8), alpha_scalings=(16, 16)):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM").to(torch_device)
rank0, rank1 = ranks
alpha0, alpha1 = alpha_scalings
config0 = LoraConfig(init_lora_weights=False, r=rank0, lora_alpha=alpha0)
config1 = LoraConfig(init_lora_weights=False, r=rank1, lora_alpha=alpha1)
# note that it is okay for the 2nd adapter to target a subset of the 1st adapter's modules, but not the other way round
config0 = LoraConfig(init_lora_weights=False, r=rank0, lora_alpha=alpha0, target_modules=["q_proj", "v_proj"])
config1 = LoraConfig(init_lora_weights=False, r=rank1, lora_alpha=alpha1, target_modules=["q_proj"])
model = get_peft_model(model, config0, adapter_name="adapter0").eval()
with torch.inference_mode():
output0 = model(inputs).logits

model.add_adapter("adapter1", config1)
model.set_adapter("adapter1")
with torch.inference_mode():
output1 = model(inputs).logits

# sanity check:
tol = 1e-5
assert not torch.allclose(output0, output1, atol=tol, rtol=tol)

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)
@@ -56,7 +68,8 @@ def check_hotswap(do_hotswap=True, ranks=(8, 8), alpha_scalings=(16, 16)):
if do_hotswap:
prepare_model_for_compiled_hotswap(model, config=model.peft_config, target_rank=max(ranks))
model = torch.compile(model, mode="reduce-overhead")
model(inputs).logits
output_after0 = model(inputs).logits
assert torch.allclose(output0, output_after0, atol=tol, rtol=tol)

# swap and check that we get the output from adapter1
if do_hotswap:
@@ -66,7 +79,8 @@ def check_hotswap(do_hotswap=True, ranks=(8, 8), alpha_scalings=(16, 16)):
model.set_adapter("other")

# we need to call forward to potentially trigger recompilation
model(inputs).logits
output_after1 = model(inputs).logits
assert torch.allclose(output1, output_after1, atol=tol, rtol=tol)


if __name__ == "__main__":
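A note on why `prepare_model_for_compiled_hotswap(..., target_rank=max(ranks))` lets this test pass under `torch.compile(mode="reduce-overhead")` without recompilation: if, as assumed here, the LoRA weights are zero-padded up to a common rank, every adapter ends up with identical tensor shapes while its LoRA delta stays the same. A rough standalone sketch of that invariant in plain torch, not the actual PEFT implementation:

```py
import torch

torch.manual_seed(0)
in_features, out_features, rank, target_rank = 16, 32, 4, 8

lora_A = torch.randn(rank, in_features)
lora_B = torch.randn(out_features, rank)

# assumption: the preparation step zero-pads to target_rank; shown here conceptually
# pad lora_A with zero rows and lora_B with zero columns up to the shared target rank
lora_A_padded = torch.zeros(target_rank, in_features)
lora_A_padded[:rank] = lora_A
lora_B_padded = torch.zeros(out_features, target_rank)
lora_B_padded[:, :rank] = lora_B

# the LoRA delta is unchanged, but all adapters now share the same shapes,
# so the compiled graph does not need to be rebuilt when adapters are swapped
assert torch.allclose(lora_B @ lora_A, lora_B_padded @ lora_A_padded)
```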
14 changes: 8 additions & 6 deletions tests/test_initialization.py
@@ -2969,8 +2969,8 @@ def test_hotswap_wrong_peft_types_raises(self, tmp_path):
with pytest.raises(ValueError, match=msg):
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")

def test_hotswap_missing_key_raises(self, tmp_path):
# When a key is missing, raise
def test_hotswap_missing_key_works(self, tmp_path):
# When a key is missing, that is fine; the extra weight from the previous adapter is zeroed out
config = LoraConfig(target_modules=["lin0", "lin1"])

model = self.get_model()
@@ -2993,9 +2993,11 @@ def test_hotswap_missing_key_raises(self, tmp_path):
model = self.get_model()
model = PeftModel.from_pretrained(model, tmp_path / "adapter0")

msg = f"Hot swapping the adapter did not succeed. Missing keys: {key}"
with pytest.raises(RuntimeError, match=msg):
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
# sanity check: the missing weight is not already all zeros
assert not (model.base_model.model.lin1.lora_A["default"].weight == 0).all()
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")
# after hotswapping, it is zeroed out
assert (model.base_model.model.lin1.lora_A["default"].weight == 0).all()

def test_hotswap_extra_key_raises(self, tmp_path):
# When there is an extra key, raise
@@ -3021,7 +3023,7 @@ def test_hotswap_extra_key_raises(self, tmp_path):
model = self.get_model()
model = PeftModel.from_pretrained(model, tmp_path / "adapter0")

msg = f"Hot swapping the adapter did not succeed. Unexpected keys: {new_key}"
msg = f"Hot swapping the adapter did not succeed, unexpected keys found: {new_key}"
with pytest.raises(RuntimeError, match=msg):
hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default")

