-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor int4 and int8 weight only quantization to use quantize
#301
Conversation
…antize` Summary: Previously we added `quantize` as a general API (pytorch#256) for Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general. The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor subclass. We'll make sure the performance does not regress for vit model. Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py reference: elapsed_time: 1.4821058654785155 milliseconds after refactor: elapsed_time: 1.4804757690429688 milliseconds generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d Reviewers: Subscribers: Tasks: Tags:
Summary: Similar to pytorch#294 we replaced the implementation of int8 weight only quant to used the newly added `quantize` function, as a part of the unification effort for affine quantization Test Plan: 1. unit perf test: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756 elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629 elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368 2. integration test: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py Reference: elapsed_time: 1.355208740234375 milliseconds After refactor: elapsed_time: 1.32778857421875 milliseconds code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845 Reviewers: Subscribers: Tasks: Tags:
…antize` Summary: Previously we added `quantize` as a general API (pytorch#256) for Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general. The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor subclass. We'll make sure the performance does not regress for vit model. Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py reference: elapsed_time: 1.4821058654785155 milliseconds after refactor: elapsed_time: 1.4804757690429688 milliseconds generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d Reviewers: Subscribers: Tasks: Tags:
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/301
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 74ecb09 with merge base 729fa4d (): BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
quantize
this is rebased on int8-wo PR (#299) so will need to update this PR after the int8-wo PR is landed |
quantize
quantize
@@ -930,6 +930,7 @@ def _test_lin_weight_subclass_impl( | |||
) | |||
|
|||
@parameterized.expand(COMMON_DEVICE_DTYPE) | |||
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there an issue or a short description of both bugs we can add, otherwise will be hard to remember when to remove the skipIf
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's just a inductor c++ compilation bug I think, I'm planning to open a PR after this, I have opened one for the other skip here: #300
return layout_cls | ||
return decorator | ||
|
||
def get_aqt_layout_cls(extended_layout: str) -> Callable: | ||
def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does ctr stand for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this means constructor, since we are returning class.from_plain
now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs a comment I don't believe ctr is a common abbreviation for constructor
# int_data = int_data.view(shape) | ||
# changed = self.from_plain(int_data, scale, zero) | ||
# return changed | ||
# TODO: changing shape is no-op for int4 packed weight right now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you share some more detail on this I'm quite curious
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, I'm confirming with @HDCharles right now, I think this is pretty weird, see comments in L575 of aqt.py for more details
torchao/dtypes/aqt.py
Outdated
|
||
@classmethod | ||
def from_plain(cls, int_data, scale, zero_point): | ||
# TODO: expose the arg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not just do it now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one needs a bit more discussions with pt core team
if extended_layout == "tensor_core_tiled": | ||
from torchao.quantization.utils import find_multiple | ||
orig_out_features, orig_in_features = input_float.shape | ||
in_features = find_multiple(orig_in_features, 1024) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where do the constants for 1024 and 8 come from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is specific to tinygemm kernels I think, copied from old code:
ao/torchao/quantization/subclass.py
Lines 585 to 586 in 8a4e693
in_features = find_multiple(orig_in_features, 1024) | |
out_features = find_multiple(orig_out_features, 8) |
torchao.apply_dynamic_quant(model) | ||
from torch._inductor import config as inductorconfig | ||
inductorconfig.force_fuse_int_mm_with_mul = True | ||
# int8 act, int8 weight dynamic quantization |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we delete code here instead of commenting it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, this is just for people to easily try out different APIs, but we can just ask people to copy paste from README as well
# groupwise int4 quantization | ||
groupsize = weight_qtensor.block_size[-1] | ||
if not _from_flinear: | ||
weight_qtensor = weight_qtensor.t() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
n00b q: why does this require a transpose?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is to align the dimensions, for block_size so that we can get groupsize from block_size argument, see L662, and also related to L575. right now the _quantized_linear
does not have a well-defined accepted weight shape, we need to fix that
@@ -507,7 +571,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): | |||
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" | |||
) | |||
|
|||
def _quantized_linear_op(input_tensor, weight_qtensor, bias): | |||
def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True): | |||
# TODO: the old tensor subclass can use the single implementation for both F.linear dispatch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@msaroufim see this comment for more details
return layout_cls | ||
return decorator | ||
|
||
def get_aqt_layout_cls(extended_layout: str) -> Callable: | ||
def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs a comment I don't believe ctr is a common abbreviation for constructor
torchao/quantization/quant_api.py
Outdated
filter_fn, | ||
) | ||
if TORCH_VERSION_AFTER_2_4: | ||
quantize(model, get_apply_int4wo_quant(**kwargs), filter_fn) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blind kwargs make it impossible to document the behavior. i understand that change_linear_weights_to_int4_woqtensors
has this as well. Seems like something that could be worth fixing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah sure
@@ -55,3 +58,10 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs): | |||
) | |||
measurement = t0.blocked_autorange() | |||
return measurement.mean * 1e6 | |||
|
|||
|
|||
def find_multiple(n: int, *args: Tuple[int]) -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we now use this in torchao/dtypes and torchao/quantization and have to do import tricks to avoid circular dep
Summary: This is similar to pytorch#294 but applied for int4 weight only quantization Test Plan: unit perf test: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297 elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314 elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793 integration perf test: reference: elapsed_time: 2.5900126953125 milliseconds after refactor: elapsed_time: 2.56680078125 milliseconds diff: no diff TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py Before: After: generated code diff: Reviewers: Subscribers: Tasks: Tags:
quantize
quantize
Please don't merge PRs when CI is red and we can't get signal for incremental changes. Fix main CI first, then merge. |
makes sense, sorry about this, will do next time |
…torch#301) * Replace implementation for int8 dynamic quantization with call to `quantize` Summary: Previously we added `quantize` as a general API (pytorch#256) for Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general. The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor subclass. We'll make sure the performance does not regress for vit model. Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py reference: elapsed_time: 1.4821058654785155 milliseconds after refactor: elapsed_time: 1.4804757690429688 milliseconds generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d Reviewers: Subscribers: Tasks: Tags: * Refactor int8 weight only quant to use `quantize` Summary: Similar to pytorch#294 we replaced the implementation of int8 weight only quant to used the newly added `quantize` function, as a part of the unification effort for affine quantization Test Plan: 1. unit perf test: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756 elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629 elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368 2. integration test: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py Reference: elapsed_time: 1.355208740234375 milliseconds After refactor: elapsed_time: 1.32778857421875 milliseconds code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845 Reviewers: Subscribers: Tasks: Tags: * Replace implementation for int8 dynamic quantization with call to `quantize` Summary: Previously we added `quantize` as a general API (pytorch#256) for Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general. The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor subclass. We'll make sure the performance does not regress for vit model. Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py reference: elapsed_time: 1.4821058654785155 milliseconds after refactor: elapsed_time: 1.4804757690429688 milliseconds generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d Reviewers: Subscribers: Tasks: Tags: * Refactor int4 weight only quantization with call to `quantize` Summary: This is similar to pytorch#294 but applied for int4 weight only quantization Test Plan: unit perf test: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297 elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314 elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793 integration perf test: reference: elapsed_time: 2.5900126953125 milliseconds after refactor: elapsed_time: 2.56680078125 milliseconds diff: no diff TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py Before: After: generated code diff: Reviewers: Subscribers: Tasks: Tags: --------- Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Summary:
This is similar to #294 but applied for int4 weight only quantization
Test Plan:
unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793
integration perf test:
reference: elapsed_time: 2.5900126953125 milliseconds
after refactor: elapsed_time: 2.56680078125 milliseconds
diff: no diff
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
Before:
After:
generated code diff:
Reviewers:
Subscribers:
Tasks:
Tags:
Refactor int8 weight only quant to use quantize #299 logs
Summary:
Similar to #294 we replaced the implementation
of int8 weight only quant to used the newly added quantize function, as a part of
the unification effort for affine quantization
Test Plan:
unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf
elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368
integration test:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
Reference: elapsed_time: 1.355208740234375 milliseconds
After refactor: elapsed_time: 1.32778857421875 milliseconds
code diff (gist): gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): internalfb.com/phabricator/paste/view/P1387333845