 import math
 import torch
 from torch.nn import functional as F
+from torch.autograd import Function
+from .weight_only import quant_weight
 from packaging.version import Version


@@ -355,3 +357,95 @@ def extra_repr(self) -> str:
         return 'in_features={}, out_features={}, bits={}, group_size={}, bias={}'.format(
             self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
         )
+
+
+class FakeAffineTensorQuantFunction(Function):
+    """Fake version of affine quantization.
+    The backward pass is a straight-through estimator."""
+
+    @staticmethod
+    def forward(ctx, inputs, num_bits=4, group_size=1024):
+        """Quantize and dequantize the inputs, simulating low-bit affine quantization.
+
+        As it will only be applied to activations with per-tensor granularity, broadcasting is not needed.
+
+        Args:
+            ctx: PyTorch convention.
+            inputs: A Tensor of type float32.
+            num_bits: An integer, the number of quantization bits.
+            group_size: An integer, the group size for fine-grained quantization.
+
+        Returns:
+            outputs: The fake-quantized Tensor, with the same shape as inputs.
+        """
+        return quant_weight(inputs, num_bits, group_size)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        """Pass the gradient straight through to the inputs.
+
+        Args:
+            ctx: PyTorch convention.
+            grad_outputs: A tensor of gradients of the outputs.
+
+        Returns:
+            grad_inputs: A tensor of gradients of the inputs; None for num_bits and group_size.
+        """
+        return grad_outputs, None, None
+
+
+class TEQLinearFakeQuant(torch.nn.Module):
+    """
+    Wrapper that applies fake weight quantization to a linear module
+    """
+
+    def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1):
+        """
+        Wrap a linear module with a trainable equivalent-transformation scale
+        :param orig_layer: the original linear module
+        :param alpha: trainable alpha/scale
+        :param num_bits: number of quantization bits
+        :param group_size: group size for fine-grained quantization
+        """
+        super(TEQLinearFakeQuant, self).__init__()
+        self.orig_layer = orig_layer
+        self.alpha = alpha
+
+        self.num_bits = num_bits
+        self.group_size = group_size
+
+    def forward(self, x):
+        # Scale the input down and the weight up by the (clipped) trainable alpha,
+        # so the product is unchanged; only the scaled weight is fake-quantized.
+        alpha = torch.clip(self.alpha, 1e-5)
+        shape_len = len(x.shape) - 1
+        shape = (1,) * shape_len + (-1,)
+        x = x / alpha.view(shape)
+        weight = self.orig_layer.weight
+        weight = weight * alpha.unsqueeze(dim=0)
+        weight_q = FakeAffineTensorQuantFunction.apply(weight, self.num_bits, self.group_size)
+        return F.linear(x, weight_q, self.orig_layer.bias)
+
+
+class TEQMulLinear(torch.nn.Module):
+    """
+    Trainable Equivalent Transformation (TEQ): linear wrapper that applies a scale to the input
+    """
+
+    def __init__(self, module, input_scale):
+        """
+        Wrap a linear module and multiply its input by a fixed scale
+        :param module: the linear module
+        :param input_scale: scale applied to the input
+        """
+
+        super().__init__()
+        self.register_buffer('input_scale', input_scale)
+        self.add_module('sq_linear', module)
+
+    @property
+    def weight(self):
+        return self.sq_linear.weight
+
+    def forward(self, X):
+        X = torch.mul(X, self.input_scale)
+        X = self.sq_linear(X)
+        return X
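
Below is a minimal usage sketch, not part of the patch, of how the training-time wrapper could be exercised. It assumes the classes above are in scope and that quant_weight from .weight_only performs a shape-preserving fake quantize-dequantize; the layer sizes are arbitrary illustrations.

import torch

linear = torch.nn.Linear(64, 16)
# One trainable scale per input channel of the wrapped layer.
alpha = torch.nn.Parameter(torch.ones(linear.in_features))
wrapped = TEQLinearFakeQuant(linear, alpha=alpha, num_bits=4, group_size=-1)

x = torch.randn(8, 64)
out = wrapped(x)         # computes F.linear(x / alpha, fake_quant(weight * alpha), bias)
out.sum().backward()     # gradients reach alpha through the straight-through backward pass
print(alpha.grad.shape)  # torch.Size([64])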
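
A companion sketch, also illustrative only, of how a learned scale might be deployed with TEQMulLinear: here the scale is folded into the weight and its reciprocal applied to the input, mirroring TEQLinearFakeQuant, though the exact folding is an assumption about the surrounding TEQ flow.

import torch

linear = torch.nn.Linear(64, 16)
alpha = torch.ones(linear.in_features) * 2.0        # stands in for a trained scale
with torch.no_grad():
    linear.weight.mul_(alpha.unsqueeze(dim=0))       # fold the scale into the weight
deployed = TEQMulLinear(linear, input_scale=1.0 / alpha)

y = deployed(torch.randn(8, 64))                     # input * (1/alpha), then the scaled linear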