Skip to content

Complex Support: bmm #10197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions kernels/optimized/cpu/op_bmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

// Performs a batch matrix-matrix product of matrices stored in input and mat2.

Expand Down Expand Up @@ -136,33 +136,32 @@ Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {

// bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
Tensor& opt_bmm_out(
KernelRuntimeContext& context,
KernelRuntimeContext& ctx,
const Tensor& self,
const Tensor& mat2,
Tensor& out) {
(void)context;
(void)ctx;

ET_KERNEL_CHECK(
context,
ctx,
resize_out_tensor(self, mat2, out) == Error::Ok,
InvalidArgument,
out);
ET_KERNEL_CHECK(
context, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

#define BMM_TENSOR(ctype, dtype) \
case ScalarType::dtype: \
bmm_kernel<ctype>(self, mat2, out); \
break;

auto scalar_type = self.scalar_type();
switch (scalar_type) {
ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR)
default:
ET_CHECK_MSG(
false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type));
ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

constexpr auto name = "bmm.out";
auto self_type = self.scalar_type();

if (executorch::runtime::isComplexType(self_type)) {
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(self, mat2, out);
});
} else {
ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {
bmm_kernel<CTYPE>(self, mat2, out);
});
}
#undef BMM_TENSOR

return out;
}
Expand Down
1 change: 1 addition & 0 deletions kernels/optimized/cpu/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ _OPTIMIZED_ATEN_OPS = (
name = "op_bmm",
deps = [
"//executorch/kernels/optimized:libblas",
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
],
),
op_target(
Expand Down
30 changes: 11 additions & 19 deletions kernels/portable/cpu/op_bmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
*/

#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
Expand Down Expand Up @@ -37,26 +36,19 @@ Tensor& bmm_out(
InvalidArgument,
out);

ET_SWITCH_REAL_TYPES_AND(
Half, in.scalar_type(), ctx, "bmm.out", CTYPE, [&]() {
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
constexpr auto name = "bmm.out";

int64_t batch_size = in.size(0);
int64_t m = in.size(1);
int64_t n = in.size(2);
int64_t p = mat2.size(2);
auto in_type = in.scalar_type();

for (int i = 0; i < batch_size; ++i) {
const CTYPE* in_data_offset = in_data + i * m * n;
const CTYPE* mat2_data_offset = mat2_data + i * n * p;
CTYPE* out_data_offset = out_data + i * m * p;

vec_matmul<CTYPE>(
out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
}
});
if (executorch::runtime::isComplexType(in_type)) {
ET_SWITCH_COMPLEXH_TYPES(in_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(in, mat2, out);
});
} else {
ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(in, mat2, out);
});
}

return out;
}
Expand Down
31 changes: 31 additions & 0 deletions kernels/portable/cpu/util/matmul_ops_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,36 @@ void get_linear_out_target_size(
Tensor::SizesType* out_sizes,
size_t* out_ndim);

namespace internal {

template <typename CTYPE>
// Reference batched matrix multiply: for each batch b,
// out[b] = in[b] (m x n) * mat2[b] (n x p), accumulated in CTYPE.
// Assumes all three tensors are contiguous and already resized/checked
// by the caller (shape validation happens in check_bmm_out_args).
void bmm_out_impl(const Tensor& in, const Tensor& mat2, Tensor& out) {
  const CTYPE* const lhs_base = in.const_data_ptr<CTYPE>();
  const CTYPE* const rhs_base = mat2.const_data_ptr<CTYPE>();
  CTYPE* const dst_base = out.mutable_data_ptr<CTYPE>();

  const int64_t batch_count = in.size(0);
  const int64_t rows = in.size(1);
  const int64_t inner = in.size(2);
  const int64_t cols = mat2.size(2);

  for (int64_t batch = 0; batch < batch_count; ++batch) {
    const CTYPE* const lhs = lhs_base + batch * rows * inner;
    const CTYPE* const rhs = rhs_base + batch * inner * cols;
    CTYPE* const dst = dst_base + batch * rows * cols;

    // Classic i-j-k triple loop; acc holds one output element at a time.
    for (int64_t row = 0; row < rows; ++row) {
      for (int64_t col = 0; col < cols; ++col) {
        CTYPE acc = static_cast<CTYPE>(0.0);
        for (int64_t k = 0; k < inner; ++k) {
          acc += lhs[row * inner + k] * rhs[k * cols + col];
        }
        dst[row * cols + col] = acc;
      }
    }
  }
}

} // namespace internal
} // namespace executor
} // namespace torch
67 changes: 66 additions & 1 deletion kernels/test/op_bmm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,61 @@ class OpBmmOutTest : public OperatorTest {

EXPECT_TENSOR_EQ(out, expected);
}

  // Smoke test for bmm.out on complex dtypes: multiplies a (2,2,3) batch by a
  // (2,3,2) batch and compares against hand-computed expected values.
  // CTYPE is the C++ complex type, DTYPE the matching ScalarType.
  // NOTE(review): imaginary parts of x and y scale linearly with the real
  // parts, so expected imaginary components are exactly 3x the real ones
  // (e.g. (22, 66)); uses EXPECT_TENSOR_CLOSE since complex arithmetic may
  // round in low-precision dtypes (e.g. complex Half).
  template <typename CTYPE, ScalarType DTYPE>
  void test_complex_dtype() {
    TensorFactory<DTYPE> tf;
    // Batch 0 rows: [1+1i, 2+2i, 3+3i], [4+4i, 5+5i, 6+6i]; batch 1 continues 7..12.
    Tensor x = tf.make(
        {2, 2, 3},
        {CTYPE(1, 1),
         CTYPE(2, 2),
         CTYPE(3, 3),
         CTYPE(4, 4),
         CTYPE(5, 5),
         CTYPE(6, 6),
         CTYPE(7, 7),
         CTYPE(8, 8),
         CTYPE(9, 9),
         CTYPE(10, 10),
         CTYPE(11, 11),
         CTYPE(12, 12)});
    // Second operand: real part = 2*imag for every element.
    Tensor y = tf.make(
        {2, 3, 2},
        {CTYPE(2, 1),
         CTYPE(4, 2),
         CTYPE(6, 3),
         CTYPE(8, 4),
         CTYPE(10, 5),
         CTYPE(12, 6),
         CTYPE(14, 7),
         CTYPE(16, 8),
         CTYPE(18, 9),
         CTYPE(20, 10),
         CTYPE(22, 11),
         CTYPE(24, 12)});
    // Output starts zeroed; bmm.out writes every element.
    Tensor out = tf.make(
        {2, 2, 2},
        {CTYPE(0, 0),
         CTYPE(0, 0),
         CTYPE(0, 0),
         CTYPE(0, 0),
         CTYPE(0, 0),
         CTYPE(0, 0),
         CTYPE(0, 0),
         CTYPE(0, 0)});
    // Hand-computed: e.g. out[0][0][0] = (1+i)(2+i)+(2+2i)(6+3i)+(3+3i)(10+5i)
    //              = (1+3i)+(6+18i)+(15+45i) = 22+66i.
    Tensor expected = tf.make(
        {2, 2, 2},
        {CTYPE(22, 66),
         CTYPE(28, 84),
         CTYPE(49, 147),
         CTYPE(64, 192),
         CTYPE(220, 660),
         CTYPE(244, 732),
         CTYPE(301, 903),
         CTYPE(334, 1002)});
    op_bmm_out(x, y, out);
    EXPECT_TENSOR_CLOSE(out, expected);
  }
};

TEST_F(OpBmmOutTest, OutputDim) {
Expand Down Expand Up @@ -132,7 +187,7 @@ TEST_F(OpBmmOutTest, OutputDimFloat) {

/// A generic smoke test that works for any dtype that supports ones() and
/// zeros().
TEST_F(OpBmmOutTest, AllDtypesSupported) {
TEST_F(OpBmmOutTest, AllRealDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
Expand All @@ -141,6 +196,16 @@ TEST_F(OpBmmOutTest, AllDtypesSupported) {
// for those types.
}

// Runs the complex-dtype smoke test across every supported complex dtype.
// Under ATen mode the standard complex set is used; otherwise the COMPLEXH
// set — presumably COMPLEXH additionally covers complex Half, which the
// ATen path does not exercise here (TODO confirm macro coverage).
TEST_F(OpBmmOutTest, AllComplexDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_complex_dtype<ctype, ScalarType::dtype>();
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    ET_FORALL_COMPLEX_TYPES(TEST_ENTRY);
  } else {
    ET_FORALL_COMPLEXH_TYPES(TEST_ENTRY);
  }
#undef TEST_ENTRY
}

TEST_F(OpBmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
TensorFactory<ScalarType::Int> tf;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ ATEN_OPS = (
name = "op_bmm",
deps = [
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
":vec_ops",
],
),
op_target(
Expand Down
Loading