Skip to content

Commit 28ccd73

Browse files
metascroy authored and Jack-Khuu committed
Optionally enable KleidiAI + clean up setup.py flags (#1826)
* init * up * up * up * up * up * up * up * up * up
1 parent 9322ee1 commit 28ccd73

File tree

12 files changed

+173
-103
lines changed

12 files changed

+173
-103
lines changed

setup.py

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import time
1111
from datetime import datetime
12+
from typing import List, Optional
1213

1314
from setuptools import Extension, find_packages, setup
1415

@@ -75,19 +76,54 @@ def use_debug_mode():
7576
CUDAExtension,
7677
)
7778

78-
build_torchao_experimental_mps = (
79-
os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1"
80-
and build_torchao_experimental
81-
and torch.mps.is_available()
82-
)
8379

84-
if os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1":
85-
if use_cpp != "1":
86-
print("Building experimental MPS ops requires USE_CPP=1")
87-
if not platform.machine().startswith("arm64") or platform.system() != "Darwin":
88-
print("Experimental MPS ops require Apple Silicon.")
89-
if not torch.mps.is_available():
90-
print("MPS not available. Skipping compilation of experimental MPS ops.")
80+
class BuildOptions:
    """Resolve the optional experimental-build switches from the environment.

    Every switch is read once, at construction time, and exposed as a boolean
    attribute:

    * ``build_cpu_aarch64`` -- ``TORCHAO_BUILD_CPU_AARCH64``; on by default
      only on Arm-based Apple machines.
    * ``build_kleidi_ai`` -- ``TORCHAO_BUILD_KLEIDIAI``; off by default.
    * ``build_experimental_mps`` -- ``TORCHAO_BUILD_EXPERIMENTAL_MPS``; off by
      default.

    A variable counts as enabled only when its value is exactly ``"1"``.

    Raises:
        AssertionError: if an enabled switch is not supported by the host
            (wrong architecture / OS, MPS unavailable) or if its prerequisite
            switch is disabled.
    """

    # platform.machine() reports "arm64" on macOS but "aarch64" on Linux; both
    # name the same 64-bit Arm architecture and must be accepted, otherwise
    # TORCHAO_BUILD_CPU_AARCH64=1 can never be honoured on non-Apple hosts.
    _ARM64_MACHINE_PREFIXES = ("arm64", "aarch64")

    def __init__(self):
        # TORCHAO_BUILD_CPU_AARCH64 is enabled by default on Arm-based Apple machines
        # The kernels require sdot/udot, which are not required on Arm until Armv8.4 or later,
        # but are available on Arm-based Apple machines. On non-Apple machines, the kernels
        # can be built by explicitly setting TORCHAO_BUILD_CPU_AARCH64=1
        self.build_cpu_aarch64 = self._os_bool_var(
            "TORCHAO_BUILD_CPU_AARCH64",
            default=(self._is_arm64() and self._is_macos()),
        )
        if self.build_cpu_aarch64:
            self._require(
                self._is_arm64(),
                "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine",
            )

        # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
        # 1) It increases the build time
        # 2) It has some accuracy issues in CI tests due to BF16
        self.build_kleidi_ai = self._os_bool_var(
            "TORCHAO_BUILD_KLEIDIAI", default=False
        )
        if self.build_kleidi_ai:
            self._require(
                self.build_cpu_aarch64,
                "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set",
            )

        # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
        self.build_experimental_mps = self._os_bool_var(
            "TORCHAO_BUILD_EXPERIMENTAL_MPS", default=False
        )
        if self.build_experimental_mps:
            # Checked in order, so torch.mps is only queried once the OS and
            # architecture checks have already passed.
            self._require(
                self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
            )
            self._require(
                self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
            )
            self._require(
                torch.mps.is_available(),
                "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available",
            )

    @staticmethod
    def _require(condition, message) -> None:
        """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

        Raised explicitly instead of via an ``assert`` statement so the check
        still runs when Python is started with ``-O``.
        """
        if not condition:
            raise AssertionError(message)

    @classmethod
    def _machine_is_arm64(cls, machine: str) -> bool:
        """Return True if ``machine`` (a ``platform.machine()`` string) is 64-bit Arm."""
        return machine.startswith(cls._ARM64_MACHINE_PREFIXES)

    def _is_arm64(self) -> bool:
        """Return True when the build host is a 64-bit Arm machine."""
        return self._machine_is_arm64(platform.machine())

    def _is_macos(self) -> bool:
        """Return True when the build host runs macOS (``Darwin``)."""
        return platform.system() == "Darwin"

    def _os_bool_var(self, var, default) -> bool:
        """Read environment variable ``var`` as a boolean.

        Returns True only when the value is exactly ``"1"``; when the variable
        is unset, ``default`` decides the result.
        """
        default_val = "1" if default else "0"
        return os.getenv(var, default_val) == "1"
126+
91127

92128
# Constant known variables used throughout this file
93129
cwd = os.path.abspath(os.path.curdir)
@@ -179,38 +215,30 @@ def build_extensions(self):
179215
def build_cmake(self, ext):
180216
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
181217

182-
build_type = "Debug" if use_debug_mode() else "Release"
183-
184-
from distutils.sysconfig import get_python_lib
185-
186-
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
187-
188218
if not os.path.exists(self.build_temp):
189219
os.makedirs(self.build_temp)
190220

191-
build_mps_ops = "ON" if build_torchao_experimental_mps else "OFF"
192-
193221
subprocess.check_call(
194222
[
195223
"cmake",
196-
ext.sourcedir,
197-
"-DCMAKE_BUILD_TYPE=" + build_type,
198-
# Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16
199-
"-DTORCHAO_BUILD_KLEIDIAI=OFF",
200-
"-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops,
201-
"-DTorch_DIR=" + torch_dir,
202-
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
203-
"-DCMAKE_INSTALL_PREFIX=cmake-out",
204-
],
224+
ext.cmake_lists_dir,
225+
]
226+
+ ext.cmake_args
227+
+ ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir],
205228
cwd=self.build_temp,
206229
)
207230
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
208231

209232

210233
class CMakeExtension(Extension):
211-
def __init__(self, name, sourcedir=""):
234+
def __init__(
235+
self, name, cmake_lists_dir: str = "", cmake_args: Optional[List[str]] = None
236+
):
212237
Extension.__init__(self, name, sources=[])
213-
self.sourcedir = os.path.abspath(sourcedir)
238+
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
239+
if cmake_args is None:
240+
cmake_args = []
241+
self.cmake_args = cmake_args
214242

215243

216244
def get_extensions():
@@ -310,10 +338,33 @@ def get_extensions():
310338
)
311339

312340
if build_torchao_experimental:
341+
build_options = BuildOptions()
342+
343+
def bool_to_on_off(value):
344+
return "ON" if value else "OFF"
345+
346+
from distutils.sysconfig import get_python_lib
347+
348+
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
349+
313350
ext_modules.append(
314351
CMakeExtension(
315352
"torchao.experimental",
316-
sourcedir="torchao/experimental",
353+
cmake_lists_dir="torchao/experimental",
354+
cmake_args=(
355+
[
356+
f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}",
357+
f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}",
358+
f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}",
359+
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
360+
"-DTorch_DIR=" + torch_dir,
361+
]
362+
+ (
363+
["-DCMAKE_INSTALL_PREFIX=cmake-out"]
364+
if build_options.build_experimental_mps
365+
else []
366+
)
367+
),
317368
)
318369
)
319370

torchao/experimental/CMakeLists.txt

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,13 @@ endif()
1717

1818
option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
1919
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)
20-
20+
option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF)
21+
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
2122

2223
if(NOT TORCHAO_INCLUDE_DIRS)
2324
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
2425
endif()
2526

26-
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
27-
if(TORCHAO_BUILD_KLEIDIAI)
28-
message(STATUS "Building with Arm KleidiAI library")
29-
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
30-
endif()
3127
include(CMakePrintHelpers)
3228

3329
add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
@@ -36,49 +32,52 @@ include(CMakePrintHelpers)
3632
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
3733
include_directories(${TORCHAO_INCLUDE_DIRS})
3834

39-
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
35+
36+
if(TORCHAO_BUILD_CPU_AARCH64)
37+
message(STATUS "Building with cpu/aarch64")
38+
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
39+
40+
# Defines torchao_kernels_aarch64
41+
add_subdirectory(kernels/cpu/aarch64)
42+
4043
if(TORCHAO_BUILD_KLEIDIAI)
4144
message(STATUS "Building with Arm KleidiAI library")
42-
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
43-
endif()
44-
# Defines target torchao_kernels_aarch64
45-
add_subdirectory(kernels/cpu/aarch64)
46-
add_subdirectory(ops/linear_8bit_act_xbit_weight)
47-
add_subdirectory(ops/embedding_xbit)
48-
49-
add_library(torchao_ops_aten SHARED)
50-
target_link_libraries(
51-
torchao_ops_aten PRIVATE
52-
torchao_ops_linear_8bit_act_xbit_weight_aten
53-
torchao_ops_embedding_xbit_aten
54-
)
55-
if (TORCHAO_BUILD_MPS_OPS)
56-
message(STATUS "Building with MPS support")
57-
add_subdirectory(ops/mps)
58-
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
45+
add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
5946
endif()
47+
endif()
48+
49+
add_subdirectory(ops/linear_8bit_act_xbit_weight)
50+
add_subdirectory(ops/embedding_xbit)
6051

52+
add_library(torchao_ops_aten SHARED)
53+
target_link_libraries(
54+
torchao_ops_aten PRIVATE
55+
torchao_ops_linear_8bit_act_xbit_weight_aten
56+
torchao_ops_embedding_xbit_aten
57+
)
58+
if (TORCHAO_BUILD_MPS_OPS)
59+
message(STATUS "Building with MPS support")
60+
add_subdirectory(ops/mps)
61+
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
62+
endif()
63+
64+
install(
65+
TARGETS torchao_ops_aten
66+
EXPORT _targets
67+
DESTINATION lib
68+
)
69+
if(TORCHAO_BUILD_EXECUTORCH_OPS)
70+
add_library(torchao_ops_executorch STATIC)
71+
target_link_libraries(torchao_ops_executorch PRIVATE
72+
torchao_ops_linear_8bit_act_xbit_weight_executorch
73+
torchao_ops_embedding_xbit_executorch
74+
)
6175
install(
62-
TARGETS torchao_ops_aten
76+
TARGETS
77+
torchao_ops_executorch
78+
torchao_ops_linear_8bit_act_xbit_weight_executorch
79+
torchao_ops_embedding_xbit_executorch
6380
EXPORT _targets
6481
DESTINATION lib
6582
)
66-
if(TORCHAO_BUILD_EXECUTORCH_OPS)
67-
add_library(torchao_ops_executorch STATIC)
68-
target_link_libraries(torchao_ops_executorch PRIVATE
69-
torchao_ops_linear_8bit_act_xbit_weight_executorch
70-
torchao_ops_embedding_xbit_executorch
71-
)
72-
install(
73-
TARGETS
74-
torchao_ops_executorch
75-
torchao_kernels_aarch64
76-
torchao_ops_linear_8bit_act_xbit_weight_executorch
77-
torchao_ops_embedding_xbit_executorch
78-
EXPORT _targets
79-
DESTINATION lib
80-
)
81-
endif()
82-
else()
83-
message(FATAL_ERROR "Torchao experimental ops can only be built on arm64 CPUs.")
8483
endif()

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64"))
7+
if (TORCHAO_BUILD_CPU_AARCH64)
88
add_library(
99
torchao_kernels_aarch64
1010
${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
@@ -22,14 +22,11 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA
2222
GIT_TAG v1.2.0)
2323
FetchContent_MakeAvailable(kleidiai)
2424

25-
# Temporarily exposing this to the parent scope until we wire
26-
# this up properly from the top level
27-
set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE)
2825
target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
2926
endif()
30-
endif()
3127

3228
install(
3329
TARGETS torchao_kernels_aarch64
3430
DESTINATION lib
3531
)
32+
endif()

torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ endif()
4040

4141
add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
4242

43-
# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
44-
if(TORCHAO_BUILD_KLEIDI)
43+
if(TORCHAO_BUILD_KLEIDIAI)
4544
add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
4645
endif()
4746

torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ cmake \
4040
${EXTRA_ARGS} \
4141
-DCMAKE_BUILD_TYPE=Debug \
4242
-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
43+
-DTORCHAO_BUILD_CPU_AARCH64=ON \
4344
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \
4445
-B ${CMAKE_OUT}
4546

torchao/experimental/ops/embedding_xbit/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ add_library(torchao_ops_embedding_xbit_aten OBJECT
1313
op_embedding_xbit_aten.cpp
1414
)
1515
target_link_torchao_parallel_backend(torchao_ops_embedding_xbit_aten "aten_openmp")
16-
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
16+
if (TORCHAO_BUILD_CPU_AARCH64)
17+
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
18+
endif()
1719
target_include_directories(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
1820
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_LIBRARIES}")
1921
target_compile_definitions(torchao_ops_embedding_xbit_aten PRIVATE USE_ATEN=1)
@@ -32,5 +34,7 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
3234
target_include_directories(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
3335
target_compile_definitions(torchao_ops_embedding_xbit_executorch PRIVATE USE_EXECUTORCH=1)
3436
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
35-
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64)
37+
if (TORCHAO_BUILD_CPU_AARCH64)
38+
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64)
39+
endif()
3640
endif()

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
#pragma once
88

9-
#if defined(__aarch64__) || defined(__ARM_NEON)
9+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
1010
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
11-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
11+
#endif // TORCHAO_BUILD_CPU_AARCH64
1212

1313
#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
1414
#include <torchao/experimental/ops/library.h>
@@ -145,7 +145,7 @@ Tensor embedding_out_cpu(
145145
index = index64_ptr[idx];
146146
}
147147
TORCHAO_CHECK(index >= 0 && index < num_embeddings, "index out of bounds");
148-
#if defined(__aarch64__) || defined(__ARM_NEON)
148+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
149149
torchao::kernels::cpu::aarch64::embedding::embedding<weight_nbit>(
150150
out.mutable_data_ptr<float>() + idx * embedding_dim,
151151
embedding_dim,
@@ -157,7 +157,7 @@ Tensor embedding_out_cpu(
157157
index);
158158
#else
159159
TORCHAO_CHECK(false, "Unsupported platform");
160-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
160+
#endif // TORCHAO_BUILD_CPU_AARCH64
161161
});
162162

163163
return out;
@@ -234,7 +234,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
234234
header.write(out.mutable_data_ptr());
235235

236236
torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
237-
#if defined(__aarch64__) || defined(__ARM_NEON)
237+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
238238
torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
239239
weight_nbit>(
240240
out.mutable_data_ptr<int8_t>() +
@@ -244,7 +244,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
244244
idx);
245245
#else
246246
TORCHAO_CHECK(false, "Unsupported platform");
247-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
247+
#endif // defined(TORCHAO_BUILD_CPU_AARCH64)
248248
});
249249

250250
return out;

0 commit comments

Comments
 (0)