PEFT-compatible GEMM
dtlzhuangz committed Apr 28, 2024
1 parent 5c08b06 commit 5a66e02
43 changes: 39 additions & 4 deletions python/eetq/modules/qlinear.py
@@ -6,6 +6,7 @@
 import random
 import numpy as np
 import math
+from torch.autograd import Function

 from EETQ import quant_weights, preprocess_weights, w8_a16_gemm

@@ -60,13 +61,45 @@ def forward(self, input):
         output = output + self.bias if self.bias is not None else output
         return output

+class EetqLinearMMFunction(Function):
+    @staticmethod
+    # ctx is the first argument to forward
+    def forward(
+        ctx,
+        x,
+        weight,
+        scales,
+        bias=None
+    ):
+        # Save the tensors that backward needs to compute gradients.
+        ctx.save_for_backward(x, weight, scales, bias)
+        output = w8_a16_gemm(x, weight, scales)
+        output = output + bias if bias is not None else output
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, scales, bias = ctx.saved_tensors
+        identity = torch.eye(weight.shape[0]).to(weight.device).to(input.dtype)
+
+        # Dequantize the weight by running an identity matrix through the GEMM
+        weight = w8_a16_gemm(identity, weight, scales)
+
+        grad_input = None  # stays None unless an input gradient is required
+        if ctx.needs_input_grad[0]:
+            # mm needs 2D operands: drop the leading dim, multiply, then restore it
+            grad_input = grad_output.squeeze(0).mm(
+                weight.transpose(0, 1)
+            ).unsqueeze(0)
+
+        return grad_input, None, None, None
+
 class EetqLinear(nn.Module):
-    def __init__(self, in_features, out_features, bias=True, device="cuda:0"):
+    def __init__(self, in_features, out_features, bias=True, device="cuda:0", training=False):
         super().__init__()

         self.in_features = in_features
         self.out_features = out_features
+        self.training = training

         self.register_buffer("weight", torch.zeros((in_features, out_features), dtype=torch.int8, device=device))

@@ -83,10 +116,12 @@ def register_scale(self, device):
         weight_scale = torch.zeros((out_features), dtype=torch.float16, device=device)
         self.register_buffer("weight_scales", weight_scale)

-    @torch.no_grad()
     def forward(self, input):
-        output = w8_a16_gemm(input, self.weight, self.weight_scales)
-        output = output + self.bias if self.bias is not None else output
+        if self.training:
+            output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
+        else:
+            with torch.no_grad():
+                output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
         return output
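A note on the dequantization step in `backward`: running an identity matrix through `w8_a16_gemm` recovers the fp16 weight, because the kernel computes `x @ weight` and then applies one scale per output channel (the per-channel layout follows from the `(out_features,)`-shaped `weight_scales` buffer above). A minimal CPU-only sketch of that equivalence, with float32 standing in for the kernel's fp16:

```python
import torch

# Shapes mirror EetqLinear's buffers: int8 weight (in_features, out_features),
# one scale per output channel.
in_features, out_features = 8, 4
weight_int8 = torch.randint(-128, 128, (in_features, out_features), dtype=torch.int8)
scales = torch.rand(out_features)

# backward() feeds an identity matrix through the GEMM...
identity = torch.eye(in_features)
dequant_via_gemm = (identity @ weight_int8.float()) * scales

# ...which is exactly the dequantized weight, since I @ W == W.
dequant_direct = weight_int8.float() * scales
assert torch.equal(dequant_via_gemm, dequant_direct)
```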

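Why the training path matters for PEFT: the custom `Function` returns `grad_input`, so the loss gradient can propagate through a frozen quantized layer to trainable adapter weights sitting earlier in the network, even though the int8 weight itself never receives a gradient. A hypothetical LoRA-style wrapper illustrating the intended usage (the wrapper and its shapes are illustrative, not part of EETQ):

```python
import torch.nn as nn

class LoRAWrapper(nn.Module):
    """Hypothetical adapter: frozen quantized base layer plus a trainable low-rank path."""

    def __init__(self, base: nn.Module, in_features: int, out_features: int, r: int = 8):
        super().__init__()
        self.base = base  # e.g. an EetqLinear constructed with training=True
        self.lora_a = nn.Linear(in_features, r, bias=False)
        self.lora_b = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op

    def forward(self, x):
        # The int8 base path contributes activations (and, via EetqLinearMMFunction's
        # backward, lets gradients reach adapters in earlier layers); only the
        # low-rank path has trainable weights here.
        return self.base(x) + self.lora_b(self.lora_a(x))
```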

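One caveat, read directly off the code: `grad_output.squeeze(0).mm(...)` assumes a leading dimension of 1 (e.g. activations of shape `(1, seq_len, hidden)`); for a larger batch, `squeeze(0)` is a no-op and `mm` rejects the 3D tensor. A batched variant under the same shape conventions could rely on `torch.matmul`, which broadcasts over leading dimensions:

```python
# Sketch of a batch-agnostic input gradient (same math, not EETQ code):
# grad_output: (..., out_features); weight here is the dequantized fp16 matrix
# of shape (in_features, out_features).
grad_input = torch.matmul(grad_output, weight.transpose(0, 1))
```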