diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index d3b215f91de..d36c8e07a25 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -185,6 +185,7 @@ tflm_kernel_cc_library( "add_n.cc", "arg_min_max.cc", "assign_variable.cc", + "batch_matmul.cc", "batch_to_space_nd.cc", "broadcast_args.cc", "broadcast_to.cc", @@ -425,6 +426,20 @@ cc_test( ], ) +cc_test( + name = "batch_matmul_test", + srcs = [ + "batch_matmul_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + cc_test( name = "batch_to_space_nd_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index 45ba9e058ae..0bd846bc679 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -114,6 +114,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/activations_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/add_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/add_n_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/arg_min_max_test.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_matmul_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_to_space_nd_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/broadcast_args_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/broadcast_to_test.cc \ diff --git a/tensorflow/lite/micro/kernels/batch_matmul.cc b/tensorflow/lite/micro/kernels/batch_matmul.cc new file mode 100644 index 00000000000..0ff8f6b2193 --- /dev/null +++ b/tensorflow/lite/micro/kernels/batch_matmul.cc @@ -0,0 +1,555 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/transpose.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputLhsTensor = 0;
+constexpr int kInputRhsTensor = 1;
+constexpr int kOutputTensor = 0;
+
+struct QuantizationOpData {
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;  // exponent
+
+  // The range of the fused activation layer. For example for kNone and
+  // int8_t these would be -128 and 127.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+
+  int32_t lhs_zero_point;
+  int32_t rhs_zero_point;
+  int32_t output_zero_point;
+};
+
+struct OpData {
+  QuantizationOpData* quantization;
+
+  // Transpose tensors and state
+  TfLiteEvalTensor* lhs_transposed_tensor;
+  TfLiteEvalTensor* rhs_transposed_tensor;
+  bool rhs_is_transposed;
+  bool lhs_is_constant_tensor;
+  bool rhs_is_constant_tensor;
+};
+
+struct OpContext {
+  OpContext(TfLiteContext* context, TfLiteNode* node)
+      : params(static_cast<TfLiteBatchMatMulParams*>(node->builtin_data)),
+        op_data(static_cast<OpData*>(node->user_data)) {}
+
+  TfLiteBatchMatMulParams* params;
+  OpData* op_data;
+};
+
+struct PrepareOpContext : OpContext {
+  PrepareOpContext(TfLiteContext* context, TfLiteNode* node)
+      : OpContext(context, node),
+        micro_context_(GetMicroContext(context)),
+        lhs(micro_context_->AllocateTempInputTensor(node, kInputLhsTensor)),
+        rhs(micro_context_->AllocateTempInputTensor(node, kInputRhsTensor)),
+        output(micro_context_->AllocateTempOutputTensor(node, kOutputTensor)) {}
+
+  ~PrepareOpContext() {
+    if (lhs != nullptr) {
+      micro_context_->DeallocateTempTfLiteTensor(lhs);
+    }
+    if (rhs != nullptr) {
+      micro_context_->DeallocateTempTfLiteTensor(rhs);
+    }
+    if (output != nullptr) {
+      micro_context_->DeallocateTempTfLiteTensor(output);
+    }
+  }
+
+ private:
+  MicroContext* micro_context_;
+
+ public:
+  TfLiteTensor* lhs;
+  TfLiteTensor* rhs;
+  TfLiteTensor* output;
+};
+
+struct EvalOpContext : OpContext {
+  EvalOpContext(TfLiteContext* context, TfLiteNode* node)
+      : OpContext(context, node),
+        lhs(tflite::micro::GetEvalInput(context, node, kInputLhsTensor)),
+        rhs(tflite::micro::GetEvalInput(context, node, kInputRhsTensor)),
+        output(tflite::micro::GetEvalOutput(context, node, kOutputTensor)) {}
+
+  const TfLiteEvalTensor* lhs;
+  const TfLiteEvalTensor* rhs;
+  TfLiteEvalTensor* output;
+};
+
+TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
+                                 const RuntimeShape& extended_lhs_shape,
+                                 const RuntimeShape& extended_rhs_shape,
+                                 bool adj_x, bool adj_y, int output_rank,
+                                 TfLiteTensor* output) {
+  int64_t orig_size = NumElements(output);
+
+  // make sure the new output dims rank does not exceed the original rank
+  TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));
+
+  // make sure output tensor dims are not in the FlatBuffer
+  TfLiteEvalTensor* output_eval =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_OK(context,
tflite::micro::CreateWritableTensorDimsWithCopy( + context, output, output_eval)); + + // Fill in any broadcast dimensions. + for (int i = 0; i < output_rank - 2; ++i) { + const int lhs_dim = extended_lhs_shape.Dims(i); + const int rhs_dim = extended_rhs_shape.Dims(i); + int broadcast_dim = lhs_dim; + if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) { + broadcast_dim = rhs_dim; + } + output->dims->data[i] = broadcast_dim; + } + // Fill in the matmul dimensions. + int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2; + int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1; + + output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index); + output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index); + output->dims->size = output_rank; + + // Check that output tensor has not been resized + // since TFLM doesn't support tensor resizing. + TF_LITE_ENSURE_EQ(context, orig_size, NumElements(output)); + + return kTfLiteOk; +} + +TfLiteEvalTensor* AllocInitTransposeTensorFromTfLiteTensor( + TfLiteContext* context, const TfLiteTensor& tensor) { + MicroContext* micro_context = GetMicroContext(context); + TfLiteEvalTensor* eval_tensor = static_cast( + micro_context->AllocatePersistentBuffer(sizeof(TfLiteEvalTensor))); + if (eval_tensor == nullptr) { + return nullptr; + } + + eval_tensor->type = tensor.type; + + const int tensor_rank = NumDimensions(&tensor); + const size_t eval_dims_size = TfLiteIntArrayGetSizeInBytes(tensor_rank); + eval_tensor->dims = static_cast( + micro_context->AllocatePersistentBuffer(eval_dims_size)); + if (eval_tensor->dims == nullptr) { + return nullptr; + } + eval_tensor->dims->size = tensor_rank; + for (int i = 0; i < tensor_rank - 2; ++i) { + eval_tensor->dims->data[i] = tensor.dims->data[i]; + } + // Swap last two dimensions. + eval_tensor->dims->data[tensor_rank - 2] = tensor.dims->data[tensor_rank - 1]; + eval_tensor->dims->data[tensor_rank - 1] = tensor.dims->data[tensor_rank - 2]; + + const size_t eval_data_size = static_cast(NumElements(&tensor)) * + TfLiteTypeGetSize(tensor.type); + eval_tensor->data.data = + micro_context->AllocatePersistentBuffer(eval_data_size); + if (eval_tensor->data.data == nullptr) { + return nullptr; + } + + return eval_tensor; +} + +// Initializes tensors to store transposed operands. +// Allocate storage for hybrid quantization if needed. +// Allocate normal quantization data if needed. +TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, + const PrepareOpContext& op_context) { + OpData* op_data = op_context.op_data; + const TfLiteTensor* lhs = op_context.lhs; + const TfLiteTensor* rhs = op_context.rhs; + MicroContext* micro_context = GetMicroContext(context); + + op_data->quantization = nullptr; + op_data->lhs_transposed_tensor = nullptr; + op_data->rhs_transposed_tensor = nullptr; + + if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) { + op_data->quantization = static_castquantization)>( + micro_context->AllocatePersistentBuffer( + sizeof(*op_data->quantization))); + TF_LITE_ENSURE(context, op_data->quantization != nullptr); + } + + // tensor for Transposed LHS; + if (op_context.params->adj_x) { + op_data->lhs_transposed_tensor = + AllocInitTransposeTensorFromTfLiteTensor(context, *lhs); + TF_LITE_ENSURE(context, op_data->lhs_transposed_tensor != nullptr); + } + + // We need a buffer for the RHS if we need to transpose the RHS. 
We + // transpose by default, so that the two inputs (LHS and RHS) are in a proper + // layout for our fast matrix multiplication routines. If the transpose flag + // is set by the caller, the data is already in the desired layout. + if (!op_context.params->adj_y) { + op_data->rhs_transposed_tensor = + AllocInitTransposeTensorFromTfLiteTensor(context, *rhs); + TF_LITE_ENSURE(context, op_data->rhs_transposed_tensor != nullptr); + } + + return kTfLiteOk; +} + +template +void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in, + TfLiteEvalTensor* tensor_out) { + const Scalar* input = tflite::micro::GetTensorData(&tensor_in); + Scalar* output = tflite::micro::GetTensorData(tensor_out); + RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in)); + RuntimeShape shape(transposed_shape); + TransposeParams params; + const int rank = shape.DimensionsCount(); + params.perm_count = rank; + for (int i = 0; i < rank - 2; ++i) { + params.perm[i] = i; + } + // Transpose the last two dimensions. + params.perm[rank - 2] = rank - 1; + params.perm[rank - 1] = rank - 2; + transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2)); + transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1)); + reference_ops::Transpose(params, shape, input, transposed_shape, output); +} + +TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in, + TfLiteEvalTensor* tensor_out) { + if (tensor_in.type == kTfLiteFloat32) { + TransposeRowsColumnsImpl(tensor_in, tensor_out); + return kTfLiteOk; + } else if (tensor_in.type == kTfLiteInt8) { + TransposeRowsColumnsImpl(tensor_in, tensor_out); + return kTfLiteOk; + } else if (tensor_in.type == kTfLiteInt16) { + TransposeRowsColumnsImpl(tensor_in, tensor_out); + return kTfLiteOk; + } else { + MicroPrintf( + "BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 " + "type."); + } + return kTfLiteError; +} + +RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) { + RuntimeShape swapped_shape(shape); + const int32_t dims = shape.DimensionsCount(); + swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1)); + swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2)); + return swapped_shape; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + MicroContext* micro_context = GetMicroContext(context); + return micro_context->AllocatePersistentBuffer(sizeof(OpData)); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + PrepareOpContext op_context(context, node); + const TfLiteTensor* lhs_data = op_context.lhs; + TF_LITE_ENSURE(context, lhs_data != nullptr); + const TfLiteTensor* rhs_data = op_context.rhs; + TF_LITE_ENSURE(context, rhs_data != nullptr); + TfLiteTensor* output = op_context.output; + TF_LITE_ENSURE(context, output != nullptr); + + TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 || + lhs_data->type == kTfLiteInt8 || + lhs_data->type == kTfLiteInt16); + TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 || + rhs_data->type == kTfLiteInt8 || + rhs_data->type == kTfLiteInt16); + // Both inputs should be of the same type. + // Hybrid input (FLOAT32 LHS, INT8 RHS) is not supported. 
+ TF_LITE_ENSURE(context, lhs_data->type == rhs_data->type); + // LHS input must match output type. INT32 output not supported. + TF_LITE_ENSURE(context, lhs_data->type == output->type); + + const int lhs_rank = NumDimensions(lhs_data); + const int rhs_rank = NumDimensions(rhs_data); + // Support dimensions between 2 and 5, inclusive. + TF_LITE_ENSURE(context, lhs_rank >= 2); + TF_LITE_ENSURE(context, lhs_rank <= 5); + TF_LITE_ENSURE(context, rhs_rank >= 2); + TF_LITE_ENSURE(context, rhs_rank <= 5); + + TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, op_context)); + + OpData* op_data = op_context.op_data; + // If the RHS is constant, we only transpose once. + op_data->rhs_is_transposed = false; + op_data->lhs_is_constant_tensor = IsConstantTensor(lhs_data); + op_data->rhs_is_constant_tensor = IsConstantTensor(rhs_data); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (lhs_data->type == kTfLiteInt8 || lhs_data->type == kTfLiteInt16) { + TF_LITE_ENSURE(context, op_data->quantization != nullptr); + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, lhs_data, rhs_data, output, &real_multiplier)); + QuantizeMultiplier(real_multiplier, + &op_data->quantization->output_multiplier, + &op_data->quantization->output_shift); + // BatchMatMul has no fused activation functions. Therefore, set + // output activation min and max to min and max of int8_t or int16_t type. + if (lhs_data->type == kTfLiteInt8) { + op_data->quantization->output_activation_min = + std::numeric_limits::min(); + op_data->quantization->output_activation_max = + std::numeric_limits::max(); + } else { + op_data->quantization->output_activation_min = + std::numeric_limits::min(); + op_data->quantization->output_activation_max = + std::numeric_limits::max(); + + TF_LITE_ENSURE_EQ(context, lhs_data->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, rhs_data->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + } + + op_data->quantization->lhs_zero_point = lhs_data->params.zero_point; + op_data->quantization->rhs_zero_point = rhs_data->params.zero_point; + op_data->quantization->output_zero_point = output->params.zero_point; + } + + const int output_rank = std::max(lhs_rank, rhs_rank); + const RuntimeShape extended_lhs_shape = + RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data)); + const RuntimeShape extended_rhs_shape = + RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data)); + + // Ensure any batch dimensions obey broacasting rules. + for (int i = 0; i < output_rank - 2; ++i) { + const int lhs_dim = extended_lhs_shape.Dims(i); + const int rhs_dim = extended_rhs_shape.Dims(i); + if (lhs_dim != rhs_dim) { + if (lhs_dim != 1) { + TF_LITE_ENSURE_EQ(context, rhs_dim, 1); + } + } + } + bool adj_x = op_context.params->adj_x; + bool adj_y = op_context.params->adj_y; + // Ensure other dimensions work for matrix multiplication. + int accum_dim_lhs = adj_x ? extended_lhs_shape.Dims(output_rank - 2) + : extended_lhs_shape.Dims(output_rank - 1); + int accum_dim_rhs = adj_y ? 
extended_rhs_shape.Dims(output_rank - 1) + : extended_rhs_shape.Dims(output_rank - 2); + + TF_LITE_ENSURE_EQ(context, accum_dim_lhs, accum_dim_rhs); + TfLiteStatus status = + ReshapeOutputTensor(context, node, extended_lhs_shape, extended_rhs_shape, + adj_x, adj_y, output_rank, output); + return status; +} + +TfLiteStatus EvalInt8(TfLiteContext* context, const OpData& data, + const RuntimeShape& lhs_shape, + const TfLiteEvalTensor& lhs, + const RuntimeShape& rhs_shape, + const TfLiteEvalTensor& rhs, + const RuntimeShape& output_shape, + TfLiteEvalTensor* output) { + TF_LITE_ENSURE(context, data.quantization != nullptr); + // Reuse params struct from FullyConnected Op. + FullyConnectedParams op_params; + op_params.input_offset = -data.quantization->lhs_zero_point; + op_params.weights_offset = + -data.quantization->rhs_zero_point; // filter offset + op_params.output_offset = data.quantization->output_zero_point; + op_params.output_multiplier = data.quantization->output_multiplier; + op_params.output_shift = data.quantization->output_shift; + op_params.quantized_activation_min = data.quantization->output_activation_min; + op_params.quantized_activation_max = data.quantization->output_activation_max; + op_params.lhs_cacheable = data.lhs_is_constant_tensor; + op_params.rhs_cacheable = data.rhs_is_constant_tensor; + + // Note we pass RHS args first, LHS args second. See note for Eval. + reference_ops::BatchMatMul( + op_params, rhs_shape, tflite::micro::GetTensorData(&rhs), + lhs_shape, tflite::micro::GetTensorData(&lhs), output_shape, + tflite::micro::GetTensorData(output)); + + return kTfLiteOk; +} + +TfLiteStatus EvalInt16(TfLiteContext* context, const OpData& data, + const RuntimeShape& lhs_shape, + const TfLiteEvalTensor& lhs, + const RuntimeShape& rhs_shape, + const TfLiteEvalTensor& rhs, + const RuntimeShape& output_shape, + TfLiteEvalTensor* output) { + TF_LITE_ENSURE(context, data.quantization != nullptr); + // Reuse params struct from FullyConnected Op. + FullyConnectedParams op_params; + op_params.input_offset = -data.quantization->lhs_zero_point; + op_params.weights_offset = + -data.quantization->rhs_zero_point; // filter offset + op_params.output_offset = data.quantization->output_zero_point; + op_params.output_multiplier = data.quantization->output_multiplier; + op_params.output_shift = data.quantization->output_shift; + op_params.quantized_activation_min = data.quantization->output_activation_min; + op_params.quantized_activation_max = data.quantization->output_activation_max; + op_params.lhs_cacheable = data.lhs_is_constant_tensor; + op_params.rhs_cacheable = data.rhs_is_constant_tensor; + + // Note we pass RHS args first, LHS args second. See note for Eval. + reference_ops::BatchMatMul( + op_params, rhs_shape, tflite::micro::GetTensorData(&rhs), + lhs_shape, tflite::micro::GetTensorData(&lhs), output_shape, + tflite::micro::GetTensorData(output)); + + return kTfLiteOk; +} + +// Perform a batch matrix multiply on +// LHS <..., A, B> X RHS<..., B, C> +// where the leading dimensions of LHS and RHS obey broadcasting rules +// (this Op will apply broadcasting rules). +// We assume that LHS and RHS are both row oriented (adjacent values in memory +// are in the same row) and will output in the same memory layout. However, +// our fast GEMM libraries assume RCC layout (LHS row oriented, +// RHS column oriented, output column oriented). Therefore, we perform +// RHS <..., C, B> X LHS <..., B, A> +// where output is a C X A column-oriented, which is equivalent to +// A X C row-oriented. 
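The comment above is easiest to follow with concrete shapes. The sketch below is an editorial illustration only (not part of the patch), tracing what Eval passes to the reference kernel for float inputs with adj_x == false, adj_y == false, and hypothetical shapes:

// LHS A: [batch, 2, 3] row-major; RHS B: [batch, 3, 4] row-major.
// Prepare allocated rhs_transposed_tensor, so Eval fills it with
// Bt = transpose(B) of shape [batch, 4, 3], swaps the trailing dims of both
// shapes, and calls the reference op with the RHS arguments first:
//   reference_ops::BatchMatMul(rhs_shape /* [batch, 4, 3] */, Bt,
//                              lhs_shape /* [batch, 3, 2] */, A,
//                              output_shape /* [batch, 2, 4] */, out);
// The 4 x 2 column-oriented result per batch is read back as the desired
// 2 x 4 row-major block of A x B.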
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + EvalOpContext op_context(context, node); + OpData* op_data = op_context.op_data; + const TfLiteEvalTensor* lhs = op_context.lhs; + const TfLiteEvalTensor* rhs = op_context.rhs; + TfLiteEvalTensor* output = op_context.output; + RuntimeShape orig_lhs_shape = tflite::micro::GetTensorShape(lhs); + RuntimeShape orig_rhs_shape = tflite::micro::GetTensorShape(rhs); + + bool adj_y = op_context.params->adj_y; + bool adj_x = op_context.params->adj_x; + + // Compress BatchMatMul when third from last RHS dimension is one. + int32_t rhs_dims_count = orig_rhs_shape.DimensionsCount(); + int32_t lhs_dims_count = orig_lhs_shape.DimensionsCount(); + // Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is + // [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and + // lhs: [..., Q * R, S]. + if (rhs_dims_count > 2 && lhs_dims_count > 2) { + int rhs_one = orig_rhs_shape.DimsData()[rhs_dims_count - 3]; + if (rhs_one == 1) { + int32_t* lhs_dims = orig_lhs_shape.DimsData(); + int32_t* rhs_dims = orig_rhs_shape.DimsData(); + RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims); + tmp_l.SetDim(lhs_dims_count - 3, + lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]); + tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]); + orig_lhs_shape.ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); + RuntimeShape tmp_r(rhs_dims_count - 1, orig_rhs_shape.DimsData()); + tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]); + tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]); + orig_rhs_shape.ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData()); + rhs_dims_count = orig_rhs_shape.DimensionsCount(); + lhs_dims_count = orig_lhs_shape.DimensionsCount(); + } + } + + TfLiteEvalTensor* rhs_tensor = adj_y ? const_cast(rhs) + : op_data->rhs_transposed_tensor; + TfLiteEvalTensor* lhs_tensor = adj_x ? op_data->lhs_transposed_tensor + : const_cast(lhs); + TF_LITE_ENSURE(context, rhs_tensor != nullptr); + TF_LITE_ENSURE(context, lhs_tensor != nullptr); + if (!adj_y) { + // TODO(b/154760341): Constant tensors should already be transposed, but + // we transpose once if necessary for now. + if (!(op_data->rhs_is_constant_tensor && op_data->rhs_is_transposed)) { + TransposeRowsColumns(*rhs, rhs_tensor); + op_data->rhs_is_transposed = true; + } + } + if (adj_x) { + TransposeRowsColumns(*lhs, lhs_tensor); + } + RuntimeShape rhs_shape = + adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape); + RuntimeShape lhs_shape = + adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape); + + switch (lhs->type) { + case kTfLiteFloat32: + // Note we pass RHS args first, LHS args second. See note above. 
+ reference_ops::BatchMatMul( + rhs_shape, tflite::micro::GetTensorData(rhs_tensor), lhs_shape, + tflite::micro::GetTensorData(lhs_tensor), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + case kTfLiteInt8: + return EvalInt8(context, *op_data, lhs_shape, *lhs_tensor, rhs_shape, + *rhs_tensor, tflite::micro::GetTensorShape(output), + output); + case kTfLiteInt16: + return EvalInt16(context, *op_data, lhs_shape, *lhs_tensor, rhs_shape, + *rhs_tensor, tflite::micro::GetTensorShape(output), + output); + default: + MicroPrintf("BATCH_MATMUL doesn't support input type %s", + TfLiteTypeGetName(lhs->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_BATCH_MATMUL() { + return tflite::micro::RegisterOp(Init, Prepare, Eval); +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/batch_matmul_test.cc b/tensorflow/lite/micro/kernels/batch_matmul_test.cc new file mode 100644 index 00000000000..abba7577764 --- /dev/null +++ b/tensorflow/lite/micro/kernels/batch_matmul_test.cc @@ -0,0 +1,736 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +constexpr float kFloatTolerance = 1e-5; + +constexpr int kNumInputs = 2; +constexpr int kNumOutputs = 1; +constexpr int kLhsInputTensorIndex = 0; +constexpr int kRhsInputTensorIndex = 1; +constexpr int kOutputTensorIndex = 2; + +// data_min/data_max are used to compute symmetric scale, zero-point is 0 +// scale should be 0 to use data_min/data_max +template +struct TestQuantizationParams { + // quantization parameters + float scale; // if 0, use data_min and data_max + int zero_point; + float data_min; // input data minimum value + float data_max; // input data maximum value + + T quantized_data[kNumElements]; // quantized storage +}; + +micro::KernelRunner* GetKernelRunnerInstance( + TfLiteTensor* tensors, int tensors_count, + const TfLiteBatchMatMulParams& params, bool need_init_prepare) { + static int kInputArrayData[] = {kNumInputs, kLhsInputTensorIndex, + kRhsInputTensorIndex}; + TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData); + static int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex}; + TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData); + + static const TFLMRegistration registration = tflite::Register_BATCH_MATMUL(); + + alignas(micro::KernelRunner) static char + kernel_runner_buffer[sizeof(micro::KernelRunner)] = {}; + + static micro::KernelRunner* runner = nullptr; + if (runner == nullptr || need_init_prepare) { + runner = new 
(kernel_runner_buffer) + micro::KernelRunner(registration, tensors, tensors_count, inputs_array, + outputs_array, ¶ms); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner->InitAndPrepare()); + } + + return runner; +} + +void TestBatchMatMulFloat(const TfLiteBatchMatMulParams& params, + const int* input_dims_data[kNumInputs], + const float* input_data_lhs, + const float* input_data_rhs, const int* expected_dims, + const float* expected_data, float* output_data, + bool need_constant_rhs = false, + bool need_init_prepare = true) { + TfLiteIntArray* input_dims_lhs = IntArrayFromInts(input_dims_data[0]); + TfLiteIntArray* input_dims_rhs = IntArrayFromInts(input_dims_data[1]); + TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims); + const int kOutputCount = ElementCount(*output_dims); + + static TfLiteTensor tensors[kNumInputs + kNumOutputs]; + + if (need_init_prepare) { + tensors[kLhsInputTensorIndex] = + CreateTensor(input_data_lhs, input_dims_lhs); + tensors[kRhsInputTensorIndex] = + CreateTensor(input_data_rhs, input_dims_rhs); + if (need_constant_rhs) { + tensors[kRhsInputTensorIndex].allocation_type = kTfLiteMmapRo; + } + tensors[kOutputTensorIndex] = CreateTensor(output_data, output_dims); + } + + constexpr int kTensorCount = std::extent::value; + micro::KernelRunner* runner = + GetKernelRunnerInstance(tensors, kTensorCount, params, need_init_prepare); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner->Invoke()); + + // check output data against expected + for (int i = 0; i < kOutputCount; i++) { + TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], + kFloatTolerance); + } + + // check output dimensions (relocated) against original dimensions + TF_LITE_MICRO_EXPECT_EQ(output_dims->size, + tensors[kOutputTensorIndex].dims->size); + for (int i = 0; i < output_dims->size; i++) { + TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i], + tensors[kOutputTensorIndex].dims->data[i]); + } +} + +template +void SetScaleAndZeroPoint(TestQuantizationParams* q_params) { + if (q_params->scale == 0.0f || q_params->data_max != 0 || + q_params->data_min != 0) { + q_params->scale = + ScaleFromMinMax(q_params->data_min, q_params->data_max); + q_params->zero_point = + ZeroPointFromMinMax(q_params->data_min, q_params->data_max); + } +} + +template +void TestBatchMatMulQuantized( + const TfLiteBatchMatMulParams& params, + TestQuantizationParams* quantization_lhs, + TestQuantizationParams* quantization_rhs, + TestQuantizationParams* quantization_output, + const int* input_dims_data[kNumInputs], const float* input_data_lhs, + const float* input_data_rhs, const int* expected_dims, + const T* expected_data, const float* output_data) { + TfLiteIntArray* input_dims_lhs = IntArrayFromInts(input_dims_data[0]); + TfLiteIntArray* input_dims_rhs = IntArrayFromInts(input_dims_data[1]); + TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims); + const int kOutputCount = ElementCount(*output_dims); + + static TfLiteTensor tensors[kNumInputs + kNumOutputs]; + + SetScaleAndZeroPoint(quantization_lhs); + tensors[kLhsInputTensorIndex] = CreateQuantizedTensor( + input_data_lhs, quantization_lhs->quantized_data, input_dims_lhs, + quantization_lhs->scale, quantization_lhs->zero_point); + SetScaleAndZeroPoint(quantization_rhs); + tensors[kRhsInputTensorIndex] = CreateQuantizedTensor( + input_data_rhs, quantization_rhs->quantized_data, input_dims_rhs, + quantization_rhs->scale, quantization_rhs->zero_point); + SetScaleAndZeroPoint(quantization_output); + tensors[kOutputTensorIndex] = CreateQuantizedTensor( + 
quantization_output->quantized_data, output_dims, + quantization_output->scale, quantization_output->zero_point); + + constexpr int kTensorCount = std::extent::value; + micro::KernelRunner* runner = + GetKernelRunnerInstance(tensors, kTensorCount, params, true); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner->Invoke()); + + // check output data against expected + for (int i = 0; i < kOutputCount; i++) { + TF_LITE_MICRO_EXPECT_EQ(expected_data[i], + quantization_output->quantized_data[i]); + } + // check dequantized output data against expected + for (int i = 0; i < kOutputCount; i++) { + float dequantized_value = (quantization_output->quantized_data[i] - + quantization_output->zero_point) * + quantization_output->scale; + TF_LITE_MICRO_EXPECT_NEAR(output_data[i], dequantized_value, + kFloatTolerance); + } + + // check output dimensions (relocated) against original dimensions + TF_LITE_MICRO_EXPECT_EQ(output_dims->size, + tensors[kOutputTensorIndex].dims->size); + for (int i = 0; i < output_dims->size; i++) { + TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i], + tensors[kOutputTensorIndex].dims->data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Ones) { + constexpr int kLhsInputDims[] = {4, 3, 2, 1, 4}; + constexpr int kRhsInputDims[] = {4, 3, 1, 4, 1}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 24; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 1); + + constexpr float kExpect[] = {30, 70, 278, 382, 782, 950}; + constexpr int kOutputDims[] = {4, 3, 2, 1, 1}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Flatten) { + constexpr int kLhsInputDims[] = {4, 3, 2, 2, 4}; + constexpr int kRhsInputDims[] = {4, 3, 1, 4, 1}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 48; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 1); + + constexpr float kExpect[] = {30, 70, 110, 150, 486, 590, + 694, 798, 1454, 1622, 1790, 1958}; + constexpr int kOutputDims[] = {4, 3, 2, 2, 1}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Simple) { + constexpr int kLhsInputDims[] = {3, 1, 2, 3}; + constexpr int kRhsInputDims[] = {3, 1, 3, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 6; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), 
std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.}; + constexpr int kOutputDims[] = {3, 1, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_SimpleRHSAdjoint) { + constexpr int kLhsInputDims[] = {3, 1, 2, 3}; + constexpr int kRhsInputDims[] = {3, 1, 4, 3}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 6; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr float kRhsInput[] = {7, 11, 15, 8, 12, 16, 9, 13, 17, 10, 14, 18}; + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.}; + constexpr int kOutputDims[] = {3, 1, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + true, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + kRhsInput, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_SimpleLHSAdjoint) { + constexpr int kLhsInputDims[] = {3, 1, 3, 2}; + constexpr int kRhsInputDims[] = {3, 1, 3, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + constexpr float kLhsInput[] = {1, 4, 2, 5, 3, 6}; + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.}; + constexpr int kOutputDims[] = {3, 1, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + true, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_BatchSizeTwo) { + constexpr int kLhsInputDims[] = {3, 2, 2, 3}; + constexpr int kRhsInputDims[] = {3, 2, 3, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + constexpr size_t kLhsInputSize = 12; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 24; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218., + 560., 584., 608., 632., 767., 800., 833., 866.}; + constexpr int kOutputDims[] = {3, 2, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + 
+TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Broadcast) { + constexpr int kLhsInputDims[] = {3, 2, 2, 3}; + constexpr int kRhsInputDims[] = {2, 3, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + constexpr size_t kLhsInputSize = 12; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218., + 272., 296., 320., 344., 371., 404., 437., 470.}; + constexpr int kOutputDims[] = {3, 2, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_BroadcastLHSAdjoint) { + constexpr int kLhsInputDims[] = {3, 2, 3, 2}; + constexpr int kRhsInputDims[] = {2, 3, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = {1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12}; + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218., + 272., 296., 320., 344., 371., 404., 437., 470.}; + constexpr int kOutputDims[] = {3, 2, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + true, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Broadcast2) { + constexpr int kLhsInputDims[] = {4, 2, 1, 3, 2}; + constexpr int kRhsInputDims[] = {3, 3, 2, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 12; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 24; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = { + 29., 32., 35., 38., 65., 72., 79., 86., 101., 112., 123., 134., + 53., 56., 59., 62., 121., 128., 135., 142., 189., 200., 211., 222., + 77., 80., 83., 86., 177., 184., 191., 198., 277., 288., 299., 310., + 137., 152., 167., 182., 173., 192., 211., 230., 209., 232., 255., 278., + 257., 272., 287., 302., 325., 344., 363., 382., 393., 416., 439., 462., + 377., 392., 407., 422., 477., 496., 515., 534., 577., 600., 623., 646.}; + constexpr int kOutputDims[] = {4, 2, 3, 3, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Broadcast2LHSAdjoint) { + constexpr int kLhsInputDims[] = {4, 2, 1, 2, 3}; + constexpr int kRhsInputDims[] = {3, 3, 2, 
4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}; + + constexpr size_t kRhsInputSize = 24; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = { + 29., 32., 35., 38., 65., 72., 79., 86., 101., 112., 123., 134., + 53., 56., 59., 62., 121., 128., 135., 142., 189., 200., 211., 222., + 77., 80., 83., 86., 177., 184., 191., 198., 277., 288., 299., 310., + 137., 152., 167., 182., 173., 192., 211., 230., 209., 232., 255., 278., + 257., 272., 287., 302., 325., 344., 363., 382., 393., 416., 439., 462., + 377., 392., 407., 422., 477., 496., 515., 534., 577., 600., 623., 646.}; + constexpr int kOutputDims[] = {4, 2, 3, 3, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + true, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Broadcast2RHSAdjoint) { + constexpr int kLhsInputDims[] = {4, 2, 1, 3, 2}; + constexpr int kRhsInputDims[] = {3, 3, 4, 2}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 12; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr float kRhsInput[] = {7, 11, 8, 12, 9, 13, 10, 14, + 15, 19, 16, 20, 17, 21, 18, 22, + 23, 27, 24, 28, 25, 29, 26, 30}; + + constexpr float kExpect[] = { + 29., 32., 35., 38., 65., 72., 79., 86., 101., 112., 123., 134., + 53., 56., 59., 62., 121., 128., 135., 142., 189., 200., 211., 222., + 77., 80., 83., 86., 177., 184., 191., 198., 277., 288., 299., 310., + 137., 152., 167., 182., 173., 192., 211., 230., 209., 232., 255., 278., + 257., 272., 287., 302., 325., 344., 363., 382., 393., 416., 439., 462., + 377., 392., 407., 422., 477., 496., 515., 534., 577., 600., 623., 646.}; + constexpr int kOutputDims[] = {4, 2, 3, 3, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + true, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + kRhsInput, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_Broadcast2BothAdjoint) { + constexpr int kLhsInputDims[] = {4, 2, 1, 2, 3}; + constexpr int kRhsInputDims[] = {3, 3, 4, 2}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}; + + constexpr float kRhsInput[] = {7, 11, 8, 12, 9, 13, 10, 14, + 15, 19, 16, 20, 17, 21, 18, 22, + 23, 27, 24, 28, 25, 29, 26, 30}; + + constexpr float kExpect[] = { + 29., 32., 35., 38., 65., 72., 79., 86., 101., 112., 123., 134., + 53., 56., 59., 62., 121., 128., 135., 142., 189., 200., 211., 222., + 77., 80., 83., 86., 177., 184., 191., 198., 277., 288., 299., 310., + 137., 152., 167., 182., 173., 192., 211., 230., 209., 232., 255., 278., + 257., 272., 287., 302., 325., 344., 363., 382., 393., 416., 439., 462., + 377., 392., 407., 422., 477., 496., 515., 534., 577., 600., 623., 646.}; + constexpr int kOutputDims[] = {4, 2, 3, 3, 4}; + constexpr int kOutputCount = 
std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + true, // adj_x + true, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + kRhsInput, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(BatchMatMulOpTestFloat32Test_BroadcastFromRHS) { + constexpr int kLhsInputDims[] = {2, 4, 5}; + constexpr int kRhsInputDims[] = {4, 3, 1, 5, 2}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 20; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr size_t kRhsInputSize = 30; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {185., 200., 460., 500., 735., 800., + 1010., 1100., 335., 350., 860., 900., + 1385., 1450., 1910., 2000., 485., 500., + 1260., 1300., 2035., 2100., 2810., 2900.}; + constexpr int kOutputDims[] = {4, 3, 1, 4, 2}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TF_LITE_MICRO_TEST(ConstRHSBatchMatMulOpModelRHSNotAdjoint) { + constexpr int kLhsInputDims[] = {3, 1, 6, 2}; + constexpr int kRhsInputDims[] = {2, 2, 3}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = {6, 3, 7, 4, 6, 9, 2, 6, 7, 4, 3, 7}; + + constexpr float kRhsInput[] = {6, 3, 7, 4, 6, 9}; + + constexpr float kExpect[] = {48, 36, 69, 58, 45, 85, 72, 72, 123, + 36, 42, 68, 58, 45, 85, 46, 51, 84}; + constexpr int kOutputDims[] = {3, 1, 6, 3}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + kRhsInput, kOutputDims, kExpect, + output_data, true); + // Eval twice to make sure constant transposed RHS is persistent. 
+ tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + kRhsInput, kOutputDims, kExpect, + output_data, true, false); +} + +TF_LITE_MICRO_TEST(QuantizedBatchMatMulOpTestSimpleTestQuantizedInt8) { + constexpr int kLhsInputDims[] = {2, 2, 10}; + constexpr int kRhsInputDims[] = {2, 10, 3}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }; + constexpr int kLhsInputCount = std::extent::value; + + constexpr float kRhsInput[] = { + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + }; + constexpr int kRhsInputCount = std::extent::value; + + constexpr int8_t kExpect[] = {22, 22, 22, 56, 56, 56}; + constexpr int kOutputDims[] = {2, 2, 3}; + constexpr int kOutputCount = std::extent::value; + constexpr float output_data[kOutputCount] = {23, 23, 23, 57, 57, 57}; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestQuantizationParams + quantization_params_lhs = {0.0f, // scale + 0, // zero_point + -63.5f, // data_min + 64.0f, // data_max + {}}; + tflite::testing::TestQuantizationParams + quantization_params_rhs = {0.0f, // scale + 0, // zero_point + -63.5f, // data_min + 64.0f, // data_max + {}}; + tflite::testing::TestQuantizationParams + quantization_params_output = {0.0f, // scale + 0, // zero_point + -127.0f, // data_min + 128.0f, // data_max + {}}; + + tflite::testing::TestBatchMatMulQuantized( + params, &quantization_params_lhs, &quantization_params_rhs, + &quantization_params_output, kInputDims, kLhsInput, kRhsInput, + kOutputDims, kExpect, output_data); +} + +TF_LITE_MICRO_TEST(QuantizedBatchMatMulOpTestSimpleTestQuantizedInt16) { + constexpr int kLhsInputDims[] = {2, 2, 10}; + constexpr int kRhsInputDims[] = {2, 10, 3}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }; + constexpr int kLhsInputCount = std::extent::value; + + constexpr float kRhsInput[] = { + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + }; + constexpr int kRhsInputCount = std::extent::value; + + constexpr int16_t kExpect[] = {23, 23, 23, 57, 57, 57}; + constexpr int kOutputDims[] = {2, 2, 3}; + constexpr int kOutputCount = std::extent::value; + constexpr float output_data[kOutputCount] = {23, 23, 23, 57, 57, 57}; + + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestQuantizationParams + quantization_params_lhs = {}; + quantization_params_lhs.scale = 10.0f / std::numeric_limits::max(); + tflite::testing::TestQuantizationParams + quantization_params_rhs = {}; + quantization_params_rhs.scale = 10.0f / std::numeric_limits::max(); + + tflite::testing::TestQuantizationParams + quantization_params_output = {}; + quantization_params_output.scale = 1.0f; + + tflite::testing::TestBatchMatMulQuantized( + params, &quantization_params_lhs, &quantization_params_rhs, + &quantization_params_output, kInputDims, kLhsInput, kRhsInput, + kOutputDims, kExpect, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/kernel_runner.cc b/tensorflow/lite/micro/kernels/kernel_runner.cc 
index d5112a1ec69..602778d7c50 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.cc +++ b/tensorflow/lite/micro/kernels/kernel_runner.cc @@ -37,7 +37,8 @@ void ClearBufferApi(TfLiteContext* context_) { KernelRunner::KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* inputs, TfLiteIntArray* outputs, - void* builtin_data, TfLiteIntArray* intermediates) + const void* builtin_data, + TfLiteIntArray* intermediates) : registration_(registration), allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_, kKernelRunnerBufferSize_)), @@ -57,7 +58,7 @@ KernelRunner::KernelRunner(const TFLMRegistration& registration, // Prepare TfLiteNode: node_.inputs = inputs; node_.outputs = outputs; - node_.builtin_data = builtin_data; + node_.builtin_data = const_cast(builtin_data); node_.intermediates = intermediates; } diff --git a/tensorflow/lite/micro/kernels/kernel_runner.h b/tensorflow/lite/micro/kernels/kernel_runner.h index d617c449b25..25b97c11302 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.h +++ b/tensorflow/lite/micro/kernels/kernel_runner.h @@ -35,7 +35,7 @@ class KernelRunner { public: KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* inputs, - TfLiteIntArray* outputs, void* builtin_data, + TfLiteIntArray* outputs, const void* builtin_data, TfLiteIntArray* intermediates = nullptr); // Calls init and prepare on the kernel (i.e. TFLMRegistration) struct. diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index caa59e9126f..2e33a6730bd 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -40,6 +40,7 @@ TFLMRegistration Register_ARG_MAX(); TFLMRegistration Register_ARG_MIN(); TFLMRegistration Register_ASSIGN_VARIABLE(); TFLMRegistration Register_AVERAGE_POOL_2D(); +TFLMRegistration Register_BATCH_MATMUL(); TFLMRegistration Register_BATCH_TO_SPACE_ND(); TFLMRegistration Register_BROADCAST_ARGS(); TFLMRegistration Register_BROADCAST_TO(); diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 4553db0638f..b2563a93072 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -143,6 +143,11 @@ class MicroMutableOpResolver : public MicroOpResolver { return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, registration, ParsePool); } + TfLiteStatus AddBatchMatMul() { + return AddBuiltin(BuiltinOperator_BATCH_MATMUL, + tflite::Register_BATCH_MATMUL(), ParseBatchMatMul); + } + TfLiteStatus AddBatchToSpaceNd() { return AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND(), ParseBatchToSpaceNd); diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 15d238203e6..3f0f5ec0826 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -1876,8 +1876,8 @@ int TestStrcmp(const char* a, const char* b) { // Create a TfLiteIntArray from an array of ints. The first element in the // supplied array must be the size of the array expressed as an int. -TfLiteIntArray* IntArrayFromInts(int* int_array) { - return reinterpret_cast(int_array); +TfLiteIntArray* IntArrayFromInts(const int* int_array) { + return reinterpret_cast(const_cast(int_array)); } // Create a TfLiteFloatArray from an array of floats. 
The first element in the diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 578282e9b28..2a6204feb12 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -195,7 +195,7 @@ void PopulateContext(TfLiteTensor* tensors, int tensors_size, // Create a TfLiteIntArray from an array of ints. The first element in the // supplied array must be the size of the array expressed as an int. -TfLiteIntArray* IntArrayFromInts(int* int_array); +TfLiteIntArray* IntArrayFromInts(const int* int_array); // Create a TfLiteFloatArray from an array of floats. The first element in the // supplied array must be the size of the array expressed as a float. diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index d6bc68a4002..dc846edeeb6 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -364,6 +364,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/add_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/add_n.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/arg_min_max.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/assign_variable.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_matmul.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_to_space_nd.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/broadcast_args.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/broadcast_to.cc \
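As a usage note, the sketch below (editorial, not part of the patch) shows how an application could pull the new kernel in through the resolver hook added above. The arena size, `model_data`, and `RunBatchMatMulModel` are placeholders assumed for illustration.

#include <cstdint>

#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {
constexpr size_t kArenaSize = 16 * 1024;       // placeholder arena size
alignas(16) uint8_t tensor_arena[kArenaSize];  // scratch memory for tensors
}  // namespace

TfLiteStatus RunBatchMatMulModel(const uint8_t* model_data) {
  tflite::MicroMutableOpResolver<1> resolver;
  // AddBatchMatMul() registers tflite::Register_BATCH_MATMUL(), both added
  // by this change.
  if (resolver.AddBatchMatMul() != kTfLiteOk) {
    return kTfLiteError;
  }

  tflite::MicroInterpreter interpreter(tflite::GetModel(model_data), resolver,
                                       tensor_arena, kArenaSize);
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    return kTfLiteError;
  }
  return interpreter.Invoke();
}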