Commit

added review updates
julian-fong committed Dec 24, 2024
1 parent 99cfb68 commit 1229fc5
Showing 2 changed files with 25 additions and 21 deletions.
13 changes: 8 additions & 5 deletions src/adapters/configuration/adapter_config.py
@@ -509,10 +509,11 @@ class LoRAConfig(AdapterConfig):
composition_mode: str = "add"
init_weights: str = "lora"
use_gating: bool = False
d: Union[bool, float] = None
b: Union[bool, float] = None
vera_d: float = None
vera_b: float = None
dtype: Optional[str] = None


@dataclass(eq=False)
class IA3Config(LoRAConfig):
"""
@@ -542,7 +543,7 @@ class VeraConfig(LoRAConfig):
LoRA config that applies Vector-based Random Matrix Adaptation (VeRA). It adds
trainable scaling vectors 'd' and 'b' while keeping the original LoRA matrices
frozen, random, and shared across layers. See the paper for details:
https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied
https://arxiv.org/pdf/2310.11454. Note that `r` will still be supplied
since we are still initializing decomposition matrices A and B.
The `composition_mode` parameter should also be set to `add`.
"""
@@ -552,9 +553,11 @@ class VeraConfig(LoRAConfig):
output_lora: bool = False

r: int = 8
d: Union[bool, float] = 0.1
b: Union[bool, float] = 0.0
vera_d: float = 0.1
vera_b: float = 0.0
init_weights: str = "vera"
composition_mode: str = "add"
dtype: Optional[str] = None


@dataclass(eq=False)
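A minimal usage sketch for the VeraConfig defined above. The import path and the add_adapter call are assumptions based on the usual adapters-library workflow; only the field names and defaults come from this diff.

from adapters import VeraConfig  # assumed export location

# r sizes the frozen, shared random matrices A and B;
# vera_d / vera_b are the init values for the trainable d and b scaling vectors.
config = VeraConfig(r=8, vera_d=0.1, vera_b=0.0)
# model.add_adapter("vera_adapter", config=config)  # assumed adapters-style call
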
33 changes: 17 additions & 16 deletions src/adapters/methods/lora.py
@@ -38,6 +38,7 @@ def __init__(
lora_B_shape,
config: LoRAConfig,
gating_heads: int = 1,
name: str = None,
):
super().__init__()
assert config.composition_mode == "add", "LoRA module only supports composition_mode='add'."
@@ -46,6 +47,7 @@ def __init__(
self.composition_mode = config.composition_mode
self.attn_matrices = config.attn_matrices
self.use_gating = config.use_gating
self.name = name
# Optional dropout
if config.dropout > 0.0:
self.lora_dropout = nn.Dropout(p=config.dropout)
@@ -115,6 +117,7 @@ def __init__(
lora_B_shape,
config: LoRAConfig,
gating_heads: int = 1,
name: str = None,
):
super().__init__()
assert config.composition_mode == "scale", "IA3 module only supports composition_mode='scale'."
@@ -125,6 +128,7 @@ def __init__(
self.composition_mode = config.composition_mode
self.attn_matrices = config.attn_matrices
self.use_gating = config.use_gating
self.name = name
# Optional dropout
if config.dropout > 0.0:
raise ValueError("IA3 module does not support dropout.")
@@ -186,13 +190,20 @@ def __init__(
lora_B_shape,
config: LoRAConfig,
gating_heads: int = 1,
name: str = None,
):
super().__init__()
self.d = config.d
self.b = config.b
self.d = config.vera_d
self.b = config.vera_b
self.r = config.r
self.alpha = config.alpha
self.use_gating = config.use_gating
self.name = name

# check to make sure that the `composition_mode` is set to `add`
self.composition_mode = config.composition_mode
if self.composition_mode != "add":
raise ValueError("Vera module only supports composition_mode='add'.")

# Optional dropout
if config.dropout > 0.0:
@@ -239,13 +250,8 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens

if getattr(self, "lora_dropout"):
hidden_states = self.lora_dropout(hidden_states)
# print(self.vera_B.shape)
# print(lora_B.shape)
# print(self.vera_D.shape)
# print(lora_A.shape)
# print((self.vera_B @ lora_B @ self.vera_D @ lora_A).shape)
# print(hidden_states.shape)
hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A )

hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A)

if self.use_gating:
gate = torch.sigmoid(self.gate(layer_input))
@@ -256,9 +262,6 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens

return hidden_states, gate

def set_vera_adapter_name(self, name):
self.name = name


def init_shared_vera_parameters(model_config, adapter_config, device):
hidden_size = model_config.hidden_size
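
A note on the composed update in forward above, reading vera_B and vera_D as the diagonal scaling matrices built from the trainable b and d vectors of the VeRA paper (an assumption about how they are stored), with lora_B and lora_A the frozen, shared random matrices:

    \Delta h = h \,(\Lambda_b B \Lambda_d A)^{\top}, \qquad \Lambda_b = \mathrm{diag}(b), \quad \Lambda_d = \mathrm{diag}(d)

which is what hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A) computes for row-vector hidden states; only b and d are trained, per the VeraConfig docstring above.
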
@@ -323,7 +326,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
)
if lora_config is not None and self._check_lora_location(lora_config):
if lora_config.composition_mode == "add":
if isinstance(lora_config.d, float) or isinstance(lora_config.b, float):
if isinstance(lora_config.vera_d, float) or isinstance(lora_config.vera_b, float):
lora_cls = Vera
else:
lora_cls = LoRA
@@ -335,10 +338,8 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
*self._get_lora_shapes(lora_config),
lora_config,
gating_heads=self.get_n_heads(lora_config),
name=adapter_name,
)
# if we're using Vera, then set the adapter name into the Vera object
if lora_cls == Vera:
lora.set_vera_adapter_name(name=adapter_name)

lora.train(self.training)
lora = lora.to(self.weight.device)
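
The isinstance check in add_adapter above is what routes a config to the new Vera module: VeraConfig sets vera_d=0.1 and vera_b=0.0 (floats), while a plain LoRAConfig leaves both at None, so ordinary LoRA keeps its original class. A quick illustrative check, assuming both configs construct with their defaults and are exported at the package root:

from adapters import LoRAConfig, VeraConfig  # assumed export location

lora_cfg = LoRAConfig()
vera_cfg = VeraConfig()
# None is not a float, so LoRAConfig stays on the LoRA path...
assert not (isinstance(lora_cfg.vera_d, float) or isinstance(lora_cfg.vera_b, float))
# ...while VeraConfig's float defaults select the Vera module.
assert isinstance(vera_cfg.vera_d, float) or isinstance(vera_cfg.vera_b, float)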
