diff --git a/clang/lib/DPCT/APINames.inc b/clang/lib/DPCT/APINames.inc index 0388fccade66..cb915a73ec6d 100644 --- a/clang/lib/DPCT/APINames.inc +++ b/clang/lib/DPCT/APINames.inc @@ -384,26 +384,32 @@ ENTRY(atomicXor, atomicXor, true, NO_FLAG, P0, "Successful: DPCT1039") // Half Arithmetic Functions ENTRY(__habs, __habs, true, NO_FLAG, P4, "Successful") ENTRY(__h2div, __h2div, true, NO_FLAG, P4, "Successful") +ENTRY(__hadd_rn, __hadd_rn, true, NO_FLAG, P4, "Successful") ENTRY(__hadd_sat, __hadd_sat, true, NO_FLAG, P4, "Successful") ENTRY(__hdiv, __hdiv, true, NO_FLAG, P4, "Successful") ENTRY(__hfma, __hfma, true, NO_FLAG, P0, "Successful") ENTRY(__hfma_sat, __hfma_sat, true, NO_FLAG, P4, "Successful") ENTRY(__hmul, __hmul, true, NO_FLAG, P4, "Successful") +ENTRY(__hmul_rn, __hmul_rn, true, NO_FLAG, P4, "Successful") ENTRY(__hmul_sat, __hmul_sat, true, NO_FLAG, P4, "Successful") ENTRY(__hneg, __hneg, true, NO_FLAG, P4, "Successful") ENTRY(__hsub, __hsub, true, NO_FLAG, P4, "Successful") +ENTRY(__hsub_rn, __hsub_rn, true, NO_FLAG, P4, "Successful") ENTRY(__hsub_sat, __hsub_sat, true, NO_FLAG, P4, "Successful") // Half2 Arithmetic Functions ENTRY(__habs2, __habs2, true, NO_FLAG, P4, "Successful") ENTRY(__hadd2, __hadd2, true, NO_FLAG, P0, "Successful") +ENTRY(__hadd2_rn, __hadd2_rn, true, NO_FLAG, P4, "Successful") ENTRY(__hadd2_sat, __hadd2_sat, true, NO_FLAG, P4, "Successful") ENTRY(__hfma2, __hfma2, true, NO_FLAG, P0, "Successful") ENTRY(__hfma2_sat, __hfma2_sat, true, NO_FLAG, P4, "Successful") ENTRY(__hmul2, __hmul2, true, NO_FLAG, P4, "Successful") +ENTRY(__hmul2_rn, __hmul2_rn, true, NO_FLAG, P4, "Successful") ENTRY(__hmul2_sat, __hmul2_sat, true, NO_FLAG, P4, "Successful") ENTRY(__hneg2, __hneg2, true, NO_FLAG, P4, "Successful") ENTRY(__hsub2, __hsub2, true, NO_FLAG, P4, "Successful") +ENTRY(__hsub2_rn, __hsub2_rn, true, NO_FLAG, P4, "Successful") ENTRY(__hsub2_sat, __hsub2_sat, true, NO_FLAG, P4, "Successful") // Half Comparison Functions @@ -1809,7 +1815,7 @@ ENTRY(__bfloat162ushort_rz, __bfloat162ushort_rz, false, NO_FLAG, P4, "comment") ENTRY(__bfloat16_as_short, __bfloat16_as_short, false, NO_FLAG, P4, "comment") ENTRY(__bfloat16_as_ushort, __bfloat16_as_ushort, false, NO_FLAG, P4, "comment") ENTRY(__double2bfloat16, __double2bfloat16, false, NO_FLAG, P4, "comment") -ENTRY(__double2half, __double2half, false, NO_FLAG, P4, "comment") +ENTRY(__double2half, __double2half, true, NO_FLAG, P4, "Successful") ENTRY(__float22bfloat162_rn, __float22bfloat162_rn, true, NO_FLAG, P4, "Successful") ENTRY(__float2bfloat162_rn, __float2bfloat162_rn, false, NO_FLAG, P4, "comment") ENTRY(__float2bfloat16_rd, __float2bfloat16_rd, false, NO_FLAG, P4, "comment") @@ -1822,20 +1828,20 @@ ENTRY(__fmaf_ieee_rn, __fmaf_ieee_rn, false, NO_FLAG, P4, "comment") ENTRY(__fmaf_ieee_ru, __fmaf_ieee_ru, false, NO_FLAG, P4, "comment") ENTRY(__fmaf_ieee_rz, __fmaf_ieee_rz, false, NO_FLAG, P4, "comment") ENTRY(__halves2bfloat162, __halves2bfloat162, false, NO_FLAG, P4, "comment") -ENTRY(__hcmadd, __hcmadd, false, NO_FLAG, P4, "comment") -ENTRY(__hfma2_relu, __hfma2_relu, false, NO_FLAG, P4, "comment") -ENTRY(__hfma_relu, __hfma_relu, false, NO_FLAG, P4, "comment") +ENTRY(__hcmadd, __hcmadd, true, NO_FLAG, P4, "Successful") +ENTRY(__hfma2_relu, __hfma2_relu, true, NO_FLAG, P4, "Successful") +ENTRY(__hfma_relu, __hfma_relu, true, NO_FLAG, P4, "Successful") ENTRY(__high2bfloat16, __high2bfloat16, false, NO_FLAG, P4, "comment") ENTRY(__high2bfloat162, __high2bfloat162, false, NO_FLAG, P4, "comment") ENTRY(__highs2bfloat162, __highs2bfloat162, false, NO_FLAG, P4, "comment") -ENTRY(__hmax, __hmax, false, NO_FLAG, P4, "comment") -ENTRY(__hmax2, __hmax2, false, NO_FLAG, P4, "comment") -ENTRY(__hmax2_nan, __hmax2_nan, false, NO_FLAG, P4, "comment") -ENTRY(__hmax_nan, __hmax_nan, false, NO_FLAG, P4, "comment") -ENTRY(__hmin, __hmin, false, NO_FLAG, P4, "comment") -ENTRY(__hmin2, __hmin2, false, NO_FLAG, P4, "comment") -ENTRY(__hmin2_nan, __hmin2_nan, false, NO_FLAG, P4, "comment") -ENTRY(__hmin_nan, __hmin_nan, false, NO_FLAG, P4, "comment") +ENTRY(__hmax, __hmax, true, NO_FLAG, P4, "Successful") +ENTRY(__hmax2, __hmax2, true, NO_FLAG, P4, "Successful") +ENTRY(__hmax2_nan, __hmax2_nan, true, NO_FLAG, P4, "Successful") +ENTRY(__hmax_nan, __hmax_nan, true, NO_FLAG, P4, "Successful") +ENTRY(__hmin, __hmin, true, NO_FLAG, P4, "Successful") +ENTRY(__hmin2, __hmin2, true, NO_FLAG, P4, "Successful") +ENTRY(__hmin2_nan, __hmin2_nan, true, NO_FLAG, P4, "Successful") +ENTRY(__hmin_nan, __hmin_nan, true, NO_FLAG, P4, "Successful") ENTRY(__int2bfloat16_rd, __int2bfloat16_rd, false, NO_FLAG, P4, "comment") ENTRY(__int2bfloat16_rn, __int2bfloat16_rn, false, NO_FLAG, P4, "comment") ENTRY(__int2bfloat16_ru, __int2bfloat16_ru, false, NO_FLAG, P4, "comment") @@ -1877,18 +1883,18 @@ ENTRY(__ushort2bfloat16_rn, __ushort2bfloat16_rn, false, NO_FLAG, P4, "comment") ENTRY(__ushort2bfloat16_ru, __ushort2bfloat16_ru, false, NO_FLAG, P4, "comment") ENTRY(__ushort2bfloat16_rz, __ushort2bfloat16_rz, false, NO_FLAG, P4, "comment") ENTRY(__ushort_as_bfloat16, __ushort_as_bfloat16, false, NO_FLAG, P4, "comment") -ENTRY(__heq2_mask, __heq2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hequ2_mask, __hequ2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hge2_mask, __hge2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hgeu2_mask, __hgeu2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hgt2_mask, __hgt2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hgtu2_mask, __hgtu2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hle2_mask, __hle2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hleu2_mask, __hleu2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hlt2_mask, __hlt2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hltu2_mask, __hltu2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hne2_mask, __hne2_mask, false, NO_FLAG, P4, "comment") -ENTRY(__hneu2_mask, __hneu2_mask, false, NO_FLAG, P4, "comment") +ENTRY(__heq2_mask, __heq2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hequ2_mask, __hequ2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hge2_mask, __hge2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hgeu2_mask, __hgeu2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hgt2_mask, __hgt2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hgtu2_mask, __hgtu2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hle2_mask, __hle2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hleu2_mask, __hleu2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hlt2_mask, __hlt2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hltu2_mask, __hltu2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hne2_mask, __hne2_mask, true, NO_FLAG, P4, "Successful") +ENTRY(__hneu2_mask, __hneu2_mask, true, NO_FLAG, P4, "Successful") ENTRY(__viaddmax_s16x2, __viaddmax_s16x2, false, NO_FLAG, P4, "comment") ENTRY(__viaddmax_s16x2_relu, __viaddmax_s16x2_relu, false, NO_FLAG, P4, "comment") ENTRY(__viaddmax_s32, __viaddmax_s32, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/APINamesMath.inc b/clang/lib/DPCT/APINamesMath.inc index c9d86d7f2228..4ef9b8157f76 100644 --- a/clang/lib/DPCT/APINamesMath.inc +++ b/clang/lib/DPCT/APINamesMath.inc @@ -352,6 +352,7 @@ ENTRY_OPERATOR("__ddiv_rz", BinaryOperatorKind::BO_Div) // Half Precision Conversion And Data Movement +ENTRY_REWRITE("__double2half") ENTRY_TYPECAST("__float22half2_rn") ENTRY_TYPECAST("__float2half") ENTRY_TYPECAST("__float2half2_rn") @@ -587,15 +588,24 @@ ENTRY_REWRITE("__vsubus2") ENTRY_REWRITE("__vsubus4") // Half Arithmetic Functions +ENTRY_REWRITE("__hadd_rn") ENTRY_REWRITE("__hadd_sat") +ENTRY_REWRITE("__hfma_relu") ENTRY_REWRITE("__hfma_sat") +ENTRY_REWRITE("__hmul_rn") ENTRY_REWRITE("__hmul_sat") +ENTRY_REWRITE("__hsub_rn") ENTRY_REWRITE("__hsub_sat") // Half2 Arithmetic Functions +ENTRY_REWRITE("__hadd2_rn") ENTRY_REWRITE("__hadd2_sat") +ENTRY_REWRITE("__hcmadd") +ENTRY_REWRITE("__hfma2_relu") ENTRY_REWRITE("__hfma2_sat") +ENTRY_REWRITE("__hmul2_rn") ENTRY_REWRITE("__hmul2_sat") +ENTRY_REWRITE("__hsub2_rn") ENTRY_REWRITE("__hsub2_sat") // Half Comparison Functions @@ -604,6 +614,10 @@ ENTRY_REWRITE("__hgeu") ENTRY_REWRITE("__hgtu") ENTRY_REWRITE("__hleu") ENTRY_REWRITE("__hltu") +ENTRY_REWRITE("__hmax") +ENTRY_REWRITE("__hmax_nan") +ENTRY_REWRITE("__hmin") +ENTRY_REWRITE("__hmin_nan") ENTRY_REWRITE("__hneu") // Half2 Comparison Functions @@ -620,18 +634,34 @@ ENTRY_REWRITE("__hbltu2") ENTRY_REWRITE("__hbne2") ENTRY_REWRITE("__hbneu2") ENTRY_REWRITE("__heq2") +ENTRY_REWRITE("__heq2_mask") ENTRY_REWRITE("__hequ2") +ENTRY_REWRITE("__hequ2_mask") ENTRY_REWRITE("__hge2") +ENTRY_REWRITE("__hge2_mask") ENTRY_REWRITE("__hgeu2") +ENTRY_REWRITE("__hgeu2_mask") ENTRY_REWRITE("__hgt2") +ENTRY_REWRITE("__hgt2_mask") ENTRY_REWRITE("__hgtu2") +ENTRY_REWRITE("__hgtu2_mask") ENTRY_REWRITE("__hisnan2") ENTRY_REWRITE("__hle2") +ENTRY_REWRITE("__hle2_mask") ENTRY_REWRITE("__hleu2") +ENTRY_REWRITE("__hleu2_mask") ENTRY_REWRITE("__hlt2") +ENTRY_REWRITE("__hlt2_mask") ENTRY_REWRITE("__hltu2") +ENTRY_REWRITE("__hltu2_mask") +ENTRY_REWRITE("__hmax2") +ENTRY_REWRITE("__hmax2_nan") +ENTRY_REWRITE("__hmin2") +ENTRY_REWRITE("__hmin2_nan") ENTRY_REWRITE("__hne2") +ENTRY_REWRITE("__hne2_mask") ENTRY_REWRITE("__hneu2") +ENTRY_REWRITE("__hneu2_mask") // Half2 Math Functions ENTRY_UNSUPPORTED("h2rcp") diff --git a/clang/lib/DPCT/APINamesMathRewrite.inc b/clang/lib/DPCT/APINamesMathRewrite.inc index 6fefd6087cc9..6f71c2fd95b7 100644 --- a/clang/lib/DPCT/APINamesMathRewrite.inc +++ b/clang/lib/DPCT/APINamesMathRewrite.inc @@ -32,6 +32,19 @@ MATH_API_REWRITER_HOST_DEVICE( EMPTY_FACTORY_ENTRY("std::abs")))) // Half Arithmetic Functions +MATH_API_REWRITER_DEVICE( + "__hadd_rn", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hadd_rn"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hadd_rn", CALL(MapNames::getClNamespace() + + "ext::intel::math::hadd", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hadd_rn"), + BINARY_OP_FACTORY_ENTRY("__hadd_rn", BinaryOperatorKind::BO_Add, ARG(0), + ARG(1)))) + MATH_API_REWRITER_DEVICE( "__hadd_sat", MATH_API_DEVICE_NODES( @@ -63,6 +76,22 @@ MATH_API_REWRITER_DEVICE( makeCallArgCreatorWithCall(0), makeCallArgCreatorWithCall(1)))) +MATH_API_REWRITER_DEVICE( + "__hfma_relu", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hfma_relu"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hfma_relu", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hfma_relu", + ARG(0), ARG(1), ARG(2)))), + EMPTY_FACTORY_ENTRY("__hfma_relu"), + CALL_FACTORY_ENTRY("__hfma_relu", + CALL(MapNames::getDpctNamespace() + "relu", + CALL(MapNames::getClNamespace() + "fma", ARG(0), + ARG(1), ARG(2)))))) + MATH_API_REWRITER_DEVICE( "__hfma_sat", MATH_API_DEVICE_NODES( @@ -95,6 +124,19 @@ MATH_API_REWRITER_DEVICE( makeCallArgCreatorWithCall(0), makeCallArgCreatorWithCall(1)))) +MATH_API_REWRITER_DEVICE( + "__hmul_rn", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmul_rn"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmul_rn", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmul", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmul_rn"), + BINARY_OP_FACTORY_ENTRY("__hmul_rn", BinaryOperatorKind::BO_Mul, ARG(0), + ARG(1)))) + MATH_API_REWRITER_DEVICE( "__hmul_sat", MATH_API_DEVICE_NODES( @@ -139,6 +181,19 @@ MATH_API_REWRITER_DEVICE( makeCallArgCreatorWithCall(0), makeCallArgCreatorWithCall(1)))) +MATH_API_REWRITER_DEVICE( + "__hsub_rn", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hsub_rn"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hsub_rn", CALL(MapNames::getClNamespace() + + "ext::intel::math::hsub", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hsub_rn"), + BINARY_OP_FACTORY_ENTRY("__hsub_rn", BinaryOperatorKind::BO_Sub, ARG(0), + ARG(1)))) + MATH_API_REWRITER_DEVICE( "__hsub_sat", MATH_API_DEVICE_NODES( @@ -199,6 +254,19 @@ MATH_API_REWRITER_DEVICE( makeCallArgCreatorWithCall(0), makeCallArgCreatorWithCall(1)))) +MATH_API_REWRITER_DEVICE( + "__hadd2_rn", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hadd2_rn"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hadd2_rn", CALL(MapNames::getClNamespace() + + "ext::intel::math::hadd2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hadd2_rn"), + BINARY_OP_FACTORY_ENTRY("__hadd2_rn", BinaryOperatorKind::BO_Add, + ARG(0), ARG(1)))) + MATH_API_REWRITER_DEVICE( "__hadd2_sat", MATH_API_DEVICE_NODES( @@ -216,6 +284,36 @@ MATH_API_REWRITER_DEVICE( BO(BinaryOperatorKind::BO_Add, ARG(0), ARG(1)), LITERAL("{0.f, 0.f}"), LITERAL("{1.f, 1.f}"))))) +MATH_API_REWRITER_DEVICE( + "__hcmadd", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hcmadd"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hcmadd", CALL(MapNames::getClNamespace() + + "ext::intel::math::hcmadd", + ARG(0), ARG(1), ARG(2)))), + EMPTY_FACTORY_ENTRY("__hcmadd"), + CALL_FACTORY_ENTRY("__hcmadd", CALL(MapNames::getDpctNamespace() + + "complex_mul_add", + ARG(0), ARG(1), ARG(2))))) + +MATH_API_REWRITER_DEVICE( + "__hfma2_relu", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hfma2_relu"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hfma2_relu", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hfma2_relu", + ARG(0), ARG(1), ARG(2)))), + EMPTY_FACTORY_ENTRY("__hfma2_relu"), + CALL_FACTORY_ENTRY("__hfma2_relu", + CALL(MapNames::getDpctNamespace() + "relu", + CALL(MapNames::getClNamespace() + "fma", ARG(0), + ARG(1), ARG(2)))))) + MATH_API_REWRITER_DEVICE( "__hfma2_sat", MATH_API_DEVICE_NODES( @@ -248,6 +346,19 @@ MATH_API_REWRITER_DEVICE( makeCallArgCreatorWithCall(0), makeCallArgCreatorWithCall(1)))) +MATH_API_REWRITER_DEVICE( + "__hmul2_rn", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmul2_rn"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmul2_rn", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmul2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmul2_rn"), + BINARY_OP_FACTORY_ENTRY("__hmul2_rn", BinaryOperatorKind::BO_Mul, + ARG(0), ARG(1)))) + MATH_API_REWRITER_DEVICE( "__hmul2_sat", MATH_API_DEVICE_NODES( @@ -292,6 +403,19 @@ MATH_API_REWRITER_DEVICE( makeCallArgCreatorWithCall(0), makeCallArgCreatorWithCall(1)))) +MATH_API_REWRITER_DEVICE( + "__hsub2_rn", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hsub2_rn"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hsub2_rn", CALL(MapNames::getClNamespace() + + "ext::intel::math::hsub2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hsub2_rn"), + BINARY_OP_FACTORY_ENTRY("__hsub2_rn", BinaryOperatorKind::BO_Sub, + ARG(0), ARG(1)))) + MATH_API_REWRITER_DEVICE( "__hsub2_sat", MATH_API_DEVICE_NODES( @@ -479,6 +603,62 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::less<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hmax", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmax"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmax"), + CALL_FACTORY_ENTRY("__hmax", CALL(MapNames::getClNamespace() + "fmax", + ARG(0), ARG(1))))) + +MATH_API_REWRITER_DEVICE( + "__hmax_nan", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmax_nan"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax_nan", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmax_nan"), + CALL_FACTORY_ENTRY("__hmax_nan", + CALL(MapNames::getDpctNamespace() + "fmax_nan", + ARG(0), ARG(1))))) + +MATH_API_REWRITER_DEVICE( + "__hmin", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmin"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmin"), + CALL_FACTORY_ENTRY("__hmin", CALL(MapNames::getClNamespace() + "fmin", + ARG(0), ARG(1))))) + +MATH_API_REWRITER_DEVICE( + "__hmin_nan", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmin_nan"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin_nan", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmin_nan"), + CALL_FACTORY_ENTRY("__hmin_nan", + CALL(MapNames::getDpctNamespace() + "fmin_nan", + ARG(0), ARG(1))))) + MATH_API_REWRITER_DEVICE( "__hne", MATH_API_DEVICE_NODES( @@ -732,6 +912,15 @@ MATH_API_REWRITER_DEVICE( "__heq2", CALL(MapNames::getDpctNamespace() + "compare", ARG(0), ARG(1), LITERAL("std::equal_to<>()")))))) +MATH_API_REWRITER_DEVICE( + "__heq2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__heq2_mask"), EMPTY_FACTORY_ENTRY("__heq2_mask"), + EMPTY_FACTORY_ENTRY("__heq2_mask"), + CALL_FACTORY_ENTRY("__heq2_mask", + CALL(MapNames::getDpctNamespace() + "compare_mask", + ARG(0), ARG(1), LITERAL("std::equal_to<>()"))))) + MATH_API_REWRITER_DEVICE( "__hequ2", MATH_API_DEVICE_NODES( @@ -749,6 +938,17 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::equal_to<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hequ2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hequ2_mask"), + EMPTY_FACTORY_ENTRY("__hequ2_mask"), + EMPTY_FACTORY_ENTRY("__hequ2_mask"), + CALL_FACTORY_ENTRY("__hequ2_mask", + CALL(MapNames::getDpctNamespace() + + "unordered_compare_mask", + ARG(0), ARG(1), LITERAL("std::equal_to<>()"))))) + MATH_API_REWRITER_DEVICE( "__hge2", MATH_API_DEVICE_NODES( @@ -765,6 +965,16 @@ MATH_API_REWRITER_DEVICE( "__hge2", CALL(MapNames::getDpctNamespace() + "compare", ARG(0), ARG(1), LITERAL("std::greater_equal<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hge2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hge2_mask"), EMPTY_FACTORY_ENTRY("__hge2_mask"), + EMPTY_FACTORY_ENTRY("__hge2_mask"), + CALL_FACTORY_ENTRY("__hge2_mask", + CALL(MapNames::getDpctNamespace() + "compare_mask", + ARG(0), ARG(1), + LITERAL("std::greater_equal<>()"))))) + MATH_API_REWRITER_DEVICE( "__hgeu2", MATH_API_DEVICE_NODES( @@ -782,6 +992,17 @@ MATH_API_REWRITER_DEVICE( CALL(MapNames::getDpctNamespace() + "unordered_compare", ARG(0), ARG(1), LITERAL("std::greater_equal<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hgeu2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hgeu2_mask"), + EMPTY_FACTORY_ENTRY("__hgeu2_mask"), + EMPTY_FACTORY_ENTRY("__hgeu2_mask"), + CALL_FACTORY_ENTRY( + "__hgeu2_mask", + CALL(MapNames::getDpctNamespace() + "unordered_compare_mask", + ARG(0), ARG(1), LITERAL("std::greater_equal<>()"))))) + MATH_API_REWRITER_DEVICE( "__hgt2", MATH_API_DEVICE_NODES( @@ -798,6 +1019,15 @@ MATH_API_REWRITER_DEVICE( "__hgt2", CALL(MapNames::getDpctNamespace() + "compare", ARG(0), ARG(1), LITERAL("std::greater<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hgt2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hgt2_mask"), EMPTY_FACTORY_ENTRY("__hgt2_mask"), + EMPTY_FACTORY_ENTRY("__hgt2_mask"), + CALL_FACTORY_ENTRY("__hgt2_mask", + CALL(MapNames::getDpctNamespace() + "compare_mask", + ARG(0), ARG(1), LITERAL("std::greater<>()"))))) + MATH_API_REWRITER_DEVICE( "__hgtu2", MATH_API_DEVICE_NODES( @@ -815,6 +1045,17 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::greater<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hgtu2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hgtu2_mask"), + EMPTY_FACTORY_ENTRY("__hgtu2_mask"), + EMPTY_FACTORY_ENTRY("__hgtu2_mask"), + CALL_FACTORY_ENTRY("__hgtu2_mask", + CALL(MapNames::getDpctNamespace() + + "unordered_compare_mask", + ARG(0), ARG(1), LITERAL("std::greater<>()"))))) + MATH_API_REWRITER_DEVICE( "__hisnan2", MATH_API_DEVICE_NODES( @@ -848,6 +1089,16 @@ MATH_API_REWRITER_DEVICE( "__hle2", CALL(MapNames::getDpctNamespace() + "compare", ARG(0), ARG(1), LITERAL("std::less_equal<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hle2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hle2_mask"), EMPTY_FACTORY_ENTRY("__hle2_mask"), + EMPTY_FACTORY_ENTRY("__hle2_mask"), + CALL_FACTORY_ENTRY("__hle2_mask", + CALL(MapNames::getDpctNamespace() + "compare_mask", + ARG(0), ARG(1), + LITERAL("std::less_equal<>()"))))) + MATH_API_REWRITER_DEVICE( "__hleu2", MATH_API_DEVICE_NODES( @@ -865,6 +1116,17 @@ MATH_API_REWRITER_DEVICE( CALL(MapNames::getDpctNamespace() + "unordered_compare", ARG(0), ARG(1), LITERAL("std::less_equal<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hleu2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hleu2_mask"), + EMPTY_FACTORY_ENTRY("__hleu2_mask"), + EMPTY_FACTORY_ENTRY("__hleu2_mask"), + CALL_FACTORY_ENTRY( + "__hleu2_mask", + CALL(MapNames::getDpctNamespace() + "unordered_compare_mask", + ARG(0), ARG(1), LITERAL("std::less_equal<>()"))))) + MATH_API_REWRITER_DEVICE( "__hlt2", MATH_API_DEVICE_NODES( @@ -881,6 +1143,15 @@ MATH_API_REWRITER_DEVICE( "__hlt2", CALL(MapNames::getDpctNamespace() + "compare", ARG(0), ARG(1), LITERAL("std::less<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hlt2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hlt2_mask"), EMPTY_FACTORY_ENTRY("__hlt2_mask"), + EMPTY_FACTORY_ENTRY("__hlt2_mask"), + CALL_FACTORY_ENTRY("__hlt2_mask", + CALL(MapNames::getDpctNamespace() + "compare_mask", + ARG(0), ARG(1), LITERAL("std::less<>()"))))) + MATH_API_REWRITER_DEVICE( "__hltu2", MATH_API_DEVICE_NODES( @@ -898,6 +1169,85 @@ MATH_API_REWRITER_DEVICE( ARG(0), ARG(1), LITERAL("std::less<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hltu2_mask", + MATH_API_DEVICE_NODES(EMPTY_FACTORY_ENTRY("__hltu2_mask"), + EMPTY_FACTORY_ENTRY("__hltu2_mask"), + EMPTY_FACTORY_ENTRY("__hltu2_mask"), + CALL_FACTORY_ENTRY("__hltu2_mask", + CALL(MapNames::getDpctNamespace() + + "unordered_compare_mask", + ARG(0), ARG(1), + LITERAL("std::less<>()"))))) + +MATH_API_REWRITER_DEVICE( + "__hmax2", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmax2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmax2"), + CALL_FACTORY_ENTRY("__hmax2", + CALL(MapNames::getClNamespace() + "half2", + CALL(MapNames::getClNamespace() + "fmax", + ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")), + ARRAY_SUBSCRIPT(ARG(1), LITERAL("0"))), + CALL(MapNames::getClNamespace() + "fmax", + ARRAY_SUBSCRIPT(ARG(0), LITERAL("1")), + ARRAY_SUBSCRIPT(ARG(1), LITERAL("1"))))))) + +MATH_API_REWRITER_DEVICE( + "__hmax2_nan", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmax2_nan"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmax2_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmax2_nan", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmax2_nan"), + CALL_FACTORY_ENTRY("__hmax2_nan", + CALL(MapNames::getDpctNamespace() + "fmax_nan", + ARG(0), ARG(1))))) + +MATH_API_REWRITER_DEVICE( + "__hmin2", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmin2"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin2", CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin2", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmin2"), + CALL_FACTORY_ENTRY("__hmin2", + CALL(MapNames::getClNamespace() + "half2", + CALL(MapNames::getClNamespace() + "fmin", + ARRAY_SUBSCRIPT(ARG(0), LITERAL("0")), + ARRAY_SUBSCRIPT(ARG(1), LITERAL("0"))), + CALL(MapNames::getClNamespace() + "fmin", + ARRAY_SUBSCRIPT(ARG(0), LITERAL("1")), + ARRAY_SUBSCRIPT(ARG(1), LITERAL("1"))))))) + +MATH_API_REWRITER_DEVICE( + "__hmin2_nan", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hmin2_nan"), + HEADER_INSERT_FACTORY( + HeaderType::HT_SYCL_Math, + CALL_FACTORY_ENTRY("__hmin2_nan", + CALL(MapNames::getClNamespace() + + "ext::intel::math::hmin2_nan", + ARG(0), ARG(1)))), + EMPTY_FACTORY_ENTRY("__hmin2_nan"), + CALL_FACTORY_ENTRY("__hmin2_nan", + CALL(MapNames::getDpctNamespace() + "fmin_nan", + ARG(0), ARG(1))))) + MATH_API_REWRITER_DEVICE( "__hne2", MATH_API_DEVICE_NODES( @@ -914,6 +1264,16 @@ MATH_API_REWRITER_DEVICE( "__hne2", CALL(MapNames::getDpctNamespace() + "compare", ARG(0), ARG(1), LITERAL("std::not_equal_to<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hne2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hne2_mask"), EMPTY_FACTORY_ENTRY("__hne2_mask"), + EMPTY_FACTORY_ENTRY("__hne2_mask"), + CALL_FACTORY_ENTRY("__hne2_mask", + CALL(MapNames::getDpctNamespace() + "compare_mask", + ARG(0), ARG(1), + LITERAL("std::not_equal_to<>()"))))) + MATH_API_REWRITER_DEVICE( "__hneu2", MATH_API_DEVICE_NODES( @@ -931,7 +1291,21 @@ MATH_API_REWRITER_DEVICE( CALL(MapNames::getDpctNamespace() + "unordered_compare", ARG(0), ARG(1), LITERAL("std::not_equal_to<>()")))))) +MATH_API_REWRITER_DEVICE( + "__hneu2_mask", + MATH_API_DEVICE_NODES( + EMPTY_FACTORY_ENTRY("__hneu2_mask"), + EMPTY_FACTORY_ENTRY("__hneu2_mask"), + EMPTY_FACTORY_ENTRY("__hneu2_mask"), + CALL_FACTORY_ENTRY( + "__hneu2_mask", + CALL(MapNames::getDpctNamespace() + "unordered_compare_mask", + ARG(0), ARG(1), LITERAL("std::not_equal_to<>()"))))) + // Half Precision Conversion and Data Movement +CALL_FACTORY_ENTRY("__double2half", + CALL(MapNames::getClNamespace() + "half", ARG(0))) + ARRAYSUBSCRIPT_EXPR_FACTORY_ENTRY("__high2float", ARG(0), LITERAL("1")) MATH_API_REWRITER_DEVICE( diff --git a/clang/runtime/dpct-rt/include/math.hpp.inc b/clang/runtime/dpct-rt/include/math.hpp.inc index eb4af2870a06..e83bf793bb55 100644 --- a/clang/runtime/dpct-rt/include/math.hpp.inc +++ b/clang/runtime/dpct-rt/include/math.hpp.inc @@ -225,6 +225,21 @@ compare(const T a, const T b, const BinaryOperation binary_op) { } // DPCT_LABEL_END +/// Performs 2 elements comparison, compare result of each element is 0 (false) +/// or 0xffff (true), returns an unsigned int by composing compare result of two +/// elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline unsigned compare_mask(const sycl::vec a, const sycl::vec b, + const BinaryOperation binary_op) { + return sycl::vec(-compare(a[0], b[0], binary_op), + -compare(a[1], b[1], binary_op)) + .as>(); +} + // DPCT_LABEL_BEGIN|unordered_compare2|dpct // DPCT_DEPENDENCY_BEGIN // Math|unordered_compare @@ -243,6 +258,22 @@ unordered_compare(const T a, const T b, const BinaryOperation binary_op) { } // DPCT_LABEL_END +/// Performs 2 elements unordered comparison, compare result of each element is +/// 0 (false) or 0xffff (true), returns an unsigned int by composing compare +/// result of two elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline unsigned unordered_compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op) { + return sycl::vec(-unordered_compare(a[0], b[0], binary_op), + -unordered_compare(a[1], b[1], binary_op)) + .as>(); +} + // DPCT_LABEL_BEGIN|isnan|dpct // DPCT_DEPENDENCY_BEGIN // Math|detail_isnan @@ -358,6 +389,63 @@ inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) { } // DPCT_LABEL_END +/// Performs relu saturation. +/// \param [in] a The input value +/// \returns the relu saturation result +template inline T relu(const T a) { + if (!detail::isnan(a) && a < 0.f) + return 0.f; + return a; +} +template inline sycl::vec relu(const sycl::vec a) { + return {relu(a[0]), relu(a[1])}; +} + +/// Performs complex number multiply addition. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns the operation result +template +inline sycl::vec complex_mul_add(const sycl::vec a, + const sycl::vec b, + const sycl::vec c) { + return sycl::vec{a[0] * b[0] - a[1] * b[1] + c[0], + a[0] * b[1] + a[1] * b[0] + c[1]}; +} + +/// Performs 2 elements comparison and returns the bigger one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the bigger value +template inline T fmax_nan(const T a, const T b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmax(a, b); +} +template +inline sycl::vec fmax_nan(const sycl::vec a, + const sycl::vec b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} + +/// Performs 2 elements comparison and returns the smaller one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the smaller value +template inline T fmin_nan(const T a, const T b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmin(a, b); +} +template +inline sycl::vec fmin_nan(const sycl::vec a, + const sycl::vec b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} + // DPCT_LABEL_BEGIN|abs|dpct // DPCT_DEPENDENCY_EMPTY // DPCT_CODE diff --git a/clang/test/dpct/cuda-math-extension-cuda11-after.cu b/clang/test/dpct/cuda-math-extension-cuda11-after.cu new file mode 100644 index 000000000000..d8e91c57545d --- /dev/null +++ b/clang/test/dpct/cuda-math-extension-cuda11-after.cu @@ -0,0 +1,62 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --format-range=none --use-dpcpp-extensions=intel_device_math -out-root %T/cuda-math-extension-cuda11-after %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only --std=c++14 +// RUN: FileCheck --input-file %T/cuda-math-extension-cuda11-after/cuda-math-extension-cuda11-after.dp.cpp --match-full-lines %s + +// CHECK: #include +#include "cuda_fp16.h" + +using namespace std; + +__global__ void kernelFuncHalf(__half *deviceArrayHalf) { + __half h, h_1, h_2; + __half2 h2, h2_1, h2_2; + + // Half Arithmetic Functions + + // CHECK: h_2 = sycl::ext::intel::math::hadd(h, h_1); + h_2 = __hadd_rn(h, h_1); + // CHECK: h_2 = sycl::ext::intel::math::hfma_relu(h, h_1, h_2); + h_2 = __hfma_relu(h, h_1, h_2); + // CHECK: h_2 = sycl::ext::intel::math::hmul(h, h_1); + h_2 = __hmul_rn(h, h_1); + // CHECK: h_2 = sycl::ext::intel::math::hsub(h, h_1); + h_2 = __hsub_rn(h, h_1); + + // Half2 Arithmetic Functions + + // CHECK: h2_2 = sycl::ext::intel::math::hadd2(h2, h2_1); + h2_2 = __hadd2_rn(h2, h2_1); + // CHECK: h2_2 = sycl::ext::intel::math::hcmadd(h2, h2_1, h2_2); + h2_2 = __hcmadd(h2, h2_1, h2_2); + // CHECK: h2_2 = sycl::ext::intel::math::hfma2_relu(h2, h2_1, h2_2); + h2_2 = __hfma2_relu(h2, h2_1, h2_2); + // CHECK: h2_2 = sycl::ext::intel::math::hmul2(h2, h2_1); + h2_2 = __hmul2_rn(h2, h2_1); + // CHECK: h2_2 = sycl::ext::intel::math::hsub2(h2, h2_1); + h2_2 = __hsub2_rn(h2, h2_1); + + // Half Comparison Functions + + // CHECK: h_2 = sycl::ext::intel::math::hmax(h, h_1); + h_2 = __hmax(h, h_1); + // CHECK: h_2 = sycl::ext::intel::math::hmax_nan(h, h_1); + h_2 = __hmax_nan(h, h_1); + // CHECK: h_2 = sycl::ext::intel::math::hmin(h, h_1); + h_2 = __hmin(h, h_1); + // CHECK: h_2 = sycl::ext::intel::math::hmin_nan(h, h_1); + h_2 = __hmin_nan(h, h_1); + + // Half2 Comparison Functions + + // CHECK: h2_2 = sycl::ext::intel::math::hmax2(h2, h2_1); + h2_2 = __hmax2(h2, h2_1); + // CHECK: h2_2 = sycl::ext::intel::math::hmax2_nan(h2, h2_1); + h2_2 = __hmax2_nan(h2, h2_1); + // CHECK: h2_2 = sycl::ext::intel::math::hmin2(h2, h2_1); + h2_2 = __hmin2(h2, h2_1); + // CHECK: h2_2 = sycl::ext::intel::math::hmin2_nan(h2, h2_1); + h2_2 = __hmin2_nan(h2, h2_1); +} + +int main() { return 0; } diff --git a/clang/test/dpct/cuda-math-intrinsics-cuda11-after.cu b/clang/test/dpct/cuda-math-intrinsics-cuda11-after.cu index 52c63460168f..7257ab7751de 100644 --- a/clang/test/dpct/cuda-math-intrinsics-cuda11-after.cu +++ b/clang/test/dpct/cuda-math-intrinsics-cuda11-after.cu @@ -10,9 +10,58 @@ using namespace std; __global__ void kernelFuncHalf(__half *deviceArrayHalf) { __half h, h_1, h_2; __half2 h2, h2_1, h2_2; + double d; + + // Half Arithmetic Functions + + // CHECK: h_2 = h + h_1; + h_2 = __hadd_rn(h, h_1); + // CHECK: h_2 = dpct::relu(sycl::fma(h, h_1, h_2)); + h_2 = __hfma_relu(h, h_1, h_2); + // CHECK: h_2 = h * h_1; + h_2 = __hmul_rn(h, h_1); + // CHECK: h_2 = h - h_1; + h_2 = __hsub_rn(h, h_1); + + // Half2 Arithmetic Functions + + // CHECK: h2_2 = h2 + h2_1; + h2_2 = __hadd2_rn(h2, h2_1); + // CHECK: h2_2 = dpct::complex_mul_add(h2, h2_1, h2_2); + h2_2 = __hcmadd(h2, h2_1, h2_2); + // CHECK: h2_2 = dpct::relu(sycl::fma(h2, h2_1, h2_2)); + h2_2 = __hfma2_relu(h2, h2_1, h2_2); + // CHECK: h2_2 = h2 * h2_1; + h2_2 = __hmul2_rn(h2, h2_1); + // CHECK: h2_2 = h2 - h2_1; + h2_2 = __hsub2_rn(h2, h2_1); + + // Half Comparison Functions + + // CHECK: h2_2 = sycl::fmax(h, h_1); + h2_2 = __hmax(h, h_1); + // CHECK: h2_2 = dpct::fmax_nan(h, h_1); + h2_2 = __hmax_nan(h, h_1); + // CHECK: h2_2 = sycl::fmin(h, h_1); + h2_2 = __hmin(h, h_1); + // CHECK: h2_2 = dpct::fmin_nan(h, h_1); + h2_2 = __hmin_nan(h, h_1); + + // Half2 Comparison Functions + + // CHECK: h2_2 = sycl::half2(sycl::fmax(h2[0], h2_1[0]), sycl::fmax(h2[1], h2_1[1])); + h2_2 = __hmax2(h2, h2_1); + // CHECK: h2_2 = dpct::fmax_nan(h2, h2_1); + h2_2 = __hmax2_nan(h2, h2_1); + // CHECK: h2_2 = sycl::half2(sycl::fmin(h2[0], h2_1[0]), sycl::fmin(h2[1], h2_1[1])); + h2_2 = __hmin2(h2, h2_1); + // CHECK: h2_2 = dpct::fmin_nan(h2, h2_1); + h2_2 = __hmin2_nan(h2, h2_1); // Half Precision Conversion and Data Movement + // CHECK: h_2 = sycl::half(d); + h_2 = __double2half(d); // CHECK: /* // CHECK-NEXT: DPCT1098:{{[0-9]+}}: The '*' expression is used instead of the __ldca call. These two expressions do not provide the exact same functionality. Check the generated code for potential precision and/or performance issues. // CHECK-NEXT: */ diff --git a/clang/test/dpct/cuda-math-intrinsics-cuda12-after.cu b/clang/test/dpct/cuda-math-intrinsics-cuda12-after.cu new file mode 100644 index 000000000000..a74f71875bc9 --- /dev/null +++ b/clang/test/dpct/cuda-math-intrinsics-cuda12-after.cu @@ -0,0 +1,42 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2, cuda-11.0, cuda-11.1, cuda-11.2, cuda-11.3, cuda-11.4, cuda-11.5, cuda-11.6, cuda-11.7, cuda-11.8 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2, v11.0, v11.1, v11.2, v11.3, v11.4, v11.5, v11.6, v11.7, v11.8 +// RUN: dpct --format-range=none -out-root %T/cuda-math-intrinsics-cuda12-after %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only --std=c++14 +// RUN: FileCheck --input-file %T/cuda-math-intrinsics-cuda12-after/cuda-math-intrinsics-cuda12-after.dp.cpp --match-full-lines %s + +#include "cuda_fp16.h" + +using namespace std; + +__global__ void kernelFuncHalf(__half *deviceArrayHalf) { + __half h, h_1, h_2; + __half2 h2, h2_1, h2_2; + + // Half2 Comparison Functions + + // CHECK: h2_2 = dpct::compare_mask(h2, h2_1, std::equal_to<>()); + h2_2 = __heq2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::unordered_compare_mask(h2, h2_1, std::equal_to<>()); + h2_2 = __hequ2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::compare_mask(h2, h2_1, std::greater_equal<>()); + h2_2 = __hge2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::unordered_compare_mask(h2, h2_1, std::greater_equal<>()); + h2_2 = __hgeu2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::compare_mask(h2, h2_1, std::greater<>()); + h2_2 = __hgt2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::unordered_compare_mask(h2, h2_1, std::greater<>()); + h2_2 = __hgtu2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::compare_mask(h2, h2_1, std::less_equal<>()); + h2_2 = __hle2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::unordered_compare_mask(h2, h2_1, std::less_equal<>()); + h2_2 = __hleu2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::compare_mask(h2, h2_1, std::less<>()); + h2_2 = __hlt2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::unordered_compare_mask(h2, h2_1, std::less<>()); + h2_2 = __hltu2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::compare_mask(h2, h2_1, std::not_equal_to<>()); + h2_2 = __hne2_mask(h2, h2_1); + // CHECK: h2_2 = dpct::unordered_compare_mask(h2, h2_1, std::not_equal_to<>()); + h2_2 = __hneu2_mask(h2, h2_1); +} + +int main() { return 0; } diff --git a/clang/test/dpct/helper_files_ref/include/math.hpp b/clang/test/dpct/helper_files_ref/include/math.hpp index 13529c91a2c6..04d78bc2a393 100644 --- a/clang/test/dpct/helper_files_ref/include/math.hpp +++ b/clang/test/dpct/helper_files_ref/include/math.hpp @@ -149,6 +149,21 @@ compare(const T a, const T b, const BinaryOperation binary_op) { return {compare(a[0], b[0], binary_op), compare(a[1], b[1], binary_op)}; } +/// Performs 2 elements comparison, compare result of each element is 0 (false) +/// or 0xffff (true), returns an unsigned int by composing compare result of two +/// elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline unsigned compare_mask(const sycl::vec a, const sycl::vec b, + const BinaryOperation binary_op) { + return sycl::vec(-compare(a[0], b[0], binary_op), + -compare(a[1], b[1], binary_op)) + .as>(); +} + /// Performs 2 element unordered comparison. /// \param [in] a The first value /// \param [in] b The second value @@ -161,6 +176,22 @@ unordered_compare(const T a, const T b, const BinaryOperation binary_op) { unordered_compare(a[1], b[1], binary_op)}; } +/// Performs 2 elements unordered comparison, compare result of each element is +/// 0 (false) or 0xffff (true), returns an unsigned int by composing compare +/// result of two elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline unsigned unordered_compare_mask(const sycl::vec a, + const sycl::vec b, + const BinaryOperation binary_op) { + return sycl::vec(-unordered_compare(a[0], b[0], binary_op), + -unordered_compare(a[1], b[1], binary_op)) + .as>(); +} + /// Determine whether 2 element value is NaN. /// \param [in] a The input value /// \returns the comparison result @@ -266,6 +297,63 @@ inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) { return sycl::max(static_cast(a), b); } +/// Performs relu saturation. +/// \param [in] a The input value +/// \returns the relu saturation result +template inline T relu(const T a) { + if (!detail::isnan(a) && a < 0.f) + return 0.f; + return a; +} +template inline sycl::vec relu(const sycl::vec a) { + return {relu(a[0]), relu(a[1])}; +} + +/// Performs complex number multiply addition. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns the operation result +template +inline sycl::vec complex_mul_add(const sycl::vec a, + const sycl::vec b, + const sycl::vec c) { + return sycl::vec{a[0] * b[0] - a[1] * b[1] + c[0], + a[0] * b[1] + a[1] * b[0] + c[1]}; +} + +/// Performs 2 elements comparison and returns the bigger one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the bigger value +template inline T fmax_nan(const T a, const T b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmax(a, b); +} +template +inline sycl::vec fmax_nan(const sycl::vec a, + const sycl::vec b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} + +/// Performs 2 elements comparison and returns the smaller one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the smaller value +template inline T fmin_nan(const T a, const T b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return sycl::fmin(a, b); +} +template +inline sycl::vec fmin_nan(const sycl::vec a, + const sycl::vec b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} + /// A sycl::abs wrapper functors. struct abs { template auto operator()(const T x) const {