FIX: reduce CorDA memory consumption + docs #2324

Merged · 5 commits · Jan 15, 2025
5 changes: 5 additions & 0 deletions examples/corda_finetuning/README.md
@@ -100,7 +100,12 @@ lora_config = LoraConfig(
init_lora_weights="corda",
corda_config=corda_config,
)

# Call `preprocess_corda` first to collect the covariance matrices and build the SVD results for the model
# For more details, please refer to the documentation of `preprocess_corda`
preprocess_corda(model, lora_config, run_model=run_model)

# Call `get_peft_model` after preprocessing, or else you'll encounter an error
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

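For context, here is a hedged end-to-end sketch of the flow shown in this README hunk, with `preprocess_corda` called before `get_peft_model`. The checkpoint name, calibration texts, and sequence length are placeholders and not part of the PR; the imports mirror the ones used in `examples/corda_finetuning/preprocess.py`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import get_peft_model
from peft.tuners.lora.config import CordaConfig, LoraConfig
from peft.tuners.lora.corda import preprocess_corda

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, not prescribed by this PR
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

calib_texts = ["..."] * 256  # replace with ~256 real calibration samples (see the `preprocess_corda` docstring)

@torch.no_grad()
def run_model():
    # Forward passes only; the CorDA hooks collect the covariance as a side effect.
    for text in calib_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
        model(**inputs)

corda_config = CordaConfig(corda_method="kpm")
lora_config = LoraConfig(init_lora_weights="corda", corda_config=corda_config)

# Preprocess first, then wrap; calling get_peft_model before preprocess_corda raises an error.
preprocess_corda(model, lora_config, run_model=run_model)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
```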
2 changes: 1 addition & 1 deletion examples/corda_finetuning/preprocess.py
@@ -21,7 +21,7 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft.mapping import get_peft_model
from peft import get_peft_model
from peft.tuners.lora.config import CordaConfig, LoraConfig
from peft.tuners.lora.corda import preprocess_corda

5 changes: 5 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -145,6 +145,8 @@ class CordaConfig:
use_float16_for_covariance (`bool`):
If true, uses float16 for the covariance matrix. This can reduce the memory usage of the covariance matrix
by half, but may lead to numerical instability. Defaults to `False`.
prune_temporary_fields (`bool`):
If true, temporary fields generated in CorDA preprocessing will be pruned. Defaults to `True`.
"""

cache_file: Optional[str] = field(
@@ -189,6 +191,9 @@ class CordaConfig:
)
},
)
prune_temporary_fields: bool = field(
default=True, metadata={"help": "If true, temporary fields generated in CorDA preprocessing will be pruned."}
)


@dataclass
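To see how the new field composes with the existing knobs in this file, here is an illustrative `CordaConfig`. The file paths are placeholders, and the comments reflect a reading of the config docstrings rather than additional changes from this PR.

```python
from peft.tuners.lora.config import CordaConfig, LoraConfig

corda_config = CordaConfig(
    corda_method="ipm",
    cache_file="corda_cache.pt",            # placeholder path; reuses the SVD results on later runs
    covariance_file="corda_covariance.pt",  # placeholder path; reuses the collected covariance on later runs
    use_float16_for_covariance=True,        # halves covariance memory at the cost of possible instability
    prune_temporary_fields=True,            # new in this PR; default True, drops preprocessing scratch fields
)
lora_config = LoraConfig(init_lora_weights="corda", corda_config=corda_config)
```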
44 changes: 22 additions & 22 deletions src/peft/tuners/lora/corda.py
@@ -61,24 +61,27 @@ def preprocess_corda(
"""
Build necessary CorDA fields for a model.

    For each `M * N` linear layer, an `M * M` covariance matrix will be built temporarily during preprocessing,
    consuming roughly another `2 * MODEL_SIZE` of memory for typical LLMs if the model weight is FP16 and the
    covariance is FP32. If that's too much, consider setting `use_float16_for_covariance` in `lora_config.corda_config`.

Args:
model (`nn.Module`):
Model to preprocess.
lora_config (`LoraConfig`):
Lora configuration of the model. `lora_config.corda_config` should be set.
run_model (`Optional[Callable[[], None]]`):
Callback to run the model when building covariance. Typically you should run model inference on your sample
dataset in this callback. Experiments have shown 256 samples to be a good default dataset size. `run_model`
can be `None` only if covariance file in `lora_config.corda_config` is already created.
            dataset in this callback. Experiments have shown that when the token count per sample is 2048 and the
            hidden dimension is 4096, collecting 256 distinct samples is enough. If you collect too few or overly
            repetitive samples, the covariance matrix may be low-rank and destabilize preprocessing. You can estimate
            the sample count as `HIDDEN_DIM / TOKEN_PER_SAMPLE * 128`. `run_model` can be `None` only if the covariance
            file in `lora_config.corda_config` has already been created.
hooked_model (`Optional[nn.Module]`):
            Model to hook when building covariance. If `None`, the original model will be hooked. This is only useful
            when you want to hook a different model than the one you are training; typically you should leave this `None`.

Upon completion, the following fields are set for each target module:
corda_method (`Literal["ipm", "kpm"]`):
CorDA method to apply. "ipm" for Instruction-Previewed Mode, "kpm" for Knowledge-Preserved Mode.
rank (`int`):
Rank of CorDA to apply.
eigens.S_WC (`torch.Tensor`):
Singular values of the weight matrix.
eigens.U_WC (`torch.Tensor`):
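The sizing guidance in the docstring above condenses to two small formulas. The helpers below are illustrative only (they are not part of peft) and assume the covariance is kept in FP32 (4 bytes per element) unless `use_float16_for_covariance` is set.

```python
def estimate_sample_count(hidden_dim: int, tokens_per_sample: int) -> int:
    """Rule of thumb from the docstring: HIDDEN_DIM / TOKEN_PER_SAMPLE * 128."""
    return int(hidden_dim / tokens_per_sample * 128)

def estimate_covariance_bytes(in_features_per_layer: list[int], dtype_bytes: int = 4) -> int:
    """Each M x N linear layer temporarily holds an M x M covariance matrix."""
    return sum(m * m * dtype_bytes for m in in_features_per_layer)

print(estimate_sample_count(4096, 2048))  # 256, matching the example in the docstring
```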
@@ -90,13 +93,12 @@
covariance_file = lora_config.corda_config.covariance_file
corda_method = lora_config.corda_config.corda_method
verbose = lora_config.corda_config.verbose
prune_temporary_fields = lora_config.corda_config.prune_temporary_fields

# If cache exists, skip building
if cache_file is not None and os.path.exists(cache_file) and os.path.getsize(cache_file) > 0:
cache = torch.load(cache_file, map_location=get_model_device(model))
for name, module in target_modules(model, lora_config):
module.corda_method = cache[f"{name}.corda_method"]
module.rank = cache[f"{name}.rank"]
module.eigens = CordaEigens(
S_WC=cache[f"{name}.eigens.S_WC"],
U_WC=cache[f"{name}.eigens.U_WC"],
@@ -123,12 +125,22 @@
# Crop CorDA eigens so that there's less to save
crop_corda_eigens(model, lora_config)

# Remove redundant fields if they exist
if prune_temporary_fields:
for name, module in target_modules(model, lora_config):
if hasattr(module, "sample_count"):
del module.sample_count
if hasattr(module, "covariance_matrix"):
del module.covariance_matrix
if hasattr(module, "corda_method"):
del module.corda_method
if hasattr(module, "rank"):
del module.rank

# Save cache to disk
if cache_file is not None:
cache: dict[str, Any] = {}
for name, module in target_modules(model, lora_config):
cache[f"{name}.corda_method"] = module.corda_method
cache[f"{name}.rank"] = module.rank
cache[f"{name}.eigens.S_WC"] = module.eigens.S_WC
cache[f"{name}.eigens.U_WC"] = module.eigens.U_WC
cache[f"{name}.eigens.V_WC"] = module.eigens.V_WC
@@ -174,15 +186,9 @@ def hook(module, input, output):
"Invalid value found in covariance. Please file an issue at https://github.com/huggingface/peft/issues."
)

# calculate mean and std
mean = input.mean(0)
std = input.std(0)

# add to module
module.sample_count += 1
module.covariance_matrix += covariance
module.mean += mean
module.std += std

# free memory
del covariance, input
@@ -191,8 +197,6 @@ def hook(module, input, output):
for name, module in target_modules(hooked_model, config):
module.sample_count = 0
module.covariance_matrix = 0
module.mean = 0
module.std = 0
handles.append(module.register_forward_hook(hook))

run_model()
@@ -213,14 +217,10 @@ def hook(module, input, output):
if name in targets:
targets[name].sample_count = module.sample_count
targets[name].covariance_matrix = module.covariance_matrix
targets[name].mean = module.mean
targets[name].std = module.std

# Divide by sample count
for name, module in target_modules(model, config):
module.covariance_matrix /= module.sample_count
module.mean /= module.sample_count
module.std /= module.sample_count

# Save covariance to disk
if covariance_file is not None:
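For readers skimming the hook changes above: after this PR each target module accumulates only an input covariance and a sample count (the mean/std accumulators are gone), and the accumulated sum is averaged once collection ends. The sketch below is conceptual; the helper name and the `state` dict are illustrative, and the real covariance computation in peft includes checks not shown here.

```python
import torch
import torch.nn as nn

def attach_covariance_hook(module: nn.Linear, state: dict):
    # Accumulators mirror module.sample_count / module.covariance_matrix in the diff above.
    state["sample_count"] = 0
    state["covariance_matrix"] = 0

    def hook(mod, inputs, output):
        x = inputs[0].detach()                   # (..., in_features)
        x = x.reshape(-1, x.shape[-1]).float()   # flatten tokens, upcast for numerical stability
        state["covariance_matrix"] += x.t() @ x  # (in_features, in_features)
        state["sample_count"] += 1

    return module.register_forward_hook(hook)

# After the calibration forward passes:
#   state["covariance_matrix"] /= state["sample_count"]
```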
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -349,6 +349,9 @@ def corda_init(self, adapter_name, init_lora_weights):
weight = weight.to(dtype)
self.get_base_layer().weight.data = weight

# Remove redundant fields
del linear.eigens

def loftq_init(self, adapter_name):
from peft.utils.loftq_utils import loftq_init

34 changes: 34 additions & 0 deletions tests/test_initialization.py
@@ -1916,13 +1916,46 @@ def data(self):
torch.manual_seed(233)
return torch.rand(1000, 1000).to(self.torch_device)

@pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
def test_lora_corda_no_redundant_fields(self, data, corda_method):
original_model = self.get_model()
model = deepcopy(original_model)

corda_config = CordaConfig(
corda_method=corda_method,
)
config = LoraConfig(
init_lora_weights="corda",
target_modules=["linear"],
corda_config=corda_config,
)
preprocess_corda(
model,
config,
run_model=lambda: model(data),
hooked_model=model,
)
peft_model = get_peft_model(model, config)

# check if the redundant fields are removed
assert not hasattr(peft_model.base_model.linear, "sample_count")
assert not hasattr(peft_model.base_model.linear, "covariance_matrix")
assert not hasattr(peft_model.base_model.linear, "corda_method")
assert not hasattr(peft_model.base_model.linear, "rank")
assert not hasattr(peft_model.base_model.linear, "eigens")

# legacy debug fields
assert not hasattr(peft_model.base_model.linear, "mean")
assert not hasattr(peft_model.base_model.linear, "std")

@pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
def test_lora_corda_sample_count(self, data, corda_method):
original_model = self.get_model()
model = deepcopy(original_model)

corda_config = CordaConfig(
corda_method=corda_method,
prune_temporary_fields=False,
)
config = LoraConfig(
init_lora_weights="corda",
@@ -1960,6 +1993,7 @@ def hook(*args):

corda_config = CordaConfig(
corda_method=corda_method,
prune_temporary_fields=False,
)
config = LoraConfig(
init_lora_weights="corda",