
Commit b4115d3

Merge branch 'main' into rocm_swizzle_reland2

2 parents f4ec46d + 5549da8

File tree

15 files changed: +390, -27 lines changed


.github/workflows/torchao_experimental_test.yml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
test-cpu-ops:
1515
strategy:
1616
matrix:
17-
runner: [macos-14]
17+
runner: [macos-14, linux.arm64.2xlarge]
1818
runs-on: ${{matrix.runner}}
1919
defaults:
2020
run:
@@ -30,7 +30,8 @@ jobs:
3030
python-version: "3.10"
3131
miniconda-version: "latest"
3232
activate-environment: venv
33-
- name: Install requirements
33+
- name: Install requirements mac
34+
if: runner.os == 'macOS'
3435
run: |
3536
conda activate venv
3637
# Install executorch first because it installs its own version
@@ -39,27 +40,37 @@ jobs:
3940
pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
4041
pip install -r dev-requirements.txt
4142
USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .
43+
- name: Install requirements linux
44+
if: runner.os == 'Linux'
45+
run: |
46+
conda activate venv
47+
pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
48+
pip install -r dev-requirements.txt
49+
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install .
4250
- name: Run python tests
4351
run: |
4452
conda activate venv
4553
pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
4654
python torchao/experimental/tests/test_embedding_xbit_quantizer.py
4755
python torchao/experimental/tests/test_quant_passes.py
4856
- name: Run kernels/cpu/aarch64/tests
57+
if: runner.os == 'macOS'
4958
run: |
5059
conda activate venv
5160
pushd torchao/experimental/kernels/cpu/aarch64/tests
5261
sh build_and_run_tests.sh
5362
rm -rf /tmp/cmake-out
5463
popd
5564
- name: Run torchao/experimental/ops/tests
65+
if: runner.os == 'macOS'
5666
run: |
5767
conda activate venv
5868
pushd torchao/experimental/ops/tests
5969
sh build_and_run_tests.sh
6070
rm -rf /tmp/cmake-out
6171
popd
6272
- name: ET ops build
73+
if: runner.os == 'macOS'
6374
run: |
6475
conda activate venv
6576
pushd torchao/experimental
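The new Linux ARM job opts into the experimental build explicitly because the automatic macOS detection in setup.py does not fire there. A minimal sketch of that gating logic, using the names introduced in the setup.py diff below (the real build does more work around this check):

import os
import platform


def is_arm64() -> bool:
    # Linux ARM runners report "aarch64"; Apple silicon reports "arm64".
    return platform.machine().startswith("arm64") or platform.machine() == "aarch64"


def is_macos() -> bool:
    return platform.system() == "Darwin"


# macOS arm64 builds the experimental kernels automatically when USE_CPP=1;
# other platforms (such as linux.arm64.2xlarge) must opt in explicitly.
build_macos_arm_auto = os.getenv("USE_CPP") == "1" and is_arm64() and is_macos()
build_experimental = build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1"
print(f"build experimental kernels: {build_experimental}")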

setup.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read_version(file_path="version.txt"):
4949

5050
import platform
5151

52-
build_torchao_experimental = (
52+
build_macos_arm_auto = (
5353
use_cpp == "1"
5454
and platform.machine().startswith("arm64")
5555
and platform.system() == "Darwin"
@@ -117,8 +117,33 @@ def __init__(self):
117117
"TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
118118
)
119119

120+
# TORCHAO_PARALLEL_BACKEND specifies which parallel backend to use
121+
# Possible values: aten_openmp, executorch, openmp, pthreadpool, single_threaded
122+
self.parallel_backend = os.getenv("TORCHAO_PARALLEL_BACKEND", "aten_openmp")
123+
124+
# TORCHAO_ENABLE_ARM_NEON_DOT enable ARM NEON Dot Product extension
125+
# Enabled by default on macOS silicon
126+
self.enable_arm_neon_dot = self._os_bool_var(
127+
"TORCHAO_ENABLE_ARM_NEON_DOT",
128+
default=(self._is_arm64() and self._is_macos()),
129+
)
130+
if self.enable_arm_neon_dot:
131+
assert self.build_cpu_aarch64, (
132+
"TORCHAO_ENABLE_ARM_NEON_DOT requires TORCHAO_BUILD_CPU_AARCH64 be set"
133+
)
134+
135+
# TORCHAO_ENABLE_ARM_I8MM enable ARM 8-bit Integer Matrix Multiply instructions
136+
# Not enabled by default on macOS as not all silicon mac supports it
137+
self.enable_arm_i8mm = self._os_bool_var(
138+
"TORCHAO_ENABLE_ARM_I8MM", default=False
139+
)
140+
if self.enable_arm_i8mm:
141+
assert self.build_cpu_aarch64, (
142+
"TORCHAO_ENABLE_ARM_I8MM requires TORCHAO_BUILD_CPU_AARCH64 be set"
143+
)
144+
120145
def _is_arm64(self) -> bool:
121-
return platform.machine().startswith("arm64")
146+
return platform.machine().startswith("arm64") or platform.machine() == "aarch64"
122147

123148
def _is_macos(self) -> bool:
124149
return platform.system() == "Darwin"
@@ -468,7 +493,8 @@ def get_extensions():
468493
)
469494
)
470495

471-
if build_torchao_experimental:
496+
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
497+
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
472498
build_options = BuildOptions()
473499

474500
def bool_to_on_off(value):
@@ -488,6 +514,9 @@ def bool_to_on_off(value):
488514
f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}",
489515
f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}",
490516
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
517+
f"-DTORCHAO_ENABLE_ARM_NEON_DOT={bool_to_on_off(build_options.enable_arm_neon_dot)}",
518+
f"-DTORCHAO_ENABLE_ARM_I8MM={bool_to_on_off(build_options.enable_arm_i8mm)}",
519+
f"-DTORCHAO_PARALLEL_BACKEND={build_options.parallel_backend}",
491520
"-DTorch_DIR=" + torch_dir,
492521
]
493522
+ (
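For context, a condensed sketch of how the new environment variables flow into CMake definitions (simplified from the BuildOptions and get_extensions code in this diff; the real BuildOptions parses booleans via _os_bool_var and validates more combinations):

import os


def bool_to_on_off(value: bool) -> str:
    return "ON" if value else "OFF"


# Simplified stand-ins for the BuildOptions fields added in this commit.
enable_arm_neon_dot = os.getenv("TORCHAO_ENABLE_ARM_NEON_DOT", "0") == "1"
enable_arm_i8mm = os.getenv("TORCHAO_ENABLE_ARM_I8MM", "0") == "1"
parallel_backend = os.getenv("TORCHAO_PARALLEL_BACKEND", "aten_openmp")

cmake_args = [
    f"-DTORCHAO_ENABLE_ARM_NEON_DOT={bool_to_on_off(enable_arm_neon_dot)}",
    f"-DTORCHAO_ENABLE_ARM_I8MM={bool_to_on_off(enable_arm_i8mm)}",
    f"-DTORCHAO_PARALLEL_BACKEND={parallel_backend}",
]
print(cmake_args)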

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
2929
from torchao.quantization import quantize_
3030
from torchao.quantization.utils import compute_error
31+
from torchao.testing.utils import skip_if_rocm
3132
from torchao.utils import (
3233
TORCH_VERSION_AT_LEAST_2_8,
3334
is_sm_at_least_89,
@@ -396,18 +397,25 @@ def test_inference_print_str():
396397
@pytest.mark.skipif(
397398
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
398399
)
399-
@pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
400400
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
401401
@pytest.mark.parametrize("bias", [True, False])
402402
@pytest.mark.parametrize("compile", [True, False])
403403
@torch.no_grad()
404+
@skip_if_rocm(
405+
"ROCm float4 gemm require gfx950"
406+
) # TODO(future): deploy gfx950 in ROCM CI
404407
def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
405408
"""
406409
Smoke test for inference compile
407410
"""
411+
# TODO(future): figure out why these CUDA capability conditions are not properly
412+
# applied when inside `pytest.mark.skipif` for this test
408413
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
409414
if not is_sm_at_least_89():
410415
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
416+
elif elem_dtype == torch.float4_e2m1fn_x2:
417+
if not is_sm_at_least_100():
418+
pytest.skip("CUDA capability >= 10.0 required for float4 gemm")
411419

412420
m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
413421
m_mx = copy.deepcopy(m)
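The test now relies on skip_if_rocm from torchao.testing.utils. As a rough illustration only (an assumed stand-in, not the actual torchao helper), such a decorator can detect a ROCm build via torch.version.hip:

import functools

import pytest
import torch


def skip_if_rocm_sketch(reason: str):
    # Hypothetical stand-in for torchao.testing.utils.skip_if_rocm.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # torch.version.hip is non-None on ROCm builds of PyTorch.
            if torch.version.hip is not None:
                pytest.skip(reason)
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@skip_if_rocm_sketch("ROCm float4 gemm requires gfx950")
def test_placeholder():
    assert True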

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
6666
if elem_dtype is torch.float8_e4m3fn:
6767
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 18.0)
6868
else:
69-
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0)
69+
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 13.0)
7070

7171

7272
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
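The relaxed threshold is a signal-to-quantization-noise ratio (SQNR) bound in dB. A small self-contained sketch of such a check, assuming the usual definition 20 * log10(||x|| / ||x - x_hat||) (the test itself uses compute_error and assert_sqnr_gt_threshold from this file):

import torch


def sqnr_db(orig: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||orig|| / ||orig - new||).
    return 20 * torch.log10(torch.linalg.norm(orig) / torch.linalg.norm(orig - new))


data_hp = torch.randn(1024)
data_mx_dq = data_hp + 0.05 * torch.randn(1024)  # stand-in for quantize/dequantize error
assert sqnr_db(data_hp, data_mx_dq) > 13.0  # threshold used above for non-fp8 element dtypes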

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,6 +2385,192 @@ def validate(self, model: torch.fx.GraphModule) -> None:
23852385
node_list,
23862386
)
23872387

2388+
def test_conv3d_bn_relu(self):
2389+
class BackendAQuantizer(Quantizer):
2390+
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
2391+
act_qspec = QuantizationSpec(
2392+
dtype=torch.uint8,
2393+
quant_min=0,
2394+
quant_max=255,
2395+
qscheme=torch.per_tensor_affine,
2396+
is_dynamic=False,
2397+
observer_or_fake_quant_ctr=observer.default_observer,
2398+
)
2399+
weight_qspec = QuantizationSpec(
2400+
dtype=torch.int8,
2401+
quant_min=-128,
2402+
quant_max=127,
2403+
qscheme=torch.per_tensor_affine,
2404+
is_dynamic=False,
2405+
observer_or_fake_quant_ctr=observer.default_weight_observer,
2406+
)
2407+
bias_qspec = QuantizationSpec(
2408+
dtype=torch.float32,
2409+
is_dynamic=False,
2410+
observer_or_fake_quant_ctr=observer.PlaceholderObserver,
2411+
)
2412+
# conv_transpose + bn is fused automatically in PTQ (not configurable)
2413+
# so we just need to annotate conv + relu for conv + bn + relu
2414+
# pattern
2415+
for n in model.graph.nodes:
2416+
if (
2417+
n.op != "call_function"
2418+
or n.target != torch.ops.aten.relu.default
2419+
):
2420+
continue
2421+
relu_node = n
2422+
n = n.args[0]
2423+
if (
2424+
n.op != "call_function"
2425+
and n.target != torch.ops.aten.conv3d.input
2426+
):
2427+
continue
2428+
conv_t_node = n
2429+
input_act = conv_t_node.args[0]
2430+
weight = conv_t_node.args[1]
2431+
bias = conv_t_node.args[2]
2432+
conv_t_node.meta["quantization_annotation"] = (
2433+
QuantizationAnnotation(
2434+
input_qspec_map={
2435+
input_act: act_qspec,
2436+
weight: weight_qspec,
2437+
bias: bias_qspec,
2438+
},
2439+
_annotated=True,
2440+
)
2441+
)
2442+
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
2443+
output_qspec=act_qspec,
2444+
_annotated=True,
2445+
)
2446+
2447+
def validate(self, model: torch.fx.GraphModule) -> None:
2448+
pass
2449+
2450+
class M(torch.nn.Module):
2451+
def __init__(self):
2452+
super().__init__()
2453+
self.conv = torch.nn.Conv3d(2, 2, 3, padding=1)
2454+
self.bn = torch.nn.BatchNorm3d(2)
2455+
2456+
def forward(self, x):
2457+
return torch.nn.functional.relu(self.bn(self.conv(x)))
2458+
2459+
example_inputs = (torch.randn(1, 2, 2, 5, 5),)
2460+
node_occurrence = {
2461+
# two for input of the first conv, one for output for the first conv
2462+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
2463+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
2464+
}
2465+
node_list = [
2466+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2467+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2468+
torch.ops.aten.conv3d.default,
2469+
torch.ops.aten.relu.default,
2470+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
2471+
]
2472+
model = M().eval()
2473+
self._test_quantizer(
2474+
model,
2475+
example_inputs,
2476+
BackendAQuantizer(),
2477+
node_occurrence,
2478+
node_list,
2479+
)
2480+
2481+
def test_conv_transpose3d_bn_relu(self):
2482+
class BackendAQuantizer(Quantizer):
2483+
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
2484+
act_qspec = QuantizationSpec(
2485+
dtype=torch.uint8,
2486+
quant_min=0,
2487+
quant_max=255,
2488+
qscheme=torch.per_tensor_affine,
2489+
is_dynamic=False,
2490+
observer_or_fake_quant_ctr=observer.default_observer,
2491+
)
2492+
weight_qspec = QuantizationSpec(
2493+
dtype=torch.int8,
2494+
quant_min=-128,
2495+
quant_max=127,
2496+
qscheme=torch.per_tensor_affine,
2497+
is_dynamic=False,
2498+
observer_or_fake_quant_ctr=observer.default_weight_observer,
2499+
)
2500+
bias_qspec = QuantizationSpec(
2501+
dtype=torch.float32,
2502+
is_dynamic=False,
2503+
observer_or_fake_quant_ctr=observer.PlaceholderObserver,
2504+
)
2505+
# conv_transpose + bn is fused automatically in PTQ (not configurable)
2506+
# so we just need to annotate conv_transpose + relu for conv_transpose + bn + relu
2507+
# pattern
2508+
for n in model.graph.nodes:
2509+
if (
2510+
n.op != "call_function"
2511+
or n.target != torch.ops.aten.relu.default
2512+
):
2513+
continue
2514+
relu_node = n
2515+
n = n.args[0]
2516+
if (
2517+
n.op != "call_function"
2518+
and n.target != torch.ops.aten.conv_transposed3d.input
2519+
):
2520+
continue
2521+
conv_t_node = n
2522+
input_act = conv_t_node.args[0]
2523+
weight = conv_t_node.args[1]
2524+
bias = conv_t_node.args[2]
2525+
conv_t_node.meta["quantization_annotation"] = (
2526+
QuantizationAnnotation(
2527+
input_qspec_map={
2528+
input_act: act_qspec,
2529+
weight: weight_qspec,
2530+
bias: bias_qspec,
2531+
},
2532+
_annotated=True,
2533+
)
2534+
)
2535+
relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
2536+
output_qspec=act_qspec,
2537+
_annotated=True,
2538+
)
2539+
2540+
def validate(self, model: torch.fx.GraphModule) -> None:
2541+
pass
2542+
2543+
class M(torch.nn.Module):
2544+
def __init__(self):
2545+
super().__init__()
2546+
self.conv_t = torch.nn.ConvTranspose3d(2, 2, 3, padding=1)
2547+
self.bn = torch.nn.BatchNorm3d(2)
2548+
2549+
def forward(self, x):
2550+
return torch.nn.functional.relu(self.bn(self.conv_t(x)))
2551+
2552+
example_inputs = (torch.randn(1, 2, 2, 5, 5),)
2553+
node_occurrence = {
2554+
# two for input of the first conv, one for output for the first conv
2555+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
2556+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
2557+
}
2558+
node_list = [
2559+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2560+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2561+
torch.ops.aten.conv_transpose3d.input,
2562+
torch.ops.aten.relu.default,
2563+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
2564+
]
2565+
model = M().eval()
2566+
self._test_quantizer(
2567+
model,
2568+
example_inputs,
2569+
BackendAQuantizer(),
2570+
node_occurrence,
2571+
node_list,
2572+
)
2573+
23882574
def test_multi_users_without_output_observer(self):
23892575
"""
23902576
Test the case in which a node is used by multiple users,

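For reference, a sketch of how a quantizer like the BackendAQuantizer above is exercised end to end in the PT2E flow (assumptions: M and BackendAQuantizer as defined in the new test, plus the prepare/convert helpers from torch.ao.quantization.quantize_pt2e; the export entry point has changed across PyTorch versions, export_for_training being one option):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = M().eval()  # M and BackendAQuantizer as in the diff above
example_inputs = (torch.randn(1, 2, 2, 5, 5),)

# Export to an FX graph that the quantizer's annotate() can walk.
exported = torch.export.export_for_training(model, example_inputs).module()

prepared = prepare_pt2e(exported, BackendAQuantizer())  # insert observers per annotations
prepared(*example_inputs)                               # calibration pass
quantized = convert_pt2e(prepared)                      # rewrite to quantize/dequantize ops
quantized(*example_inputs)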