[BE] Convert quant_primitives methods private #2350

Open · wants to merge 4 commits into base: main
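Every change in the diff below follows the same mechanical pattern: quantization primitives that are no longer part of the public API gain a leading underscore, and each import and call site is updated to match. A minimal before/after sketch of the import change (names taken from the diff; the full set of renamed symbols is exactly what the PR touches):

# Before this PR: public names exported from quant_primitives.
# from torchao.quantization.quant_primitives import (
#     choose_qparams_affine_float8,
#     dequantize_affine_float8,
#     quantize_affine_float8,
# )

# After this PR: the same helpers, now underscore-prefixed to mark them private.
from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)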
6 changes: 0 additions & 6 deletions docs/source/api_ref_quantization.rst
@@ -63,14 +63,8 @@ Quantization Primitives

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
quantize_affine
quantize_affine_floatx
dequantize_affine
dequantize_affine_floatx
choose_qparams_and_quantize_affine_hqq
fake_quantize_affine
fake_quantize_affine_cachemask
safe_int_mm
int_scaled_matmul
MappingType
20 changes: 10 additions & 10 deletions test/dtypes/test_affine_quantized_float.py
@@ -42,10 +42,10 @@
)
from torchao.quantization.quant_primitives import (
MappingType,
_choose_qparams_affine_float8,
_dequantize_affine_float8,
_quantize_affine_float8,
choose_qparams_affine,
choose_qparams_affine_float8,
dequantize_affine_float8,
quantize_affine_float8,
)
from torchao.utils import (
is_sm_at_least_89,
@@ -358,21 +358,21 @@ def test_mm_float8dq_per_row(
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
"""Test dequantize_affine_float8 with various configurations"""
"""Test _dequantize_affine_float8 with various configurations"""

device = "cuda"
input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)

# Choose quantization parameters
scale = choose_qparams_affine_float8(
scale = _choose_qparams_affine_float8(
input_tensor, float8_dtype=float8_dtype, block_size=block_size
)

# Quantize
quantized = quantize_affine_float8(input_tensor, scale, float8_dtype)
quantized = _quantize_affine_float8(input_tensor, scale, float8_dtype)

# Dequantize
dequantized = dequantize_affine_float8(quantized, scale, output_dtype)
dequantized = _dequantize_affine_float8(quantized, scale, output_dtype)

# Verify output properties
self.assertEqual(dequantized.dtype, output_dtype)
@@ -395,7 +395,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim

# Choose quantization parameters
scale = choose_qparams_affine_float8(
scale = _choose_qparams_affine_float8(
input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size
)

@@ -407,10 +407,10 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
self.assertEqual(scale.shape, expected_scale_shape)

# Quantize
quantized = quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn)
quantized = _quantize_affine_float8(input_tensor, scale, torch.float8_e4m3fn)

# Dequantize
dequantized = dequantize_affine_float8(quantized, scale, torch.float32)
dequantized = _dequantize_affine_float8(quantized, scale, torch.float32)

# Verify shapes match
self.assertEqual(dequantized.shape, input_tensor.shape)
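For reference, a minimal sketch of the float8 round trip the updated test exercises, using the post-rename private helpers with the call pattern shown above (the test itself runs on CUDA; CPU tensors and the per-tensor block_size=None case are assumed here purely for illustration):

import torch
from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)

x = torch.randn(8, 64, dtype=torch.float32)

# Per-tensor scale (block_size=None); the test also parametrizes blockwise layouts.
scale = _choose_qparams_affine_float8(
    x, float8_dtype=torch.float8_e4m3fn, block_size=None
)
q = _quantize_affine_float8(x, scale, torch.float8_e4m3fn)  # scale and cast to float8
dq = _dequantize_affine_float8(q, scale, torch.float32)     # back to float32

assert dq.shape == x.shape and dq.dtype == torch.float32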
8 changes: 4 additions & 4 deletions test/dtypes/test_floatx.py
@@ -91,13 +91,13 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
@parametrize("ebits,mbits", _Floatx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
from torchao.quantization.quant_primitives import (
choose_qparams_affine_floatx,
quantize_affine_floatx,
_choose_qparams_affine_floatx,
_quantize_affine_floatx,
)

x = torch.randn(256, 64)
scale = choose_qparams_affine_floatx(x, ebits, mbits)
x = quantize_affine_floatx(x, scale, ebits, mbits)
scale = _choose_qparams_affine_floatx(x, ebits, mbits)
x = _quantize_affine_floatx(x, scale, ebits, mbits)
_layout = FloatxTensorCoreLayout(ebits, mbits)
floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(
x, scale, None, _layout
4 changes: 2 additions & 2 deletions test/prototype/test_gguf_quant.py
@@ -13,7 +13,7 @@
GGUFWeightOnlyConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import choose_qparams_gguf
from torchao.quantization.quant_primitives import _choose_qparams_gguf
from torchao.quantization.utils import compute_error


@@ -31,7 +31,7 @@ def test_choose_qparams_gguf(self):
super_block_min_scale,
quantized_block_scale,
quantized_block_min,
) = choose_qparams_gguf(self.input, self.block_size, self.dtype)
) = _choose_qparams_gguf(self.input, self.block_size, self.dtype)

assert super_block_scale_scale.shape, (2, 8)
assert super_block_min_scale.shape, (2, 8)
4 changes: 2 additions & 2 deletions test/quantization/test_marlin_qqq.py
@@ -21,7 +21,7 @@
)
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_and_quantize_affine_qqq,
_choose_qparams_and_quantize_affine_qqq,
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -102,7 +102,7 @@ def test_pack_unpack_equivalence(self):

for group_size in [-1, 128]:
# Quantize weights
q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
q_w, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq(
w, num_bits, group_size
)

4 changes: 2 additions & 2 deletions test/quantization/test_qat.py
@@ -64,9 +64,9 @@
MappingType,
TorchAODType,
ZeroPointDomain,
_fake_quantize_affine,
choose_qparams_affine,
dequantize_affine,
fake_quantize_affine,
quantize_affine,
)
from torchao.quantization.unified import (
@@ -637,7 +637,7 @@ def test_qat_4w_primitives(self):
group_size,
scales_precision,
)
w_fq = fake_quantize_affine(
w_fq = _fake_quantize_affine(
weight,
block_size,
scales,
12 changes: 6 additions & 6 deletions test/quantization/test_quant_primitives.py
@@ -13,11 +13,11 @@
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
_choose_qparams_affine_tinygemm,
_fake_quantize_affine,
_fake_quantize_affine_cachemask,
choose_qparams_affine,
choose_qparams_affine_tinygemm,
dequantize_affine,
fake_quantize_affine,
fake_quantize_affine_cachemask,
quantize_affine,
)

@@ -672,7 +672,7 @@ def test_get_groupwise_affine_qparams(self):
zero_point_domain=zero_point_domain,
)
if zero_point_domain == ZeroPointDomain.FLOAT:
scale, zero_point = choose_qparams_affine_tinygemm(
scale, zero_point = _choose_qparams_affine_tinygemm(
input,
mapping_type,
block_size,
@@ -780,7 +780,7 @@ def test_fake_quantize_affine(self):
dequantized = dequantize_affine(
quantized, block_size, scale, zero_point, dtype, quant_min, quant_max
)
fake_quantized = fake_quantize_affine(
fake_quantized = _fake_quantize_affine(
input, block_size, scale, zero_point, dtype, quant_min, quant_max
)
torch.testing.assert_close(dequantized, fake_quantized)
@@ -816,7 +816,7 @@ def test_fake_quantize_affine_cachemask(self):
dequantized = dequantize_affine(
quantized, block_size, scale, zero_point, dtype, quant_min, quant_max
)
(fake_quantized, mask) = fake_quantize_affine_cachemask(
(fake_quantized, mask) = _fake_quantize_affine_cachemask(
input,
block_size,
scale,
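The fake-quantize tests above check a simple identity: fake-quantizing a tensor matches quantizing and then dequantizing it. A minimal sketch of that equivalence with the post-rename private name (choose_qparams_affine accepts further optional arguments; only the positional prefix visible in the diff is assumed here, and the shapes and dtypes are illustrative):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    _fake_quantize_affine,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

input = torch.randn(2, 4)
block_size = (1, 4)
dtype = torch.int8
quant_min, quant_max = -128, 127

# Assumed positional prefix: (input, mapping_type, block_size, target_dtype, ...).
scale, zero_point = choose_qparams_affine(
    input, MappingType.SYMMETRIC, block_size, dtype, quant_min, quant_max
)
q = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
dq = dequantize_affine(q, block_size, scale, zero_point, dtype, quant_min, quant_max)
fq = _fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
torch.testing.assert_close(dq, fq)  # fake-quantize == quantize followed by dequantize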
6 changes: 4 additions & 2 deletions test/test_ops.py
@@ -23,7 +23,9 @@
marlin_qqq_workspace,
pack_to_marlin_qqq,
)
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
from torchao.quantization.quant_primitives import (
_choose_qparams_and_quantize_affine_qqq,
)
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
@@ -713,7 +715,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
)

# Quantize weights
q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
q_w, s_group, s_channel, w_ref = _choose_qparams_and_quantize_affine_qqq(
b_weight, num_bits, group_size
)
q_w = q_w.t()
58 changes: 29 additions & 29 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -18,22 +18,22 @@
FP8_TYPES,
MappingType,
ZeroPointDomain,
_choose_qparams_affine_dont_preserve_zero,
_choose_qparams_affine_float8,
_choose_qparams_affine_floatx,
_choose_qparams_affine_tinygemm,
_choose_qparams_and_quantize_affine_hqq,
_dequantize_affine_float8,
_dequantize_affine_floatx,
_dequantize_affine_no_zero_point,
_dequantize_affine_tinygemm,
_quantize_affine_float8,
_quantize_affine_floatx,
_quantize_affine_no_zero_point,
_quantize_affine_tinygemm,
choose_qparams_affine,
choose_qparams_affine_dont_preserve_zero,
choose_qparams_affine_float8,
choose_qparams_affine_floatx,
choose_qparams_affine_tinygemm,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_float8,
dequantize_affine_floatx,
dequantize_affine_no_zero_point,
dequantize_affine_tinygemm,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
quantize_affine_no_zero_point,
quantize_affine_tinygemm,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
@@ -142,7 +142,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor

if isinstance(self._layout, FloatxTensorCoreLayout):
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
return _dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
@@ -151,11 +151,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
)
elif isinstance(self._layout, Float8Layout):
data, scale, _ = self.tensor_impl.get_plain()
return dequantize_affine_float8(data, scale, output_dtype)
return _dequantize_affine_float8(data, scale, output_dtype)
else:
data, scale, zero_point = self.tensor_impl.get_plain()
if self.zero_point_domain == ZeroPointDomain.FLOAT:
dq = dequantize_affine_tinygemm(
dq = _dequantize_affine_tinygemm(
data,
self.block_size,
scale,
@@ -166,7 +166,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
output_dtype=output_dtype,
)
elif self.zero_point_domain == ZeroPointDomain.NONE:
dq = dequantize_affine_no_zero_point(
dq = _dequantize_affine_no_zero_point(
data,
self.block_size,
scale,
@@ -270,7 +270,7 @@ def from_hp_to_intx(
from torchao.dtypes import Int4CPULayout
from torchao.dtypes.uintx import TensorCoreTiledLayout

data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
input_float,
nbits=nbits,
group_size=group_size,
@@ -291,7 +291,7 @@ def from_hp_to_intx(
data = data.to(target_dtype)
else:
if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
scale, zero_point = choose_qparams_affine_tinygemm(
scale, zero_point = _choose_qparams_affine_tinygemm(
input_float,
mapping_type,
block_size,
@@ -303,7 +303,7 @@ def from_hp_to_intx(
zero_point_dtype,
)
elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero:
scale, zero_point = choose_qparams_affine_dont_preserve_zero(
scale, zero_point = _choose_qparams_affine_dont_preserve_zero(
input_float,
mapping_type,
block_size,
@@ -329,7 +329,7 @@ def from_hp_to_intx(
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
if zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine_no_zero_point(
data = _quantize_affine_no_zero_point(
input_float,
block_size,
scale,
@@ -339,7 +339,7 @@ def from_hp_to_intx(
quant_max,
)
elif zero_point_domain == ZeroPointDomain.FLOAT:
data = quantize_affine_tinygemm(
data = _quantize_affine_tinygemm(
input_float,
block_size,
scale,
@@ -400,7 +400,7 @@ def from_hp_to_intx_static(

if zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
int_data = quantize_affine_no_zero_point(
int_data = _quantize_affine_no_zero_point(
input_float,
block_size,
scale,
@@ -410,7 +410,7 @@ def from_hp_to_intx_static(
quant_max,
)
elif zero_point_domain == ZeroPointDomain.FLOAT:
int_data = quantize_affine_tinygemm(
int_data = _quantize_affine_tinygemm(
input_float,
block_size,
scale,
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
if target_dtype in FP8_TYPES:
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
scale = choose_qparams_affine_float8(
scale = _choose_qparams_affine_float8(
input_float, float8_dtype=target_dtype, block_size=block_size
)
data = quantize_affine_float8(input_float, scale, target_dtype)
data = _quantize_affine_float8(input_float, scale, target_dtype)
data, scale, zero_point = _layout.post_process(
data, scale, None, block_size
)
@@ -499,7 +499,7 @@ def from_hp_to_floatx_static(
input_float, scale, ZeroPointDomain.NONE, block_size
)

data = quantize_affine_float8(
data = _quantize_affine_float8(
input_float,
scale,
target_dtype,
@@ -545,8 +545,8 @@ def from_hp_to_fpx(

ebits, mbits = _layout.ebits, _layout.mbits
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
scale = _choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = _quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed, scale, _ = _layout.post_process(
floatx_unpacked, scale, None, block_size
)
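Relatedly, a minimal sketch of the fpx path touched in from_hp_to_fpx above, using the post-rename private floatx helpers with the call pattern from test/dtypes/test_floatx.py (ebits=3, mbits=2, i.e. an fp6 layout, chosen here only for illustration):

import torch
from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_floatx,
    _quantize_affine_floatx,
)

w = torch.randn(256, 64)
ebits, mbits = 3, 2  # fp6 e3m2

# These helpers are hardcoded to per-axis (axis=1) quantization, as noted above.
scale = _choose_qparams_affine_floatx(w, ebits, mbits)
w_fpx = _quantize_affine_floatx(w, scale, ebits, mbits)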