diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
index 03f2ca10..5ca4f8bf 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
@@ -59,9 +59,9 @@ class LoRA_MLP(torch.autograd.Function):
     @staticmethod
     @torch.cuda.amp.custom_fwd
     def forward(ctx, X : torch.Tensor,
-                gateW, gateW_quant, gateA, gateB, gateS,
-                upW, upW_quant, upA, upB, upS,
-                downW, downW_quant, downA, downB, downS,
+                gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+                upW, upW_quant, up_bias, upA, upB, upS,
+                downW, downW_quant, down_bias, downA, downB, downS,
                 _forward_function, _backward_function,
                 dropout_gate=None, dropout_up=None, dropout_down=None,
                 ):
@@ -69,8 +69,11 @@ def forward(ctx, X : torch.Tensor,

         e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS, dropout=dropout_gate)
         g = matmul_lora(X, upW, upW_quant, upA, upB, upS, dropout=dropout_up)
+        e += gate_bias
+        g += up_bias
         h = _forward_function(e, g)
         i = matmul_lora(h, downW, downW_quant, downA, downB, downS, dropout=dropout_down)
+        i += down_bias

         # Extract post-dropout X for use in backward computation
         _dropped_X = []
@@ -152,13 +155,13 @@ def backward(ctx, dY : torch.Tensor):
             del gateW
         dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())

-        # gateW, gateW_quant, gateA, gateB, gateS,
-        # upW, upW_quant, upA, upB, upS,
-        # downW, downW_quant, downA, downB, downS,
+        # gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+        # upW, upW_quant, up_bias, upA, upB, upS,
+        # downW, downW_quant, down_bias, downA, downB, downS,
         return (dX.view(batch, seq_len, hd), \
-            None, None, d_gateA.t(), d_gateB.t(), None, \
-            None, None, d_upA.t(), d_upB.t(), None, \
-            None, None, d_downA.t(), d_downB.t(), None, \
+            None, None, None, d_gateA.t(), d_gateB.t(), None, \
+            None, None, None, d_upA.t(), d_upB.t(), None, \
+            None, None, None, d_downA.t(), d_downB.t(), None, \
             None, None, # _backward and _forward
             None, None, None, # dropout modules
         )
@@ -168,13 +171,13 @@
 from ..swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel

 def apply_lora_mlp_swiglu(self, X):
-    gateW, gateW_quant, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
-    downW, downW_quant, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+    gateW, gateW_quant, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+    upW, upW_quant, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
+    downW, downW_quant, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
     out = LoRA_MLP.apply(X,
-                         gateW, gateW_quant, gateA, gateB, gateS,
-                         upW, upW_quant, upA, upB, upS,
-                         downW, downW_quant, downA, downB, downS,
+                         gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+                         upW, upW_quant, up_bias, upA, upB, upS,
+                         downW, downW_quant, down_bias, downA, downB, downS,
                          swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
                          dropout_gate, dropout_up, dropout_down,
                          )
@@ -184,13 +187,13 @@ def apply_lora_mlp_swiglu(self, X):
 from ..geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel

 def apply_lora_mlp_geglu_exact(self, X):
-    gateW, gateW_quant, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
-    downW, downW_quant, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+    gateW, gateW_quant, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+    upW, upW_quant, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
+    downW, downW_quant, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
     out = LoRA_MLP.apply(X,
-                         gateW, gateW_quant, gateA, gateB, gateS,
-                         upW, upW_quant, upA, upB, upS,
-                         downW, downW_quant, downA, downB, downS,
+                         gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+                         upW, upW_quant, up_bias, upA, upB, upS,
+                         downW, downW_quant, down_bias, downA, downB, downS,
                          geglu_exact_forward_kernel, geglu_exact_backward_kernel,
                          dropout_gate, dropout_up, dropout_down,
                          )
@@ -200,13 +203,13 @@ def apply_lora_mlp_geglu_exact(self, X):
 from ..geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel

 def apply_lora_mlp_geglu_approx(self, X):
-    gateW, gateW_quant, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
-    downW, downW_quant, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+    gateW, gateW_quant, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+    upW, upW_quant, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self. up_proj)
+    downW, downW_quant, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
     out = LoRA_MLP.apply(X,
-                         gateW, gateW_quant, gateA, gateB, gateS,
-                         upW, upW_quant, upA, upB, upS,
-                         downW, downW_quant, downA, downB, downS,
+                         gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
+                         upW, upW_quant, up_bias, upA, upB, upS,
+                         downW, downW_quant, down_bias, downA, downB, downS,
                          geglu_approx_forward_kernel, geglu_approx_backward_kernel,
                          dropout_gate, dropout_up, dropout_down,
                          )
@@ -247,9 +250,9 @@ class LoRA_QKV(torch.autograd.Function):
     @staticmethod
     @torch.cuda.amp.custom_fwd
     def forward(ctx, X : torch.Tensor,
-                QW, QW_quant, QA, QB, QS,
-                KW, KW_quant, KA, KB, KS,
-                VW, VW_quant, VA, VB, VS,
+                QW, QW_quant, Q_bias, QA, QB, QS,
+                KW, KW_quant, K_bias, KA, KB, KS,
+                VW, VW_quant, V_bias, VA, VB, VS,
                 dropout_Q=None, dropout_K=None, dropout_V=None
                 ):
         dtype = X.dtype
@@ -258,6 +261,10 @@ def forward(ctx, X : torch.Tensor,
         K = matmul_lora(X, KW, KW_quant, KA, KB, KS, dropout=dropout_K)
         V = matmul_lora(X, VW, VW_quant, VA, VB, VS, dropout=dropout_V)

+        Q += Q_bias
+        K += K_bias
+        V += V_bias
+
         # Extract post-dropout X for use in backward computation
         _dropped_X = []
         for _dropout in [
@@ -340,26 +347,26 @@ def backward(ctx, dQ, dK, dV):
             del VW
         dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

-        # QW, QW_quant, QA, QB, QS,
-        # KW, KW_quant, KA, KB, KS,
-        # VW, VW_quant, VA, VB, VS,
+        # QW, QW_quant, Q_bias, QA, QB, QS,
+        # KW, KW_quant, K_bias, KA, KB, KS,
+        # VW, VW_quant, V_bias, VA, VB, VS,
         return dX.view(batch, seq_len, hd), \
-            None, None, d_QA.t(), d_QB.t(), None, \
-            None, None, d_KA.t(), d_KB.t(), None, \
-            None, None, d_VA.t(), d_VB.t(), None, \
+            None, None, None, d_QA.t(), d_QB.t(), None, \
+            None, None, None, d_KA.t(), d_KB.t(), None, \
+            None, None, None, d_VA.t(), d_VB.t(), None, \
             None, None, None # dropout
     pass
 pass


 def apply_lora_qkv(self, X):
-    QW, QW_quant, QA, QB, QS, dropoutQ = get_lora_parameters(self.q_proj)
-    KW, KW_quant, KA, KB, KS, dropoutK = get_lora_parameters(self.k_proj)
-    VW, VW_quant, VA, VB, VS, dropoutV = get_lora_parameters(self.v_proj)
+    QW, QW_quant, Q_bias, QA, QB, QS, dropoutQ = get_lora_parameters(self.q_proj)
+    KW, KW_quant, K_bias, KA, KB, KS, dropoutK = get_lora_parameters(self.k_proj)
+    VW, VW_quant, V_bias, VA, VB, VS, dropoutV = get_lora_parameters(self.v_proj)
     Q, K, V = LoRA_QKV.apply(X,
-                             QW, QW_quant, QA, QB, QS,
-                             KW, KW_quant, KA, KB, KS,
-                             VW, VW_quant, VA, VB, VS,
+                             QW, QW_quant, Q_bias, QA, QB, QS,
+                             KW, KW_quant, K_bias, KA, KB, KS,
+                             VW, VW_quant, V_bias, VA, VB, VS,
                              dropoutQ, dropoutK, dropoutV,
                              )
     return Q, K, V
@@ -396,9 +403,10 @@ class LoRA_W(torch.autograd.Function):
     @staticmethod
     @torch.cuda.amp.custom_fwd
     def forward(ctx, X : torch.Tensor,
-                W, W_quant, A, B, S, dropout_O):
+                W, W_quant, bias, A, B, S, dropout_O):
         dtype = X.dtype
         XW = matmul_lora(X, W, W_quant, A, B, S, dropout=dropout_O)
+        XW += bias

         # Extract post-dropout X for use in backward computation
         if dropout_O is not None:
@@ -440,21 +448,21 @@ def backward(ctx, dY : torch.Tensor):
         # W, W_quant, A, B, S
         return (
             dX.view(batch, seq_len, hd),
-            None, None, d_A.t(), d_B.t(), None,
+            None, None, None, d_A.t(), d_B.t(), None,
             None, # dropout modules
         )
     pass
 pass


 def apply_lora_o(self, X):
-    OW, OW_quant, OA, OB, OS, dropoutO = get_lora_parameters(self.o_proj)
-    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS, dropoutO)
+    OW, OW_quant, Obias, OA, OB, OS, dropoutO = get_lora_parameters(self.o_proj)
+    O = LoRA_W.apply(X, OW, OW_quant, Obias, OA, OB, OS, dropoutO)
     return O
 pass

 # added by flim@sg.ibm.com
 # this will be patchable on the actual module
 def apply_lora_o_v2(self, X):
-    OW, OW_quant, OA, OB, OS, dropoutO = get_lora_parameters(self)
-    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS, dropoutO)
+    OW, OW_quant, Obias, OA, OB, OS, dropoutO = get_lora_parameters(self)
+    O = LoRA_W.apply(X, OW, OW_quant, Obias, OA, OB, OS, dropoutO)
     return O
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
index 4000a258..b8e1cfcb 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
@@ -97,12 +97,13 @@ def get_lora_parameters(proj):
     # For DPO or disabled adapters
     base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
     qstate = extract_gptq_state(base_layer)
+    bias = base_layer.bias if hasattr(base_layer, 'bias') else None

     if base_layer.__module__.startswith("auto_gptq"):
         setattr(qstate.qzeros, "offset", 1)

     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
-        return qstate, None, None, None, None
+        return qstate, None, None, None, None, None

     active_adapter = (
         proj.active_adapters[0]
@@ -114,7 +115,7 @@ def get_lora_parameters(proj):
     s = proj.scaling[active_adapter]
     dropout = proj.lora_dropout[active_adapter] if hasattr(proj, "lora_dropout") else None
     dropout.X = None
-    return qstate, A, B, s, dropout
+    return qstate, bias, A, B, s, dropout

 # modified by aaron.chew1@ibm.com
 def matmul_lora_canonicalized(X, W, A, B, s, dropout=None):
@@ -211,6 +212,7 @@ def forward(
         gate_qzeros,
         gate_g_idx,
         gate_bits,
+        gate_bias,
         gateA,
         gateB,
         gateS,
@@ -219,6 +221,7 @@ def forward(
         up_qzeros,
         up_g_idx,
         up_bits,
+        up_bias,
         upA,
         upB,
         upS,
@@ -227,6 +230,7 @@ def forward(
         down_qzeros,
         down_g_idx,
         down_bits,
+        down_bias,
         downA,
         downB,
         downS,
@@ -243,6 +247,8 @@ def forward(
         e = matmul_lora(X, gateW, gateA, gateB, gateS, dropout=dropout_gate)
         upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
         g = matmul_lora(X, upW, upA, upB, upS, dropout=dropout_up)
+        e += gate_bias
+        g += up_bias
         # f = torch.nn.functional.silu(e)
         # h = f * g
         h = swiglu_fg_kernel(e, g)
@@ -251,6 +257,7 @@ def forward(
             down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
         )
         i = matmul_lora(h, downW, downA, downB, downS, dropout=dropout_down)
+        i += down_bias

         ctx.custom_saved_tensors = (
             gate_qweight,
@@ -389,6 +396,7 @@ def backward(ctx, dY: torch.Tensor):
             None, # qzeros
             None, # g_idx
             None, # bits
+            None, # gate_bias
             d_gateA.t(),
             d_gateB.t(),
             None,
@@ -397,6 +405,7 @@ def backward(ctx, dY: torch.Tensor):
             None,
             None,
             None,
+            None, # up_bias
             d_upA.t(),
             d_upB.t(),
             None, # dS
@@ -405,6 +414,7 @@ def backward(ctx, dY: torch.Tensor):
             None,
             None,
             None,
+            None, # down_bias
             d_downA.t(),
             d_downB.t(),
             None,
@@ -415,20 +425,23 @@


 def apply_lora_mlp(self, X):
-    gateQstate, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
-    upQState, upA, upB, upS, dropout_up = get_lora_parameters(self.up_proj)
-    downQState, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
+    gateQstate, gate_bias, gateA, gateB, gateS, dropout_gate = get_lora_parameters(self.gate_proj)
+    upQState, up_bias, upA, upB, upS, dropout_up = get_lora_parameters(self.up_proj)
+    downQState, down_bias, downA, downB, downS, dropout_down = get_lora_parameters(self.down_proj)
     out = LoRA_MLP.apply(
         X,
         *unpack_gptqstate(gateQstate),
+        gate_bias,
         gateA,
         gateB,
         gateS,
         *unpack_gptqstate(upQState),
+        up_bias,
         upA,
         upB,
         upS,
         *unpack_gptqstate(downQState),
+        down_bias,
         downA,
         downB,
         downS,
@@ -480,6 +493,7 @@ def forward(
         Q_qzeros,
         Q_g_idx,
         Q_bits,
+        Q_bias,
         QA,
         QB,
         QS,
@@ -488,6 +502,7 @@ def forward(
         K_qzeros,
         K_g_idx,
         K_bits,
+        K_bias,
         KA,
         KB,
         KS,
@@ -496,6 +511,7 @@ def forward(
         V_qzeros,
         V_g_idx,
         V_bits,
+        V_bias,
         VA,
         VB,
         VS,
@@ -513,6 +529,10 @@ def forward(
         K = matmul_lora(X, KW, KA, KB, KS, dropout=dropout_K)
         V = matmul_lora(X, VW, VA, VB, VS, dropout=dropout_V)

+        Q += Q_bias
+        K += K_bias
+        V += V_bias
+
         ctx.custom_saved_tensors = (
             Q_qweight,
             Q_scales,
@@ -642,9 +662,9 @@ def backward(ctx, dQ, dK, dV):
             del VW
         dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

-        # Q_qweight, Q_scales, Q_qzeros, Q_wf, Q_g_idx, Q_bits, QA, QB, QS,
-        # K_qweight, K_scales, K_qzeros, K_wf, K_g_idx, K_bits, KA, KB, KS,
-        # V_qweight, V_scales, V_qzeros, V_wf, V_g_idx, V_bits, VA, VB, VS,
+        # Q_qweight, Q_scales, Q_qzeros, Q_wf, Q_g_idx, Q_bits, Q_bias, QA, QB, QS,
+        # K_qweight, K_scales, K_qzeros, K_wf, K_g_idx, K_bits, K_bias, KA, KB, KS,
+        # V_qweight, V_scales, V_qzeros, V_wf, V_g_idx, V_bits, V_bias, VA, VB, VS,
         return (
             dX.view(batch, seq_len, hd),
             None,
@@ -652,6 +672,7 @@ def backward(ctx, dQ, dK, dV):
             None,
             None,
             None,
+            None,
             d_QA.t(),
             d_QB.t(),
             None, # d_QS.t(),
@@ -660,6 +681,7 @@ def backward(ctx, dQ, dK, dV):
             None,
             None,
             None,
+            None,
             d_KA.t(),
             d_KB.t(),
             None, # d_KS.t(),
@@ -668,6 +690,7 @@ def backward(ctx, dQ, dK, dV):
             None,
             None,
             None,
+            None,
             d_VA.t(),
             d_VB.t(),
             None,
@@ -678,20 +701,23 @@


 def apply_lora_qkv(self, X):
-    Qqstate, QA, QB, QS, Qdropout = get_lora_parameters(self.q_proj)
-    Kqstate, KA, KB, KS, Kdropout = get_lora_parameters(self.k_proj)
-    Vqstate, VA, VB, VS, Vdropout = get_lora_parameters(self.v_proj)
+    Qqstate, Q_bias, QA, QB, QS, Qdropout = get_lora_parameters(self.q_proj)
+    Kqstate, K_bias, KA, KB, KS, Kdropout = get_lora_parameters(self.k_proj)
+    Vqstate, V_bias, VA, VB, VS, Vdropout = get_lora_parameters(self.v_proj)
     Q, K, V = LoRA_QKV.apply(
         X,
         *unpack_gptqstate(Qqstate),
+        Q_bias,
         QA,
         QB,
         QS,
         *unpack_gptqstate(Kqstate),
+        K_bias,
         KA,
         KB,
         KS,
         *unpack_gptqstate(Vqstate),
+        V_bias,
         VA,
         VB,
         VS,
@@ -740,6 +766,7 @@ def forward(
         O_qzeros,
         O_g_idx,
         O_bits,
+        O_bias,
         A,
         B,
         S,
@@ -747,6 +774,7 @@ def forward(
     ):
         W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
         XW = matmul_lora(X, W, A, B, S, dropout=dropout_O)
+        XW += O_bias
         del W
         ctx.custom_saved_tensors = (
             O_qweight,
@@ -791,7 +819,7 @@ def backward(ctx, dY: torch.Tensor):
         del W
         dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())

-        # O_qweight, O_scales, O_qzeros, O_wf, O_g_idx, O_bits, A, B, S
+        # O_qweight, O_scales, O_qzeros, O_wf, O_g_idx, O_bits, O_bias, A, B, S
         return (
             dX.view(batch, seq_len, hd),
             None,
@@ -799,6 +827,7 @@ def backward(ctx, dY: torch.Tensor):
             None,
             None,
             None,
+            None,
             d_A.t(),
             d_B.t(),
             None,
@@ -807,13 +836,13 @@


 def apply_lora_o(self, X):
-    Oqstate, OA, OB, OS, dropout = get_lora_parameters(self.o_proj)
-    O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS, dropout)
+    Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self.o_proj)
+    O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), O_bias, OA, OB, OS, dropout)
     return O

 # added by flim@sg.ibm.com
 # this version can be directly patched on the output linear
 def apply_lora_o_v2(self, X):
-    Oqstate, OA, OB, OS, dropout = get_lora_parameters(self)
-    O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS, dropout)
+    Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self.o_proj)
+    O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), O_bias, OA, OB, OS, dropout)
     return O
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
index a7d88de6..4316aa40 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
@@ -62,9 +62,10 @@ def get_lora_parameters(proj):
     # For DPO or disabled adapters
     base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
     W = base_layer.weight
+    bias = base_layer.bias if hasattr(base_layer, 'bias') else None

     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
-        return W, QUANT_STATE(W, base_layer), None, None, None, None
+        return W, QUANT_STATE(W, base_layer), None, None, None, None, None
     pass

     active_adapter = proj.active_adapters[0] if \
@@ -73,7 +74,7 @@ def get_lora_parameters(proj):
     B = proj.lora_B [active_adapter].weight
     s = proj.scaling[active_adapter]
     dropout = proj.lora_dropout[active_adapter] if hasattr(proj, "lora_dropout") else None
-    return W, QUANT_STATE(W, base_layer), A, B, s, dropout
+    return W, QUANT_STATE(W, base_layer), bias, A, B, s, dropout
 pass

 # modified by flim@sg.ibm.com
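Note: the following is a minimal usage sketch, not part of the patch above, showing how the widened get_lora_parameters return signature is consumed; it assumes a PEFT-style `proj` wrapper whose base linear may or may not carry a bias. The `None` guard is illustrative only, since the fused kernels in this patch add the returned bias unconditionally and therefore assume the base layers do have bias tensors.

    # illustrative sketch, mirroring the patched utils.get_lora_parameters signature
    W, W_quant, bias, A, B, s, dropout = get_lora_parameters(proj)
    out = matmul_lora(X, W, W_quant, A, B, s, dropout=dropout)
    if bias is not None:  # bias is None when the base layer has no bias term
        out += bias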