From 92152766f20a59d72a7370616416b4e7a1350b4a Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 12 Jul 2024 10:13:08 -0700 Subject: [PATCH] Provide kernels with true reference implementations for quantized ops (#4108) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4108 We want to be able to run the reference implementations on x86, so we don't want any intrinsics or anything like that in the reference kernels. In the end, this change has a lot of things: - introduce a `reference` folder for reference implementations - moved the primary cmake flow from HiFi to reference, so that the default mode can run on x86 - that means we will need a proper flag to use HiFi optimized ops, which we will add later - add a `quantized_matmul` reference kernel Differential Revision: D59238748 --- backends/cadence/CMakeLists.txt | 4 +- backends/cadence/aot/functions.yaml | 12 +- .../cadence/reference/kernels/CMakeLists.txt | 16 ++ .../cadence/reference/kernels/kernels.cpp | 109 ++++++++++++ backends/cadence/reference/kernels/kernels.h | 41 +++++ .../reference/operators/CMakeLists.txt | 86 ++++++++++ .../operators/dequantize_per_tensor.cpp | 51 ++++++ .../{hifi => reference}/operators/op_add.cpp | 0 .../operators/op_embedding.cpp | 3 +- .../{hifi => reference}/operators/op_full.cpp | 0 .../operators/op_view_copy.cpp | 4 +- .../operators/quantize_per_tensor.cpp | 53 ++++++ .../operators/quantized_conv_out.cpp | 4 +- .../operators/quantized_layer_norm.cpp | 157 ++++++++++++++++++ .../operators/quantized_linear_out.cpp | 80 +++++++++ .../operators/quantized_matmul_out.cpp | 150 +++++++++++++++++ .../operators/quantized_relu_out.cpp | 4 +- 17 files changed, 757 insertions(+), 17 deletions(-) create mode 100644 backends/cadence/reference/kernels/CMakeLists.txt create mode 100644 backends/cadence/reference/kernels/kernels.cpp create mode 100644 backends/cadence/reference/kernels/kernels.h create mode 100644 backends/cadence/reference/operators/CMakeLists.txt create mode 100644 backends/cadence/reference/operators/dequantize_per_tensor.cpp rename backends/cadence/{hifi => reference}/operators/op_add.cpp (100%) rename backends/cadence/{hifi => reference}/operators/op_embedding.cpp (95%) rename backends/cadence/{hifi => reference}/operators/op_full.cpp (100%) rename backends/cadence/{hifi => reference}/operators/op_view_copy.cpp (83%) create mode 100644 backends/cadence/reference/operators/quantize_per_tensor.cpp rename backends/cadence/{hifi => reference}/operators/quantized_conv_out.cpp (99%) create mode 100644 backends/cadence/reference/operators/quantized_layer_norm.cpp create mode 100644 backends/cadence/reference/operators/quantized_linear_out.cpp create mode 100644 backends/cadence/reference/operators/quantized_matmul_out.cpp rename backends/cadence/{hifi => reference}/operators/quantized_relu_out.cpp (96%) diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 8781086e113..b3c3b80d00b 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -25,5 +25,5 @@ include(${EXECUTORCH_ROOT}/build/Utils.cmake) set(_common_include_directories ${EXECUTORCH_ROOT}/..) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/operators) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/kernels) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/reference/operators) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/reference/kernels) diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 6e74fbc25a6..f79d5f870da 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -107,30 +107,30 @@ variants: function kernels: - arg_meta: null - kernel_name: impl::HiFi::quantize_per_tensor_out + kernel_name: impl::reference::quantize_per_tensor_out - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: impl::HiFi::dequantize_per_tensor_out + kernel_name: impl::reference::dequantize_per_tensor_out - func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_conv_out + kernel_name: impl::reference::quantized_conv_out - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_layer_norm_out + kernel_name: impl::reference::quantized_layer_norm_out - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_linear_out + kernel_name: impl::reference::quantized_linear_out - func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_relu_out + kernel_name: impl::reference::quantized_relu_out diff --git a/backends/cadence/reference/kernels/CMakeLists.txt b/backends/cadence/reference/kernels/CMakeLists.txt new file mode 100644 index 00000000000..eadb01f54d5 --- /dev/null +++ b/backends/cadence/reference/kernels/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# lint_cmake: -linelength +add_library( + cadence_kernels + kernels.cpp +) + +target_include_directories( + cadence_kernels + PUBLIC . +) diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp new file mode 100644 index 00000000000..735d390bc74 --- /dev/null +++ b/backends/cadence/reference/kernels/kernels.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "kernels.h" + +#include +#include + +namespace impl { +namespace reference { +namespace kernels { + +// Quantize a fp32 value to an int8_t/uint8_t value +template +__attribute__((always_inline)) T +quantize(const float x, float scale, int32_t zero_point) { + constexpr float min_val = std::numeric_limits::min(); + constexpr float max_val = std::numeric_limits::max(); + float tmp = roundf(x * scale + zero_point); + return std::max(std::min(tmp, max_val), min_val); +} + +// Quantize an fp32 array to an int8_t/uint8_t array +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float inv_scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = quantize(x[i], inv_scale, zero_point); + } +} + +// Dequantize an int8_t/uint8_t value to an fp32 value +template +__attribute__((always_inline)) float +dequantize(const T x, float scale, int32_t zero_point) { + return scale * (x - zero_point); +} + +// Dequantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size) { + for (size_t i = 0; i < size; ++i) { + y[i] = dequantize(x[i], scale, zero_point); + } +} + +// explicit template instantiation + +#define typed_quantize_val(dtype) \ + template __attribute__((always_inline)) dtype quantize( \ + const float x, float inv_scale, int32_t zero_point); +typed_quantize_val(int8_t); +typed_quantize_val(uint8_t); +typed_quantize_val(int16_t); +typed_quantize_val(int32_t); +#undef typed_quantize_val + +#define typed_quantize_vec(dtype) \ + template void quantize( \ + dtype* __restrict__ y, \ + const float* __restrict__ x, \ + float inv_scale, \ + int32_t zero_point, \ + size_t size); +typed_quantize_vec(int8_t); +typed_quantize_vec(uint8_t); +typed_quantize_vec(int16_t); +typed_quantize_vec(int32_t); +#undef typed_quantize_vec + +#define typed_dequantize_val(dtype) \ + template __attribute__((always_inline)) float dequantize( \ + const dtype x, float scale, int32_t zero_point); +typed_dequantize_val(int8_t); +typed_dequantize_val(uint8_t); +typed_dequantize_val(int16_t); +typed_dequantize_val(int32_t); +#undef typed_dequantize_val + +#define typed_dequantize_vec(dtype) \ + template void dequantize( \ + float* __restrict__ y, \ + const dtype* __restrict__ x, \ + float scale, \ + int32_t zero_point, \ + size_t size); +typed_dequantize_vec(int8_t); +typed_dequantize_vec(uint8_t); +typed_dequantize_vec(int16_t); +typed_dequantize_vec(int32_t); +#undef typed_dequantize_vec + +}; // namespace kernels +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/reference/kernels/kernels.h b/backends/cadence/reference/kernels/kernels.h new file mode 100644 index 00000000000..76400405144 --- /dev/null +++ b/backends/cadence/reference/kernels/kernels.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "inttypes.h" +#include "stddef.h" + +namespace impl { +namespace reference { +namespace kernels { + +template +T quantize(const float x, float scale, int32_t zero_point); + +template +float dequantize(const T x, float scale, int32_t zero_point); + +template +void quantize( + T* __restrict__ y, + const float* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +// Deuantize an int8_t/uint8_t/int16_t array to an fp32 array +template +void dequantize( + float* __restrict__ y, + const T* __restrict__ x, + float scale, + int32_t zero_point, + size_t size); + +}; // namespace kernels +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt new file mode 100644 index 00000000000..c22dc0c9976 --- /dev/null +++ b/backends/cadence/reference/operators/CMakeLists.txt @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +include(${EXECUTORCH_ROOT}/build/Utils.cmake) +include(${EXECUTORCH_ROOT}/build/Codegen.cmake) + +if(NOT PYTHON_EXECUTABLE) + resolve_python_executable() +endif() + +# ATen compliant ops that are needed to run this model. +set(_aten_ops__srcs + "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp") +add_library(aten_ops_cadence ${_aten_ops__srcs}) +target_link_libraries(aten_ops_cadence PUBLIC executorch) +target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels) + +# Let files say "include ". +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/.. + ${CMAKE_BINARY_DIR} + ${_common_include_directories}) + +# Custom ops that are needed to run the test model. +add_library( + custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp" + "quantized_relu_out.cpp" "quantized_layer_norm.cpp" + "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp") +target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/.. + ${CMAKE_BINARY_DIR} + ${_common_include_directories}) + +target_link_libraries(custom_ops PUBLIC executorch) +target_link_libraries(custom_ops PRIVATE cadence_kernels) + +# Generate C++ bindings to register kernels into both PyTorch (for AOT) and +# Executorch (for runtime). Here select all ops in functions.yaml +gen_selected_ops( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions.yaml" "" "" +) +generate_bindings_for_kernels( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + FUNCTIONS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions.yaml +) +message("Generated files ${gen_command_sources}") + +gen_operators_lib( + LIB_NAME "cadence_ops_lib" + KERNEL_LIBS custom_ops + DEPS aten_ops_cadence) diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp new file mode 100644 index 00000000000..4d6a6180347 --- /dev/null +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +namespace impl { +namespace reference { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; +using ScalarType = exec_aten::ScalarType; + +void dequantize_per_tensor_out( + RuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Int) { + const int32_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else { + ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + } +} + +}; // namespace native +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/hifi/operators/op_add.cpp b/backends/cadence/reference/operators/op_add.cpp similarity index 100% rename from backends/cadence/hifi/operators/op_add.cpp rename to backends/cadence/reference/operators/op_add.cpp diff --git a/backends/cadence/hifi/operators/op_embedding.cpp b/backends/cadence/reference/operators/op_embedding.cpp similarity index 95% rename from backends/cadence/hifi/operators/op_embedding.cpp rename to backends/cadence/reference/operators/op_embedding.cpp index b4100feacc1..f0b625c963e 100644 --- a/backends/cadence/hifi/operators/op_embedding.cpp +++ b/backends/cadence/reference/operators/op_embedding.cpp @@ -7,7 +7,6 @@ */ #include -#include "kernels.h" namespace torch { namespace executor { @@ -31,7 +30,7 @@ void embedding_out( for (int i = 0, e = indices.numel(); i < e; i++) { // memcpy(dest, src, nbytes); - impl::HiFi::kernels::memcpy( + memcpy( out_data, w_data + nbytes_per_entry * indices_ptr[i], nbytes_per_entry); out_data += nbytes_per_entry; } diff --git a/backends/cadence/hifi/operators/op_full.cpp b/backends/cadence/reference/operators/op_full.cpp similarity index 100% rename from backends/cadence/hifi/operators/op_full.cpp rename to backends/cadence/reference/operators/op_full.cpp diff --git a/backends/cadence/hifi/operators/op_view_copy.cpp b/backends/cadence/reference/operators/op_view_copy.cpp similarity index 83% rename from backends/cadence/hifi/operators/op_view_copy.cpp rename to backends/cadence/reference/operators/op_view_copy.cpp index e856c1592cb..a363125c375 100644 --- a/backends/cadence/hifi/operators/op_view_copy.cpp +++ b/backends/cadence/reference/operators/op_view_copy.cpp @@ -7,7 +7,6 @@ */ #include -#include "kernels.h" namespace torch { namespace executor { @@ -21,8 +20,7 @@ Tensor& view_copy_out( const Tensor& input, const IntArrayRef size, Tensor& out) { - impl::HiFi::kernels::memcpy( - out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes()); + memcpy(out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes()); return out; } diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp new file mode 100644 index 00000000000..8e25b58a07c --- /dev/null +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +namespace impl { +namespace reference { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; +using ScalarType = exec_aten::ScalarType; + +// Quantize the input tensor (PT2 version). Note that quant_ are not +// used in any computation. +void quantize_per_tensor_out( + RuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Int) { + int32_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else { + ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type()); + } +} + +}; // namespace native +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/hifi/operators/quantized_conv_out.cpp b/backends/cadence/reference/operators/quantized_conv_out.cpp similarity index 99% rename from backends/cadence/hifi/operators/quantized_conv_out.cpp rename to backends/cadence/reference/operators/quantized_conv_out.cpp index 23e189e6bcb..95236b4397b 100644 --- a/backends/cadence/hifi/operators/quantized_conv_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_out.cpp @@ -13,7 +13,7 @@ #include namespace impl { -namespace HiFi { +namespace reference { namespace native { using Tensor = exec_aten::Tensor; @@ -223,5 +223,5 @@ void quantized_conv_out( } }; // namespace native -}; // namespace HiFi +}; // namespace reference }; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_layer_norm.cpp b/backends/cadence/reference/operators/quantized_layer_norm.cpp new file mode 100644 index 00000000000..22075f632ef --- /dev/null +++ b/backends/cadence/reference/operators/quantized_layer_norm.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +#include +#include +#include + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +namespace impl { +namespace reference { +namespace native { + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + float input_scale, + int64_t input_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Get the raw pointers to input, output, weight, and bias + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.const_data_ptr(); + + float output_inv_scale = XT_RECIP_S(output_scale); + + size_t last_dim = input.size(input.dim() - 1); + size_t leading_dims = getLeadingDims(input, input.dim() - 1); + + // Visualize the input tensor as a set of 1d vectors, and compute the + // layer_norm for each vector. + for (size_t i = 0; i < leading_dims; ++i) { + const T* __restrict__ x = in_data + i * last_dim; + T* __restrict__ y = out_data + i * last_dim; + + // compute sum and squared sum. The fp32 sum can be approximated as: + // (X_1 - in_zero_point) * in_scale + (X_2 - in_zero_point) * in_scale + ... + // (X_N - in_zero_point) * in_scale. + int32_t sum = 0; + int32_t sq_sum = last_dim * input_zero_point * input_zero_point; +#pragma simd + for (size_t j = 0; j < last_dim; ++j) { + int32_t val = x[j]; + sum += val; + sq_sum += val * val; + } + sq_sum -= (2 * sum * input_zero_point); + sum -= (last_dim * input_zero_point); + + float mean = XT_DIV_S(XT_MUL_S(input_scale, sum), last_dim); + float variance = + XT_DIV_S( + XT_MUL_S(sq_sum, XT_MUL_S(input_scale, input_scale)), last_dim) - + XT_MUL_S(mean, mean); + float inv_std = XT_RECIP_S(XT_SQRT_S(XT_ADD_S(variance, (float)eps))); + + // y = (x - mean) / std * kGamma + kBeta +#pragma simd + for (size_t j = 0; j < last_dim; ++j) { + // Since X is quantized, we dequantize it, compute fp32 result, and + // quantize the result to an int8/uint8 value. + float val = kernels::dequantize(x[j], input_scale, input_zero_point); + val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; + y[j] = kernels::quantize(val, output_inv_scale, output_zero_point); + } + } +} + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. + float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + + // Call other overload + quantized_layer_norm_( + input, + input_scale, + input_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); +} + +void quantized_layer_norm_out( + RuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const exec_aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == exec_aten::ScalarType::Byte) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == exec_aten::ScalarType::Char) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + } +} + +}; // namespace native +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp new file mode 100644 index 00000000000..fa40f16427b --- /dev/null +++ b/backends/cadence/reference/operators/quantized_linear_out.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +namespace impl { +namespace reference { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +void quantized_linear_out( + RuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + const exec_aten::optional& offset, + Tensor& out) { + // Assuming uint8_t for now, but needs to be updated for other quantization + // types + const uint8_t* __restrict__ src_data = src.const_data_ptr(); + const uint8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + uint8_t* __restrict__ out_data = out.mutable_data_ptr(); + + int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + const auto M = weight.size(0); // = out_dim + const auto N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... * d_{N-2} + const auto leading_dims = getLeadingDims(src, src.dim() - 1); + + ET_CHECK_MSG( + out_multiplier.numel() == 1, "out_multiplier should have one element"); + ET_CHECK_MSG( + out_shift.numel() == 1, "out_multiplier should have one element"); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += (src_data[i * N + k] - src_zero_point) * + (weight_data[j * N + k] - weight_zero_point); + } + out_data[i * M + j] = + kernels::quantize(sum, out_scale, out_zero_point); + } + } +} + +}; // namespace native +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp new file mode 100644 index 00000000000..95df35caba7 --- /dev/null +++ b/backends/cadence/reference/operators/quantized_matmul_out.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +namespace impl { +namespace reference { +namespace native { + +// The quantized matmul. The quantized matmul accumulates in a wider register, +// whose type is TA. +template < + typename TZ, + typename TA = float, + bool transposed = false, + typename TX = TZ, + typename TY = TZ> +__attribute__((noinline)) void qmatmul( + TZ* __restrict__ Z, + int32_t Z_multiplier, + int32_t Z_shift, + int32_t Z_zero_point, + const TX* __restrict__ X, + int32_t X_zero_point, + const TY* __restrict__ y, + int32_t Y_zero_point, + size_t m, + size_t n, + size_t p) { + // Compute the Z_scale from Z_multiplier and Z_shift + const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < p; ++j) { + TA sum = 0; + for (size_t k = 0; k < n; ++k) { + if (transposed) { + sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); + } else { + sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); + } + } + Z[i * p + j] = kernels::quantize(sum, Z_scale, Z_zero_point); + } + } +} + +template +void inline _typed_quantized_matmul( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const c10::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + ctype* __restrict__ out_data = out.mutable_data_ptr(); + const ctype* __restrict__ X_data = X.const_data_ptr(); + const ctype* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const ctype* x = X_data + i * leading_dim * in_dim; + const ctype* y = Y_data + i * in_dim * out_dim; + ctype* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } + break; +} + +void quantized_matmul_out( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const c10::optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + (void)bias; + + size_t batch_size = getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + if (out.ScalarType() == at::ScalarType::Byte) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } else if (out.ScalarType() == at::ScalarType::Char) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + } +} + +}; // namespace native +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/hifi/operators/quantized_relu_out.cpp b/backends/cadence/reference/operators/quantized_relu_out.cpp similarity index 96% rename from backends/cadence/hifi/operators/quantized_relu_out.cpp rename to backends/cadence/reference/operators/quantized_relu_out.cpp index 1643747baec..54f6b723c68 100644 --- a/backends/cadence/hifi/operators/quantized_relu_out.cpp +++ b/backends/cadence/reference/operators/quantized_relu_out.cpp @@ -10,7 +10,7 @@ #include "kernels.h" namespace impl { -namespace HiFi { +namespace reference { namespace native { using Tensor = exec_aten::Tensor; @@ -47,5 +47,5 @@ void quantized_relu_out( } }; // namespace native -}; // namespace HiFi +}; // namespace reference }; // namespace impl