From da735fe8d654e4ca349debdeaeb19dd090ad82ee Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 12 Jun 2024 13:16:08 +0100 Subject: [PATCH] [SYCL][COMPAT] Add math extend_v*4 to SYCLCompat (#14078) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds math `extend_v*4` operators (18 in total) along with unit-tests for signed and unsigned int32 cases. *Some changes overlap with the previous `extend_v*2` PR #13953 and thus should be reviewed/merged first. --------- Co-authored-by: Alberto Cabrera Pérez Co-authored-by: Joe Todd Co-authored-by: Yihan Wang --- sycl/doc/syclcompat/README.md | 238 ++++++++++++ sycl/include/syclcompat/math.hpp | 339 ++++++++++++++++- .../syclcompat/math/math_extend_v.cpp | 342 +++++++++++++++++- 3 files changed, 898 insertions(+), 21 deletions(-) diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index e697fd00be374..1b7f0ab003bf2 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -2178,6 +2178,244 @@ template inline constexpr RetT extend_vavrg2_sat(AT a, BT b, RetT c); ``` +Similarly, a set of vectorized extend 32-bit operations is provided in the math +header treating each of the 32-bit operands as a 4-element vector (8 bits each) +while handling sign extension to 9 bits internally. There is support for `add`, +`sub`, `absdiff`, `min`, `max` and `avg` binary operations. +Each operation has a `_sat` variant which determines if the returning +value is saturated or not, and a `_add` variant that computes the binary sum +of the initial operation outputs and a third operand. + +```cpp +/// Compute vectorized addition of \p a and \p b, with each value treated as a +/// 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values +template +inline constexpr RetT extend_vadd4(AT a, BT b, RetT c); + +/// Compute vectorized addition of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized addition of the two +/// values and the third value +template +inline constexpr RetT extend_vadd4_add(AT a, BT b, RetT c); + +/// Compute vectorized addition of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values with saturation +template +inline constexpr RetT extend_vadd4_sat(AT a, BT b, RetT c); + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values +template +inline constexpr RetT extend_vsub4(AT a, BT b, RetT c); + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 4 elements vector type and extend each element to 9 bit. Then add each +/// half of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized subtraction of the +/// two values and the third value +template +inline constexpr RetT extend_vsub4_add(AT a, BT b, RetT c); + +/// Compute vectorized subtraction of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values with saturation +template +inline constexpr RetT extend_vsub4_sat(AT a, BT b, RetT c); + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values +template +inline constexpr RetT extend_vabsdiff4(AT a, BT b, RetT c); + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized abs_diff of the +/// two values and the third value +template +inline constexpr RetT extend_vabsdiff4_add(AT a, BT b, RetT c); + +/// Compute vectorized abs_diff of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values with saturation +template +inline constexpr RetT extend_vabsdiff4_sat(AT a, BT b, RetT c); + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values +template +inline constexpr RetT extend_vmin4(AT a, BT b, RetT c); + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized minimum of the +/// two values and the third value +template +inline constexpr RetT extend_vmin4_add(AT a, BT b, RetT c); + +/// Compute vectorized minimum of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values with saturation +template +inline constexpr RetT extend_vmin4_sat(AT a, BT b, RetT c); + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values +template +inline constexpr RetT extend_vmax4(AT a, BT b, RetT c); + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized maximum of the +/// two values and the third value +template +inline constexpr RetT extend_vmax4_add(AT a, BT b, RetT c); + +/// Compute vectorized maximum of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values with saturation +template +inline constexpr RetT extend_vmax4_sat(AT a, BT b, RetT c); + +/// Compute vectorized average of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values +template +inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c); + +/// Compute vectorized average of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized average of the +/// two values and the third value +template +inline constexpr RetT extend_vavrg4_add(AT a, BT b, RetT c); + +/// Compute vectorized average of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values with saturation +template +inline constexpr RetT extend_vavrg4_sat(AT a, BT b, RetT c); +``` + The math header file provides APIs for bit-field insertion (`bfi_safe`) and bit-field extraction (`bfe_safe`). These are bounds-checked variants of underlying `detail` APIs (`detail::bfi`, `detail::bfe`) which, in future diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index cdd8ee99e1c4d..91990b6585fc8 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -51,6 +51,10 @@ inline ValueT clamp(ValueT val, ValueT min_val, ValueT max_val) { return sycl::clamp(val, min_val, max_val); } +template +constexpr bool is_int32_type = std::is_same_v, int32_t> || + std::is_same_v, uint32_t>; + #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS // TODO: Follow the process to add this to the extension. If added, // remove this functionality from the header. 
@@ -140,41 +144,77 @@ inline constexpr RetT extend_binary(AT a, BT b, CT c, return second_op(extend_temp, extend_c); } -template sycl::vec extractAndExtend2(T a) { +template sycl::vec extract_and_extend2(T a) { sycl::vec ret; sycl::vec va{a}; - using Tint = - typename std::conditional, int16_t, uint16_t>::type; - auto v = va.template as>(); + using IntT = std::conditional_t, int16_t, uint16_t>; + auto v = va.template as>(); ret[0] = zero_or_signed_extend(v[0], 17); ret[1] = zero_or_signed_extend(v[1], 17); return ret; } +template sycl::vec extract_and_extend4(T a) { + sycl::vec ret; + sycl::vec va{a}; + using IntT = std::conditional_t, int8_t, uint8_t>; + auto v = va.template as>(); + ret[0] = zero_or_signed_extend(v[0], 9); + ret[1] = zero_or_signed_extend(v[1], 9); + ret[2] = zero_or_signed_extend(v[2], 9); + ret[3] = zero_or_signed_extend(v[3], 9); + return ret; +} + template inline constexpr RetT extend_vbinary2(AT a, BT b, RetT c, BinaryOperation binary_op) { - static_assert(std::is_integral_v && std::is_integral_v && - std::is_integral_v && sizeof(AT) == 4 && - sizeof(BT) == 4 && sizeof(RetT) == 4); - sycl::vec extend_a = extractAndExtend2(a); - sycl::vec extend_b = extractAndExtend2(b); + static_assert(is_int32_type && is_int32_type && is_int32_type); + sycl::vec extend_a = extract_and_extend2(a); + sycl::vec extend_b = extract_and_extend2(b); sycl::vec temp{binary_op(extend_a[0], extend_b[0]), binary_op(extend_a[1], extend_b[1])}; - using Tint = typename std::conditional, int16_t, - uint16_t>::type; + using IntT = std::conditional_t, int16_t, uint16_t>; if constexpr (NeedSat) { int32_t min_val = 0, max_val = 0; - min_val = std::numeric_limits::min(); - max_val = std::numeric_limits::max(); - temp = detail::clamp(temp, {min_val, min_val}, {max_val, max_val}); + min_val = std::numeric_limits::min(); + max_val = std::numeric_limits::max(); + temp = detail::clamp(temp, sycl::vec(min_val), + sycl::vec(max_val)); } if constexpr (NeedAdd) { return temp[0] + 
temp[1] + c; } - return sycl::vec{temp[0], temp[1]}.template as>(); + return sycl::vec{temp[0], temp[1]}.template as>(); +} + +template +inline constexpr RetT extend_vbinary4(AT a, BT b, RetT c, + BinaryOperation binary_op) { + static_assert(is_int32_type && is_int32_type && is_int32_type); + sycl::vec extend_a = extract_and_extend4(a); + sycl::vec extend_b = extract_and_extend4(b); + sycl::vec temp{ + binary_op(extend_a[0], extend_b[0]), binary_op(extend_a[1], extend_b[1]), + binary_op(extend_a[2], extend_b[2]), binary_op(extend_a[3], extend_b[3])}; + using IntT = std::conditional_t, int8_t, uint8_t>; + + if constexpr (NeedSat) { + int16_t min_val = 0, max_val = 0; + min_val = std::numeric_limits::min(); + max_val = std::numeric_limits::max(); + temp = detail::clamp(temp, sycl::vec(min_val), + sycl::vec(max_val)); + } + if constexpr (NeedAdd) { + return temp[0] + temp[1] + temp[2] + temp[3] + c; + } + + return sycl::vec{temp[0], temp[1], temp[2], temp[3]} + .template as>(); } template inline bool isnan(const ValueT a) { @@ -973,10 +1013,6 @@ template sycl::vec extract_and_sign_or_zero_extend2(T val) { .template convert(); } -template -constexpr bool is_int32_type = - std::is_same_v || std::is_same_v; - } // namespace detail /// Two-way dot product-accumulate. Calculate and return integer_vector2( @@ -1820,4 +1856,269 @@ inline constexpr RetT extend_vavrg2_sat(AT a, BT b, RetT c) { return detail::extend_vbinary2(a, b, c, detail::average()); } +/// Compute vectorized addition of \p a and \p b, with each value treated as a +/// 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values +template +inline constexpr RetT extend_vadd4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::plus()); +} + +/// Compute vectorized addition of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized addition of the two +/// values and the third value +template +inline constexpr RetT extend_vadd4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::plus()); +} + +/// Compute vectorized addition of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values with saturation +template +inline constexpr RetT extend_vadd4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::plus()); +} + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values +template +inline constexpr RetT extend_vsub4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::minus()); +} + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 4 elements vector type and extend each element to 9 bit. Then add each +/// half of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized subtraction of the +/// two values and the third value +template +inline constexpr RetT extend_vsub4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::minus()); +} + +/// Compute vectorized subtraction of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values with saturation +template +inline constexpr RetT extend_vsub4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::minus()); +} + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values +template +inline constexpr RetT extend_vabsdiff4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, abs_diff()); +} + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized abs_diff of the +/// two values and the third value +template +inline constexpr RetT extend_vabsdiff4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, abs_diff()); +} + +/// Compute vectorized abs_diff of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values with saturation +template +inline constexpr RetT extend_vabsdiff4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, abs_diff()); +} + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values +template +inline constexpr RetT extend_vmin4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, minimum()); +} + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized minimum of the +/// two values and the third value +template +inline constexpr RetT extend_vmin4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, minimum()); +} + +/// Compute vectorized minimum of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values with saturation +template +inline constexpr RetT extend_vmin4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, minimum()); +} + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values +template +inline constexpr RetT extend_vmax4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, maximum()); +} + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized maximum of the +/// two values and the third value +template +inline constexpr RetT extend_vmax4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, maximum()); +} + +/// Compute vectorized maximum of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values with saturation +template +inline constexpr RetT extend_vmax4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, maximum()); +} + +/// Compute vectorized average of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values +template +inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, + detail::average()); +} + +/// Compute vectorized average of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. 
+/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized average of the +/// two values and the third value +template +inline constexpr RetT extend_vavrg4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, detail::average()); +} + +/// Compute vectorized average of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values with saturation +template +inline constexpr RetT extend_vavrg4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, detail::average()); +} + } // namespace syclcompat diff --git a/sycl/test-e2e/syclcompat/math/math_extend_v.cpp b/sycl/test-e2e/syclcompat/math/math_extend_v.cpp index 6b079422a6156..27bacc106b9e9 100644 --- a/sycl/test-e2e/syclcompat/math/math_extend_v.cpp +++ b/sycl/test-e2e/syclcompat/math/math_extend_v.cpp @@ -20,8 +20,7 @@ * math extend-vectorized helpers tests **************************************************************************/ -// ===----------- math_extend_vfunc[2/4].cpp ---------- -*- C++ -* -// --------------===// +// ===------------- math_extend_vfunc[2/4].cpp --------------*- C++ -*-----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. @@ -275,6 +274,246 @@ std::pair vavrg2_add() { return {nullptr, 0}; } +// v4: checks for the extend_v*4 operators, which treat each 32-bit operand as four 8-bit elements +std::pair vadd4() { + CHECK(syclcompat::extend_vadd4(0x0102FFFE, 0x01FF02FF, 0), + 0x020101FD); + CHECK(syclcompat::extend_vadd4((int32_t)0x7E81FEFF, + (int32_t)0x02FD03FF, 0), + 0x807E01FE); + CHECK(syclcompat::extend_vadd4((uint32_t)0x7E81FEFF, + (uint32_t)0x02FD03FF, 0), + 0x807E01FE); + CHECK(syclcompat::extend_vadd4((uint32_t)0x7E81FEFF, + (int32_t)0x02FD03FF, 0), + 0x807E01FE); + CHECK(syclcompat::extend_vadd4((int32_t)0x7E81FEFF, + (uint32_t)0x02FD03FF, 0), + 0x807E01FE); + CHECK(syclcompat::extend_vadd4_sat((int32_t)0x7E81FEFF, + (int32_t)0x02FD03FF, 0), + 0x7F8001FE); + CHECK(syclcompat::extend_vadd4_sat((uint32_t)0x7E81FEFF, + (uint32_t)0x02FD03FF, 0), + 0x7F7F7F7F); + CHECK(syclcompat::extend_vadd4_sat((uint32_t)0x7E81FEFF, + (int32_t)0x02FD03FF, 0), + 0x7F7E7F7F); + CHECK(syclcompat::extend_vadd4_sat((int32_t)0x7E81FEFF, + (uint32_t)0x02FD03FF, 0), + 0x7F7E017F); + + CHECK(syclcompat::extend_vadd4(0x01020304, 0x0A0B0C0D, 0), + 0x0B0D0F11); + CHECK(syclcompat::extend_vadd4((uint32_t)0x000100FF, + (uint32_t)0x00FE0001, 0), + 0x00FF0000); + CHECK(syclcompat::extend_vadd4_sat((uint32_t)0x000100FF, + (uint32_t)0x00FE0001, 0), + 0x00FF00FF); + + return {nullptr, 0}; +} + +std::pair vadd4_add() { + + CHECK(syclcompat::extend_vadd4_add(0x0102FFFE, 0x01FF02FF, 1), + 0x00000002); + CHECK(syclcompat::extend_vadd4_add((int32_t)0x7E81FEFF, + (int32_t)0x02FD03FF, -1), + 0xFFFFFFFC); + CHECK(syclcompat::extend_vadd4_add((uint32_t)0x7E81FEFF, + (uint32_t)0x02FD03FF, -1), + 0x000004FC); + CHECK(syclcompat::extend_vadd4_add((uint32_t)0x7E81FEFF, + (int32_t)0x02FD03FF, -1), + 0x000002FC); + CHECK(syclcompat::extend_vadd4_add((int32_t)0x7E81FEFF, + (uint32_t)0x02FD03FF, -1), + 0x000001FC); + + CHECK(syclcompat::extend_vadd4_add(0x01020304, 0x01000100, 1), + 0x0000000D); + CHECK(syclcompat::extend_vadd4_add((uint32_t)0x000100FF, + 
(uint32_t)0x00FE0001, 1), + 0x0000000200); + + return {nullptr, 0}; +} + +std::pair vsub4() { + + CHECK(syclcompat::extend_vsub4((int32_t)0x0102FFFF, + (int32_t)0x020101FE, 0), + 0xFF01FE01); + CHECK(syclcompat::extend_vsub4((int32_t)0x01807F10, 0x0102FE10, 0), + 0x007E8100); + CHECK( + syclcompat::extend_vsub4_sat((int32_t)0x01807F10, 0x0102FE10, 0), + 0x00807F00); + + CHECK(syclcompat::extend_vsub4(0x02020C0B, 0x02010A0A, 0), + 0x00010201); + CHECK(syclcompat::extend_vsub4(0x01020304, 0x02040608, 0), + 0xFFFEFDFC); + CHECK(syclcompat::extend_vsub4_sat(0x01020304, 0x02040608, 0), + 0x00000000); + + return {nullptr, 0}; +} + +std::pair vsub4_add() { + + CHECK(syclcompat::extend_vsub4_add((int32_t)0x0102FFFF, + (int32_t)0x020101FE, -1), + 0xFFFFFFFE); + CHECK( + syclcompat::extend_vsub4_add((int32_t)0x01807F10, 0x0102FE10, 2), + 0x00000001); + + CHECK(syclcompat::extend_vsub4_add(0x02020C0B, 0x02010A0A, 2), + 0x00000006); + CHECK(syclcompat::extend_vsub4_add(0x01020304, 0x02040608, 1), + 0xFFFFFFF7); + + CHECK(syclcompat::extend_vsub4_add((uint32_t)0x01020304, + (uint32_t)0x02040608, 1), + 0xFFFFFFF7); + + return {nullptr, 0}; +} + +std::pair vabsdiff4() { + + CHECK( + syclcompat::extend_vabsdiff4((int32_t)0xFF01FF02, 0x01FF02FF, 0), + 0x02020303); + CHECK(syclcompat::extend_vabsdiff4((int32_t)0x8002007F, + (int32_t)0x01010080, 0), + 0x810100FF); + CHECK(syclcompat::extend_vabsdiff4_sat((int32_t)0x8002007F, + (int32_t)0x01010080, 0), + 0x7F01007F); + + CHECK(syclcompat::extend_vabsdiff4(0x01020304, 0x04030201, 0), + 0x03010103); + CHECK(syclcompat::extend_vabsdiff4((uint32_t)0xFEFF0001, + (int32_t)0xF0FE0003, 0), + 0x0E010002); + CHECK(syclcompat::extend_vabsdiff4_sat((uint32_t)0xFEFF0001, + (int32_t)0xF0FE0003, 0), + 0xFFFF0002); + + return {nullptr, 0}; +} + +std::pair vabsdiff4_add() { + + CHECK(syclcompat::extend_vabsdiff4_add((int32_t)0xFF01FF02, + 0x01FF02FF, 1), + 0x0000000B); + CHECK(syclcompat::extend_vabsdiff4_add((int32_t)0x8002007F, + 
(int32_t)0x01010080, -1), + 0x00000180); + + CHECK(syclcompat::extend_vabsdiff4_add(0x01020304, 0x04030201, 2), + 0x0000000A); + CHECK(syclcompat::extend_vabsdiff4_add((uint32_t)0xFEFF0001, + (int32_t)0xF0FE0003, 1), + 0x00000212); + + return {nullptr, 0}; +} + +std::pair vmin4() { + + CHECK(syclcompat::extend_vmin4((int32_t)0xFFFF0102, + (int32_t)0xFE010201, 0), + 0xFEFF0101); + + CHECK(syclcompat::extend_vmin4_sat(0x0102FF00, 0x0201FE00, 0), + 0x0101FE00); + + CHECK(syclcompat::extend_vmin4(0x010A020D, 0x000B020C, 0), + 0x000A020C); + + CHECK(syclcompat::extend_vmin4_sat(0x020201FF, 0x0201FFFE, 0), + 0x02010000); + + return {nullptr, 0}; +} + +std::pair vmax4() { + + CHECK(syclcompat::extend_vmax4((int32_t)0xFFFF0102, + (int32_t)0xFE010201, 0), + 0xFF010202); + CHECK(syclcompat::extend_vmax4_sat(0x0102FF00, 0x0201FE00, 0), + 0x0202FF00); + + CHECK(syclcompat::extend_vmax4(0x010A020D, 0x000B020C, 0), + 0x010B020D); + CHECK(syclcompat::extend_vmax4_sat(0x020201FF, 0x0201FFFE, 0), + 0x02020100); + + return {nullptr, 0}; +} + +std::pair vmin4_vmax4_add() { + + CHECK(syclcompat::extend_vmin4_add((int32_t)0xFFFF0102, + (int32_t)0xFE010201, -1), + 0xFFFFFFFE); + + CHECK(syclcompat::extend_vmin4_add(0x010A020D, 0x000B020C, 1), + 0x00000019); + + CHECK(syclcompat::extend_vmax4_add((int32_t)0xFFFF0102, + (int32_t)0xFE010201, 2), + 0x00000006); + CHECK(syclcompat::extend_vmax4_add(0x010A020D, 0x000B020C, -1), + 0x0000001A); + + return {nullptr, 0}; +} + +std::pair vavrg4() { + + CHECK(syclcompat::extend_vavrg4((int32_t)0xFF01FF01, 0x0505FF00, 0), + 0x0203FF01); + CHECK(syclcompat::extend_vavrg4_sat((int32_t)0xFF01FF01, 0x0505FF00, + 0), + 0x0203FF01); + + CHECK(syclcompat::extend_vavrg4(0x00010106, (int32_t)0xFC050101, 0), + (int32_t)0xFE030104); + CHECK(syclcompat::extend_vavrg4_sat(0x00010106, (int32_t)0xFC050101, + 0), + (int32_t)0x00030104); + + return {nullptr, 0}; +} + +std::pair vavrg4_add() { + + CHECK(syclcompat::extend_vavrg4_add((int32_t)0xFF01FF01, 0x0505FF00, + 
1), + 0x00000006); + CHECK(syclcompat::extend_vavrg4_add((int32_t)0xFF01FF01, 0x0505FF00, + -6), + 0xFFFFFFFF); + + CHECK(syclcompat::extend_vavrg4_add(0x00010106, (int32_t)0xFC050101, + 1), + (int32_t)0x00000007); + + CHECK(syclcompat::extend_vavrg4_add(0x00010106, (int32_t)0xFC050101, + -1), + (int32_t)0x00000005); + + return {nullptr, 0}; +} + void test(const sycl::stream &s, int *ec) { { auto res = vadd2(); @@ -375,6 +614,105 @@ void test(const sycl::stream &s, int *ec) { } s << "vabsdiff2_add check passed!\n"; } + { + auto res = vadd4(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 12; + return; + } + s << "vadd4 check passed!\n"; + } + { + auto res = vsub4(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 13; + return; + } + s << "vsub4 check passed!\n"; + } + { + auto res = vadd4_add(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 14; + return; + } + s << "vadd4_add check passed!\n"; + } + { + auto res = vsub4_add(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 15; + return; + } + s << "vsub4_add check passed!\n"; + } + { + auto res = vabsdiff4(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 16; + return; + } + s << "vabsdiff4 check passed!\n"; + } + { + auto res = vabsdiff4_add(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 17; + return; + } + s << "vabsdiff4_add check passed!\n"; + } + { + auto res = vmin4(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 18; + return; + } + s << "vmin4 check passed!\n"; + } + { + auto res = vmax4(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 19; + return; + } + s << "vmax4 check passed!\n"; + } + { + auto res = vmin4_vmax4_add(); + if (res.first) { + s << res.first << " 
= " << res.second << " check failed!\n"; + *ec = 20; + return; + } + s << "vmin4_add/vmax4_add check passed!\n"; + } + { + auto res = vavrg4(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 21; + return; + } + s << "vavrg4 check passed!\n"; + } + { + auto res = vavrg4_add(); + if (res.first) { + s << res.first << " = " << res.second << " check failed!\n"; + *ec = 22; + return; + } + s << "vavrg4_add check passed!\n"; + } *ec = 0; }