Skip to content

Commit 456fecc

Browse files
committed
[Experimental] Enable kleidi AI examples to run on graviton3
1 parent 2a3fbff commit 456fecc

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

torchao/experimental/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@ if(TORCHAO_BUILD_KLEIDIAI)
2929
endif()
3030
include(CMakePrintHelpers)
3131

32-
add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
32+
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
33+
add_compile_options("-Wall" "-Werror" "-Wno-deprecated" "-march=armv8.2-a+dotprod" "-fPIC" "-Wno-error=unknown-pragmas")
34+
else()
35+
add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
36+
endif()
3337

3438
include(CMakePrintHelpers)
3539
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
3640
include_directories(${TORCHAO_INCLUDE_DIRS})
3741

38-
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
42+
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
3943
if(TORCHAO_BUILD_KLEIDIAI)
4044
message(STATUS "Building with Arm KleidiAI library")
4145
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)

torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
88
#include <cassert>
99
#include <cstring>
10+
#include <cstdint>
1011

1112
// Interleaves data across channels (row/column) and groups.
1213
// Each channel is the same size (vals_per_channel) and is

torchao/experimental/ops/packed_weights_header.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class PackedWeightsHeader {
5151
auto header = reinterpret_cast<const int*>(packed_weights);
5252
assert(header[0] == PackedWeightsHeader::magic);
5353
params_type params;
54-
for (int i = 0; i < params.size(); i++) {
54+
for (size_t i = 0; i < params.size(); i++) {
5555
params[i] = header[i + 2];
5656
}
5757
return PackedWeightsHeader(
@@ -62,7 +62,7 @@ class PackedWeightsHeader {
6262
if (type != other.type) {
6363
return false;
6464
}
65-
for (int i = 0; i < params.size(); i++) {
65+
for (size_t i = 0; i < params.size(); i++) {
6666
if (params[i] != other.params[i]) {
6767
return false;
6868
}

torchao/experimental/ops/parallel-aten-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// LICENSE file in the root directory of this source tree.
66

77
#pragma once
8-
#include <Aten/Parallel.h>
8+
#include <torchao/experimental/ops/parallel.h>
99
#include <torch/library.h>
1010
#include <torch/torch.h>
1111

0 commit comments

Comments
 (0)