Support qlora in CPU (#9233)
* support qlora in CPU

* revert example

* fix style
yangw1234 authored Oct 27, 2023
1 parent 8838707 commit 163d033
Showing 1 changed file with 45 additions and 9 deletions.
54 changes: 45 additions & 9 deletions python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -323,6 +323,37 @@ def backward(ctx, grad_output):
         return grad_A, grad_weight, None
 
 
+class MatMulLowBitCPU(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, A, weight):
+        if torch.is_autocast_enabled():
+            A = A.to(torch.get_autocast_dtype())
+        ctx.is_empty = False
+        x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
+                                         weight._shape[0] * weight._shape[1])
+        result = torch.matmul(A, x0_fp32.T)
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (A, weight)
+        else:
+            ctx.tensors = (None, None)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+        req_gradA, _, = ctx.needs_input_grad
+        A, weight = ctx.tensors
+        grad_A, grad_weight = None, None
+        if req_gradA:
+            x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
+                                             weight._shape[0] * weight._shape[1])
+            grad_A = torch.matmul(grad_output, x0_fp32.to(grad_output.dtype))
+        return grad_A, grad_weight, None
+
+
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None):
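The new class mirrors the existing MatMulLowBit path used on XPU: forward dequantizes the packed int4 weight back to fp32 with ggml_int4_convert_fp32 and runs a plain torch.matmul, while backward only produces a gradient for the activations A, so the quantized base weight stays frozen — which is what QLoRA needs, since only the LoRA adapters feeding such a layer are trained. (The ctx.is_empty branch, which forward always sets to False, and the extra trailing None returned from backward look like carry-overs from the GPU-side MatMulLowBit signature.) Below is a minimal, self-contained sketch of the same pattern, with an ordinary fp32 tensor standing in for the ggml-packed buffer and a hypothetical dequantize() helper in place of ggml_int4_convert_fp32:

import torch

def dequantize(packed_weight):
    # Stand-in for ggml_int4_convert_fp32: here the "packed" weight is already
    # a plain [out_features, in_features] matrix.
    return packed_weight.float()

class MatMulFrozenQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, weight):
        w_fp32 = dequantize(weight)
        ctx.save_for_backward(weight)
        return torch.matmul(A, w_fp32.T)

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        w_fp32 = dequantize(weight)
        # Gradient flows only to the activations; the quantized weight stays frozen.
        grad_A = torch.matmul(grad_output, w_fp32.to(grad_output.dtype))
        return grad_A, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8)              # stands in for the packed low-bit weight
y = MatMulFrozenQuant.apply(x, w)   # shape (4, 16)
y.sum().backward()                  # x.grad is (4, 8); w receives no gradient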
@@ -388,16 +419,21 @@ def forward(self, x: torch.Tensor):
                               and self.qtype != FP4,
                               "NF3, NF4, FP4 and FP8 quantization are currently not"
                               " supported on CPU")
-            if IS_SERVER and (not IS_SPR) and \
-                    self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
-                x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
-                result = F.linear(x, x0_fp32, self.bias)
-            else:
-                result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
-                new_shape = x_shape[:-1] + (self.out_len,)
-                result = result.view(new_shape)
+            if self.training and x.requires_grad:
+                result = MatMulLowBitCPU.apply(x, self.weight)
+                if self.bias is not None:
+                    result = result + self.bias
+            else:
+                if IS_SERVER and (not IS_SPR) and \
+                        self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+                    x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
+                    result = F.linear(x, x0_fp32, self.bias)
+                else:
+                    result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
+                    new_shape = x_shape[:-1] + (self.out_len,)
+                    result = result.view(new_shape)
+                    if self.bias is not None:
+                        result += self.bias
         return result
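In LowBitLinear.forward, the CPU branch now splits on self.training and x.requires_grad: during fine-tuning the call is routed through MatMulLowBitCPU.apply so autograd can propagate through the quantized layer, and the bias is added out of place (result = result + self.bias) rather than with += to avoid an in-place op on a tensor that is part of the graph. The existing inference paths — F.linear on a dequantized fp32 weight on non-SPR servers, or the ggml int4 kernel otherwise — are kept as they were, just nested one level deeper. A stripped-down illustration of that routing, again with a plain fp32 buffer standing in for the packed low-bit weight (names are illustrative, not bigdl-llm API):

import torch
import torch.nn.functional as F

class TinyLowBitLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Frozen stand-in for the quantized base weight (a buffer, not a Parameter).
        self.register_buffer("weight", torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        if self.training and x.requires_grad:
            # Training path: dequantize-and-matmul with plain torch ops so autograd
            # can reach the activations (MatMulLowBitCPU.apply in the commit).
            w_fp32 = self.weight.float()       # stand-in for ggml_int4_convert_fp32
            out = torch.matmul(x, w_fp32.T)
            return out + self.bias             # out-of-place add keeps the graph intact
        # Inference path: optimized kernel (F.linear / ggml_matmul_src1_x_src0_t in the commit).
        return F.linear(x, self.weight, self.bias)

layer = TinyLowBitLinear(8, 16)
layer.train()
x = torch.randn(2, 8, requires_grad=True)      # e.g. the output of a LoRA branch
layer(x).sum().backward()                      # gradients reach x, not layer.weight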



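Taken together with the QLoRA example referenced in the commit message, this change is what lets the bigdl-llm QLoRA fine-tuning flow run on a CPU-only machine: the base weights stay frozen in 4-bit, the LoRA adapters train in full precision, and the new autograd path carries gradients through the quantized linear layers. A rough usage sketch follows — import paths and arguments are based on the bigdl-llm QLoRA examples of that period and should be treated as assumptions rather than a verified recipe:

import torch
from transformers import AutoTokenizer
from peft import LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
# Assumed import path for the QLoRA helpers; check the bigdl-llm examples for the exact module.
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training

model_path = "meta-llama/Llama-2-7b-hf"        # any causal-LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Base weights are quantized to 4-bit on load and stay frozen on the CPU.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="sym_int4",
                                             optimize_model=False)
model = prepare_model_for_kbit_training(model)

# Only the LoRA adapters are trainable, in full precision.
config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"],
                    lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, config)
model.train()

batch = tokenizer("QLoRA fine-tuning on CPU", return_tensors="pt")
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()    # gradients flow through MatMulLowBitCPU into the adapters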