Skip to content

Add SupportedTensorDtypes::{BOOL,REALH} #9584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -643,13 +643,18 @@ target_link_options_shared_lib(executorch)
# Real integrations should supply their own YAML file that only lists the
# operators necessary for the models that will run.
#
if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
# find pytorch lib here to make it available to all
# sub-directories. Find it before including portable so that
# optimized_portable_kernels can use it.
find_package_torch_headers()
endif()

if(BUILD_EXECUTORCH_PORTABLE_OPS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/portable)
endif()

if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
# find pytorch lib here to make it available to all sub-directories
find_package_torch_headers()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/optimized)
endif()

Expand Down
1 change: 1 addition & 0 deletions kernels/optimized/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ message("Generated files ${gen_command_sources}")
list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_library(optimized_kernels ${_optimized_kernels__srcs})
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS} "${EXECUTORCH_ROOT}/third-party/pocketfft")
target_compile_definitions(optimized_kernels PRIVATE ET_USE_PYTORCH_HEADERS)
target_link_libraries(
optimized_kernels PUBLIC executorch_core cpublas extension_threadpool
)
Expand Down
2 changes: 2 additions & 0 deletions kernels/portable/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ if(BUILD_OPTIMIZED_PORTABLE_KERNELS)
target_link_libraries(optimized_portable_kernels PRIVATE executorch)
target_link_libraries(optimized_portable_kernels PUBLIC extension_threadpool)
target_compile_options(optimized_portable_kernels PUBLIC ${_common_compile_options})
target_include_directories(optimized_portable_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
target_compile_definitions(optimized_portable_kernels PRIVATE ET_USE_PYTORCH_HEADERS)
install(
TARGETS optimized_portable_kernels
DESTINATION lib
Expand Down
4 changes: 4 additions & 0 deletions kernels/portable/cpu/util/dtype_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ bool check_tensor_dtype(
return executorch::runtime::tensor_is_realhbbf16_type(t);
case SupportedTensorDtypes::REALHBF16:
return executorch::runtime::tensor_is_realhbf16_type(t);
case SupportedTensorDtypes::REALH:
Copy link
Contributor

@JacobSzwejbka JacobSzwejbka Mar 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the h in realh? half?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it matches the naming scheme in

#define ET_SWITCH_REALH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \

return executorch::runtime::tensor_is_realh_type(t);
case SupportedTensorDtypes::FLOATHBF16:
return executorch::runtime::tensor_is_floating_type(t);
case SupportedTensorDtypes::INTB:
return executorch::runtime::tensor_is_integral_type(t, true);
case SupportedTensorDtypes::BOOL:
return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
case SupportedTensorDtypes::BOOL_OR_BYTE:
return (executorch::runtime::tensor_is_type(
t, ScalarType::Bool, ScalarType::Byte));
Expand Down
86 changes: 55 additions & 31 deletions kernels/portable/cpu/util/dtype_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
return result;
}

template <typename CTYPE_COMMON, const char* op_name>
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realh(const Tensor& t) {
  // Resolve a loader that widens the tensor's REALH dtype (real + Half)
  // to CTYPE_COMMON; stays nullptr for any unsupported dtype.
  CTYPE_COMMON (*fn)(const void*) = nullptr;
  ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, CTYPE, [&]() {
    fn = internal::load_and_convert<CTYPE_COMMON, CTYPE>;
  });
  return fn;
}

template <typename CTYPE_COMMON, const char* op_name>
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
const Tensor& t) {
Expand All @@ -72,6 +81,16 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
return result;
}

template <typename CTYPE_COMMON, const char* op_name>
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool(const Tensor& t) {
  // Only Bool tensors are accepted here; anything else is a fatal error.
  const auto dtype = t.scalar_type();
  ET_CHECK_MSG(
      dtype == ScalarType::Bool,
      "Unhandled dtype %s for %s",
      ::executorch::runtime::toString(dtype),
      op_name);
  // Loader that converts a stored bool element to CTYPE_COMMON.
  return internal::load_and_convert<CTYPE_COMMON, bool>;
}

template <typename CTYPE_COMMON, const char* op_name>
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
const Tensor& t) {
Expand All @@ -86,12 +105,6 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
template <typename CTYPE_COMMON, const char* op_name>
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
const Tensor& t) {
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
ET_CHECK_MSG(
t.scalar_type() == common_scalar_type,
"Unhandled dtype %s for %s",
::executorch::runtime::toString(common_scalar_type),
op_name);
return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
}

Expand Down Expand Up @@ -143,6 +156,16 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
return result;
}

template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realh(
    const Tensor& t) {
  // Resolve a store callback that narrows CTYPE_COMMON down to the
  // tensor's REALH dtype (real + Half); stays nullptr if unsupported.
  void (*fn)(CTYPE_COMMON, void*) = nullptr;
  ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, CTYPE, [&]() {
    fn = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
  });
  return fn;
}

template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
Expand All @@ -165,6 +188,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
return result;
}

template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_bool(
    const Tensor& t) {
  // Only Bool tensors are accepted here; anything else is a fatal error.
  const auto dtype = t.scalar_type();
  ET_CHECK_MSG(
      dtype == ScalarType::Bool,
      "Unhandled dtype %s for %s",
      ::executorch::runtime::toString(dtype),
      op_name);
  // Store callback converting a CTYPE_COMMON value into a bool element.
  return internal::convert_and_store<bool, CTYPE_COMMON>;
}

template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
Expand All @@ -179,33 +213,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
ET_CHECK_MSG(
t.scalar_type() == common_scalar_type,
"Unhandled dtype %s for %s",
::executorch::runtime::toString(common_scalar_type),
op_name);
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
// We already validate tensor types earlier in the process, so at
// this phase, treat same_as_compute the same as our widest
// SupportedTensorDtypes set.
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
}

template <
typename CTYPE_COMMON,
const char* op_name,
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
void (*result)(CTYPE_COMMON, void*) = nullptr;
ET_SWITCH_THREE_TYPES(
Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
});
return result;
}

template <
typename CTYPE_COMMON,
const char* op_name,
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
Expand All @@ -217,8 +231,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
enum class SupportedTensorDtypes {
REALHBBF16,
REALHBF16,
REALH,
FLOATHBF16,
INTB,
BOOL,
BOOL_OR_BYTE,
SAME_AS_COMPUTE,
SAME_AS_COMMON,
Expand All @@ -235,10 +251,14 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::REALHBF16:
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::REALH:
return get_load_to_common_fn_realh<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::FLOATHBF16:
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::INTB:
return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::BOOL:
return get_load_to_common_fn_bool<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::BOOL_OR_BYTE:
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::SAME_AS_COMPUTE:
Expand All @@ -259,10 +279,14 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::REALHBF16:
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::REALH:
return get_store_common_to_tensor_fn_realh<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::FLOATHBF16:
return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::INTB:
return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::BOOL:
return get_store_common_to_tensor_fn_bool<CTYPE_COMMON, op_name>(t);
case SupportedTensorDtypes::BOOL_OR_BYTE:
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
t);
Expand Down
Loading
Loading