Refactor int8 dynamic quantization with call to quantize #294
@@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
             fn(self.zero_point),
         )
 
+    def _change_shape(self, shape):
+        return self.__class__(
+            self.int_data.view(shape), self.scale, self.zero_point
+        )
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
@@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
+        if func is aten.view.default:
+            assert len(args) == 2
+            new = args[0]._change_shape(args[1])
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
         raise NotImplementedError(
             f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
         )
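The two hunks above follow the standard wrapper-subclass pattern: keep the packed payload in inner tensors and rebuild the wrapper on shape-changing ops. A minimal self-contained sketch of that pattern (hypothetical names, not torchao's actual class; the real code wraps the result in return_and_correct_aliasing rather than returning it directly):

    import torch

    aten = torch.ops.aten

    class PlainLayoutSketch(torch.Tensor):
        # hypothetical stand-in for PlainAQTLayout, for illustration only
        @staticmethod
        def __new__(cls, int_data, scale, zero_point):
            return torch.Tensor._make_wrapper_subclass(cls, int_data.shape, dtype=int_data.dtype)

        def __init__(self, int_data, scale, zero_point):
            self.int_data, self.scale, self.zero_point = int_data, scale, zero_point

        def _change_shape(self, shape):
            # reshape the packed payload; per-channel scale/zero_point are unchanged
            return self.__class__(self.int_data.view(shape), self.scale, self.zero_point)

        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs):
            kwargs = {} if kwargs is None else kwargs
            if func is aten.view.default:
                assert len(args) == 2  # (tensor, new_shape)
                return args[0]._change_shape(args[1])
            raise NotImplementedError(f"sketch: {func} not supported")

    layout = PlainLayoutSketch(
        torch.randint(-128, 127, (4, 4), dtype=torch.int8), torch.ones(4), torch.zeros(4)
    )
    print(type(layout.view([2, 8])))  # still PlainLayoutSketch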
@@ -245,6 +255,7 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
+        # TODO: fix the unflatten logic
         return cls(packed_weight, scale_and_zero)
 
     def to(self, *args, **kwargs):
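For context, `__tensor_unflatten__` is the inverse of `__tensor_flatten__` in the tensor-subclass tracing protocol; the TODO notes that `outer_size`/`outer_stride` are currently ignored. A rough self-contained sketch of the protocol (hypothetical class, not the PR's actual packed layout):

    import torch

    class PackedSketch(torch.Tensor):
        @staticmethod
        def __new__(cls, packed_weight, scale_and_zero):
            return torch.Tensor._make_wrapper_subclass(cls, packed_weight.shape, dtype=packed_weight.dtype)

        def __init__(self, packed_weight, scale_and_zero):
            self.packed_weight, self.scale_and_zero = packed_weight, scale_and_zero

        def __tensor_flatten__(self):
            # names of inner tensors for torch.compile to trace, plus non-tensor attrs
            return ["packed_weight", "scale_and_zero"], []

        @classmethod
        def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
            # like the hunk above, this ignores outer_size/outer_stride (the TODO)
            return cls(tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"])

    t = PackedSketch(torch.randn(2, 2), torch.randn(2, 2))
    names, attrs = t.__tensor_flatten__()
    rebuilt = PackedSketch.__tensor_unflatten__(
        {n: getattr(t, n) for n in names}, attrs, t.shape, t.stride()
    )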
@@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn):
             strides=self.stride(),
         )
 
+    def _change_shape(self, shape, block_size):
+        return self.__class__(
+            self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
+        )
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
@@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
     )
 
-@implements_aqt_torch_function(torch.nn.functional.linear)
-def functional_linear(*args, **kwargs):
-    input_tensor, weight_qtensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
-    )
+def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     is_cuda = weight_qtensor.is_cuda
     is_cpu = weight_qtensor.device == torch.device("cpu")
     if isinstance(weight_qtensor, AffineQuantizedTensor):
@@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs):
             # if input tensor is quantized, either dispatch to the int8 mm kernel
             # or just dequantize the input tensor
             input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
-            input_tensor_dtype_is_expected = input_tensor.dtype in [
-                torch.float,
-                torch.bfloat16
-            ]
             if (
                 is_cuda and
                 input_is_int8 and
-                input_tensor_dtype_is_expected and
+                input_tensor.dtype == weight_qtensor.dtype and
                 input_tensor.layout == "plain" and
                 weight_qtensor.layout == "plain"
             ):
@@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs):
                 weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
                 weight_qtensor.layout == "plain"
             ):
-                # TODO: enable mps path as well
+                # TODO: enable cpu and mps efficient path
                 # per channel int8 weight only quantizated mm
-                return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
-            else:
-                weight_tensor = weight_qtensor.dequantize()
-                return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-        else:
+                w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
+                orig_dtype = input_tensor.dtype
+                y = (
+                    torch.mm(
+                        input_tensor.reshape(-1, input_tensor.shape[-1]),
+                        w_vals_int8_t.to(input_tensor.dtype),
+                    )
+                    * weight_qtensor.scale
+                )
+                y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
+                if bias is not None:
+                    y += bias
+                return y.to(orig_dtype)
+
+                # is_cpu and is_mps only, some issue with is_contiguous() currently
+                # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
+
+    raise NotImplementedError("No specialized dispatch found for quantized linear op")
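As a standalone illustration of the per-channel int8 weight-only path above (hypothetical shapes, not torchao code): quantize the weight with one scale per output channel, do the matmul against the transposed int8 weight cast to the activation dtype, then rescale.

    import torch

    x = torch.randn(3, 8)                              # activations
    w = torch.randn(4, 8)                              # weight: (out_features, in_features)
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0  # one scale per output channel
    w_int8 = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)

    # mirror of the dispatch path: mm against the transposed int8 weight, then rescale
    y = torch.mm(x, w_int8.t().to(x.dtype)) * scale.t()
    print((y - x @ w.t()).abs().max())                 # small quantization error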
+@implements_aqt_torch_function(torch.nn.functional.linear)
+def functional_linear(*args, **kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
+    # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
+    # make the branches easier to understand in `_quantized_linear_op`
+    try:
+        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+    except:
+        if isinstance(input_tensor, AffineQuantizedTensor):
+            input_tensor = input_tensor.dequantize()
+        if isinstance(weight_tensor, AffineQuantizedTensor):
+            weight_tensor = weight_tensor.dequantize()
+        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
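For readers unfamiliar with the decorator: `implements_aqt_torch_function` registers an override in a dispatch table consulted by the subclass's `__torch_function__`, which is why the fallback only has to be written once per registered op. Roughly (a simplified sketch, not torchao's exact code):

    import torch

    _AQT_TORCH_FUNCTION_TABLE = {}

    def implements_torch_function(torch_function):
        def decorator(impl):
            _AQT_TORCH_FUNCTION_TABLE[torch_function] = impl
            return impl
        return decorator

    # the tensor subclass would then route overridden functions through the table:
    #
    #     @classmethod
    #     def __torch_function__(cls, func, types, args=(), kwargs=None):
    #         kwargs = kwargs or {}
    #         if func in _AQT_TORCH_FUNCTION_TABLE:
    #             return _AQT_TORCH_FUNCTION_TABLE[func](*args, **kwargs)
    #         with torch._C.DisableTorchFunctionSubclass():
    #             return func(*args, **kwargs)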
 @implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
 def aten_mm(func, *args, **kwargs):
     if not args[0].is_floating_point():
         raise NotImplementedError(f"{func} is not implemented for non floating point input")
 
+    # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
+    # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
+    # make the branches easier to understand in `_quantized_linear_op`
     if func == aten.addmm.default:
         assert args[1].shape[-1] == args[2].shape[0], (
             f"need mat1 shape: {args[1].shape} final"
             f"dim to match mat2 shape: {args[2].shape} first dim "
         )
-        input_tensor, weight_qtensor, bias = (
+        input_tensor, weight_tensor, bias = (
             args[1],
             args[2],
             args[0],
         )
+        try:
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        except:
Reviewer: there's a bunch of code duplication here; also, why do we need the try/except block?

Author: oh, we actually need to call the function in different ways here, not sure when the change was reverted, will fix. The try/except is used as a fallback for when the specific configuration of input and weight tensor is not caught by any of the special dispatches in `_quantized_linear_op`.

Author: added some comments
+            if isinstance(input_tensor, AffineQuantizedTensor):
+                input_tensor = input_tensor.dequantize()
+            if isinstance(weight_tensor, AffineQuantizedTensor):
+                weight_tensor = weight_tensor.dequantize()
+            return func(bias, input_tensor, weight_tensor)
     else:
         assert args[0].shape[-1] == args[1].shape[0], (
             f"need mat1 shape: {args[0].shape} final dim"
             f"to match mat2 shape: {args[1].shape} first dim"
         )
-        input_tensor, weight_qtensor, bias = (
+        input_tensor, weight_tensor, bias = (
             args[0],
             args[1],
-            None if len(args) == 2 else args[2],
+            None
         )
-        weight_tensor = weight_qtensor.dequantize()
-        return func(input_tensor, weight_tensor, bias)
+        try:
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        except:
+            if isinstance(input_tensor, AffineQuantizedTensor):
+                input_tensor = input_tensor.dequantize()
+            if isinstance(weight_tensor, AffineQuantizedTensor):
+                weight_tensor = weight_tensor.dequantize()
+            return func(input_tensor, weight_tensor)
Author: so here is a difference in how we call the function: since we have aten.mm here, the order of passing the args is different from aten.addmm.
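To make that arg-order point concrete: aten.addmm takes the bias as its first argument, while aten.mm has no bias argument at all, so the fallback has to call `func` differently in the two branches.

    import torch

    mat1, mat2, bias = torch.randn(2, 3), torch.randn(3, 4), torch.randn(4)
    out_addmm = torch.ops.aten.addmm.default(bias, mat1, mat2)  # bias + mat1 @ mat2
    out_mm = torch.ops.aten.mm.default(mat1, mat2)              # mat1 @ mat2, no bias slot
    assert torch.allclose(out_addmm, out_mm + bias)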
 
 @implements_aqt_aten_ops([aten.detach.default])
 def detach(func, *args, **kwargs):
@@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs):
 
 @implements_aqt_aten_ops([aten.t.default])
 def t(func, *args, **kwargs):
-    # TODO: need to implement this
-    # args[0].transposed = not args[0].transposed
-    # new = args[0]._change_shape(args[0].shape[::-1])
-    # return return_and_correct_aliasing(func, args, kwargs, new)
-    raise Exception("transpose not implemented yet")
+    block_size = args[0].block_size
+    assert len(block_size) == 2
+    transposed_block_size = (block_size[1], block_size[0])
+    new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
+    return return_and_correct_aliasing(func, args, kwargs, new)
 
 to_aq = AffineQuantizedTensor.from_float
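The rule the new `t` implementation encodes: the quantization block size permutes together with the shape, so per-channel quantization of a weight stays per-channel after transposition. With hypothetical sizes:

    shape, block_size = (64, 128), (1, 128)        # per-row (per-channel) quantization
    t_shape = shape[::-1]                          # (128, 64)
    t_block_size = (block_size[1], block_size[0])  # (128, 1): still one block per channel
    assert t_shape == (128, 64) and t_block_size == (128, 1)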
Reviewer: this test might end up being flaky; also, how long does this test take? Strange to do a benchmark in unit tests.

Author: this is pretty quick when I run it on my A100 machine, finishes in a few seconds. I could also skip this by default and just have people run it locally when making changes to these APIs.

Author: skipped this one by default.

Reviewer: Yeah, doing benchmarks in unit tests is a known anti-pattern. Test environments don't need to be consistent, and it's likely a waste of resources to make them so.
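One common way to implement the "skipped by default, run locally" behavior mentioned above (hypothetical names; the PR's actual test may differ) is to gate the benchmark behind an environment variable:

    import os
    import unittest

    @unittest.skipUnless(
        os.getenv("RUN_BENCHMARKS") == "1",
        "benchmark test; run locally with RUN_BENCHMARKS=1",
    )
    class TestQuantApiBenchmark(unittest.TestCase):
        def test_int8_dynamic_quant_speedup(self):
            # time the quantized model against the eager baseline here
            ...

    if __name__ == "__main__":
        unittest.main()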