Skip to content

Commit

Permalink
Provide kernels with true reference implementations for quantized ops (
Browse files Browse the repository at this point in the history
…#4108)

Summary:
Pull Request resolved: #4108

We want to be able to run the reference implementations on x86, so we don't want any intrinsics or anything like that in the reference kernels.

In the end, this change has a lot of things:
- introduce a `reference` folder for reference implementations
- moved the primary cmake flow from HiFi to reference, so that the default mode can run on x86
- that means we will need a proper flag to use HiFi optimized ops, which we will add later
- add a `quantized_matmul` reference kernel

Reviewed By: dulinriley

Differential Revision: D59238748

fbshipit-source-id: 830c89fe9ee8dd87ece963e1174ca3cbd1e0fbc6
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Jul 16, 2024
1 parent 7d6c8fc commit ef640bf
Show file tree
Hide file tree
Showing 17 changed files with 757 additions and 17 deletions.
4 changes: 2 additions & 2 deletions backends/cadence/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ include(${EXECUTORCH_ROOT}/build/Utils.cmake)
set(_common_include_directories ${EXECUTORCH_ROOT}/..)


add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/operators)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/kernels)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/reference/operators)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/reference/kernels)
12 changes: 6 additions & 6 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,30 @@
variants: function
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantize_per_tensor_out
kernel_name: impl::reference::quantize_per_tensor_out

- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
variants: function
kernels:
- arg_meta: null
kernel_name: impl::HiFi::dequantize_per_tensor_out
kernel_name: impl::reference::dequantize_per_tensor_out

- func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_conv_out
kernel_name: impl::reference::quantized_conv_out

- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_layer_norm_out
kernel_name: impl::reference::quantized_layer_norm_out

- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_linear_out
kernel_name: impl::reference::quantized_linear_out

- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_relu_out
kernel_name: impl::reference::quantized_relu_out
16 changes: 16 additions & 0 deletions backends/cadence/reference/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# lint_cmake: -linelength
add_library(
cadence_kernels
kernels.cpp
)

target_include_directories(
cadence_kernels
PUBLIC .
)
109 changes: 109 additions & 0 deletions backends/cadence/reference/kernels/kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "kernels.h"

#include <algorithm>
#include <limits>

namespace impl {
namespace reference {
namespace kernels {

// Quantize a fp32 value to an int8_t/uint8_t value
template <typename T>
__attribute__((always_inline)) T
quantize(const float x, float scale, int32_t zero_point) {
constexpr float min_val = std::numeric_limits<T>::min();
constexpr float max_val = std::numeric_limits<T>::max();
float tmp = roundf(x * scale + zero_point);
return std::max(std::min(tmp, max_val), min_val);
}

// Quantize an fp32 array to an int8_t/uint8_t array
template <typename T>
void quantize(
T* __restrict__ y,
const float* __restrict__ x,
float inv_scale,
int32_t zero_point,
size_t size) {
for (size_t i = 0; i < size; ++i) {
y[i] = quantize<T>(x[i], inv_scale, zero_point);
}
}

// Dequantize an int8_t/uint8_t value to an fp32 value
template <typename T>
__attribute__((always_inline)) float
dequantize(const T x, float scale, int32_t zero_point) {
return scale * (x - zero_point);
}

// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array
template <typename T>
void dequantize(
float* __restrict__ y,
const T* __restrict__ x,
float scale,
int32_t zero_point,
size_t size) {
for (size_t i = 0; i < size; ++i) {
y[i] = dequantize<T>(x[i], scale, zero_point);
}
}

// explicit template instantiation

#define typed_quantize_val(dtype) \
template __attribute__((always_inline)) dtype quantize( \
const float x, float inv_scale, int32_t zero_point);
typed_quantize_val(int8_t);
typed_quantize_val(uint8_t);
typed_quantize_val(int16_t);
typed_quantize_val(int32_t);
#undef typed_quantize_val

#define typed_quantize_vec(dtype) \
template void quantize( \
dtype* __restrict__ y, \
const float* __restrict__ x, \
float inv_scale, \
int32_t zero_point, \
size_t size);
typed_quantize_vec(int8_t);
typed_quantize_vec(uint8_t);
typed_quantize_vec(int16_t);
typed_quantize_vec(int32_t);
#undef typed_quantize_vec

#define typed_dequantize_val(dtype) \
template __attribute__((always_inline)) float dequantize( \
const dtype x, float scale, int32_t zero_point);
typed_dequantize_val(int8_t);
typed_dequantize_val(uint8_t);
typed_dequantize_val(int16_t);
typed_dequantize_val(int32_t);
#undef typed_dequantize_val

#define typed_dequantize_vec(dtype) \
template void dequantize( \
float* __restrict__ y, \
const dtype* __restrict__ x, \
float scale, \
int32_t zero_point, \
size_t size);
typed_dequantize_vec(int8_t);
typed_dequantize_vec(uint8_t);
typed_dequantize_vec(int16_t);
typed_dequantize_vec(int32_t);
#undef typed_dequantize_vec

}; // namespace kernels
}; // namespace reference
}; // namespace impl
41 changes: 41 additions & 0 deletions backends/cadence/reference/kernels/kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "inttypes.h"
#include "stddef.h"

namespace impl {
namespace reference {
namespace kernels {

template <typename T>
T quantize(const float x, float scale, int32_t zero_point);

template <typename T>
float dequantize(const T x, float scale, int32_t zero_point);

template <typename T>
void quantize(
T* __restrict__ y,
const float* __restrict__ x,
float scale,
int32_t zero_point,
size_t size);

// Deuantize an int8_t/uint8_t/int16_t array to an fp32 array
template <typename T>
void dequantize(
float* __restrict__ y,
const T* __restrict__ x,
float scale,
int32_t zero_point,
size_t size);

}; // namespace kernels
}; // namespace reference
}; // namespace impl
86 changes: 86 additions & 0 deletions backends/cadence/reference/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()

include(${EXECUTORCH_ROOT}/build/Utils.cmake)
include(${EXECUTORCH_ROOT}/build/Codegen.cmake)

if(NOT PYTHON_EXECUTABLE)
resolve_python_executable()
endif()

# ATen compliant ops that are needed to run this model.
set(_aten_ops__srcs
"${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp")
add_library(aten_ops_cadence ${_aten_ops__srcs})
target_link_libraries(aten_ops_cadence PUBLIC executorch)
target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels)

# Let files say "include <executorch/path/to/header.h>".
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/..
${CMAKE_BINARY_DIR}
${_common_include_directories})

# Custom ops that are needed to run the test model.
add_library(
custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp"
"quantized_relu_out.cpp" "quantized_layer_norm.cpp"
"quantize_per_tensor.cpp" "dequantize_per_tensor.cpp")
target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/..
${CMAKE_BINARY_DIR}
${_common_include_directories})

target_link_libraries(custom_ops PUBLIC executorch)
target_link_libraries(custom_ops PRIVATE cadence_kernels)

# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
# Executorch (for runtime). Here select all ops in functions.yaml
gen_selected_ops(
LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML
"${CMAKE_CURRENT_LIST_DIR}/../../aot/functions.yaml" "" ""
)
generate_bindings_for_kernels(
LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML
FUNCTIONS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions.yaml
)
message("Generated files ${gen_command_sources}")

gen_operators_lib(
LIB_NAME "cadence_ops_lib"
KERNEL_LIBS custom_ops
DEPS aten_ops_cadence)
51 changes: 51 additions & 0 deletions backends/cadence/reference/operators/dequantize_per_tensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>
#include "kernels.h"

namespace impl {
namespace reference {
namespace native {

using Tensor = exec_aten::Tensor;
using RuntimeContext = torch::executor::RuntimeContext;
using ScalarType = exec_aten::ScalarType;

void dequantize_per_tensor_out(
RuntimeContext& context,
const Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
Tensor& out) {
float* out_data = out.mutable_data_ptr<float>();
size_t numel = out.numel();

if (input.scalar_type() == ScalarType::Byte) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
impl::reference::kernels::dequantize<uint8_t>(
out_data, input_data, scale, zero_point, numel);
} else if (input.scalar_type() == ScalarType::Char) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
impl::reference::kernels::dequantize<int8_t>(
out_data, input_data, scale, zero_point, numel);
} else if (input.scalar_type() == ScalarType::Int) {
const int32_t* input_data = input.const_data_ptr<int32_t>();
impl::reference::kernels::dequantize<int32_t>(
out_data, input_data, scale, zero_point, numel);
} else {
ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
}
}

}; // namespace native
}; // namespace reference
}; // namespace impl
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
*/

#include <executorch/runtime/kernel/kernel_includes.h>
#include "kernels.h"

namespace torch {
namespace executor {
Expand All @@ -31,7 +30,7 @@ void embedding_out(

for (int i = 0, e = indices.numel(); i < e; i++) {
// memcpy(dest, src, nbytes);
impl::HiFi::kernels::memcpy(
memcpy(
out_data, w_data + nbytes_per_entry * indices_ptr[i], nbytes_per_entry);
out_data += nbytes_per_entry;
}
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
*/

#include <executorch/runtime/kernel/kernel_includes.h>
#include "kernels.h"

namespace torch {
namespace executor {
Expand All @@ -21,8 +20,7 @@ Tensor& view_copy_out(
const Tensor& input,
const IntArrayRef size,
Tensor& out) {
impl::HiFi::kernels::memcpy(
out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes());
memcpy(out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes());
return out;
}

Expand Down
Loading

0 comments on commit ef640bf

Please sign in to comment.