-
Notifications
You must be signed in to change notification settings - Fork 354
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Provide kernels with true reference implementations for quantized ops (…
…#4108) Summary: Pull Request resolved: #4108 We want to be able to run the reference implementations on x86, so we don't want any intrinsics or anything like that in the reference kernels. In the end, this change has a lot of things: - introduce a `reference` folder for reference implementations - moved the primary cmake flow from HiFi to reference, so that the default mode can run on x86 - that means we will need a proper flag to use HiFi optimized ops, which we will add later - add a `quantized_matmul` reference kernel Differential Revision: D59238748
- Loading branch information
1 parent
f9efb05
commit 9215276
Showing
17 changed files
with
757 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# lint_cmake: -linelength | ||
add_library( | ||
cadence_kernels | ||
kernels.cpp | ||
) | ||
|
||
target_include_directories( | ||
cadence_kernels | ||
PUBLIC . | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "kernels.h" | ||
|
||
#include <algorithm> | ||
#include <limits> | ||
|
||
namespace impl { | ||
namespace reference { | ||
namespace kernels { | ||
|
||
// Quantize a fp32 value to an int8_t/uint8_t value | ||
template <typename T> | ||
__attribute__((always_inline)) T | ||
quantize(const float x, float scale, int32_t zero_point) { | ||
constexpr float min_val = std::numeric_limits<T>::min(); | ||
constexpr float max_val = std::numeric_limits<T>::max(); | ||
float tmp = roundf(x * scale + zero_point); | ||
return std::max(std::min(tmp, max_val), min_val); | ||
} | ||
|
||
// Quantize an fp32 array to an int8_t/uint8_t array | ||
template <typename T> | ||
void quantize( | ||
T* __restrict__ y, | ||
const float* __restrict__ x, | ||
float inv_scale, | ||
int32_t zero_point, | ||
size_t size) { | ||
for (size_t i = 0; i < size; ++i) { | ||
y[i] = quantize<T>(x[i], inv_scale, zero_point); | ||
} | ||
} | ||
|
||
// Dequantize an int8_t/uint8_t value to an fp32 value | ||
template <typename T> | ||
__attribute__((always_inline)) float | ||
dequantize(const T x, float scale, int32_t zero_point) { | ||
return scale * (x - zero_point); | ||
} | ||
|
||
// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array | ||
template <typename T> | ||
void dequantize( | ||
float* __restrict__ y, | ||
const T* __restrict__ x, | ||
float scale, | ||
int32_t zero_point, | ||
size_t size) { | ||
for (size_t i = 0; i < size; ++i) { | ||
y[i] = dequantize<T>(x[i], scale, zero_point); | ||
} | ||
} | ||
|
||
// explicit template instantiation | ||
|
||
#define typed_quantize_val(dtype) \ | ||
template __attribute__((always_inline)) dtype quantize( \ | ||
const float x, float inv_scale, int32_t zero_point); | ||
typed_quantize_val(int8_t); | ||
typed_quantize_val(uint8_t); | ||
typed_quantize_val(int16_t); | ||
typed_quantize_val(int32_t); | ||
#undef typed_quantize_val | ||
|
||
#define typed_quantize_vec(dtype) \ | ||
template void quantize( \ | ||
dtype* __restrict__ y, \ | ||
const float* __restrict__ x, \ | ||
float inv_scale, \ | ||
int32_t zero_point, \ | ||
size_t size); | ||
typed_quantize_vec(int8_t); | ||
typed_quantize_vec(uint8_t); | ||
typed_quantize_vec(int16_t); | ||
typed_quantize_vec(int32_t); | ||
#undef typed_quantize_vec | ||
|
||
#define typed_dequantize_val(dtype) \ | ||
template __attribute__((always_inline)) float dequantize( \ | ||
const dtype x, float scale, int32_t zero_point); | ||
typed_dequantize_val(int8_t); | ||
typed_dequantize_val(uint8_t); | ||
typed_dequantize_val(int16_t); | ||
typed_dequantize_val(int32_t); | ||
#undef typed_dequantize_val | ||
|
||
#define typed_dequantize_vec(dtype) \ | ||
template void dequantize( \ | ||
float* __restrict__ y, \ | ||
const dtype* __restrict__ x, \ | ||
float scale, \ | ||
int32_t zero_point, \ | ||
size_t size); | ||
typed_dequantize_vec(int8_t); | ||
typed_dequantize_vec(uint8_t); | ||
typed_dequantize_vec(int16_t); | ||
typed_dequantize_vec(int32_t); | ||
#undef typed_dequantize_vec | ||
|
||
}; // namespace kernels | ||
}; // namespace reference | ||
}; // namespace impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "inttypes.h" | ||
#include "stddef.h" | ||
|
||
namespace impl { | ||
namespace reference { | ||
namespace kernels { | ||
|
||
template <typename T> | ||
T quantize(const float x, float scale, int32_t zero_point); | ||
|
||
template <typename T> | ||
float dequantize(const T x, float scale, int32_t zero_point); | ||
|
||
template <typename T> | ||
void quantize( | ||
T* __restrict__ y, | ||
const float* __restrict__ x, | ||
float scale, | ||
int32_t zero_point, | ||
size_t size); | ||
|
||
// Deuantize an int8_t/uint8_t/int16_t array to an fp32 array | ||
template <typename T> | ||
void dequantize( | ||
float* __restrict__ y, | ||
const T* __restrict__ x, | ||
float scale, | ||
int32_t zero_point, | ||
size_t size); | ||
|
||
}; // namespace kernels | ||
}; // namespace reference | ||
}; // namespace impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
cmake_minimum_required(VERSION 3.19) | ||
|
||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) | ||
if(NOT CMAKE_CXX_STANDARD) | ||
set(CMAKE_CXX_STANDARD 17) | ||
endif() | ||
|
||
include(${EXECUTORCH_ROOT}/build/Utils.cmake) | ||
include(${EXECUTORCH_ROOT}/build/Codegen.cmake) | ||
|
||
if(NOT PYTHON_EXECUTABLE) | ||
resolve_python_executable() | ||
endif() | ||
|
||
# ATen compliant ops that are needed to run this model. | ||
set(_aten_ops__srcs | ||
"${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" | ||
"${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp" | ||
"${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp" | ||
"${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" | ||
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp") | ||
add_library(aten_ops_cadence ${_aten_ops__srcs}) | ||
target_link_libraries(aten_ops_cadence PUBLIC executorch) | ||
target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels) | ||
|
||
# Let files say "include <executorch/path/to/header.h>". | ||
set(_common_include_directories ${EXECUTORCH_ROOT}/..) | ||
|
||
target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/.. | ||
${CMAKE_BINARY_DIR} | ||
${_common_include_directories}) | ||
|
||
# Custom ops that are needed to run the test model. | ||
add_library( | ||
custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp" | ||
"quantized_relu_out.cpp" "quantized_layer_norm.cpp" | ||
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp") | ||
target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/.. | ||
${CMAKE_BINARY_DIR} | ||
${_common_include_directories}) | ||
|
||
target_link_libraries(custom_ops PUBLIC executorch) | ||
target_link_libraries(custom_ops PRIVATE cadence_kernels) | ||
|
||
# Generate C++ bindings to register kernels into both PyTorch (for AOT) and | ||
# Executorch (for runtime). Here select all ops in functions.yaml | ||
gen_selected_ops( | ||
LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML | ||
"${CMAKE_CURRENT_LIST_DIR}/../../aot/functions.yaml" "" "" | ||
) | ||
generate_bindings_for_kernels( | ||
LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML | ||
FUNCTIONS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions.yaml | ||
) | ||
message("Generated files ${gen_command_sources}") | ||
|
||
gen_operators_lib( | ||
LIB_NAME "cadence_ops_lib" | ||
KERNEL_LIBS custom_ops | ||
DEPS aten_ops_cadence) |
51 changes: 51 additions & 0 deletions
51
backends/cadence/reference/operators/dequantize_per_tensor.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/runtime/kernel/kernel_includes.h> | ||
#include "kernels.h" | ||
|
||
namespace impl { | ||
namespace reference { | ||
namespace native { | ||
|
||
using Tensor = exec_aten::Tensor; | ||
using RuntimeContext = torch::executor::RuntimeContext; | ||
using ScalarType = exec_aten::ScalarType; | ||
|
||
void dequantize_per_tensor_out( | ||
RuntimeContext& context, | ||
const Tensor& input, | ||
double scale, | ||
int64_t zero_point, | ||
int64_t quant_min, | ||
int64_t quant_max, | ||
ScalarType dtype, | ||
Tensor& out) { | ||
float* out_data = out.mutable_data_ptr<float>(); | ||
size_t numel = out.numel(); | ||
|
||
if (input.scalar_type() == ScalarType::Byte) { | ||
const uint8_t* input_data = input.const_data_ptr<uint8_t>(); | ||
impl::reference::kernels::dequantize<uint8_t>( | ||
out_data, input_data, scale, zero_point, numel); | ||
} else if (input.scalar_type() == ScalarType::Char) { | ||
const int8_t* input_data = input.const_data_ptr<int8_t>(); | ||
impl::reference::kernels::dequantize<int8_t>( | ||
out_data, input_data, scale, zero_point, numel); | ||
} else if (input.scalar_type() == ScalarType::Int) { | ||
const int32_t* input_data = input.const_data_ptr<int32_t>(); | ||
impl::reference::kernels::dequantize<int32_t>( | ||
out_data, input_data, scale, zero_point, numel); | ||
} else { | ||
ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); | ||
} | ||
} | ||
|
||
}; // namespace native | ||
}; // namespace reference | ||
}; // namespace impl |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.