
Commit 959bb1b

Update ExecuTorch for XNNPACK 87ee0b4 (#4916)

Summary: Pull Request resolved: #4916
Reviewed By: digantdesai
Differential Revision: D61822607
Pulled By: GregoryComer

1 parent 9fd8e53 · commit 959bb1b

File tree

11 files changed: +102 −74

.gitmodules
+1 −1

@@ -21,7 +21,7 @@
 	url = https://github.com/Maratyszcza/FXdiv.git
 [submodule "backends/xnnpack/third-party/XNNPACK"]
 	path = backends/xnnpack/third-party/XNNPACK
-	url = https://github.com/digantdesai/XNNPACK.git
+	url = https://github.com/google/XNNPACK.git
 [submodule "backends/xnnpack/third-party/cpuinfo"]
 	path = backends/xnnpack/third-party/cpuinfo
 	url = https://github.com/pytorch/cpuinfo.git

backends/xnnpack/cmake/Dependencies.cmake
+4 −0

@@ -36,6 +36,10 @@ set(XNNPACK_ENABLE_AVXVNNI
     OFF
     CACHE BOOL ""
 )
+set(XNNPACK_ENABLE_KLEIDIAI
+    OFF
+    CACHE BOOL ""
+)
 add_subdirectory("${XNNPACK_SOURCE_DIR}")
 include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})
 list(APPEND xnnpack_third_party XNNPACK)
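
Note: the new cache entry mirrors the XNNPACK_ENABLE_AVXVNNI pattern above it. The submodule bump pulls in XNNPACK's optional Arm KleidiAI integration, and pinning XNNPACK_ENABLE_KLEIDIAI to OFF keeps those kernels out of the ExecuTorch build by default. Like any CMake cache option, it can still be overridden at configure time, e.g. cmake -DXNNPACK_ENABLE_KLEIDIAI=ON ...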

backends/xnnpack/runtime/XNNCompiler.cpp
+66 −7

@@ -21,6 +21,25 @@ namespace executor {
 namespace xnnpack {
 namespace delegate {
 
+/*
+ * Provide compile-time allocation.
+ */
+class CompileAllocator {
+ public:
+  /*
+   * Allocate memory which will be automatically freed at the end
+   * of the compilation process.
+   */
+  void* allocateTemporary(size_t size) {
+    auto mem = new uint8_t[size];
+    temporaries_.emplace_back(mem);
+    return mem;
+  }
+
+ private:
+  std::vector<std::unique_ptr<uint8_t[]>> temporaries_;
+};
+
 // Flatbuffer types
 using ValuePtr = const fb_xnnpack::XValue*;
 using NodePtr = const fb_xnnpack::XNode*;

@@ -35,6 +54,23 @@ using DefineNodeFunc = Error (*)(
     const std::unordered_map<uint32_t, uint32_t>&,
     NodePtr) noexcept;
 
+/*
+Convert a tensor from fp32 to bf16.
+*/
+void convertF32TensorToBF16(
+    const float* f32_data,
+    uint16_t* bf16_data_out,
+    size_t numel) {
+  for (auto i = 0u; i < numel; i++) {
+    // Adjust the f32 value such that it rounds properly after truncation.
+    // Constant factor scales 1+2^-8 to 1+2^-7.
+    float f32_adjusted = f32_data[i] * 1.00389105f;
+    uint32_t f32_bits;
+    memcpy(&f32_bits, &f32_adjusted, sizeof(float));
+    bf16_data_out[i] = static_cast<uint16_t>(f32_bits >> 16);
+  }
+}
+
 /*
 Gets the output min and output max for a given node operator
 */

@@ -152,7 +188,8 @@ Error defineTensor(
     GraphPtr flatbuffer_graph,
     const uint8_t* constant_data_ptr,
     std::vector<uint32_t>& input_ids,
-    std::vector<uint32_t>& output_ids) {
+    std::vector<uint32_t>& output_ids,
+    CompileAllocator& allocator) {
   const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
   const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;

@@ -356,12 +393,31 @@ Error defineTensor(
       size_t group_size = qparams->group_size();
       size_t output_channels = tensor_value->dims()->Get(0);
       size_t input_channels = tensor_value->dims()->Get(1);
+
+      const uint16_t* scale_data = nullptr;
+      uint32_t scale_numel = 0;
+
+      // Block scales are preferably serialized as bf16 but can also be
+      // serialized as fp32 for backwards compatibility.
+      if (qparams->scale_bf16() != nullptr) {
+        scale_data =
+            static_cast<const uint16_t*>(qparams->scale_bf16()->data());
+        scale_numel = qparams->scale_bf16()->size();
+      } else {
+        // Read fp32 scales, convert to bf16.
+        auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
+            qparams->scale()->size() * sizeof(uint16_t)));
+        scale_numel = qparams->scale()->size();
+        convertF32TensorToBF16(
+            qparams->scale()->data(), conv_buffer, scale_numel);
+        scale_data = conv_buffer;
+      }
+
       ET_CHECK_OR_RETURN_ERROR(
-          qparams->scale()->size() ==
-              output_channels * input_channels / group_size,
+          scale_numel == output_channels * input_channels / group_size,
           Internal,
           "scale size %zu != output channels %zu * group size %zu",
-          (size_t)qparams->scale()->size(),
+          static_cast<size_t>(scale_numel),
           output_channels,
           group_size);
       int32_t zero_point =

@@ -370,18 +426,19 @@
           Debug,
           "define quant tensor (per channel group): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, group_size: %zu, output_channels: %zu, dtype: %u, zero_point: %d, datatype: %d\n",
           buffer_ptr,
-          qparams->scale()->size(),
+          scale_numel,
           qparams->channel_dim(),
           group_size,
           output_channels,
           datatype,
           zero_point,
           datatype);
+
       status = xnn_define_blockwise_quantized_tensor_value(
           /*subgraph=*/subgraph_ptr,
           /*datatype=*/datatype,
           /*zero_point=*/zero_point,
-          /*scale=*/qparams->scale()->data(),
+          /*scale=*/scale_data,
           /*num_dims=*/tensor_value->num_dims(),
           /*channel_dim=*/qparams->channel_dim(),
           /*block_size=*/qparams->group_size(),

@@ -1617,6 +1674,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
   Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
   const uint8_t* flatbuffer_data = nullptr;
   const uint8_t* constant_data = nullptr;
+  CompileAllocator compile_allocator;
 
   // Header status can only either be Error::Ok or Error::NotFound
   if (header.ok()) {

@@ -1688,7 +1746,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
         flatbuffer_graph,
         constant_data,
         input_ids,
-        output_ids);
+        output_ids,
+        compile_allocator);
 
     if (err != Error::Ok) {
       return err;
backends/xnnpack/serialization/runtime_schema.fbs
+1 −0

@@ -63,6 +63,7 @@ table PerChannelGroupQuant {
   scale:[float];
   channel_dim:int;
   group_size:int;
+  scale_bf16:[ushort];
 }
 
 table XNNTensorValue {

backends/xnnpack/serialization/schema.fbs
+1 −0

@@ -48,6 +48,7 @@ table PerChannelGroupQuant {
   scale:[float];
   channel_dim:int;
   group_size:int;
+  scale_bf16:[ushort];
 }
 
 table PerChannelQuant {
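
Since FlatBuffers tables stay compatible when new fields are appended at the end, adding scale_bf16 to PerChannelGroupQuant in both the AOT schema and the runtime schema means older payloads that only carry fp32 scale still parse and simply take the conversion fallback in XNNCompiler.cpp above.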

backends/xnnpack/test/ops/linear.py
+6 −6

@@ -407,8 +407,8 @@ def test_qd8_per_channel_linear_parallel_and_sequential(self):
         )
     def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
         M_sizes = [1, 2, 17, 31]
-        K_sizes = [8, 32, 64, 128]
-        bl_sizes = [8, 16, 16, 32]
+        K_sizes = [32, 32, 64, 128]
+        bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]
 
         for use_bias in [True, False]:

@@ -430,8 +430,8 @@ def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
         )
     def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
         M_sizes = [1, 2, 17, 31]
-        K_sizes = [8, 32, 64, 128]
-        bl_sizes = [8, 16, 16, 32]
+        K_sizes = [32, 32, 64, 128]
+        bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]
 
         for use_bias in [True, False]:

@@ -602,8 +602,8 @@ def _test_groupwise_dq_linear(
         use_bias: bool = False,
         group_size: int = 8,
         num_linears: int = 1,
-        atol: float = 1e-3,
-        rtol: float = 1e-3,
+        atol: float = 5e-3,
+        rtol: float = 5e-3,
     ):
         quantize_(mod, int8_dynamic_activation_int4_weight(group_size=group_size))
         unwrap_tensor_subclass(mod)
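
Both test updates follow from the runtime change: K and block-length values move up to multiples of 32, which appears to be the smallest blockwise group size the updated XNNPACK kernels accept, and atol/rtol loosen from 1e-3 to 5e-3, consistent with block scales now round-tripping through bf16 and losing a few mantissa bits of precision.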

backends/xnnpack/third-party/XNNPACK

Submodule XNNPACK updated: 13,544 files (now tracking upstream 87ee0b4)

backends/xnnpack/third-party/generate-xnnpack-wrappers.py
+21 −6

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 from __future__ import print_function
+from pathlib import Path
 import collections
 import os
 import sys

@@ -36,8 +37,8 @@
     "PROD_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
     "PROD_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
     "PROD_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
-    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
     "PROD_AVX512VNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
+    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
     "PROD_RVV_MICROKERNEL_SRCS": "defined(__riscv) || defined(__riscv__)",
     "PROD_AVXVNNI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
     "AARCH32_ASM_MICROKERNEL_SRCS": "defined(__arm__)",

@@ -46,7 +47,7 @@
     # add non-prod microkernel sources here:
 }
 
-SRC_NAMES = set([
+SRC_NAMES = {
     "OPERATOR_SRCS",
     "SUBGRAPH_SRCS",
     "LOGGING_SRCS",

@@ -81,30 +82,42 @@
     "PROD_AVX512F_MICROKERNEL_SRCS",
     "PROD_AVX512SKX_MICROKERNEL_SRCS",
     "PROD_AVX512VBMI_MICROKERNEL_SRCS",
-    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS",
     "PROD_AVX512VNNI_MICROKERNEL_SRCS",
+    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS",
     "PROD_RVV_MICROKERNEL_SRCS",
     "PROD_AVXVNNI_MICROKERNEL_SRCS",
     "AARCH32_ASM_MICROKERNEL_SRCS",
     "AARCH64_ASM_MICROKERNEL_SRCS",
 
     # add non-prod microkernel sources here:
-])
+}
 
 def handle_singleline_parse(line):
     start_index = line.find("(")
     end_index = line.find(")")
     line = line[start_index+1:end_index]
     key_val = line.split(" ")
-    return key_val[0], list(map(lambda x: x[4:], key_val[1:]))
+    return key_val[0], [x[4:] for x in key_val[1:]]
 
 def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"):
+    print(f"Updating sources from {cmakefile}")
     sources = collections.defaultdict(list)
     with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
         lines = cmake.readlines()
         i = 0
         while i < len(lines):
             line = lines[i]
+
+            if lines[i].startswith("INCLUDE"):
+                file, _ = handle_singleline_parse(line)
+                if file.startswith("cmake/gen/"):
+                    path = Path(xnnpack_path) / "XNNPACK" / file
+                    local_sources = update_sources(xnnpack_path, path.absolute().as_posix())
+                    for k, v in local_sources.items():
+                        if k in sources:
+                            sources[k] = sources[k] + local_sources[k]
+                        else:
+                            sources[k] = local_sources[k]
 
             if lines[i].startswith("SET") and "src/" in lines[i]:
                 name, val = handle_singleline_parse(line)

@@ -132,7 +145,7 @@ def gen_wrappers(xnnpack_path):
     xnnpack_sources = collections.defaultdict(list)
     sources = update_sources(xnnpack_path)
 
-    microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake")
+    microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/gen/microkernels.cmake")
     for key in microkernels_sources:
         sources[key] = microkernels_sources[key]

@@ -186,6 +199,8 @@ def gen_wrappers(xnnpack_path):
 
 
 def main(argv):
+    print("Generating wrappers...")
+
     if argv is None or len(argv) == 0:
         gen_wrappers(".")
     else:
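
These script changes track XNNPACK's build-system reorganization: source lists now live in generated fragments under cmake/gen/, so update_sources() learns to follow INCLUDE(cmake/gen/...) directives recursively, and the microkernel list is read from cmake/gen/microkernels.cmake instead of cmake/microkernels.cmake. The SRC_NAMES set literal and the list comprehension in handle_singleline_parse are behavior-preserving cleanups.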

backends/xnnpack/third-party/xnnpack.buck.bzl
+1 −27

@@ -1,7 +1,6 @@
 load("//third-party:glob_defs.bzl", "subdir_glob")
 load(
     ":xnnpack_src_defs.bzl",
-    "JIT_SRCS",
     "LOGGING_SRCS",
     "OPERATOR_SRCS",
     "SUBGRAPH_SRCS",

@@ -69,27 +68,6 @@ def define_xnnpack():
         ],
     )
 
-    # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
-    native.cxx_library(
-        name = "jit_memory",
-        srcs = JIT_SRCS,
-        headers = subdir_glob([
-            ("XNNPACK/src", "**/*.h"),
-        ]),
-        header_namespace = "",
-        compiler_flags = [
-            "-std=c++17",
-        ],
-        preferred_linkage = "static",
-        preprocessor_flags = [
-            "-DXNN_LOG_LEVEL=0",
-        ],
-        exported_deps = [
-            ":clog",
-            ":interface",
-        ],
-    )
-
     # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
     native.cxx_library(
         name = "operators",

@@ -139,7 +117,6 @@ def define_xnnpack():
         preferred_linkage = "static",
         preprocessor_flags = [
             "-DXNN_LOG_LEVEL=0",
-            "-DXNN_ENABLE_JIT=0",
             "-DXNN_ENABLE_SPARSE=0",
             "-DXNN_ENABLE_GEMM_M_SPECIALIZATION=0",
             "-DXNN_ENABLE_MEMOPT",

@@ -1223,7 +1200,6 @@ def define_xnnpack():
     ]
 
     ARM_XNNPACK_DEPS = [
-        ":jit_memory",
         ":ukernels_armsimd32",
         ":ukernels_fp16arith",
         ":ukernels_asm",

@@ -1246,11 +1222,10 @@ def define_xnnpack():
             "XNNPACK/src/configs/hardware-config.c",
             "XNNPACK/src/microparams-init.c",
             "XNNPACK/src/operator-run.c",
-            "XNNPACK/src/operators/post-operation.c",
             "XNNPACK/src/microkernel-utils.c",
         ],
         headers = subdir_glob([
-            ("XNNPACK/src", "xnnpack/*.h"),
+            ("XNNPACK/src", "**/*.h"),
            ("XNNPACK/include", "**/*.h"),
        ]),
        exported_headers = {

@@ -1271,7 +1246,6 @@ def define_xnnpack():
             "-DXNN_NO_X8_OPERATORS",
             "-DXNN_ENABLE_MEMOPT",
             "-DXNN_ENABLE_SPARSE=0",
-            "-DXNN_ENABLE_JIT=0",
             "-DXNN_ENABLE_ASSEMBLY",
             "-DXNN_ENABLE_GEMM_M_SPECIALIZATION",
             "-DXNN_ENABLE_ARM_DOTPROD",
