[CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight #3075
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3075
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit c7524ea with merge base 1a9b6f4:
BROKEN TRUNK - The following job failed but was also present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
CC @mingfeima for review. Thanks.

Hi @mingfeima @jerryzh168 @andrewor14 Could you please review this PR? Thanks.
test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py
Hi @mingfeima @jerryzh168 @andrewor14 Though this PR depends on #3100, could you please review this PR? Thanks.

@jerryzh168 Could you please review this PR again? Thanks.
Hi @jerryzh168 It's been a while since the last update of this PR because of reversions of prior PRs. I have rebased this PR. And for the …
```python
common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)
```
nit: seems that we can add this to the TestFloat8OpaqueTensor class as a decorator
Updated. Thanks.
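(For context, the decorator form the reviewer suggests looks like the sketch below; the test body is illustrative, not the PR's actual test.)

```python
from torch.testing._internal import common_utils


# instantiate_parametrized_tests can be applied as a class decorator instead
# of being called on the class after its definition.
@common_utils.instantiate_parametrized_tests
class TestFloat8OpaqueTensor(common_utils.TestCase):
    @common_utils.parametrize("bias", [True, False])
    def test_dynamic_float8_linear(self, bias):
        ...  # placeholder body
```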
| @common_utils.parametrize("x_dim", [2, 3]) | ||
| @common_utils.parametrize("bias", [True, False]) | ||
| @common_utils.parametrize("bs", [4, 128]) | ||
| def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs): |
what does ref mean? what's the difference between this test and the previous one?
It tests the fallback path in the kernel. I have updated the name and added a comment to explain this. Thanks.
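(For context: a reference-path test of this kind typically checks the kernel output against a plain high-precision linear within a loose tolerance. The sketch below emulates float8 with a simple cast round-trip; names and tolerances are assumptions, not the PR's code.)

```python
import torch

x = torch.randn(4, 64)
w = torch.randn(32, 64)
y_ref = torch.nn.functional.linear(x, w)  # high-precision reference

# Emulate float8 rounding via a cast round-trip (no scaling, for brevity).
x_q = x.to(torch.float8_e4m3fn).to(torch.float32)
w_q = w.to(torch.float8_e4m3fn).to(torch.float32)
y = torch.nn.functional.linear(x_q, w_q)

# Loose tolerance: float8 carries only ~2-3 significant decimal digits.
torch.testing.assert_close(y, y_ref, atol=0.5, rtol=0.5)
```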
```python
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

with torch.no_grad():
```
can do 8e3b3da in setUp
Updated. Thanks.
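(A minimal sketch of the suggested refactor; the exact contents of commit 8e3b3da aren't shown in this thread, so the body below is an assumed example of the pattern, not the actual change.)

```python
import torch
from torch.testing._internal import common_utils


class TestFloat8OpaqueTensor(common_utils.TestCase):
    def setUp(self):
        super().setUp()
        # Assumed: shared per-test preparation hoisted here so individual
        # tests don't have to repeat it.
        torch.manual_seed(0)
```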
```python
class ToyLinearModel(torch.nn.Module):
```
we also just landed a util for this b4ec4cb
Updated. Thanks.
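(For readers without the repo open: the toy model being replaced typically looks like the sketch below; the landed util from b4ec4cb may differ in signature.)

```python
import torch


class ToyLinearModel(torch.nn.Module):
    """Two stacked linears; just enough structure to exercise quantize_."""

    def __init__(self, k: int = 64, n: int = 32, dtype=torch.bfloat16):
        super().__init__()
        self.linear1 = torch.nn.Linear(k, n, bias=False, dtype=dtype)
        self.linear2 = torch.nn.Linear(n, k, bias=True, dtype=dtype)

    def forward(self, x):
        return self.linear2(self.linear1(x))
```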
```python
from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
```
nit: we can just import from torchao.quantization; I feel in the end we might be able to make quant_api.py an implementation detail and not expose it to users
Updated. Thanks.
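(The resulting import style, assuming the config class is re-exported from torchao.quantization, which is the point of the suggestion:)

```python
from torchao import quantize_
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerGroup,
    PerRow,
    PerTensor,
)
```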
torchao/quantization/quant_api.py (Outdated)
```diff
     return weight

-    elif not _fp8_mm_compat(weight):
+    elif not is_cpu and not _fp8_mm_compat(weight):
```
maybe check packing_format instead of device, since we are trying to make this not device-specific?
Updated. Thanks.
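(A sketch of the suggested shape of the check. `PackingFormat.OPAQUE` is assumed here as the name of the CPU opaque layout; the PR's actual field and enum value may differ.)

```python
# Gate the CUDA-oriented compatibility check on the packing format rather
# than on the weight's device.
if config.packing_format != PackingFormat.OPAQUE and not _fp8_mm_compat(weight):
    # Weight shape is unsuitable for fp8 matmul on this path; leave it unquantized.
    return weight
```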
torchao/quantization/quant_api.py (Outdated)
```diff
     return weight

-    if isinstance(weight_granularity, PerRow):
+    if not is_cpu and isinstance(weight_granularity, PerRow):
```
same here
Updated. Thanks.
| "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details" | ||
| ) | ||
|
|
||
| _check_hardware_support(granularity) |
this function name seems too general, but we can improve this later
Yeah, this function is not added by this PR. I just moved it here for the config.version == 1 path.
```python
    *,
    parameter_name: str = "weight",
):
    assert is_sm_at_least_89() or is_MI300(), (
```
why is this removed? is it because it's already checked in _check_hardware_support?
It's probably due to rebase conflicts. I have added this back. Thanks.
```python
assert is_sm_at_least_89() or is_MI300(), (
    "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
)
if config.set_inductor_config:
```
also why is this removed?
It's probably due to rebase conflicts. I have added this back. Thanks.
```python
    PerRow,
    PerTensor,
)
from torchao.quantization.observer import get_block_size
```
why import from observer? I thought it was moved to torchao.quantization.utils?
Updated. Thanks.
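(For context, get_block_size maps a tensor shape plus a granularity to the block_size consumed by the quant primitives; the return values below are inferred from the granularity semantics.)

```python
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.utils import get_block_size

shape = (4, 128)
get_block_size(shape, PerTensor())   # (4, 128): one scale for the whole tensor
get_block_size(shape, PerRow())      # (1, 128): one scale per row
get_block_size(shape, PerGroup(32))  # (1, 32): one scale per 32-element group
```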
| f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" | ||
| ) | ||
|
|
||
| act_mat = input_tensor.contiguous() |
isn't this going to be slow?
On CPU, we require input tensors to be contiguous. In practice we almost always get contiguous inputs, so the reordering won't actually happen; the call just enforces the assumption.
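(This matches PyTorch semantics: `.contiguous()` returns the same tensor object when the input is already contiguous, so no copy happens in the common case.)

```python
import torch

x = torch.randn(128, 64)
assert x.contiguous() is x        # already contiguous: no copy, same object

xt = x.t()                        # a transposed, non-contiguous view
assert not xt.is_contiguous()
assert xt.contiguous() is not xt  # only this case materializes a copy
```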
```python
granularity = weight_tensor.act_quant_kwargs.granularity
if isinstance(granularity, PerGroup):
    group_size = granularity.group_size
    if weight_tensor.block_size[1] < weight_tensor.size(-1):
        # weight_tensor is also per-group quantized
        assert weight_tensor.block_size[1] == group_size, (
            "input and weight should have the same group size but got"
            f" {weight_tensor.block_size[1]} and {group_size}"
        )
act_block_size = get_block_size(act_mat.shape, granularity)
act_scale = _choose_scale_float8(
    act_mat,
    float8_dtype=torch.float8_e4m3fn,
    block_size=act_block_size,
)
act_mat = _quantize_affine_float8(act_mat, act_scale, torch.float8_e4m3fn)
```
why is this not using `input_tensor = _choose_quant_func_and_quantize_tensor(...)`?
Thanks for the pointer. However, _choose_quant_func_and_quantize_tensor does the following:

```python
if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
    return Float8Tensor.from_hp(...)
```

Unfortunately, Float8OpaqueTensor also uses QuantizeTensorToFloat8Kwargs, so it cannot distinguish the two.
Besides, in the implementation of Float8Tensor, the activation is quantized by Float8Tensor.from_hp into a Float8Tensor and then unwrapped to get the quantized tensor data for computation, and this part of the logic is not exposed to users. So I feel it's unnecessary to use Float8OpaqueTensor.from_hp to quantize the activation and then unwrap it; quantizing it with _quantize_affine_float8 looks good.
What do you think? If you want Float8OpaqueTensor to be aligned with Float8Tensor, we may need to define a counterpart of QuantizeTensorToFloat8Kwargs for Float8OpaqueTensor so that we can distinguish them. Thanks.
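(A minimal sketch of that counterpart-kwargs idea; all names below are hypothetical, following the existing QuantizeTensorToFloat8Kwargs pattern.)

```python
from dataclasses import dataclass


@dataclass
class QuantizeTensorToFloat8OpaqueKwargs(QuantizeTensorToFloat8Kwargs):
    """Hypothetical marker subclass so dispatch can tell the opaque CPU path
    apart from the plain Float8Tensor path."""


def _choose_quant_func_and_quantize_tensor(tensor, quant_kwargs):
    # The subclass must be checked before the parent: an isinstance() check
    # against the parent class would also match the subclass.
    if isinstance(quant_kwargs, QuantizeTensorToFloat8OpaqueKwargs):
        return Float8OpaqueTensor.from_hp(tensor, ...)
    if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
        return Float8Tensor.from_hp(tensor, ...)
```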
```python
packed_weight,
scale,
bias.float() if bias is not None else bias,  # requires bias to be float
torch.float,  # out_dtype
```
shouldn't this align with the original activation dtype orig_dtype? Oh, or are you trying to do this for better precision?
Updated. Thanks.
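(A small illustration of the dtype contract being fixed: with torch.float hard-coded, a bfloat16 activation silently produced a float32 output; threading the original dtype through restores it. Names here are assumptions for illustration.)

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8, 8)

y_fp32 = torch.nn.functional.linear(x.float(), w)  # kernel math in fp32
y = y_fp32.to(x.dtype)  # cast back so out_dtype follows the activation's orig_dtype
assert y.dtype == torch.bfloat16
```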
```python
assert K % block_size[1] == 0, (
    f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}"
)
scale = _choose_scale_float8(
```
I recently found #3324, does it affect your per-tensor use case?
Thanks for the reminder. I didn't run into any issues with this.
Hi @jerryzh168 I have updated this PR per your comments. Could you please review again? Thanks.
Summary
We split the original big PR #2505 into the following smaller ones:
Test plan