From a5e3dbc3d787967f7820839c754389aa3b9979d3 Mon Sep 17 00:00:00 2001
From: 5eqn
Date: Mon, 13 Jan 2025 08:23:40 +0800
Subject: [PATCH 1/5] FIX: prune redundant fields

---
 src/peft/tuners/lora/config.py |  7 ++++++-
 src/peft/tuners/lora/corda.py  | 25 +++++++++++++++++--------
 src/peft/tuners/lora/layer.py  |  3 +++
 tests/test_initialization.py   | 32 ++++++++++++++++++++++++++++++++
 4 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index df1fc06958..27d8c23a21 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -145,6 +145,8 @@ class CordaConfig:
         use_float16_for_covariance (`bool`):
             If true, uses float16 for the covariance matrix. This can reduce the memory usage of the covariance
             matrix by half, but may lead to numerical instability. Defaults to `False`.
+        prune_temporary_fields (`bool`):
+            If true, temporary fields generated in CorDA preprocessing will be pruned. Defaults to `True`.
     """
 
     cache_file: Optional[str] = field(
@@ -189,6 +191,9 @@ class CordaConfig:
             )
         },
     )
+    prune_temporary_fields: bool = field(
+        default=True, metadata={"help": "If true, temporary fields generated in CorDA preprocessing will be pruned."}
+    )
 
 
 @dataclass
@@ -232,7 +237,7 @@ class LoraConfig(PeftConfig):
             How to initialize the weights of the adapter layers. Passing True (default) results in the default
             initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian
             initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to
-            completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing
+            completely hrandom initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing
             `'eva'` results in a data-driven initialization of Explained Variance Adaptation. EVA initalizes LoRA based
             on the SVD of layer input activations and achieves SOTA performance due to its ability to adapt to the
             finetuning data. Pass `'olora'` to use OLoRA initialization.
diff --git a/src/peft/tuners/lora/corda.py b/src/peft/tuners/lora/corda.py
index 0d1d70b1a8..5a01c773c0 100644
--- a/src/peft/tuners/lora/corda.py
+++ b/src/peft/tuners/lora/corda.py
@@ -75,10 +75,6 @@ def preprocess_corda(
             you want to hook a different model than the one you are training, typically you should leave this `None`.
 
     Upon completion, the following fields are set for each target module:
-        corda_method (`Literal["ipm", "kpm"]`):
-            CorDA method to apply. "ipm" for Instruction-Previewed Mode, "kpm" for Knowledge-Preserved Mode.
-        rank (`int`):
-            Rank of CorDA to apply.
         eigens.S_WC (`torch.Tensor`):
             Singular values of the weight matrix.
         eigens.U_WC (`torch.Tensor`):
@@ -90,13 +86,12 @@ def preprocess_corda(
     covariance_file = lora_config.corda_config.covariance_file
     corda_method = lora_config.corda_config.corda_method
     verbose = lora_config.corda_config.verbose
+    prune_temporary_fields = lora_config.corda_config.prune_temporary_fields
 
     # If cache exists, skip building
     if cache_file is not None and os.path.exists(cache_file) and os.path.getsize(cache_file) > 0:
         cache = torch.load(cache_file, map_location=get_model_device(model))
         for name, module in target_modules(model, lora_config):
-            module.corda_method = cache[f"{name}.corda_method"]
-            module.rank = cache[f"{name}.rank"]
             module.eigens = CordaEigens(
                 S_WC=cache[f"{name}.eigens.S_WC"],
                 U_WC=cache[f"{name}.eigens.U_WC"],
@@ -123,12 +118,26 @@ def preprocess_corda(
     # Crop CorDA eigens so that there's less to save
     crop_corda_eigens(model, lora_config)
 
+    # Remove redundant fields if exist
+    if prune_temporary_fields:
+        for name, module in target_modules(model, lora_config):
+            if hasattr(module, "sample_count"):
+                del module.sample_count
+            if hasattr(module, "covariance_matrix"):
+                del module.covariance_matrix
+            if hasattr(module, "mean"):
+                del module.mean
+            if hasattr(module, "std"):
+                del module.std
+            if hasattr(module, "corda_method"):
+                del module.corda_method
+            if hasattr(module, "rank"):
+                del module.rank
+
     # Save cache to disk
     if cache_file is not None:
         cache: dict[str, Any] = {}
         for name, module in target_modules(model, lora_config):
-            cache[f"{name}.corda_method"] = module.corda_method
-            cache[f"{name}.rank"] = module.rank
             cache[f"{name}.eigens.S_WC"] = module.eigens.S_WC
             cache[f"{name}.eigens.U_WC"] = module.eigens.U_WC
             cache[f"{name}.eigens.V_WC"] = module.eigens.V_WC
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 20bef8ed10..1a94c8cec5 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -349,6 +349,9 @@ def corda_init(self, adapter_name, init_lora_weights):
             weight = weight.to(dtype)
         self.get_base_layer().weight.data = weight
 
+        # Remove redundant fields
+        del linear.eigens
+
     def loftq_init(self, adapter_name):
         from peft.utils.loftq_utils import loftq_init
 
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 15ad34ea89..8e1956cc18 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -1916,6 +1916,36 @@ def data(self):
         torch.manual_seed(233)
         return torch.rand(1000, 1000).to(self.torch_device)
 
+    @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
+    def test_lora_corda_no_redundant_fields(self, data, corda_method):
+        original_model = self.get_model()
+        model = deepcopy(original_model)
+
+        corda_config = CordaConfig(
+            corda_method=corda_method,
+        )
+        config = LoraConfig(
+            init_lora_weights="corda",
+            target_modules=["linear"],
+            corda_config=corda_config,
+        )
+        preprocess_corda(
+            model,
+            config,
+            run_model=lambda: model(data),
+            hooked_model=model,
+        )
+        peft_model = get_peft_model(model, config)
+
+        # check if the redundant fields are removed
+        assert not hasattr(peft_model.base_model.linear, "sample_count")
+        assert not hasattr(peft_model.base_model.linear, "covariance_matrix")
+        assert not hasattr(peft_model.base_model.linear, "mean")
+        assert not hasattr(peft_model.base_model.linear, "std")
+        assert not hasattr(peft_model.base_model.linear, "corda_method")
+        assert not hasattr(peft_model.base_model.linear, "rank")
+        assert not hasattr(peft_model.base_model.linear, "eigens")
+
     @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
     def test_lora_corda_sample_count(self, data, corda_method):
         original_model = self.get_model()
@@ -1923,6 +1953,7 @@ def test_lora_corda_sample_count(self, data, corda_method):
 
         corda_config = CordaConfig(
             corda_method=corda_method,
+            prune_temporary_fields=False,
         )
         config = LoraConfig(
             init_lora_weights="corda",
@@ -1960,6 +1991,7 @@ def hook(*args):
 
         corda_config = CordaConfig(
             corda_method=corda_method,
+            prune_temporary_fields=False,
         )
         config = LoraConfig(
             init_lora_weights="corda",

From 6f7011610b951b9fcb54a4ce093d9b6380b0447b Mon Sep 17 00:00:00 2001
From: 5eqn
Date: Mon, 13 Jan 2025 08:44:09 +0800
Subject: [PATCH 2/5] FIX: remove mean and std

---
 src/peft/tuners/lora/corda.py | 16 ----------------
 tests/test_initialization.py  |  6 ++++--
 2 files changed, 4 insertions(+), 18 deletions(-)

diff --git a/src/peft/tuners/lora/corda.py b/src/peft/tuners/lora/corda.py
index 5a01c773c0..3d44be946f 100644
--- a/src/peft/tuners/lora/corda.py
+++ b/src/peft/tuners/lora/corda.py
@@ -125,10 +125,6 @@ def preprocess_corda(
                 del module.sample_count
             if hasattr(module, "covariance_matrix"):
                 del module.covariance_matrix
-            if hasattr(module, "mean"):
-                del module.mean
-            if hasattr(module, "std"):
-                del module.std
             if hasattr(module, "corda_method"):
                 del module.corda_method
             if hasattr(module, "rank"):
@@ -183,15 +179,9 @@ def hook(module, input, output):
                 "Invalid value found in covariance. Please file an issue at https://github.com/huggingface/peft/issues."
             )
 
-        # calculate mean and std
-        mean = input.mean(0)
-        std = input.std(0)
-
         # add to module
         module.sample_count += 1
         module.covariance_matrix += covariance
-        module.mean += mean
-        module.std += std
 
         # free memory
         del covariance, input
@@ -200,8 +190,6 @@ def hook(module, input, output):
     for name, module in target_modules(hooked_model, config):
         module.sample_count = 0
         module.covariance_matrix = 0
-        module.mean = 0
-        module.std = 0
         handles.append(module.register_forward_hook(hook))
 
     run_model()
@@ -222,14 +210,10 @@ def hook(module, input, output):
         if name in targets:
             targets[name].sample_count = module.sample_count
             targets[name].covariance_matrix = module.covariance_matrix
-            targets[name].mean = module.mean
-            targets[name].std = module.std
 
     # Divide by sample count
     for name, module in target_modules(model, config):
         module.covariance_matrix /= module.sample_count
-        module.mean /= module.sample_count
-        module.std /= module.sample_count
 
     # Save covariance to disk
     if covariance_file is not None:
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 8e1956cc18..d85c25404a 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -1940,12 +1940,14 @@ def test_lora_corda_no_redundant_fields(self, data, corda_method):
         # check if the redundant fields are removed
         assert not hasattr(peft_model.base_model.linear, "sample_count")
         assert not hasattr(peft_model.base_model.linear, "covariance_matrix")
-        assert not hasattr(peft_model.base_model.linear, "mean")
-        assert not hasattr(peft_model.base_model.linear, "std")
         assert not hasattr(peft_model.base_model.linear, "corda_method")
         assert not hasattr(peft_model.base_model.linear, "rank")
         assert not hasattr(peft_model.base_model.linear, "eigens")
 
+        # legacy debug fields
+        assert not hasattr(peft_model.base_model.linear, "mean")
+        assert not hasattr(peft_model.base_model.linear, "std")
+
     @pytest.mark.parametrize("corda_method", ("ipm", "kpm"))
     def test_lora_corda_sample_count(self, data, corda_method):
         original_model = self.get_model()

From 5703009ae6bc53e266379f105680de3fbeaca0cd Mon Sep 17 00:00:00 2001
From: 5eqn
Date: Mon, 13 Jan 2025 21:21:12 +0800
Subject: [PATCH 3/5] DOC: memory and sample size

---
 examples/corda_finetuning/README.md |  5 +++++
 src/peft/tuners/lora/corda.py       | 11 +++++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/examples/corda_finetuning/README.md b/examples/corda_finetuning/README.md
index c248e99ae1..f07672f7a5 100644
--- a/examples/corda_finetuning/README.md
+++ b/examples/corda_finetuning/README.md
@@ -100,7 +100,12 @@ lora_config = LoraConfig(
     init_lora_weights="corda",
     corda_config=corda_config,
 )
+
+# Call `preprocess_corda` first to collect covariance matrix and build SVD result for model
+# For more details, please refer to documentation of `preprocess_corda`
 preprocess_corda(model, lora_config, run_model=run_model)
+
+# Call `get_peft_model` after preprocessing, or else you'll encounter error
 peft_model = get_peft_model(model, lora_config)
 
 peft_model.print_trainable_parameters()
diff --git a/src/peft/tuners/lora/corda.py b/src/peft/tuners/lora/corda.py
index 3d44be946f..8b991d276e 100644
--- a/src/peft/tuners/lora/corda.py
+++ b/src/peft/tuners/lora/corda.py
@@ -61,6 +61,10 @@ def preprocess_corda(
     """
     Build necessary CorDA fields for a model.
 
+    For each `M * N` linear layer, a `M * M` covariance matrix will be built temporarily during the preprocessing
+    process, consuming roughly another `2 * MODEL_SIZE` memory for typical LLMs if model weight is FP16 and covariance
+    is FP32. If that's too much, consider specifying `use_float16_for_covariance` in `lora_config.corda_config`.
+
     Args:
         model (`nn.Module`):
             Model to preprocess.
@@ -68,8 +72,11 @@ def preprocess_corda(
             Lora configuration of the model. `lora_config.corda_config` should be set.
         run_model (`Optional[Callable[[], None]]`):
             Callback to run the model when building covariance. Typically you should run model inference on your sample
-            dataset in this callback. Experiments have shown 256 samples to be a good default dataset size. `run_model`
-            can be `None` only if covariance file in `lora_config.corda_config` is already created.
+            dataset in this callback. Experiments have shown that when token count per sample is 2048, hidden dimension
+            is 4096, collecting 256 distinct samples is enough. If you collect too few or too repetitive samples, the
+            covariance matrix may be low-ranked and unstabilize preprocessing. You can estimate sample count as
+            `HIDDEN_DIM / TOKEN_PER_SAMPLE * 128`. `run_model` can be `None` only if covariance file in
+            `lora_config.corda_config` is already created.
         hooked_model (`Optional[nn.Module]`):
             Model to hook when building covariance. If none, original model will be hooked. This is only useful when
             you want to hook a different model than the one you are training, typically you should leave this `None`.

From fba893fc3e9b5827f0ab6cc30b35cfdabbde674c Mon Sep 17 00:00:00 2001
From: 5eqn
Date: Mon, 13 Jan 2025 21:24:07 +0800
Subject: [PATCH 4/5] FIX: remove typo

---
 src/peft/tuners/lora/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 27d8c23a21..b36de0c43a 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -237,7 +237,7 @@ class LoraConfig(PeftConfig):
             How to initialize the weights of the adapter layers. Passing True (default) results in the default
             initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian
             initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to
-            completely hrandom initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing
+            completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing
             `'eva'` results in a data-driven initialization of Explained Variance Adaptation. EVA initalizes LoRA based
             on the SVD of layer input activations and achieves SOTA performance due to its ability to adapt to the
             finetuning data. Pass `'olora'` to use OLoRA initialization.

From 8826f2bb86880e42f54186d94676d2817108d5da Mon Sep 17 00:00:00 2001
From: 5eqn
Date: Wed, 15 Jan 2025 06:46:18 +0800
Subject: [PATCH 5/5] FIX: import

---
 examples/corda_finetuning/preprocess.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/corda_finetuning/preprocess.py b/examples/corda_finetuning/preprocess.py
index 01721d296e..15bb18cb6b 100644
--- a/examples/corda_finetuning/preprocess.py
+++ b/examples/corda_finetuning/preprocess.py
@@ -21,7 +21,7 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from peft.mapping import get_peft_model
+from peft import get_peft_model
 from peft.tuners.lora.config import CordaConfig, LoraConfig
 from peft.tuners.lora.corda import preprocess_corda
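
For reference, a minimal end-to-end sketch of the workflow these patches touch: CorDA preprocessing with the new `prune_temporary_fields` option, followed by `get_peft_model`. It mirrors the README snippet updated in PATCH 3/5 and the imports fixed in PATCH 5/5; the checkpoint name, calibration texts, and `target_modules` below are illustrative placeholders, not values taken from the patches.

```python
# Sketch only: model id, calibration texts, and target module names are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import get_peft_model
from peft.tuners.lora.config import CordaConfig, LoraConfig
from peft.tuners.lora.corda import preprocess_corda

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Roughly 256 distinct samples of ~2048 tokens each; see the sample-count note
# added to the `preprocess_corda` docstring in PATCH 3/5.
calib_texts = ["..."] * 256  # replace with real calibration data


@torch.no_grad()
def run_model():
    # Covariance is collected by forward hooks, so plain inference is enough.
    model.eval()
    for text in calib_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
        model(**inputs)


corda_config = CordaConfig(
    corda_method="kpm",           # or "ipm"
    prune_temporary_fields=True,  # default; set False to keep the intermediate stats for debugging
)
lora_config = LoraConfig(
    init_lora_weights="corda",
    target_modules=["q_proj", "v_proj"],  # placeholder module names
    corda_config=corda_config,
)

# Collect covariance and build the SVD-based initialization first...
preprocess_corda(model, lora_config, run_model=run_model)
# ...then wrap the model; `get_peft_model` must come after preprocessing.
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
```

With `prune_temporary_fields=True` (the default introduced in PATCH 1/5), the intermediate `sample_count`, `covariance_matrix`, `corda_method`, `rank`, and `eigens` attributes are dropped once initialization is done, so the wrapped model carries no leftover preprocessing state.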