# lora.py (forked from metavoiceio/metavoice-src)
import math

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from fam.llm.fast_model import Transformer


def get_lora_model(model: nn.Module) -> nn.Module:
    """Freeze all parameters except the LoRA ones (those with "lora" in their name)."""
    for name, param in model.named_parameters():
        if "lora" in name:
            print("Enabling gradient for LoRA parameter:", name)
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
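

# --- Illustrative helper (not part of the original file) ------------------------------
# A minimal sketch of how get_lora_model is meant to be used: after it runs, only
# parameters whose name contains "lora" should still receive gradients. The helper
# below is illustrative only; e.g. a single LoRALinear(16, 16, lora_rank=4,
# lora_alpha=8) wrapped this way would report 2 * 4 * 16 = 128 trainable values
# (lora_A plus lora_B).
def count_trainable_params(model: nn.Module) -> int:
    """Return the number of parameter values that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)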


class LoRALinear(nn.Linear):
    def __init__(self,
                 # nn.Linear arguments
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 # LoRA parameters
                 lora_rank: int = 0,
                 lora_alpha: float = 0.0,
                 lora_dropout: float = 0.0
                 ) -> None:
        nn.Linear.__init__(
            self,
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )
        # LoRA state: A is (rank, in_features), B is (out_features, rank), and the
        # update B @ A is scaled by alpha / rank. Gradients are disabled here and
        # re-enabled selectively by get_lora_model.
        self.has_weights_merged = False
        if lora_rank > 0:
            self.lora_dropout = nn.Dropout(lora_dropout)
            self.lora_scaling = lora_alpha / lora_rank
            self.lora_A = nn.Parameter(torch.empty((lora_rank, self.in_features), device=device, dtype=dtype))
            self.lora_B = nn.Parameter(torch.empty((self.out_features, lora_rank), device=device, dtype=dtype))
            self.lora_A.requires_grad = False
            self.lora_B.requires_grad = False
            self.reset_parameters()

    def is_lora(self) -> bool:
        return hasattr(self, 'lora_A')

    def reset_parameters(self) -> None:
        nn.Linear.reset_parameters(self)
        if self.is_lora():
            # A gets a Kaiming init; B starts at zero so the initial LoRA update is zero.
            torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            torch.nn.init.zeros_(self.lora_B)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = nn.Linear.forward(self, input)
        if not self.has_weights_merged and self.is_lora():
            # x += scaling * (dropout(input) @ A.T) @ B.T
            x += self.lora_scaling * F.linear(
                F.linear(
                    self.lora_dropout(input),
                    self.lora_A
                ),
                self.lora_B
            )
        return x

    def extra_repr(self) -> str:
        out = nn.Linear.extra_repr(self)
        if self.is_lora():
            out += f', lora_rank={self.lora_A.shape[0]}, lora_scaling={self.lora_scaling}, lora_dropout={self.lora_dropout.p}'
        return out

    def train(self, mode: bool = True) -> "LoRALinear":
        nn.Linear.train(self, mode)
        if self.has_weights_merged and self.is_lora():
            # Un-merge so the LoRA branch is applied explicitly during training.
            self.weight.data -= self.lora_scaling * self.lora_B @ self.lora_A
            self.has_weights_merged = False
        return self

    def eval(self) -> "LoRALinear":
        nn.Linear.eval(self)
        if not self.has_weights_merged and self.is_lora():
            # Merge the LoRA update into the base weight for inference.
            self.weight.data += self.lora_scaling * self.lora_B @ self.lora_A
            self.has_weights_merged = True
        return self
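

# --- Illustrative usage (not part of the original file) -------------------------------
# A minimal sketch of how LoRALinear behaves, using arbitrary example shapes. In train
# mode the low-rank update  scaling * (dropout(x) @ A.T) @ B.T  is added on the fly;
# calling .eval() folds  scaling * B @ A  into `weight`, so the merged and unmerged
# paths produce the same output.
def _lora_linear_example() -> None:
    layer = LoRALinear(32, 64, lora_rank=8, lora_alpha=16, lora_dropout=0.0)
    x = torch.randn(4, 32)
    y_train = layer.train()(x)  # base output + LoRA branch applied explicitly
    y_eval = layer.eval()(x)    # LoRA update merged into layer.weight
    assert torch.allclose(y_train, y_eval, atol=1e-5)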


class TransformerWithLoRA(nn.Module):
    def __init__(self, base_model: Transformer, rank: int = 8, alpha: int = 16, dropout: float = 0.1, training_mode: bool = True):
        super().__init__()
        self.config = base_model.config
        # Inject LoRALinear into the attention projections. Note: the early break means
        # only the first transformer layer (i == 0) is adapted.
        for i, layer in enumerate(base_model.layers):
            if i == 1:
                break
            layer.attention.wqkv = LoRALinear(
                in_features=layer.attention.wqkv.in_features,
                out_features=layer.attention.wqkv.out_features,
                lora_rank=rank,
                lora_alpha=alpha,
                lora_dropout=dropout
            )
            layer.attention.wo = LoRALinear(
                in_features=layer.attention.wo.in_features,
                out_features=layer.attention.wo.out_features,
                lora_rank=rank,
                lora_alpha=alpha,
                lora_dropout=dropout
            )
        if training_mode:
            # Freeze everything except the LoRA parameters.
            base_model = get_lora_model(base_model)
        self.base_model = base_model

    def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor, targets: Tensor = None, debug_mode=False):
        return self.base_model(idx, spk_emb, input_pos, targets, debug_mode)

    def setup_spk_cond_mask(self):
        self.base_model.setup_spk_cond_mask()

    def setup_caches(self, *args, **kwargs):
        self.base_model.setup_caches(*args, **kwargs)

    def save_lora(self, path: str):
        # NOTE: this persists only the speaker conditioning layer's weights,
        # not the LoRA A/B matrices injected above.
        torch.save(self.base_model.speaker_cond_pos.state_dict(), path)

    def load_lora(self, path: str):
        self.base_model.speaker_cond_pos.load_state_dict(torch.load(path))
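

# --- Illustrative usage (not part of the original file) -------------------------------
# A rough sketch of how this wrapper is intended to be driven, assuming `base` is an
# already-constructed fam.llm.fast_model.Transformer (its construction is out of scope
# here) and `lora_ckpt` is a hypothetical checkpoint path.
def _finetune_setup_example(base: Transformer, lora_ckpt: str = "lora.pt"):
    model = TransformerWithLoRA(base, rank=8, alpha=16, dropout=0.1, training_mode=True)
    # get_lora_model has frozen everything except parameters named "*lora*",
    # so the optimizer only ever sees the adapter weights.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-4)
    # ... training loop over (idx, spk_emb, input_pos, targets) batches would go here ...
    model.save_lora(lora_ckpt)
    return model, optimizer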