@@ -51,6 +51,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
51
51
return result;
52
52
}
53
53
54
+ template <typename CTYPE_COMMON, const char * op_name>
55
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realh (const Tensor& t) {
56
+ CTYPE_COMMON (*result)(const void *) = nullptr ;
57
+ ET_SWITCH_REALH_TYPES (t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
58
+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
59
+ });
60
+ return result;
61
+ }
62
+
54
63
template <typename CTYPE_COMMON, const char * op_name>
55
64
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16 (
56
65
const Tensor& t) {
@@ -72,6 +81,16 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
72
81
return result;
73
82
}
74
83
84
+ template <typename CTYPE_COMMON, const char * op_name>
85
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool (const Tensor& t) {
86
+ ET_CHECK_MSG (
87
+ t.scalar_type () == ScalarType::Bool,
88
+ " Unhandled dtype %s for %s" ,
89
+ ::executorch::runtime::toString (t.scalar_type()),
90
+ op_name);
91
+ return internal::load_and_convert<CTYPE_COMMON, bool >;
92
+ }
93
+
75
94
template <typename CTYPE_COMMON, const char * op_name>
76
95
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (
77
96
const Tensor& t) {
@@ -137,6 +156,16 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137
156
return result;
138
157
}
139
158
159
+ template <typename CTYPE_COMMON, const char * op_name>
160
+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realh (
161
+ const Tensor& t) {
162
+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
163
+ ET_SWITCH_REALH_TYPES (t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
164
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
165
+ });
166
+ return result;
167
+ }
168
+
140
169
template <typename CTYPE_COMMON, const char * op_name>
141
170
store_common_to_tensor_fn<CTYPE_COMMON>
142
171
get_store_common_to_tensor_fn_floathbf16 (const Tensor& t) {
@@ -159,6 +188,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159
188
return result;
160
189
}
161
190
191
+ template <typename CTYPE_COMMON, const char * op_name>
192
+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_bool (
193
+ const Tensor& t) {
194
+ ET_CHECK_MSG (
195
+ t.scalar_type () == ScalarType::Bool,
196
+ " Unhandled dtype %s for %s" ,
197
+ ::executorch::runtime::toString (t.scalar_type()),
198
+ op_name);
199
+ return internal::convert_and_store<bool , CTYPE_COMMON>;
200
+ }
201
+
162
202
template <typename CTYPE_COMMON, const char * op_name>
163
203
store_common_to_tensor_fn<CTYPE_COMMON>
164
204
get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
@@ -206,8 +246,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
206
246
enum class SupportedTensorDtypes {
207
247
REALHBBF16,
208
248
REALHBF16,
249
+ REALH,
209
250
FLOATHBF16,
210
251
INTB,
252
+ BOOL,
211
253
BOOL_OR_BYTE,
212
254
SAME_AS_COMPUTE,
213
255
SAME_AS_COMMON,
@@ -224,10 +266,14 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
224
266
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
225
267
case SupportedTensorDtypes::REALHBF16:
226
268
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
269
+ case SupportedTensorDtypes::REALH:
270
+ return get_load_to_common_fn_realh<CTYPE_COMMON, op_name>(t);
227
271
case SupportedTensorDtypes::FLOATHBF16:
228
272
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
229
273
case SupportedTensorDtypes::INTB:
230
274
return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
275
+ case SupportedTensorDtypes::BOOL:
276
+ return get_load_to_common_fn_bool<CTYPE_COMMON, op_name>(t);
231
277
case SupportedTensorDtypes::BOOL_OR_BYTE:
232
278
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
233
279
case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -248,10 +294,14 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
248
294
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
249
295
case SupportedTensorDtypes::REALHBF16:
250
296
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
297
+ case SupportedTensorDtypes::REALH:
298
+ return get_store_common_to_tensor_fn_realh<CTYPE_COMMON, op_name>(t);
251
299
case SupportedTensorDtypes::FLOATHBF16:
252
300
return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
253
301
case SupportedTensorDtypes::INTB:
254
302
return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
303
+ case SupportedTensorDtypes::BOOL:
304
+ return get_store_common_to_tensor_fn_bool<CTYPE_COMMON, op_name>(t);
255
305
case SupportedTensorDtypes::BOOL_OR_BYTE:
256
306
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
257
307
t);
0 commit comments