Skip to content

[SYCL][libdevice] Add sqrt with rounding mode supported in sycl::ext::intel::math #12571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libdevice/imf_impl_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ class __iml_ui128 {
return (this->bits[0] > x.bits[0]);
}

bool operator>=(const __iml_ui128 &x) {
return operator==(x) || operator>(x);
}

bool operator>(const uint64_t &x) {
if (this->bits[1] > 0)
return true;
Expand Down
113 changes: 113 additions & 0 deletions libdevice/imf_rounding_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define __LIBDEVICE_IMF_ROUNDING_OP_H__
#include "imf_impl_utils.hpp"
#include <limits>

template <typename Ty>
static Ty __handling_fp_overflow(unsigned z_sig, int rd) {
typedef typename __iml_fp_config<Ty>::utype UTy;
Expand Down Expand Up @@ -1569,4 +1570,116 @@ template <typename FTy> FTy __fp_fma(FTy x, FTy y, FTy z, int rd) {
}
}

template <typename UTy> UTy integer_sqrt(UTy n, bool &is_squares) {
UTy x{n}, c{0}, d{1};
d = d << (sizeof(UTy) * 8 - 2);
while (d > n)
d = d >> 2;

while (d != 0) {
if (x >= (c + d)) {
x -= (c + d);
c = (c >> 1) + d;
} else
c = c >> 1;
d = d >> 2;
}

if (c * c > n)
c -= 1;
if (c * c == n)
is_squares = true;
else
is_squares = false;
return c;
}

template <typename FTy> FTy __fp_sqrt(FTy x, int rd) {
typedef typename __iml_fp_config<FTy>::utype UTy;
typedef typename __iml_get_double_size_unsigned<UTy>::utype DSUTy;
constexpr int fra_digits = std::numeric_limits<FTy>::digits - 1;
UTy x_bit = __builtin_bit_cast(UTy, x);
UTy x_exp = (x_bit & __iml_fp_config<FTy>::pos_inf_bits) >> fra_digits;
UTy x_fra = x_bit & __iml_fp_config<FTy>::fra_mask;
UTy x_sig = x_bit >> (sizeof(FTy) * 8 - 1);
DSUTy Bit1(1);
constexpr UTy NAN_BITS = __iml_fp_config<FTy>::nan_bits;
constexpr UTy INF_BITS = __iml_fp_config<FTy>::pos_inf_bits;

if ((x_exp == __iml_fp_config<FTy>::exp_mask) && (x_fra != 0x0))
return __builtin_bit_cast(FTy, NAN_BITS);

if ((x_exp == 0x0) && (x_fra == 0x0))
return __builtin_bit_cast(FTy, static_cast<UTy>(0x0));

if (x_sig == 1)
return __builtin_bit_cast(FTy, NAN_BITS);

if ((x_exp == __iml_fp_config<FTy>::exp_mask) && (x_fra == 0x0))
return __builtin_bit_cast(FTy, INF_BITS);

// For all postive subnormal and normal values, the result of sqrt
// is a normal value.
int32_t sx_exp = x_exp;
if (sx_exp == 0x0)
sx_exp = 1 - __iml_fp_config<FTy>::bias;
else
sx_exp -= __iml_fp_config<FTy>::bias;

DSUTy fra_holder{x_fra};
if (x_exp != 0)
fra_holder = (Bit1 << fra_digits) | fra_holder;
sx_exp -= fra_digits;

// 2^x_exp * 1.mant can be represented as: 2^(x_exp - 52) * fra_holder
// for normal value and 2^-1022 * 0.mant can be represented as:
// 2^(-1074) * fra_holder for subnormal value. For fp32, 2^x_exp * 1.mant
// can be represented as: 2^(x_exp - 23) * fra_holder for normal value and
// 2^-126 * 0.mant can be represented as 2^-149 * fra_holder for subnormal.
// fra_holder is a non-zero value.
size_t lz = 0;
if constexpr (std::is_same<DSUTy, __iml_ui128>::value)
lz = 127 - fra_holder.ui128_msb_pos();
else
lz = 63 - get_msb_pos(fra_holder);

fra_holder = fra_holder << lz;
sx_exp -= lz;
if (static_cast<uint32_t>(sx_exp) & 0x1) {
sx_exp += 1;
fra_holder = fra_holder >> 1;
}

bool is_squares = false;
DSUTy sqrt_fra = integer_sqrt<DSUTy>(fra_holder, is_squares);
sx_exp = sx_exp / 2;

if constexpr (std::is_same<DSUTy, __iml_ui128>::value)
lz = 127 - sqrt_fra.ui128_msb_pos();
else
lz = 63 - get_msb_pos(sqrt_fra);
UTy fra1 =
static_cast<UTy>(sqrt_fra >> (sizeof(DSUTy) * 8 - lz - fra_digits - 1));
fra1 = fra1 & __iml_fp_config<FTy>::fra_mask;
sx_exp += sizeof(DSUTy) * 8 - 1 - lz + __iml_fp_config<FTy>::bias;

size_t grs_nsbit = sizeof(FTy) * 16 - lz - 1 - fra_digits;
uint32_t grs_bits =
static_cast<uint32_t>(sqrt_fra & ((Bit1 << grs_nsbit) - Bit1));
uint32_t s_bits =
grs_bits & static_cast<uint32_t>((Bit1 << (grs_nsbit - 3)) - Bit1);
grs_bits = grs_bits >> (grs_nsbit - 3);
if ((s_bits > 0) || !is_squares)
grs_bits |= 0x1;

uint32_t rb =
__handling_rounding(0U, static_cast<uint32_t>(fra1), grs_bits, rd);
fra1 += rb;
if (fra1 > __iml_fp_config<FTy>::fra_mask) {
fra1 = 0x0;
sx_exp++;
}
return __builtin_bit_cast(FTy,
(static_cast<UTy>(sx_exp) << fra_digits) | fra1);
}
#endif
12 changes: 12 additions & 0 deletions libdevice/imf_utils/fp32_round.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,16 @@ DEVICE_EXTERN_C_INLINE
float __devicelib_imf_fmaf_rz(float x, float y, float z) {
return __fp_fma(x, y, z, __IML_RTZ);
}

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rd(float x) { return __fp_sqrt(x, __IML_RTN); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rn(float x) { return __fp_sqrt(x, __IML_RTE); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_ru(float x) { return __fp_sqrt(x, __IML_RTP); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rz(float x) { return __fp_sqrt(x, __IML_RTZ); }
#endif
12 changes: 12 additions & 0 deletions libdevice/imf_utils/fp64_round.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,16 @@ DEVICE_EXTERN_C_INLINE
double __devicelib_imf_fma_rz(double x, double y, double z) {
return __fp_fma(x, y, z, __IML_RTZ);
}

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rd(double x) { return __fp_sqrt(x, __IML_RTN); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rn(double x) { return __fp_sqrt(x, __IML_RTE); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_ru(double x) { return __fp_sqrt(x, __IML_RTP); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rz(double x) { return __fp_sqrt(x, __IML_RTZ); }
#endif
24 changes: 24 additions & 0 deletions libdevice/imf_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1992,4 +1992,28 @@ DEVICE_EXTERN_C_INLINE
float __imf_fmaf_rz(float x, float y, float z) {
return __devicelib_imf_fmaf_rz(x, y, z);
}

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rd(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_rd(float x) { return __devicelib_imf_sqrtf_rd(x); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rn(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_rn(float x) { return __devicelib_imf_sqrtf_rn(x); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_ru(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_ru(float x) { return __devicelib_imf_sqrtf_ru(x); }

DEVICE_EXTERN_C_INLINE
float __devicelib_imf_sqrtf_rz(float);

DEVICE_EXTERN_C_INLINE
float __imf_sqrtf_rz(float x) { return __devicelib_imf_sqrtf_rz(x); }
#endif // __LIBDEVICE_IMF_ENABLED__
24 changes: 24 additions & 0 deletions libdevice/imf_wrapper_fp64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,28 @@ DEVICE_EXTERN_C_INLINE
double __imf_fma_rz(double x, double y, double z) {
return __devicelib_imf_fma_rz(x, y, z);
}

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rd(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_rd(double x) { return __devicelib_imf_sqrt_rd(x); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rn(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_rn(double x) { return __devicelib_imf_sqrt_rn(x); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_ru(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_ru(double x) { return __devicelib_imf_sqrt_ru(x); }

DEVICE_EXTERN_C_INLINE
double __devicelib_imf_sqrt_rz(double);

DEVICE_EXTERN_C_INLINE
double __imf_sqrt_rz(double x) { return __devicelib_imf_sqrt_rz(x); }
#endif // __LIBDEVICE_IMF_ENABLED__
8 changes: 8 additions & 0 deletions llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ SYCLDeviceLibFuncMap SDLMap = {
{"__devicelib_imf_fmaf_rn", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_fmaf_ru", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_fmaf_rz", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_rd", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_rn", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_ru", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_sqrtf_rz", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_float2int_rd", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_float2int_rn", DeviceLibExt::cl_intel_devicelib_imf},
{"__devicelib_imf_float2int_ru", DeviceLibExt::cl_intel_devicelib_imf},
Expand Down Expand Up @@ -528,6 +532,10 @@ SYCLDeviceLibFuncMap SDLMap = {
{"__devicelib_imf_fma_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_fma_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_fma_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_rd", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_rn", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_ru", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_sqrt_rz", DeviceLibExt::cl_intel_devicelib_imf_fp64},
{"__devicelib_imf_bfloat162float",
DeviceLibExt::cl_intel_devicelib_imf_bf16},
{"__devicelib_imf_bfloat162int_rd",
Expand Down
8 changes: 8 additions & 0 deletions sycl/include/sycl/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rd(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rn(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_ru(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_fmaf_rz(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rd(float x);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rn(float x);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_ru(float x);
extern __DPCPP_SYCL_EXTERNAL float __imf_sqrtf_rz(float x);
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rd(float x);
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_rn(float x);
extern __DPCPP_SYCL_EXTERNAL int __imf_float2int_ru(float x);
Expand Down Expand Up @@ -358,6 +362,10 @@ extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rd(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rn(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_ru(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_drcp_rz(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rd(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rn(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_ru(double x);
extern __DPCPP_SYCL_EXTERNAL double __imf_sqrt_rz(double x);
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rd(double x);
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_rn(double x);
extern __DPCPP_SYCL_EXTERNAL float __imf_double2float_ru(double x);
Expand Down
24 changes: 24 additions & 0 deletions sycl/include/sycl/ext/intel/math/imf_rounding_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ float __imf_fmaf_rz(float, float, float);
float __imf_fmaf_rn(float, float, float);
float __imf_fmaf_ru(float, float, float);
float __imf_fmaf_rd(float, float, float);
float __imf_sqrtf_rz(float);
float __imf_sqrtf_rn(float);
float __imf_sqrtf_ru(float);
float __imf_sqrtf_rd(float);

double __imf_dadd_rz(double, double);
double __imf_dadd_rn(double, double);
Expand All @@ -60,6 +64,10 @@ double __imf_fma_rz(double, double, double);
double __imf_fma_rn(double, double, double);
double __imf_fma_ru(double, double, double);
double __imf_fma_rd(double, double, double);
double __imf_sqrt_rz(double);
double __imf_sqrt_rn(double);
double __imf_sqrt_ru(double);
double __imf_sqrt_rd(double);
};

namespace sycl {
Expand Down Expand Up @@ -154,6 +162,14 @@ template <typename Tp = float> Tp fmaf_rz(Tp x, Tp y, Tp z) {
return __imf_fmaf_rz(x, y, z);
}

template <typename Tp = float> Tp fsqrt_rd(Tp x) { return __imf_sqrtf_rd(x); }

template <typename Tp = float> Tp fsqrt_rn(Tp x) { return __imf_sqrtf_rn(x); }

template <typename Tp = float> Tp fsqrt_ru(Tp x) { return __imf_sqrtf_ru(x); }

template <typename Tp = float> Tp fsqrt_rz(Tp x) { return __imf_sqrtf_rz(x); }

template <typename Tp = double> Tp dadd_rd(Tp x, Tp y) {
return __imf_dadd_rd(x, y);
}
Expand Down Expand Up @@ -242,6 +258,14 @@ template <typename Tp = double> Tp fma_rz(Tp x, Tp y, Tp z) {
return __imf_fma_rz(x, y, z);
}

template <typename Tp = double> Tp dsqrt_rd(Tp x) { return __imf_sqrt_rd(x); }

template <typename Tp = double> Tp dsqrt_rn(Tp x) { return __imf_sqrt_rn(x); }

template <typename Tp = double> Tp dsqrt_ru(Tp x) { return __imf_sqrt_ru(x); }

template <typename Tp = double> Tp dsqrt_rz(Tp x) { return __imf_sqrt_rz(x); }

} // namespace ext::intel::math
} // namespace _V1
} // namespace sycl
30 changes: 30 additions & 0 deletions sycl/test-e2e/DeviceLib/imf_fp32_rounding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,35 @@ int main(int, char **) {
std::cout << "sycl::ext::intel::math::fmaf_rz passes." << std::endl;
}

{
std::initializer_list<float> input_vals = {
0x1.ba90e6p+1, 0x1.4p+1, 0x1.ea77e6p-2, 0x1.e8330ap+19,
0x1.4ffd68p+5, 0x1.443084p-15, 0x1.605fb2p+6, 0x1.2eb718p-7};
std::initializer_list<unsigned> ref_vals_rd = {
0x3fee0264, 0x3fca62c1, 0x3f312c12, 0x4479faa2,
0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e};
std::initializer_list<unsigned> ref_vals_rn = {
0x3fee0265, 0x3fca62c2, 0x3f312c13, 0x4479faa2,
0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e};
std::initializer_list<unsigned> ref_vals_ru = {
0x3fee0265, 0x3fca62c2, 0x3f312c13, 0x4479faa3,
0x40cf616e, 0x3bcbb4d1, 0x41162c49, 0x3dc4d80f};
std::initializer_list<unsigned> ref_vals_rz = {
0x3fee0264, 0x3fca62c1, 0x3f312c12, 0x4479faa2,
0x40cf616d, 0x3bcbb4d0, 0x41162c48, 0x3dc4d80e};
test(device_queue, input_vals, ref_vals_rd,
FT(unsigned, sycl::ext::intel::math::fsqrt_rd));
std::cout << "sycl::ext::intel::math::fsqrt_rd passes." << std::endl;
test(device_queue, input_vals, ref_vals_rn,
FT(unsigned, sycl::ext::intel::math::fsqrt_rn));
std::cout << "sycl::ext::intel::math::fsqrt_rn passes." << std::endl;
test(device_queue, input_vals, ref_vals_ru,
FT(unsigned, sycl::ext::intel::math::fsqrt_ru));
std::cout << "sycl::ext::intel::math::fsqrt_ru passes." << std::endl;
test(device_queue, input_vals, ref_vals_rz,
FT(unsigned, sycl::ext::intel::math::fsqrt_rz));
std::cout << "sycl::ext::intel::math::fsqrt_rz passes." << std::endl;
}

return 0;
}
29 changes: 29 additions & 0 deletions sycl/test-e2e/DeviceLib/imf_fp64_rounding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,34 @@ int main(int, char **) {
std::cout << "sycl::ext::intel::math::fmaf_rz passes." << std::endl;
}

{
std::initializer_list<double> input_vals1 = {
0x1p+2, 0x1.fbd37afb0f8edp-1, 0x1.9238e38e38e35p+6, 0x1.7p+3};
std::initializer_list<unsigned long long> ref_vals_rd = {
0x4000000000000000, 0x3fefde8a59acb0bb, 0x40240e33d899cd1b,
0x400b211b1c70d023};
std::initializer_list<unsigned long long> ref_vals_rn = {
0x4000000000000000, 0x3fefde8a59acb0bc, 0x40240e33d899cd1c,
0x400b211b1c70d023};
std::initializer_list<unsigned long long> ref_vals_ru = {
0x4000000000000000, 0x3fefde8a59acb0bc, 0x40240e33d899cd1c,
0x400b211b1c70d024};
std::initializer_list<unsigned long long> ref_vals_rz = {
0x4000000000000000, 0x3fefde8a59acb0bb, 0x40240e33d899cd1b,
0x400b211b1c70d023};
test(device_queue, input_vals1, ref_vals_rd,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_rd));
std::cout << "sycl::ext::intel::math::dsqrt_rd passes." << std::endl;
test(device_queue, input_vals1, ref_vals_rn,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_rn));
std::cout << "sycl::ext::intel::math::dsqrt_rn passes." << std::endl;
test(device_queue, input_vals1, ref_vals_ru,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_ru));
std::cout << "sycl::ext::intel::math::dsqrt_ru passes." << std::endl;
test(device_queue, input_vals1, ref_vals_rz,
FT(unsigned long long, sycl::ext::intel::math::dsqrt_rz));
std::cout << "sycl::ext::intel::math::dsqrt_rz passes." << std::endl;
}

return 0;
}