@@ -348,9 +348,14 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
348
348
349
349
// In this context, "COMPLEX" means complex types based on primitive C types,
350
350
// which is why ComplexHalf is not included.
351
- #define ET_FORALL_COMPLEX_TYPES (_ ) \
352
- _ (::torch::executor::complex <float >, ComplexFloat) \
353
- _ (::torch::executor::complex <double >, ComplexDouble)
351
+ #define ET_FORALL_COMPLEX_TYPES (_ ) \
352
+ _ (::executorch::aten::complex <float >, ComplexFloat) \
353
+ _ (::executorch::aten::complex <double >, ComplexDouble)
354
+
355
+ #define ET_FORALL_COMPLEXH_TYPES (_ ) \
356
+ _ (::executorch::aten::complex <::executorch::aten::Half>, ComplexHalf) \
357
+ _ (::executorch::aten::complex <float >, ComplexFloat) \
358
+ _ (::executorch::aten::complex <double >, ComplexDouble)
354
359
355
360
//
356
361
// Utility functions to retrieve metadata for a given ScalarType
@@ -593,7 +598,7 @@ inline bool isUnderlying(
593
598
return type == ::executorch::runtime::toUnderlying (qtype);
594
599
}
595
600
596
- inline ::executorch::aten::ScalarType toRealValueType (
601
+ inline constexpr ::executorch::aten::ScalarType toRealValueType (
597
602
::executorch::aten::ScalarType t) {
598
603
switch (t) {
599
604
case ::executorch::aten::ScalarType::ComplexHalf:
@@ -607,7 +612,7 @@ inline ::executorch::aten::ScalarType toRealValueType(
607
612
}
608
613
}
609
614
610
- inline ::executorch::aten::ScalarType toComplexType (
615
+ inline constexpr ::executorch::aten::ScalarType toComplexType (
611
616
::executorch::aten::ScalarType t) {
612
617
switch (t) {
613
618
case ::executorch::aten::ScalarType::BFloat16:
@@ -1060,6 +1065,14 @@ struct promote_types {
1060
1065
ET_INTERNAL_SWITCH_CASE( \
1061
1066
::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)
1062
1067
1068
+ #define ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES (CTYPE_ALIAS, ...) \
1069
+ ET_INTERNAL_SWITCH_CASE ( \
1070
+ ::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \
1071
+ ET_INTERNAL_SWITCH_CASE( \
1072
+ ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \
1073
+ ET_INTERNAL_SWITCH_CASE( \
1074
+ ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)
1075
+
1063
1076
#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES (CTYPE_ALIAS, ...) \
1064
1077
ET_INTERNAL_SWITCH_CASE ( \
1065
1078
::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \
@@ -1278,6 +1291,13 @@ struct promote_types {
1278
1291
NAME, \
1279
1292
ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES (CTYPE_ALIAS, __VA_ARGS__))
1280
1293
1294
+ #define ET_SWITCH_COMPLEXH_TYPES (TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
1295
+ ET_INTERNAL_SWITCH ( \
1296
+ TYPE, \
1297
+ CONTEXT, \
1298
+ NAME, \
1299
+ ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES (CTYPE_ALIAS, __VA_ARGS__))
1300
+
1281
1301
#define ET_SWITCH_SCALAR_OBJ_TYPES (TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
1282
1302
ET_INTERNAL_SWITCH ( \
1283
1303
TYPE, \
0 commit comments