update bnb and gptq fast lora to support bias
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Sep 28, 2024
1 parent 3993b8c commit e87f351
Showing 3 changed files with 104 additions and 66 deletions.
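
This commit threads the frozen base-layer bias of each projection through the fused fast-LoRA autograd functions (LoRA_MLP, LoRA_QKV, LoRA_W): get_lora_parameters now also yields the bias, the forward passes add it after matmul_lora, and the backward passes return one extra None for it. A minimal sketch of the biased projection, assuming matmul_lora computes the dequantized base matmul plus the scaled LoRA update (the helper below is illustrative, not the repository's implementation):

import torch

def lora_projection_with_bias(X, W, A, B, s, bias=None):
    # Frozen base projection plus scaled LoRA update; the frozen bias is
    # added on top, mirroring the new "+= bias" lines in the diff below.
    out = X @ W.t() + (X @ A.t()) @ B.t() * s
    if bias is not None:
        out = out + bias
    return out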
@@ -59,18 +59,21 @@ class LoRA_MLP(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, X : torch.Tensor,
- gateW, gateW_quant, gateA, gateB, gateS,
- upW, upW_quant, upA, upB, upS,
- downW, downW_quant, downA, downB, downS,
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+ upW, upW_quant, up_bias, upA, upB, upS,
+ downW, downW_quant, down_bias, downA, downB, downS,
_forward_function, _backward_function,
dropout_gate=None, dropout_up=None, dropout_down=None,
):
dtype = X.dtype

e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS, dropout=dropout_gate)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS, dropout=dropout_up)
+ e += gate_bias
+ g += up_bias
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS, dropout=dropout_down)
+ i += down_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
@@ -152,13 +155,13 @@ def backward(ctx, dY : torch.Tensor):
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())

- # gateW, gateW_quant, gateA, gateB, gateS,
- # upW, upW_quant, upA, upB, upS,
- # downW, downW_quant, downA, downB, downS,
+ # gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+ # upW, upW_quant, up_bias, upA, upB, upS,
+ # downW, downW_quant, down_bias, downA, downB, downS,
return (dX.view(batch, seq_len, hd), \
- None, None, d_gateA.t(), d_gateB.t(), None, \
- None, None, d_upA.t(), d_upB.t(), None, \
- None, None, d_downA.t(), d_downB.t(), None, \
+ None, None, None, d_gateA.t(), d_gateB.t(), None, \
+ None, None, None, d_upA.t(), d_upB.t(), None, \
+ None, None, None, d_downA.t(), d_downB.t(), None, \
None, None, # _backward and _forward
None, None, None, # dropout modules
)
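
Because torch.autograd.Function.backward must return one value per input to forward, each projection's new bias argument gets a matching extra None in the return tuple; no gradient flows to the bias, consistent with the base layer staying frozen. A toy Function illustrating the rule (not code from this diff):

import torch

class AddFrozenBias(torch.autograd.Function):
    # Toy example: forward takes two inputs, so backward must return two
    # values; the frozen bias gets None instead of a gradient.
    @staticmethod
    def forward(ctx, X, bias):
        return X + bias

    @staticmethod
    def backward(ctx, dY):
        return dY, None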
@@ -168,13 +171,13 @@ def backward(ctx, dY : torch.Tensor):

from ..swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def apply_lora_mlp_swiglu(self, X):
- gateW, gateW_quant, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
- upW, upW_quant, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
- downW, downW_quant, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+ upW, upW_quant, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
+ downW, downW_quant, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
- gateW, gateW_quant, gateA, gateB, gateS,
- upW, upW_quant, upA, upB, upS,
- downW, downW_quant, downA, downB, downS,
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+ upW, upW_quant, up_bias, upA, upB, upS,
+ downW, downW_quant, down_bias, downA, downB, downS,
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
dropout_gate, dropout_up, dropout_down,
)
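
For reference, an unfused PyTorch equivalent of the SwiGLU MLP path, assuming swiglu_fg_kernel computes silu(e) * g; the modules and shapes below are illustrative stand-ins for gate_proj/up_proj/down_proj, not code from the repository:

import torch
import torch.nn as nn
import torch.nn.functional as F

def reference_swiglu_mlp(X, gate, up, down):
    # Gate and up projections (now including their biases), SiLU gating,
    # then the down projection with its bias.
    e = F.linear(X, gate.weight, gate.bias)
    g = F.linear(X, up.weight, up.bias)
    return F.linear(F.silu(e) * g, down.weight, down.bias)

gate, up, down = nn.Linear(16, 32), nn.Linear(16, 32), nn.Linear(32, 16)
out = reference_swiglu_mlp(torch.randn(2, 4, 16), gate, up, down)  # (2, 4, 16)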
@@ -184,13 +187,13 @@ def apply_lora_mlp_swiglu(self, X):

from ..geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X):
- gateW, gateW_quant, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
- upW, upW_quant, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
- downW, downW_quant, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+ upW, upW_quant, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
+ downW, downW_quant, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
- gateW, gateW_quant, gateA, gateB, gateS,
- upW, upW_quant, upA, upB, upS,
- downW, downW_quant, downA, downB, downS,
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+ upW, upW_quant, up_bias, upA, upB, upS,
+ downW, downW_quant, down_bias, downA, downB, downS,
geglu_exact_forward_kernel, geglu_exact_backward_kernel,
dropout_gate, dropout_up, dropout_down,
)
@@ -200,13 +203,13 @@ def apply_lora_mlp_geglu_exact(self, X):

from ..geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
def apply_lora_mlp_geglu_approx(self, X):
- gateW, gateW_quant, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
- upW, upW_quant, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
- downW, downW_quant, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+ upW, upW_quant, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
+ downW, downW_quant, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(X,
- gateW, gateW_quant, gateA, gateB, gateS,
- upW, upW_quant, upA, upB, upS,
- downW, downW_quant, downA, downB, downS,
+ gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+ upW, upW_quant, up_bias, upA, upB, upS,
+ downW, downW_quant, down_bias, downA, downB, downS,
geglu_approx_forward_kernel, geglu_approx_backward_kernel,
dropout_gate, dropout_up, dropout_down,
)
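
The GeGLU variants follow the same pattern and differ only in the gating activation; a matching unfused reference (illustrative only), where approximate="tanh" stands in for the approximate kernel:

import torch.nn.functional as F

def reference_geglu_mlp(X, gate, up, down, approximate="none"):
    # Same structure as the SwiGLU reference above, with GELU gating.
    e = F.linear(X, gate.weight, gate.bias)
    g = F.linear(X, up.weight, up.bias)
    return F.linear(F.gelu(e, approximate=approximate) * g, down.weight, down.bias)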
@@ -247,9 +250,9 @@ class LoRA_QKV(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, X : torch.Tensor,
- QW, QW_quant, QA, QB, QS,
- KW, KW_quant, KA, KB, KS,
- VW, VW_quant, VA, VB, VS,
+ QW, QW_quant, Q_bias, QA, QB, QS,
+ KW, KW_quant, K_bias, KA, KB, KS,
+ VW, VW_quant, V_bias, VA, VB, VS,
dropout_Q=None, dropout_K=None, dropout_V=None
):
dtype = X.dtype
@@ -258,6 +261,10 @@ def forward(ctx, X : torch.Tensor,
K = matmul_lora(X, KW, KW_quant, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS, dropout=dropout_V)

+ Q += Q_bias
+ K += K_bias
+ V += V_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
for _dropout in [
@@ -340,26 +347,26 @@ def backward(ctx, dQ, dK, dV):
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

- # QW, QW_quant, QA, QB, QS,
- # KW, KW_quant, KA, KB, KS,
- # VW, VW_quant, VA, VB, VS,
+ # QW, QW_quant, Q_bias, QA, QB, QS,
+ # KW, KW_quant, K_bias, KA, KB, KS,
+ # VW, VW_quant, V_bias, VA, VB, VS,
return dX.view(batch, seq_len, hd), \
- None, None, d_QA.t(), d_QB.t(), None, \
- None, None, d_KA.t(), d_KB.t(), None, \
- None, None, d_VA.t(), d_VB.t(), None, \
+ None, None, None, d_QA.t(), d_QB.t(), None, \
+ None, None, None, d_KA.t(), d_KB.t(), None, \
+ None, None, None, d_VA.t(), d_VB.t(), None, \
None, None, None # dropout
pass
pass


def apply_lora_qkv(self, X):
- QW, QW_quant, QA, QB, QS, dropoutQ = get_lora_parameters(self.q_proj)
- KW, KW_quant, KA, KB, KS, dropoutK = get_lora_parameters(self.k_proj)
- VW, VW_quant, VA, VB, VS, dropoutV = get_lora_parameters(self.v_proj)
+ QW, QW_quant, Q_bias, QA, QB, QS, dropoutQ = get_lora_parameters(self.q_proj)
+ KW, KW_quant, K_bias, KA, KB, KS, dropoutK = get_lora_parameters(self.k_proj)
+ VW, VW_quant, V_bias, VA, VB, VS, dropoutV = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(X,
- QW, QW_quant, QA, QB, QS,
- KW, KW_quant, KA, KB, KS,
- VW, VW_quant, VA, VB, VS,
+ QW, QW_quant, Q_bias, QA, QB, QS,
+ KW, KW_quant, K_bias, KA, KB, KS,
+ VW, VW_quant, V_bias, VA, VB, VS,
dropoutQ, dropoutK, dropoutV,
)
return Q, K, V
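
Q_bias, K_bias and V_bias are added with a plain in-place +=, relying on broadcasting of a 1-D (out_features,) bias over the batch and sequence dimensions. A quick shape sketch with made-up sizes:

import torch

batch, seq_len, out_features = 2, 8, 64
Q = torch.randn(batch, seq_len, out_features)
Q_bias = torch.randn(out_features)
Q += Q_bias                      # broadcasts over (batch, seq_len)
assert Q.shape == (batch, seq_len, out_features)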
@@ -396,9 +403,10 @@ class LoRA_W(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, X : torch.Tensor,
- W, W_quant, A, B, S, dropout_O):
+ W, W_quant, bias, A, B, S, dropout_O):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S, dropout=dropout_O)
+ XW += bias

# Extract post-dropout X for use in backward computation
if dropout_O is not None:
@@ -440,21 +448,21 @@ def backward(ctx, dY : torch.Tensor):
# W, W_quant, A, B, S
return (
dX.view(batch, seq_len, hd),
- None, None, d_A.t(), d_B.t(), None,
+ None, None, None, d_A.t(), d_B.t(), None,
None, # dropout modules
)
pass
pass

def apply_lora_o(self, X):
- OW, OW_quant, OA, OB, OS, dropoutO = get_lora_parameters(self.o_proj)
- O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS, dropoutO)
+ OW, OW_quant, Obias, OA, OB, OS, dropoutO = get_lora_parameters(self.o_proj)
+ O = LoRA_W.apply(X, OW, OW_quant, Obias, OA, OB, OS, dropoutO)
return O
pass

# added by flim@sg.ibm.com
# this will be patchable on the actual module
def apply_lora_o_v2(self, X):
- OW, OW_quant, OA, OB, OS, dropoutO = get_lora_parameters(self)
- O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS, dropoutO)
+ OW, OW_quant, Obias, OA, OB, OS, dropoutO = get_lora_parameters(self)
+ O = LoRA_W.apply(X, OW, OW_quant, Obias, OA, OB, OS, dropoutO)
return O
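
A small parity check for the single-projection LoRA_W path, assuming it computes the base matmul plus bias plus the scaled LoRA update; every name and shape below is made up for illustration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(2, 4, 16)
W = torch.randn(8, 16)      # frozen base weight
bias = torch.randn(8)       # frozen base bias, newly threaded through
A = torch.randn(4, 16)      # lora_A weight (rank 4)
B = torch.randn(8, 4)       # lora_B weight
s = 0.5                     # LoRA scaling

fused_style = X @ W.t() + (X @ A.t()) @ B.t() * s + bias
reference = F.linear(X, W, bias) + F.linear(F.linear(X, A), B) * s
assert torch.allclose(fused_style, reference, atol=1e-4)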
