
Adding bmm, mm, view_copy, slice_copy, split_with_sizes_copy optimizations #9877


Merged · 15 commits · Apr 26, 2025

This PR switches the Cadence HiFi backend's bmm, mm, view_copy, slice_copy, and split_with_sizes_copy operators from the generic torch::executor kernels to Cadence HiFi (NNLib-backed) implementations.
13 changes: 9 additions & 4 deletions backends/cadence/aot/functions_hifi.yaml
@@ -35,7 +35,7 @@
 - op: bmm.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::bmm_out
+      kernel_name: cadence::impl::HiFi::bmm_out

 - op: cat.out
   kernels:
@@ -107,6 +107,11 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::minimum_out

+- op: mm.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::mm_out
+
 - op: mul.out
   kernels:
     - arg_meta: null
@@ -150,12 +155,12 @@
 - op: slice_copy.Tensor_out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::slice_copy_Tensor_out
+      kernel_name: cadence::impl::HiFi::slice_copy_Tensor_out

 - op: split_with_sizes_copy.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::split_with_sizes_copy_out
+      kernel_name: cadence::impl::HiFi::split_with_sizes_copy_out

 - op: sub.out
   kernels:
@@ -170,7 +175,7 @@
 - op: view_copy.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::view_copy_out
+      kernel_name: cadence::impl::HiFi::view_copy_out

 - op: where.self_out
   kernels:
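
Note that each registered kernel_name above resolves to a C++ function in a nested native namespace; the bmm_out definition in op_bmm.cpp below follows exactly this pattern. As a rough sketch (the ::native nesting is inferred from that file, not spelled out by the YAML itself), the declaration the bmm.out registration binds to looks like:

#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

// Declaration the bmm.out registration binds to; the definition is in
// op_bmm.cpp below. (The ::native nesting is inferred from that file.)
exec_aten::Tensor& bmm_out(
    executorch::runtime::KernelRuntimeContext& ctx,
    const exec_aten::Tensor& in,
    const exec_aten::Tensor& mat2,
    exec_aten::Tensor& out);

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence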
9 changes: 5 additions & 4 deletions backends/cadence/hifi/operators/CMakeLists.txt
@@ -22,34 +22,35 @@ endif()
 set(_aten_ops__srcs
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_atan2.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_cat.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_clamp.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_full.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_maximum.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_minimum.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mm.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_permute_copy.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_pow.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_remainder.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_rsqrt.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_slice_copy.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_softmax.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_split_with_sizes_copy.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_view_copy.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gt.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_hardtanh.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_max_pool2d_with_indices.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_view_copy.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp"
171 changes: 171 additions & 0 deletions backends/cadence/hifi/operators/op_bmm.cpp
@@ -0,0 +1,171 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cstring>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using Tensor = exec_aten::Tensor;
using exec_aten::ScalarType;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::kTensorDimensionLimit;
using executorch::runtime::resize_tensor;
using executorch::runtime::tensors_have_same_dim_order;
using executorch::runtime::tensor_is_default_dim_order;
using torch::executor::check_bmm_args;
using torch::executor::Error;
using torch::executor::get_bmm_out_target_size;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

Tensor& bmm_out(
KernelRuntimeContext& ctx,
const Tensor& in,
const Tensor& mat2,
Tensor& out) {
ET_KERNEL_CHECK(ctx, check_bmm_args(in, mat2, out), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);

ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

size_t output_ndim = 0;
exec_aten::SizesType output_sizes[kTensorDimensionLimit];
get_bmm_out_target_size(in, mat2, output_sizes, &output_ndim);
ET_KERNEL_CHECK(
ctx,
resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
InvalidArgument,
out);

constexpr auto name = "bmm.out";
constexpr int kNnlibMaxDim = 3;

bool optimized = true;

if (out.scalar_type() != ScalarType::Float)
optimized = false;

if (in.dim() > kNnlibMaxDim)
optimized = false;

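  // NNLib fast path: float-only, at most kNnlibMaxDim dimensions; everything
  // else falls through to the portable per-batch loop below.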
if (optimized) {
const float* in_data = in.const_data_ptr<float>();
const float* mat2_data = mat2.const_data_ptr<float>();
float* out_data = out.mutable_data_ptr<float>();

int64_t batch_size = in.size(0);
int64_t m = in.size(1);
int64_t n = in.size(2);
int64_t p = mat2.size(2);

WORD32 rows = m;
WORD32 cols1 = n;
WORD32 row_stride1 = n;
WORD32 vec_count = p;
WORD32 vec_offset = n;
WORD32 out_offset = 1;
WORD32 out_stride = p;

    WORD32* __restrict__ tmp =
        (WORD32* __restrict__)kernels::allocate_temp_memory(
            ctx, (batch_size * m * p) * sizeof(float));

    ET_KERNEL_CHECK(ctx, tmp != nullptr, MemoryAllocationFailed, out);

    // Zero the scratch buffer; it is passed to the NNLib matmul below as an
    // all-zero bias.
    std::memset(tmp, 0, (batch_size * m * p) * sizeof(float));

    // Per-batch scratch for mat2 transposed to (p, n); reused across batches.
    WORD32* __restrict__ p_o =
        (WORD32* __restrict__)kernels::allocate_temp_memory(
            ctx, (n * p) * sizeof(WORD32));

    ET_KERNEL_CHECK(ctx, p_o != nullptr, MemoryAllocationFailed, out);

for (int i = 0; i < batch_size; ++i) {
const FLOAT32* __restrict__ p_mat1 = in_data + i * m * n;
const FLOAT32* __restrict__ p_vec1 = mat2_data + i * n * p;
FLOAT32* __restrict__ p_out = out_data + i * m * p;
const FLOAT32* __restrict__ p_bias = (const FLOAT32* __restrict__)tmp;

WORD32* p_inp = (WORD32*)p_vec1;

WORD32 p_inp_shape[kNnlibMaxDim];
p_inp_shape[0] = n;
p_inp_shape[1] = p;
p_inp_shape[2] = 1;

WORD32 p_out_shape[kNnlibMaxDim];
p_out_shape[0] = p;
p_out_shape[1] = n;
p_out_shape[2] = 1;

WORD32 p_permute_vec[kNnlibMaxDim] = {1, 0, 2};

WORD32 num_out_dims = kNnlibMaxDim;
WORD32 num_inp_dims = kNnlibMaxDim;

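      // Transpose this batch of mat2 from (n, p) to (p, n) so that each
      // column of mat2 (one vector of the matmul) becomes contiguous.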
xa_nn_transpose_32_32(
p_o,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);

const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;

xa_nn_matmul_f32xf32_f32(
p_out,
p_mat1,
p_vec,
p_bias,
rows,
cols1,
row_stride1,
vec_count,
vec_offset,
out_offset,
out_stride);
}

return out;
}

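  // Portable fallback: per-batch vec_matmul over all other real dtypes
  // (including Half).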
ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

int64_t batch_size = in.size(0);
int64_t m = in.size(1);
int64_t n = in.size(2);
int64_t p = mat2.size(2);

for (int i = 0; i < batch_size; ++i) {
const CTYPE* in_data_offset = in_data + i * m * n;
const CTYPE* mat2_data_offset = mat2_data + i * n * p;
CTYPE* out_data_offset = out_data + i * m * p;

torch::executor::vec_matmul<CTYPE>(
out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
}
});

return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
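
For intuition, here is a minimal plain-C++ model of what the xa_nn_matmul_f32xf32_f32 call above is assumed to compute under the parameter mapping bmm_out uses (rows = m, cols1 = row_stride1 = n, vec_count = p, vec_offset = n, out_offset = 1, out_stride = p). matmul_model and its per-row bias convention are illustrative assumptions, not the NNLib implementation; since bmm_out passes an all-zero bias, the bias convention does not affect the result here.

#include <cstdio>

// Assumed semantics of the NNLib matmul as configured in bmm_out: p_mat1 is
// (rows x cols1) row-major, p_vec holds vec_count vectors of length cols1
// spaced vec_offset apart (one batch of mat2 transposed to p x n), and the
// element for (row, vec) lands at row * out_stride + vec * out_offset.
static void matmul_model(
    float* p_out, const float* p_mat1, const float* p_vec,
    const float* p_bias, int rows, int cols1, int row_stride1,
    int vec_count, int vec_offset, int out_offset, int out_stride) {
  for (int vec = 0; vec < vec_count; ++vec) {
    for (int row = 0; row < rows; ++row) {
      // Bias convention (per output row) is an assumption; all-zero in bmm_out.
      float acc = p_bias[row];
      for (int k = 0; k < cols1; ++k) {
        acc += p_mat1[row * row_stride1 + k] * p_vec[vec * vec_offset + k];
      }
      p_out[row * out_stride + vec * out_offset] = acc;
    }
  }
}

int main() {
  // One batch: in is 2x3, mat2 is 3x2, out is 2x2.
  const float in[6] = {1, 2, 3, 4, 5, 6};
  const float mat2_t[6] = {1, 3, 5, 2, 4, 6}; // mat2 transposed to 2x3
  const float bias[2] = {0, 0};
  float out[4];
  matmul_model(out, in, mat2_t, bias,
               /*rows=*/2, /*cols1=*/3, /*row_stride1=*/3,
               /*vec_count=*/2, /*vec_offset=*/3,
               /*out_offset=*/1, /*out_stride=*/2);
  printf("%g %g\n%g %g\n", out[0], out[1], out[2], out[3]);
  // Prints 22 28 / 49 64, i.e. in @ mat2.
  return 0;
}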