
Conversation

@Xia-Weiwen (Collaborator) commented Sep 26, 2025

Summary
We split the original big PR #2505 into several smaller ones; this PR is one of them.

Test plan

pytest -sv test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py

@pytorch-bot bot commented Sep 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3075

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit c7524ea with merge base 1a9b6f4:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Sep 26, 2025
@Xia-Weiwen added the topic: new feature label Sep 26, 2025
@Xia-Weiwen (Collaborator Author)

CC @mingfeima for review. Thanks.

@Xia-Weiwen (Collaborator Author)

Hi @mingfeima @jerryzh168 @andrewor14 Could you please review this PR? Thanks.

@Xia-Weiwen Xia-Weiwen marked this pull request as draft September 30, 2025 01:28
@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review September 30, 2025 01:35
@Xia-Weiwen (Collaborator Author)

Hi @mingfeima @jerryzh168 @andrewor14, although this PR depends on #3100, could you please review it? Thanks.

@Xia-Weiwen Xia-Weiwen requested a review from jerryzh168 October 14, 2025 01:59
@Xia-Weiwen (Collaborator Author)

@jerryzh168 Could you please review this PR again? Thanks.

@Xia-Weiwen (Collaborator Author)

Hi @jerryzh168, it's been a while since the last update of this PR because of reversions of prior PRs. I have rebased it. As for the _normalize_granularities question we discussed above (#3075 (comment)), I have added a _normalize_and_check_granularities method to Float8OpaqueTensor. This keeps the changes minimal and the code clean. Could you please review this PR again? Thanks.
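For illustration only, such a helper might look roughly like the sketch below; the exact signature, the set of supported granularities, and the checks are assumptions rather than the code in this PR (in the PR it lives on Float8OpaqueTensor):

    from torchao.quantization import PerGroup, PerRow, PerTensor

    def _normalize_and_check_granularities(granularity):
        # Accept either a single granularity or an (activation, weight) pair and
        # always return a normalized (act_granularity, weight_granularity) tuple.
        if isinstance(granularity, (tuple, list)):
            act_gran, weight_gran = granularity
        else:
            act_gran = weight_gran = granularity
        supported = (PerTensor, PerRow, PerGroup)
        assert isinstance(act_gran, supported) and isinstance(weight_gran, supported), (
            f"unsupported granularities: {act_gran}, {weight_gran}"
        )
        # If both sides are per-group quantized, their group sizes must match.
        if isinstance(act_gran, PerGroup) and isinstance(weight_gran, PerGroup):
            assert act_gran.group_size == weight_gran.group_size, (
                "activation and weight group sizes must match"
            )
        return act_gran, weight_gran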

)


common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)
Contributor

nit: seems that we can add this to TestFloat8OpaqueTensor class as a decorator
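For reference, the decorator form suggested here would look something like this (a sketch; the base class and the class body are assumed, not taken from the PR):

    from torch.testing._internal import common_utils

    @common_utils.instantiate_parametrized_tests
    class TestFloat8OpaqueTensor(common_utils.TestCase):
        ...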

Collaborator Author

Updated. Thanks.

@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [4, 128])
def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
Contributor

what does ref mean? what's the difference between this test and the previous one?

Collaborator Author

It tests the fallback path in the kernel. I have updated the name and added a comment to explain this. Thanks.

example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

with torch.no_grad():
Contributor

can do 8e3b3da in setUp

Collaborator Author

Updated. Thanks.

)


class ToyLinearModel(torch.nn.Module):
Contributor

we also just landed a util for this b4ec4cb

Collaborator Author

Updated. Thanks.


from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
Contributor

nit: we can just import from torchao.quantization; I feel in the end we might be able to make quant_api.py an implementation detail and not expose it to users
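For illustration, the suggested style would look like the line below; the specific name shown is taken from elsewhere in this PR, since the original import list is truncated above:

    # instead of: from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig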

Collaborator Author

Updated. Thanks.

return weight

- elif not _fp8_mm_compat(weight):
+ elif not is_cpu and not _fp8_mm_compat(weight):
Contributor

maybe check packing_format instead of device, since we are trying to make this not device-specific?
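A rough sketch of that alternative, mirroring the lines under review (the packing-format attribute name and its value are assumptions here, not the actual torchao API):

    def _maybe_skip_fp8(weight, config):
        # Hypothetical: select the opaque/CPU path by the requested packing format
        # rather than by weight.device, so the check stays device-agnostic.
        is_opaque = getattr(config, "packing_format", None) == "opaque"  # assumed field and value
        if not is_opaque and not _fp8_mm_compat(weight):  # _fp8_mm_compat: existing helper used above
            return weight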

Collaborator Author

Updated. Thanks.

return weight

- if isinstance(weight_granularity, PerRow):
+ if not is_cpu and isinstance(weight_granularity, PerRow):
Contributor

same here

Collaborator Author

Updated. Thanks.

"Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
)

_check_hardware_support(granularity)
Contributor

this function name seems too general, but we can improve this later

Collaborator Author

Yeah. This function is not added by this PR. I just moved it here for config.version == 1.

*,
parameter_name: str = "weight",
):
assert is_sm_at_least_89() or is_MI300(), (
Contributor

why is this removed? is it because it's already checked in _check_hardware_support?

Collaborator Author

It's probably due to rebase conflicts. I have added this back. Thanks.

assert is_sm_at_least_89() or is_MI300(), (
"Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
)
if config.set_inductor_config:
Contributor

also why is this removed?

Collaborator Author

It's probably due to rebase conflicts. I have added this back. Thanks.

PerRow,
PerTensor,
)
from torchao.quantization.observer import get_block_size
Contributor

why import from observer? I thought it was moved to torchao.quantization.utils?
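Presumably the intended import (per the comment above; the exact module path is an assumption) is simply:

    from torchao.quantization.utils import get_block_size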

Collaborator Author

Updated. Thanks.

f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
)

act_mat = input_tensor.contiguous()
Contributor

isn't this going to be slow?

Collaborator Author

On CPU, we require input tensors to be contiguous. In practice, we almost always get contiguous inputs, so the reordering won't actually happen; the call just enforces that assumption.
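As a quick illustration of why this is cheap for already-contiguous inputs (standard PyTorch behavior, not code from this PR):

    import torch

    x = torch.randn(4, 128)      # contiguous by construction
    assert x.contiguous() is x   # no copy: .contiguous() returns the tensor itself

    xt = x.t()                   # a transposed view is not contiguous
    assert not xt.is_contiguous()
    assert xt.contiguous().is_contiguous()  # only this path pays for a real copy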

Comment on lines +231 to +246
granularity = weight_tensor.act_quant_kwargs.granularity
if isinstance(granularity, PerGroup):
    group_size = granularity.group_size
    if weight_tensor.block_size[1] < weight_tensor.size(-1):
        # weight_tensor is also per group quantized
        assert weight_tensor.block_size[1] == group_size, (
            "input and weight should have the same group size but got"
            f" {weight_tensor.block_size[1]} and {group_size}"
        )
act_block_size = get_block_size(act_mat.shape, granularity)
act_scale = _choose_scale_float8(
    act_mat,
    float8_dtype=torch.float8_e4m3fn,
    block_size=act_block_size,
)
act_mat = _quantize_affine_float8(act_mat, act_scale, torch.float8_e4m3fn)
Contributor

why is this not using

input_tensor = _choose_quant_func_and_quantize_tensor(

Collaborator Author

Thanks for the pointer. However, _choose_quant_func_and_quantize_tensor does the following:

    if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
        return Float8Tensor.from_hp(...)

Unfortunately, Float8OpaqueTensor also uses QuantizeTensorToFloat8Kwargs so it cannot distinguish them.
Besides, in the implementation of Float8Tensor, the activation is quantized by Float8Tensor.from_hp into a Float8Tensor and then unwrapped to get the quantized tensor data for computation, and this part of the logic is not exposed to users. So I feel it's unnecessary to use Float8OpaqueTensor.from_hp to quantize the activation and then unwrap it; quantizing it directly with _quantize_affine_float8 looks fine.
What do you think? If you want Float8OpaqueTensor to be aligned with Float8Tensor, we may need to define a counterpart of QuantizeTensorToFloat8Kwargs for Float8OpaqueTensor so that we can distinguish them. Thanks.
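If that alignment is ever wanted, the counterpart could be as small as a dedicated kwargs class that dispatch can key on; a hypothetical sketch (the class name, field, and dispatch branch below are illustrative only, not part of this PR):

    from dataclasses import dataclass

    # Hypothetical counterpart to QuantizeTensorToFloat8Kwargs for the opaque path.
    @dataclass
    class QuantizeTensorToFloat8OpaqueKwargs:
        granularity: object = None  # e.g. PerTensor(), PerRow(), or PerGroup(group_size)

    # _choose_quant_func_and_quantize_tensor could then add a branch such as:
    #
    #     if isinstance(quant_kwargs, QuantizeTensorToFloat8OpaqueKwargs):
    #         return Float8OpaqueTensor.from_hp(...)
    #     if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
    #         return Float8Tensor.from_hp(...)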

packed_weight,
scale,
bias.float() if bias is not None else bias, # requires bias to be float
torch.float, # out_dtype
@jerryzh168 (Contributor) commented Nov 11, 2025

shouldn't this align with the original activation dtype orig_dtype? Or are you trying to do this for better precision?

Collaborator Author

Updated. Thanks.

assert K % block_size[1] == 0, (
f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}"
)
scale = _choose_scale_float8(
Contributor

I recently found #3324, does it affect your per tensor use case?

Collaborator Author

Thanks for the reminder. I didn't run into any issues with this.

@Xia-Weiwen (Collaborator Author)

Hi @jerryzh168 I have updated this PR per your comments. Could you please review again? Thanks.
