Skip to content

Optionally enable KleidiAI + clean up setup.py flags #1826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 83 additions & 32 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import time
from datetime import datetime
from typing import List, Optional

from setuptools import Extension, find_packages, setup

Expand Down Expand Up @@ -75,19 +76,54 @@ def use_debug_mode():
CUDAExtension,
)

build_torchao_experimental_mps = (
os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1"
and build_torchao_experimental
and torch.mps.is_available()
)

if os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1":
if use_cpp != "1":
print("Building experimental MPS ops requires USE_CPP=1")
if not platform.machine().startswith("arm64") or platform.system() != "Darwin":
print("Experimental MPS ops require Apple Silicon.")
if not torch.mps.is_available():
print("MPS not available. Skipping compilation of experimental MPS ops.")
class BuildOptions:
def __init__(self):
# TORCHAO_BUILD_CPU_AARCH64 is enabled by default on Arm-based Apple machines
# The kernels require sdot/udot, which are not required on Arm until Armv8.4 or later,
# but are available on Arm-based Apple machines. On non-Apple machines, the kernels
# can be built by explicitly setting TORCHAO_BUILD_CPU_AARCH64=1
self.build_cpu_aarch64 = self._os_bool_var(
"TORCHAO_BUILD_CPU_AARCH64",
default=(self._is_arm64() and self._is_macos()),
)
if self.build_cpu_aarch64:
assert (
self._is_arm64()
), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"

# TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
# 1) It increases the build time
# 2) It has some accuracy issues in CI tests due to BF16
self.build_kleidi_ai = self._os_bool_var(
"TORCHAO_BUILD_KLEIDIAI", default=False
)
if self.build_kleidi_ai:
assert (
self.build_cpu_aarch64
), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"

# TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
self.build_experimental_mps = self._os_bool_var(
"TORCHAO_BUILD_EXPERIMENTAL_MPS", default=False
)
if self.build_experimental_mps:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to assert `_is_arm64()` here, since the experimental MPS ops are intended for Apple Silicon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added arm64

assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
assert (
torch.mps.is_available()
), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"

def _is_arm64(self) -> bool:
    """Return True when the build host reports a 64-bit Arm machine type.

    macOS reports the machine as "arm64", while Linux reports the same
    architecture as "aarch64". Both spellings are accepted so that
    TORCHAO_BUILD_CPU_AARCH64=1 can be set explicitly on non-Apple Arm
    machines instead of tripping the arm64 assertion.
    """
    # str.startswith accepts a tuple of prefixes, so one call covers both names.
    return platform.machine().startswith(("arm64", "aarch64"))

def _is_macos(self) -> bool:
    """Return True when the build host's operating system is macOS.

    platform.system() reports macOS as "Darwin".
    """
    system_name = platform.system()
    return system_name == "Darwin"

def _os_bool_var(self, var, default) -> bool:
    """Read environment variable `var` as a boolean build flag.

    Only the exact string "1" enables the flag. When the variable is not
    set at all, the truthiness of `default` decides the result.
    """
    raw_value = os.getenv(var)
    if raw_value is None:
        # Unset variable: fall back to the caller-supplied default.
        return bool(default)
    return raw_value == "1"


# Constant known variables used throughout this file
cwd = os.path.abspath(os.path.curdir)
Expand Down Expand Up @@ -179,38 +215,30 @@ def build_extensions(self):
def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

build_type = "Debug" if use_debug_mode() else "Release"

from distutils.sysconfig import get_python_lib

torch_dir = get_python_lib() + "/torch/share/cmake/Torch"

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

build_mps_ops = "ON" if build_torchao_experimental_mps else "OFF"

subprocess.check_call(
[
"cmake",
ext.sourcedir,
"-DCMAKE_BUILD_TYPE=" + build_type,
# Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16
"-DTORCHAO_BUILD_KLEIDIAI=OFF",
"-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops,
"-DTorch_DIR=" + torch_dir,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DCMAKE_INSTALL_PREFIX=cmake-out",
],
ext.cmake_lists_dir,
]
+ ext.cmake_args
+ ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir],
cwd=self.build_temp,
)
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)


class CMakeExtension(Extension):
def __init__(self, name, sourcedir=""):
def __init__(
self, name, cmake_lists_dir: str = "", cmake_args: Optional[List[str]] = None
):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
if cmake_args is None:
cmake_args = []
self.cmake_args = cmake_args


def get_extensions():
Expand Down Expand Up @@ -310,10 +338,33 @@ def get_extensions():
)

if build_torchao_experimental:
build_options = BuildOptions()

def bool_to_on_off(value):
    """Translate a truthy/falsy value into the CMake option string "ON" or "OFF"."""
    if value:
        return "ON"
    return "OFF"

from distutils.sysconfig import get_python_lib

torch_dir = get_python_lib() + "/torch/share/cmake/Torch"

ext_modules.append(
CMakeExtension(
"torchao.experimental",
sourcedir="torchao/experimental",
cmake_lists_dir="torchao/experimental",
cmake_args=(
[
f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}",
f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}",
f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}",
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
"-DTorch_DIR=" + torch_dir,
]
+ (
["-DCMAKE_INSTALL_PREFIX=cmake-out"]
if build_options.build_experimental_mps
else []
)
),
)
)

Expand Down
85 changes: 42 additions & 43 deletions torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@ endif()

option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)

option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF)
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)

if(NOT TORCHAO_INCLUDE_DIRS)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
if(TORCHAO_BUILD_KLEIDIAI)
message(STATUS "Building with Arm KleidiAI library")
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()
include(CMakePrintHelpers)

add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
Expand All @@ -36,49 +32,52 @@ include(CMakePrintHelpers)
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")

if(TORCHAO_BUILD_CPU_AARCH64)
message(STATUS "Building with cpu/aarch64")
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)

# Defines torchao_kernels_aarch64
add_subdirectory(kernels/cpu/aarch64)

if(TORCHAO_BUILD_KLEIDIAI)
message(STATUS "Building with Arm KleidiAI library")
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()
# Defines target torchao_kernels_aarch64
add_subdirectory(kernels/cpu/aarch64)
add_subdirectory(ops/linear_8bit_act_xbit_weight)
add_subdirectory(ops/embedding_xbit)

add_library(torchao_ops_aten SHARED)
target_link_libraries(
torchao_ops_aten PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_aten
torchao_ops_embedding_xbit_aten
)
if (TORCHAO_BUILD_MPS_OPS)
message(STATUS "Building with MPS support")
add_subdirectory(ops/mps)
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
endif()
endif()

add_subdirectory(ops/linear_8bit_act_xbit_weight)
add_subdirectory(ops/embedding_xbit)

add_library(torchao_ops_aten SHARED)
target_link_libraries(
torchao_ops_aten PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_aten
torchao_ops_embedding_xbit_aten
)
if (TORCHAO_BUILD_MPS_OPS)
message(STATUS "Building with MPS support")
add_subdirectory(ops/mps)
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
endif()

install(
TARGETS torchao_ops_aten
EXPORT _targets
DESTINATION lib
)
if(TORCHAO_BUILD_EXECUTORCH_OPS)
add_library(torchao_ops_executorch STATIC)
target_link_libraries(torchao_ops_executorch PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_executorch
torchao_ops_embedding_xbit_executorch
)
install(
TARGETS torchao_ops_aten
TARGETS
torchao_ops_executorch
torchao_ops_linear_8bit_act_xbit_weight_executorch
torchao_ops_embedding_xbit_executorch
EXPORT _targets
DESTINATION lib
)
if(TORCHAO_BUILD_EXECUTORCH_OPS)
add_library(torchao_ops_executorch STATIC)
target_link_libraries(torchao_ops_executorch PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_executorch
torchao_ops_embedding_xbit_executorch
)
install(
TARGETS
torchao_ops_executorch
torchao_kernels_aarch64
torchao_ops_linear_8bit_act_xbit_weight_executorch
torchao_ops_embedding_xbit_executorch
EXPORT _targets
DESTINATION lib
)
endif()
else()
message(FATAL_ERROR "Torchao experimental ops can only be built on arm64 CPUs.")
endif()
7 changes: 2 additions & 5 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64"))
if (TORCHAO_BUILD_CPU_AARCH64)
add_library(
torchao_kernels_aarch64
${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
Expand All @@ -22,14 +22,11 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA
GIT_TAG v1.2.0)
FetchContent_MakeAvailable(kleidiai)

# Temporarily exposing this to the parent scope until we wire
# this up properly from the top level
set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE)
target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
endif()
endif()

install(
TARGETS torchao_kernels_aarch64
DESTINATION lib
)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ endif()

add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
if(TORCHAO_BUILD_KLEIDI)
if(TORCHAO_BUILD_KLEIDIAI)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ cmake \
${EXTRA_ARGS} \
-DCMAKE_BUILD_TYPE=Debug \
-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
-DTORCHAO_BUILD_CPU_AARCH64=ON \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \
-B ${CMAKE_OUT}

Expand Down
8 changes: 6 additions & 2 deletions torchao/experimental/ops/embedding_xbit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ add_library(torchao_ops_embedding_xbit_aten OBJECT
op_embedding_xbit_aten.cpp
)
target_link_torchao_parallel_backend(torchao_ops_embedding_xbit_aten "aten_openmp")
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
if (TORCHAO_BUILD_CPU_AARCH64)
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
endif()
target_include_directories(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_embedding_xbit_aten PRIVATE USE_ATEN=1)
Expand All @@ -32,5 +34,7 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
target_include_directories(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
target_compile_definitions(torchao_ops_embedding_xbit_executorch PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64)
if (TORCHAO_BUILD_CPU_AARCH64)
target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64)
endif()
endif()
12 changes: 6 additions & 6 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)
#if defined(TORCHAO_BUILD_CPU_AARCH64)
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
#endif // defined(__aarch64__) || defined(__ARM_NEON)
#endif // TORCHAO_BUILD_CPU_AARCH64

#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
#include <torchao/experimental/ops/library.h>
Expand Down Expand Up @@ -145,7 +145,7 @@ Tensor embedding_out_cpu(
index = index64_ptr[idx];
}
TORCHAO_CHECK(index >= 0 && index < num_embeddings, "index out of bounds");
#if defined(__aarch64__) || defined(__ARM_NEON)
#if defined(TORCHAO_BUILD_CPU_AARCH64)
torchao::kernels::cpu::aarch64::embedding::embedding<weight_nbit>(
out.mutable_data_ptr<float>() + idx * embedding_dim,
embedding_dim,
Expand All @@ -157,7 +157,7 @@ Tensor embedding_out_cpu(
index);
#else
TORCHAO_CHECK(false, "Unsupported platform");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
#endif // TORCHAO_BUILD_CPU_AARCH64
});

return out;
Expand Down Expand Up @@ -234,7 +234,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
header.write(out.mutable_data_ptr());

torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
#if defined(__aarch64__) || defined(__ARM_NEON)
#if defined(TORCHAO_BUILD_CPU_AARCH64)
torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
weight_nbit>(
out.mutable_data_ptr<int8_t>() +
Expand All @@ -244,7 +244,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
idx);
#else
TORCHAO_CHECK(false, "Unsupported platform");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
#endif // defined(TORCHAO_BUILD_CPU_AARCH64)
});

return out;
Expand Down
Loading
Loading