Commit 7038f8b

Rename AQT#2 LayoutType -> Layout (#1049)
1 parent 10601b3 commit 7038f8b

File tree

34 files changed: +358 −361 lines
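The rename is mechanical: each *LayoutType class loses the "Type" suffix, and the layout_type= keyword becomes layout=. A minimal before/after sketch of the migration (the model shape and inner_k_tiles value here are illustrative, not taken from the diff):

import torch
from torch import nn
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayout

# int4 weight-only quantization expects bfloat16 weights on CUDA
model = nn.Sequential(nn.Linear(256, 256)).to(torch.bfloat16).cuda()

# Before this commit:
#   quantize_(model, int4_weight_only(layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))
# After this commit:
quantize_(model, int4_weight_only(layout=TensorCoreTiledLayout(inner_k_tiles=8)))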

benchmarks/benchmark_fp6.py

Lines changed: 2 additions & 2 deletions
@@ -2,14 +2,14 @@
 import pandas as pd
 import torch.nn.functional as F
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType
+from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
     float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
+    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
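For context, the renamed fpx entry point can be exercised on its own. A minimal sketch using the same fp6 configuration (3 exponent bits, 2 mantissa bits) as the benchmark, assuming a CUDA device:

import torch
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

weight = torch.randn(1024, 1024, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(weight, FloatxTensorCoreLayout(3, 2))  # ebits=3, mbits=2
roundtrip = fp6_weight.dequantize(torch.half)  # back to fp16 for comparison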

test/dtypes/test_affine_quantized.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
-from torchao.dtypes import SemiSparseLayoutType
+from torchao.dtypes import SemiSparseLayout
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))

     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))

     if is_cuda_8_9:
         base_functions.append(float8_weight_only())
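The semi-sparse path exercised here follows the same renaming. A standalone sketch, assuming a CUDA device with 2:4 sparse kernel support, using apply_fake_sparsity from torchao.sparsity.sparse_api (imported by the sparsity tests below) to prune the weights first:

import torch
from torch import nn
from torchao.dtypes import SemiSparseLayout
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.sparsity.sparse_api import apply_fake_sparsity

model = nn.Sequential(nn.Linear(128, 128)).half().cuda()
apply_fake_sparsity(model)  # prune weights to a 2:4 pattern so the semi-sparse kernel applies
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))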

test/dtypes/test_floatx.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 )
 from torchao.dtypes.floatx import (
     FloatxTensorCoreAQTTensorImpl,
-    FloatxTensorCoreLayoutType,
+    FloatxTensorCoreLayout,
     to_scaled_tc_floatx,
     from_scaled_tc_floatx,
 )
@@ -81,8 +81,8 @@ def test_to_copy_device(self, ebits, mbits):
         x = torch.randn(256, 64)
         scale = choose_qparams_affine_floatx(x, ebits, mbits)
         x = quantize_affine_floatx(x, scale, ebits, mbits)
-        layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
-        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda()
+        _layout = FloatxTensorCoreLayout(ebits, mbits)
+        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda()
         assert floatx_tensor_impl.device.type == "cuda"
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"

test/hqq/test_hqq_affine.py

Lines changed: 2 additions & 2 deletions
@@ -4,9 +4,9 @@
     to_affine_quantized_intx,
     ZeroPointDomain,
     PlainAQTTensorImpl,
-    PlainLayoutType,
+    PlainLayout,
     TensorCoreTiledAQTTensorImpl,
-    TensorCoreTiledLayoutType,
+    TensorCoreTiledLayout,
     MappingType,
 )

test/integration/test_integration.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.dtypes import TensorCoreTiledLayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -876,7 +876,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
+                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}

                     def api(mod):
                         kwargs_copy = kwargs.copy()
@@ -888,7 +888,7 @@ def api(mod):
                             unwrap_tensor_subclass(mod)
                         else:
                             kwargs_copy["inner_k_tiles"] = inner_k_tiles
-                            del kwargs_copy["layout_type"]
+                            del kwargs_copy["layout"]
                             change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

         self._test_lin_weight_subclass_api_impl(

test/quantization/test_qat.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.dtypes import (
-    TensorCoreTiledLayoutType,
+    TensorCoreTiledLayout,
 )
 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,

test/sparsity/test_marlin.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-from torchao.dtypes import MarlinSparseLayoutType
+from torchao.dtypes import MarlinSparseLayout
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
 from torchao.sparsity.marlin import (
@@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)

         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
@@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         self.model.forward = torch.compile(self.model.forward, fullgraph=True)
         sparse_result = self.model(self.input)

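A standalone sketch of the Marlin path under the new spelling, assuming a CUDA device, an fp16 model, and the apply_fake_sparsity helper this test imports (shapes are illustrative):

import torch
from torch import nn
from torchao.dtypes import MarlinSparseLayout
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.sparsity.sparse_api import apply_fake_sparsity

model = nn.Sequential(nn.Linear(256, 1024), nn.Linear(1024, 256)).half().cuda()
apply_fake_sparsity(model)  # prune to the 2:4 pattern the Marlin kernel expects
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
out = model(torch.randn(16, 256, dtype=torch.half, device="cuda"))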
test/sparsity/test_sparse_api.py

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 from torch.testing._internal import common_utils
-from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
+from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_dynamic_activation_int8_weight,
@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
         )
         if compile:
             model = torch.compile(model)
@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile):
         dense_result = model_copy(input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if compile:
             model = torch.compile(model)
         sparse_result = model(input)
@@ -185,12 +185,12 @@ def test_sparse(self, compile):
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         reference = model_copy(input)

-        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout

         quantize_(
             model,
             int8_dynamic_activation_int8_weight(
-                layout_type=BlockSparseLayoutType(blocksize=64)
+                layout=BlockSparseLayout(blocksize=64)
             ),
         )
         if compile:
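The prototype block-sparse path follows the same pattern. A minimal sketch of just the renamed call; the pruning the test performs beforehand is elided here, so treat this as illustrating the signature rather than a complete recipe:

from torch import nn
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight

model = nn.Sequential(nn.Linear(128, 128))
# layout_type=BlockSparseLayoutType(blocksize=64) becomes:
quantize_(model, int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)))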

torchao/_models/llama/eval.py

Lines changed: 2 additions & 2 deletions
@@ -100,8 +100,8 @@ def run_evaluation(
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayoutType
-            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+            from torchao.dtypes import MarlinSparseLayout
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "int4wo" in quantization and "gptq" in quantization:
             # avoid circular imports
             from torchao._models._eval import InputRecorder

torchao/_models/llama/generate.py

Lines changed: 2 additions & 2 deletions
@@ -230,8 +230,8 @@ def main(
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayoutType
-            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+            from torchao.dtypes import MarlinSparseLayout
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if quantization.startswith("awq"):
