@@ -72,6 +72,16 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_intb(const Tensor& t) {
7272  return  result;
7373}
7474
75+ template  <typename  CTYPE_COMPUTE, const  char * op_name>
76+ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool (const  Tensor& t) {
77+   ET_CHECK_MSG (
78+       t.scalar_type () == ScalarType::Bool,
79+       " Unhandled dtype %s for %s" 
80+       ::executorch::runtime::toString (t.scalar_type()),
81+       op_name);
82+   return  internal::load_and_convert<CTYPE_COMPUTE, bool >;
83+ }
84+ 
7585template  <typename  CTYPE_COMPUTE, const  char * op_name>
7686load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_bool_or_byte (
7787    const  Tensor& t) {
@@ -165,6 +175,17 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_intb(
165175  return  result;
166176}
167177
178+ template  <typename  CTYPE_COMPUTE, const  char * op_name>
179+ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn_bool (
180+     const  Tensor& t) {
181+   ET_CHECK_MSG (
182+       t.scalar_type () == ScalarType::Bool,
183+       " Unhandled dtype %s for %s" 
184+       ::executorch::runtime::toString (t.scalar_type()),
185+       op_name);
186+   return  internal::convert_and_store<bool , CTYPE_COMPUTE>;
187+ }
188+ 
168189template  <typename  CTYPE_COMPUTE, const  char * op_name>
169190store_compute_to_tensor_fn<CTYPE_COMPUTE>
170191get_store_compute_to_tensor_fn_bool_or_byte (const  Tensor& t) {
@@ -219,6 +240,7 @@ enum class SupportedTensorDtypes {
219240  REALHBF16,
220241  FLOATHBF16,
221242  INTB,
243+   BOOL,
222244  BOOL_OR_BYTE,
223245  //  DEPRECATED: not likely to be correct; use SAME_AS_COMMON.
224246  SAME_AS_COMPUTE,
@@ -240,6 +262,8 @@ load_to_compute_fn<CTYPE_COMPUTE> get_load_to_compute_fn_impl(
240262      return  get_load_to_compute_fn_realhbf16<CTYPE_COMPUTE, op_name>(t);
241263    case  SupportedTensorDtypes::INTB:
242264      return  get_load_to_compute_fn_intb<CTYPE_COMPUTE, op_name>(t);
265+     case  SupportedTensorDtypes::BOOL:
266+       return  get_load_to_compute_fn_bool<CTYPE_COMPUTE, op_name>(t);
243267    case  SupportedTensorDtypes::BOOL_OR_BYTE:
244268      return  get_load_to_compute_fn_bool_or_byte<CTYPE_COMPUTE, op_name>(t);
245269    case  SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -271,6 +295,8 @@ store_compute_to_tensor_fn<CTYPE_COMPUTE> get_store_compute_to_tensor_fn(
271295          t);
272296    case  SupportedTensorDtypes::INTB:
273297      return  get_store_compute_to_tensor_fn_intb<CTYPE_COMPUTE, op_name>(t);
298+     case  SupportedTensorDtypes::BOOL:
299+       return  get_store_compute_to_tensor_fn_bool<CTYPE_COMPUTE, op_name>(t);
274300    case  SupportedTensorDtypes::BOOL_OR_BYTE:
275301      return  get_store_compute_to_tensor_fn_bool_or_byte<
276302          CTYPE_COMPUTE,
@@ -318,12 +344,14 @@ bool check_tensor_dtype(
318344    const  ScalarType compute_type);
319345
320346// / Return the one output type we are willing to emit specialized code
321- // / to handle, given a compute type of CTYPE_COMMON  and supported
347+ // / to handle, given a compute type of CTYPE_COMPUTE  and supported
322348// / output types of out_dtypes.
323349template  <typename  CTYPE_COMPUTE>
324350inline  constexpr  ScalarType specialized_output_scalar_type (
325351    SupportedTensorDtypes out_dtypes) {
326352  switch  (out_dtypes) {
353+     case  SupportedTensorDtypes::BOOL:
354+       return  ScalarType::Bool;
327355    case  SupportedTensorDtypes::BOOL_OR_BYTE:
328356      return  ScalarType::Bool;
329357    case  SupportedTensorDtypes::REALHBBF16:
0 commit comments