Skip to content

Implement portable abs for complex input #8183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions kernels/portable/cpu/op_abs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,48 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
out,
"Failed to resize output tensor.");

ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
const bool in_is_complex =
executorch::runtime::isComplexType(in.scalar_type());
ET_KERNEL_CHECK(
ctx,
in_is_complex || tensors_have_same_dtype(in, out),
InvalidArgument,
out);
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
apply_unary_map_fn(
[](const CTYPE val_in) {
if (val_in < 0) {
return static_cast<CTYPE>(-val_in);
} else {
return static_cast<CTYPE>(val_in);
}
},
in.const_data_ptr<CTYPE>(),
out.mutable_data_ptr<CTYPE>(),
in.numel());
});
if (in_is_complex) {
// NOTE: Elected not to add COMPLEXH to dtype_util.h for now
// because I am not planning wide rollout of complex support; if
// we do add SupportedTensorDtypes::COMPLEXH support, then we
// should use it here.
ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "abs.out", CTYPE_OUT, [&] {
apply_unary_map_fn<CTYPE_IN, CTYPE_OUT>(
[](const CTYPE_IN val_in) -> CTYPE_OUT {
return sqrt(
val_in.real_ * val_in.real_ + val_in.imag_ * val_in.imag_);
},
in.const_data_ptr<CTYPE_IN>(),
out.mutable_data_ptr<CTYPE_OUT>(),
in.numel());
});
});
} else {
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
apply_unary_map_fn(
[](const CTYPE val_in) {
if (val_in < 0) {
return static_cast<CTYPE>(-val_in);
} else {
return static_cast<CTYPE>(val_in);
}
},
in.const_data_ptr<CTYPE>(),
out.mutable_data_ptr<CTYPE>(),
in.numel());
});
}

return out;
}
Expand Down
26 changes: 26 additions & 0 deletions kernels/test/op_abs_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,39 @@ class OpAbsTest : public OperatorTest {
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, expected);
}

// Smoke-tests abs.out on a complex input tensor. DTYPE is a complex
// ScalarType; the output tensor uses the matching real-valued dtype
// (via toRealValueType), exercising the complex-in/real-out path.
// Inputs are Pythagorean triples so expected magnitudes are exact:
// |3+4i| = 5, |5+12i| = 13.
template <typename CTYPE, ScalarType DTYPE>
void run_complex_smoke_test() {
TensorFactory<DTYPE> tf;
// Output dtype is the real counterpart of the complex input dtype.
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
TensorFactory<REAL_DTYPE> tf_out;
using REAL_CTYPE =
typename executorch::runtime::ScalarTypeToCppType<REAL_DTYPE>::type;
Tensor in = tf.make(
{1, 2},
{CTYPE{REAL_CTYPE(3), REAL_CTYPE(4)},
CTYPE{REAL_CTYPE(5), REAL_CTYPE(12)}});
Tensor out = tf_out.zeros({1, 2});
Tensor expected = tf_out.make({1, 2}, {5, 13});
Tensor ret = op_abs_out(in, out);
// The op returns its `out` argument; compare values with tolerance
// since the magnitude is computed in floating point.
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_CLOSE(out, expected);
}
};

// Runs the real-valued smoke test for every floating-point dtype
// (float, double, Half, BFloat16) via the FLOATHBF16 iteration macro.
TEST_F(OpAbsTest, SmokeTest) {
#define RUN_SMOKE_TEST(ctype, dtype) run_smoke_test<ScalarType::dtype>();
// TODO: cover all REALHBF16 types with generalized unary function test
// harness.
ET_FORALL_FLOATHBF16_TYPES(RUN_SMOKE_TEST);
#undef RUN_SMOKE_TEST
}

// Runs the complex smoke test for every complex dtype
// (ComplexHalf, ComplexFloat, ComplexDouble).
TEST_F(OpAbsTest, ComplexSmokeTest) {
#define RUN_SMOKE_TEST(ctype, dtype) \
run_complex_smoke_test<ctype, ScalarType::dtype>();
ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST);
#undef RUN_SMOKE_TEST
}

TEST_F(OpAbsTest, MemoryFormatCheck) {
Expand Down
30 changes: 25 additions & 5 deletions runtime/core/exec_aten/util/scalar_type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,14 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)

// In this context, "COMPLEX" means complex types based on primitive C types,
// which is why ComplexHalf is not included.
#define ET_FORALL_COMPLEX_TYPES(_) \
_(::torch::executor::complex<float>, ComplexFloat) \
_(::torch::executor::complex<double>, ComplexDouble)
// Applies _(ctype, ScalarType name) to each complex dtype whose element
// type is a primitive C type; ComplexHalf is deliberately excluded
// (see note above). Use ET_FORALL_COMPLEXH_TYPES to include it.
#define ET_FORALL_COMPLEX_TYPES(_) \
_(::executorch::aten::complex<float>, ComplexFloat) \
_(::executorch::aten::complex<double>, ComplexDouble)

// Like ET_FORALL_COMPLEX_TYPES, but additionally covers ComplexHalf
// (complex with ::executorch::aten::Half elements) — the "H" suffix.
#define ET_FORALL_COMPLEXH_TYPES(_) \
_(::executorch::aten::complex<::executorch::aten::Half>, ComplexHalf) \
_(::executorch::aten::complex<float>, ComplexFloat) \
_(::executorch::aten::complex<double>, ComplexDouble)

//
// Utility functions to retrieve metadata for a given ScalarType
Expand Down Expand Up @@ -593,7 +598,7 @@ inline bool isUnderlying(
return type == ::executorch::runtime::toUnderlying(qtype);
}

inline ::executorch::aten::ScalarType toRealValueType(
inline constexpr ::executorch::aten::ScalarType toRealValueType(
::executorch::aten::ScalarType t) {
switch (t) {
case ::executorch::aten::ScalarType::ComplexHalf:
Expand All @@ -607,7 +612,7 @@ inline ::executorch::aten::ScalarType toRealValueType(
}
}

inline ::executorch::aten::ScalarType toComplexType(
inline constexpr ::executorch::aten::ScalarType toComplexType(
::executorch::aten::ScalarType t) {
switch (t) {
case ::executorch::aten::ScalarType::BFloat16:
Expand Down Expand Up @@ -1060,6 +1065,14 @@ struct promote_types {
ET_INTERNAL_SWITCH_CASE( \
::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)

// Switch-case arms for all complex dtypes including ComplexHalf; each
// case binds the dtype's C++ type to CTYPE_ALIAS for the case body.
// Consumed by ET_SWITCH_COMPLEXH_TYPES.
#define ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, ...) \
ET_INTERNAL_SWITCH_CASE( \
::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \
ET_INTERNAL_SWITCH_CASE( \
::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \
ET_INTERNAL_SWITCH_CASE( \
::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)

#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, ...) \
ET_INTERNAL_SWITCH_CASE( \
::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \
Expand Down Expand Up @@ -1278,6 +1291,13 @@ struct promote_types {
NAME, \
ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__))

// Dispatches on a runtime ScalarType TYPE over ComplexHalf/ComplexFloat/
// ComplexDouble, binding the matching C++ type to CTYPE_ALIAS inside the
// trailing lambda/body. CONTEXT and NAME feed kernel error reporting.
#define ET_SWITCH_COMPLEXH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
ET_INTERNAL_SWITCH( \
TYPE, \
CONTEXT, \
NAME, \
ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, __VA_ARGS__))

#define ET_SWITCH_SCALAR_OBJ_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
ET_INTERNAL_SWITCH( \
TYPE, \
Expand Down
Loading