@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
7272 return result;
7373}
7474
75+ template <typename CTYPE_COMPUTE, const char * op_name>
76+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool (const Tensor& t) {
77+ ET_CHECK_MSG (
78+ t.scalar_type () == ScalarType::Bool,
79+ " Unhandled dtype %s for %s" ,
80+ ::executorch::runtime::toString (t.scalar_type()),
81+ op_name);
82+ return internal::load_and_convert<CTYPE_COMPUTE, bool >;
83+ }
84+
7585template <typename CTYPE_COMPUTE, const char * op_name>
7686load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte (
7787 const Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
165175 return result;
166176}
167177
178+ template <typename CTYPE_COMPUTE, const char * op_name>
179+ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool (
180+ const Tensor& t) {
181+ ET_CHECK_MSG (
182+ t.scalar_type () == ScalarType::Bool,
183+ " Unhandled dtype %s for %s" ,
184+ ::executorch::runtime::toString (t.scalar_type()),
185+ op_name);
186+ return internal::convert_and_store<bool , CTYPE_COMPUTE>;
187+ }
188+
168189template <typename CTYPE_COMPUTE, const char * op_name>
169190store_compute_to_tensor_fn<CTYPE_COMPUTE>
170191get_store_compute_to_tensor_fn_bool_or_byte (const Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
219240 REALHBF16,
220241 FLOATHBF16,
221242 INTB,
243+ BOOL,
222244 BOOL_OR_BYTE,
223245 // DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
224246 SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
240262 return get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
241263 case SupportedTensorDtypes::INTB:
242264 return get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
265+ case SupportedTensorDtypes::BOOL:
266+ return get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
243267 case SupportedTensorDtypes::BOOL_OR_BYTE:
244268 return get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
245269 case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
271295 t);
272296 case SupportedTensorDtypes::INTB:
273297 return get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
298+ case SupportedTensorDtypes::BOOL:
299+ return get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
274300 case SupportedTensorDtypes::BOOL_OR_BYTE:
275301 return get_store_compute_to_tensor_fn_bool_or_byte<
276302 CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
318344 const ScalarType compute_type);
319345
320346// / Return the one output type we are willing to emit specialized code
321- // / to handle, given a compute type of CTYPE_COMMON and supported
347+ // / to handle, given a compute type of CTYPE_COMPUTE and supported
322348// / output types of out_dtypes.
323349template <typename CTYPE_COMPUTE>
324350inline constexpr ScalarType specialized_output_scalar_type (
325351 SupportedTensorDtypes out_dtypes) {
326352 switch (out_dtypes) {
353+ case SupportedTensorDtypes::BOOL:
354+ return ScalarType::Bool;
327355 case SupportedTensorDtypes::BOOL_OR_BYTE:
328356 return ScalarType::Bool;
329357 case SupportedTensorDtypes::REALHBBF16:
0 commit comments