lora.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraLinear(nn.Linear):
    """Linear layer with a frozen base weight plus a trainable low-rank (LoRA) update."""
    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_rank: int,
        bias: bool = False,
        lora_alpha: float = 1.0,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # Low-rank factors: delta_W = lora_B @ lora_A, with A of shape (rank, in_features)
        # and B of shape (out_features, rank). B starts at zero, so at initialization
        # the layer behaves exactly like the frozen base Linear.
        self.lora_A = nn.Parameter(torch.randn(lora_rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, lora_rank))
        # LoRA scaling factor alpha / r applied to the low-rank update.
        self.scale = lora_alpha / float(lora_rank)
        # Freeze the pretrained weight; only lora_A and lora_B are trained.
        self.weight.requires_grad = False
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Effective weight update: scale * (lora_B @ lora_A), shape (out_features, in_features).
        delta = torch.matmul(self.lora_B, self.lora_A * self.scale)
        # Frozen base projection plus the low-rank correction.
        return super().forward(input) + F.linear(input, delta)
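

# Minimal usage sketch, assuming arbitrary example sizes (in_features=16,
# out_features=32, lora_rank=4, lora_alpha=8.0): it checks that only the
# LoRA factors receive gradients while the base weight stays frozen.
if __name__ == "__main__":
    layer = LoraLinear(16, 32, lora_rank=4, lora_alpha=8.0)
    x = torch.randn(2, 16)
    out = layer(x)          # same output shape as a plain nn.Linear(16, 32)
    out.sum().backward()
    assert layer.weight.grad is None      # frozen base weight gets no gradient
    assert layer.lora_A.grad is not None  # LoRA factors are trainable
    assert layer.lora_B.grad is not None
    print(out.shape)  # torch.Size([2, 32])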