Skip to content

Commit 92a3da4

Browse files
committed
up
1 parent f401783 commit 92a3da4

File tree

11 files changed

+149
-101
lines changed

11 files changed

+149
-101
lines changed

setup.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,53 @@ def use_debug_mode():
7676
IS_WINDOWS,
7777
)
7878

79-
build_torchao_experimental_mps = (
80-
os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1"
81-
and build_torchao_experimental
82-
and torch.mps.is_available()
83-
)
8479

85-
if os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1":
86-
if use_cpp != "1":
87-
print("Building experimental MPS ops requires USE_CPP=1")
88-
if not platform.machine().startswith("arm64") or platform.system() != "Darwin":
89-
print("Experimental MPS ops require Apple Silicon.")
90-
if not torch.mps.is_available():
91-
print("MPS not available. Skipping compilation of experimental MPS ops.")
80+
class BuildOptions:
81+
def __init__(self):
82+
# TORCHAO_BUILD_CPU_AARCH64 is enabled by default on Arm-based Apple machines
83+
# The kernels require sdot/udot, which are not required on Arm until Armv8.4 or later,
84+
# but are available on Arm-based Apple machines. On non-Apple machines, the kernels
85+
# can be built by explicitly setting TORCHAO_BUILD_CPU_AARCH64=1
86+
self.build_cpu_aarch64 = self._os_bool_var(
87+
"TORCHAO_BUILD_CPU_AARCH64",
88+
default=(self._is_arm64() and self._is_macos()),
89+
)
90+
if self.build_cpu_aarch64:
91+
assert (
92+
self._is_arm64()
93+
), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
94+
95+
# TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
96+
# 1) It increases the build time
97+
# 2) It has some accuracy issues in CI tests due to BF16
98+
self.build_kleidi_ai = self._os_bool_var(
99+
"TORCHAO_BUILD_KLEIDIAI", default=False
100+
)
101+
if self.build_kleidi_ai:
102+
assert (
103+
self.build_cpu_aarch64
104+
), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
105+
106+
# TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
107+
self.build_experimental_mps = self._os_bool_var(
108+
"TORCHAO_BUILD_EXPERIMENTAL_MPS", default=False
109+
)
110+
if self.build_experimental_mps:
111+
assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
112+
assert (
113+
torch.mps.is_available()
114+
), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
115+
116+
def _is_arm64(self) -> bool:
117+
return platform.machine().startswith("arm64")
118+
119+
def _is_macos(self) -> bool:
120+
return platform.system() == "Darwin"
121+
122+
def _os_bool_var(self, var, default) -> bool:
123+
default_val = "1" if default else "0"
124+
return os.getenv(var, default_val) == "1"
125+
92126

93127
# Constant known variables used throughout this file
94128
cwd = os.path.abspath(os.path.curdir)
@@ -303,6 +337,11 @@ def get_extensions():
303337
)
304338

305339
if build_torchao_experimental:
340+
build_options = BuildOptions()
341+
342+
def bool_to_on_off(value):
343+
return "ON" if value else "OFF"
344+
306345
from distutils.sysconfig import get_python_lib
307346

308347
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
@@ -313,23 +352,17 @@ def get_extensions():
313352
cmake_lists_dir="torchao/experimental",
314353
cmake_args=(
315354
[
316-
(
317-
"-DCMAKE_BUILD_TYPE=" + "Debug"
318-
if use_debug_mode()
319-
else "Release"
320-
),
321-
# Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16
322-
"-DTORCHAO_BUILD_KLEIDIAI=ON",
323-
(
324-
"-DTORCHAO_BUILD_MPS_OPS=" + "ON"
325-
if build_torchao_experimental_mps
326-
else "OFF"
327-
),
355+
f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}",
356+
f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}",
357+
f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}",
358+
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
328359
"-DTorch_DIR=" + torch_dir,
329360
]
330-
+ ["-DCMAKE_INSTALL_PREFIX=cmake-out"]
331-
if build_torchao_experimental_mps
332-
else []
361+
+ (
362+
["-DCMAKE_INSTALL_PREFIX=cmake-out"]
363+
if build_options.build_experimental_mps
364+
else []
365+
)
333366
),
334367
)
335368
)

torchao/experimental/CMakeLists.txt

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,13 @@ endif()
1717

1818
option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
1919
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)
20-
20+
option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF)
21+
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
2122

2223
if(NOT TORCHAO_INCLUDE_DIRS)
2324
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
2425
endif()
2526

26-
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
27-
if(TORCHAO_BUILD_KLEIDIAI)
28-
message(STATUS "Building with Arm KleidiAI library")
29-
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
30-
endif()
3127
include(CMakePrintHelpers)
3228

3329
add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
@@ -36,49 +32,56 @@ include(CMakePrintHelpers)
3632
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
3733
include_directories(${TORCHAO_INCLUDE_DIRS})
3834

39-
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
35+
36+
if(TORCHAO_BUILD_CPU_AARCH64)
37+
message(STATUS "Building with cpu/aarch64")
38+
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
39+
4040
if(TORCHAO_BUILD_KLEIDIAI)
4141
message(STATUS "Building with Arm KleidiAI library")
4242
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
4343
endif()
44-
# Defines target torchao_kernels_aarch64
44+
endif()
45+
46+
47+
if (TORCHAO_BUILD_CPU_AARCH64)
48+
# Defines torchao_kernels_aarch64
4549
add_subdirectory(kernels/cpu/aarch64)
46-
add_subdirectory(ops/linear_8bit_act_xbit_weight)
47-
add_subdirectory(ops/embedding_xbit)
48-
49-
add_library(torchao_ops_aten SHARED)
50-
target_link_libraries(
51-
torchao_ops_aten PRIVATE
52-
torchao_ops_linear_8bit_act_xbit_weight_aten
53-
torchao_ops_embedding_xbit_aten
54-
)
55-
if (TORCHAO_BUILD_MPS_OPS)
56-
message(STATUS "Building with MPS support")
57-
add_subdirectory(ops/mps)
58-
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
59-
endif()
50+
endif()
51+
52+
add_subdirectory(ops/linear_8bit_act_xbit_weight)
53+
add_subdirectory(ops/embedding_xbit)
6054

55+
add_library(torchao_ops_aten SHARED)
56+
target_link_libraries(
57+
torchao_ops_aten PRIVATE
58+
torchao_ops_linear_8bit_act_xbit_weight_aten
59+
torchao_ops_embedding_xbit_aten
60+
)
61+
if (TORCHAO_BUILD_MPS_OPS)
62+
message(STATUS "Building with MPS support")
63+
add_subdirectory(ops/mps)
64+
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
65+
endif()
66+
67+
install(
68+
TARGETS torchao_ops_aten
69+
EXPORT _targets
70+
DESTINATION lib
71+
)
72+
if(TORCHAO_BUILD_EXECUTORCH_OPS)
73+
add_library(torchao_ops_executorch STATIC)
74+
target_link_libraries(torchao_ops_executorch PRIVATE
75+
torchao_ops_linear_8bit_act_xbit_weight_executorch
76+
torchao_ops_embedding_xbit_executorch
77+
)
6178
install(
62-
TARGETS torchao_ops_aten
79+
TARGETS
80+
torchao_ops_executorch
81+
# torchao_kernels_aarch64
82+
torchao_ops_linear_8bit_act_xbit_weight_executorch
83+
torchao_ops_embedding_xbit_executorch
6384
EXPORT _targets
6485
DESTINATION lib
6586
)
66-
if(TORCHAO_BUILD_EXECUTORCH_OPS)
67-
add_library(torchao_ops_executorch STATIC)
68-
target_link_libraries(torchao_ops_executorch PRIVATE
69-
torchao_ops_linear_8bit_act_xbit_weight_executorch
70-
torchao_ops_embedding_xbit_executorch
71-
)
72-
install(
73-
TARGETS
74-
torchao_ops_executorch
75-
torchao_kernels_aarch64
76-
torchao_ops_linear_8bit_act_xbit_weight_executorch
77-
torchao_ops_embedding_xbit_executorch
78-
EXPORT _targets
79-
DESTINATION lib
80-
)
81-
endif()
82-
else()
83-
message(FATAL_ERROR "Torchao experimental ops can only be built on arm64 CPUs.")
8487
endif()

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64"))
7+
if (TORCHAO_BUILD_CPU_AARCH64)
88
add_library(
99
torchao_kernels_aarch64
1010
${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
@@ -27,9 +27,9 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA
2727
set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE)
2828
target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
2929
endif()
30-
endif()
3130

3231
install(
3332
TARGETS torchao_kernels_aarch64
3433
DESTINATION lib
3534
)
35+
endif()

torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ cmake \
4040
${EXTRA_ARGS} \
4141
-DCMAKE_BUILD_TYPE=Debug \
4242
-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
43+
-DTORCHAO_BUILD_CPU_AARCH64=ON \
4344
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \
4445
-B ${CMAKE_OUT}
4546

torchao/experimental/ops/embedding_xbit/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ add_library(torchao_ops_embedding_xbit_aten OBJECT
1313
op_embedding_xbit_aten.cpp
1414
)
1515
target_link_torchao_parallel_backend(torchao_ops_embedding_xbit_aten "aten_openmp")
16-
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
16+
if (TORCHAO_BUILD_CPU_AARCH64)
17+
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
18+
endif()
1719
target_include_directories(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
1820
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_LIBRARIES}")
1921
target_compile_definitions(torchao_ops_embedding_xbit_aten PRIVATE USE_ATEN=1)
@@ -32,5 +34,7 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
3234
target_include_directories(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
3335
target_compile_definitions(torchao_ops_embedding_xbit_executorch PRIVATE USE_EXECUTORCH=1)
3436
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
35-
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64)
37+
if (TORCHAO_BUILD_CPU_AARCH64)
38+
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64)
39+
endif()
3640
endif()

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
#pragma once
88

9-
#if defined(__aarch64__) || defined(__ARM_NEON)
9+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
1010
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
11-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
11+
#endif // TORCHAO_BUILD_CPU_AARCH64
1212

1313
#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
1414
#include <torchao/experimental/ops/library.h>
@@ -145,7 +145,7 @@ Tensor embedding_out_cpu(
145145
index = index64_ptr[idx];
146146
}
147147
TORCHAO_CHECK(index >= 0 && index < num_embeddings, "index out of bounds");
148-
#if defined(__aarch64__) || defined(__ARM_NEON)
148+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
149149
torchao::kernels::cpu::aarch64::embedding::embedding<weight_nbit>(
150150
out.mutable_data_ptr<float>() + idx * embedding_dim,
151151
embedding_dim,
@@ -157,7 +157,7 @@ Tensor embedding_out_cpu(
157157
index);
158158
#else
159159
TORCHAO_CHECK(false, "Unsupported platform");
160-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
160+
#endif // TORCHAO_BUILD_CPU_AARCH64
161161
});
162162

163163
return out;
@@ -234,7 +234,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
234234
header.write(out.mutable_data_ptr());
235235

236236
torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
237-
#if defined(__aarch64__) || defined(__ARM_NEON)
237+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
238238
torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
239239
weight_nbit>(
240240
out.mutable_data_ptr<int8_t>() +
@@ -244,7 +244,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
244244
idx);
245245
#else
246246
TORCHAO_CHECK(false, "Unsupported platform");
247-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
247+
#endif // defined(TORCHAO_BUILD_CPU_AARCH64)
248248
});
249249

250250
return out;

torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,18 @@ FetchContent_Declare(cpuinfo
1818
FetchContent_MakeAvailable(
1919
cpuinfo)
2020

21-
FetchContent_Declare(
22-
glog
23-
GIT_REPOSITORY https://github.com/google/glog.git
24-
GIT_TAG v0.7.1
25-
)
26-
FetchContent_MakeAvailable(glog)
2721

2822
find_package(Torch REQUIRED)
2923
add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT
3024
linear_8bit_act_xbit_weight.cpp
3125
op_linear_8bit_act_xbit_weight_aten.cpp
3226
)
3327
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp)
34-
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64)
28+
29+
if(TORCHAO_BUILD_CPU_AARCH64)
30+
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64)
31+
endif()
3532
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo)
36-
# target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE glog)
3733
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
3834
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
3935
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1)
@@ -55,7 +51,8 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
5551
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
5652
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
5753
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
58-
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64)
54+
if(TORCHAO_BUILD_CPU_AARCH64)
55+
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64)
56+
endif()
5957
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo)
60-
# target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE glog)
6158
endif()

0 commit comments

Comments
 (0)