Skip to content

[SYCL] Fix 1-element vec ambiguities #17722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 2, 2025
Merged
102 changes: 102 additions & 0 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ struct IncDec {};

template <class T> static constexpr bool not_fp = !is_vgenfloat_v<T>;

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
// True for every element type except `std::byte`. Deliberately spelled out
// here instead of reusing `is_byte_v`, so that this header does not need the
// `half`/`bfloat16` headers.
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
template <class T>
static constexpr bool not_byte = !std::is_same_v<T, std::byte>;
#else
// `_HAS_STD_BYTE` is 0 here, so `std::byte` is not referenced at all.
template <class T> static constexpr bool not_byte = true;
#endif
#endif

// To provide information about operators availability depending on vec/swizzle
// element type.
template <typename Op, typename T>
Expand All @@ -80,6 +92,7 @@ inline constexpr bool is_op_available_for_type<OpAssign<Op>, T> =
inline constexpr bool is_op_available_for_type<OP, T> = COND;

// clang-format off
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
__SYCL_OP_AVAILABILITY(std::plus<void> , true)
__SYCL_OP_AVAILABILITY(std::minus<void> , true)
__SYCL_OP_AVAILABILITY(std::multiplies<void> , true)
Expand Down Expand Up @@ -110,6 +123,38 @@ __SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(UnaryPlus , true)

__SYCL_OP_AVAILABILITY(IncDec , true)
#else
// Operator availability table for the non-libsycl8 vec implementation.
// The second macro argument is the condition (in terms of element type T)
// under which the operator is reported as available. In this table byte
// vectors keep only the bitwise (&, |, ^, ~) and comparison operators;
// floating-point vectors lose the integer-only ones.
__SYCL_OP_AVAILABILITY(std::plus<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::minus<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::multiplies<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::divides<void> , not_byte<T>)
// `std::byte` has no `operator%`, so modulus needs the byte guard as well as
// the floating-point one.
__SYCL_OP_AVAILABILITY(std::modulus<void> , not_byte<T> && not_fp<T>)

__SYCL_OP_AVAILABILITY(std::bit_and<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(std::bit_or<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(std::bit_xor<void> , not_fp<T>)

__SYCL_OP_AVAILABILITY(std::equal_to<void> , true)
__SYCL_OP_AVAILABILITY(std::not_equal_to<void> , true)
__SYCL_OP_AVAILABILITY(std::less<void> , true)
__SYCL_OP_AVAILABILITY(std::greater<void> , true)
__SYCL_OP_AVAILABILITY(std::less_equal<void> , true)
__SYCL_OP_AVAILABILITY(std::greater_equal<void> , true)

__SYCL_OP_AVAILABILITY(std::logical_and<void> , not_byte<T> && not_fp<T>)
__SYCL_OP_AVAILABILITY(std::logical_or<void> , not_byte<T> && not_fp<T>)

__SYCL_OP_AVAILABILITY(ShiftLeft , not_byte<T> && not_fp<T>)
__SYCL_OP_AVAILABILITY(ShiftRight , not_byte<T> && not_fp<T>)

// Unary
__SYCL_OP_AVAILABILITY(std::negate<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::logical_not<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(UnaryPlus , not_byte<T>)

__SYCL_OP_AVAILABILITY(IncDec , not_byte<T>)
#endif
// clang-format on

#undef __SYCL_OP_AVAILABILITY
Expand Down Expand Up @@ -188,6 +233,12 @@ template <typename Self> struct VecOperators {
using element_type = typename from_incomplete<Self>::element_type;
static constexpr int N = from_incomplete<Self>::size();

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename T>
static constexpr bool is_compatible_scalar =
std::is_convertible_v<T, typename from_incomplete<Self>::element_type>;
#endif

template <typename Op>
using result_t = std::conditional_t<
is_logical<Op>, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
Expand Down Expand Up @@ -293,6 +344,7 @@ template <typename Self> struct VecOperators {
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, IncDec>>>
: public IncDecImpl<Self> {};

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
Expand Down Expand Up @@ -341,13 +393,60 @@ template <typename Self> struct VecOperators {
friend auto operator OPERATOR(const Self &v) { return apply<OP>(v); } \
};

#else

// Defines the OpMixin specialization that supplies binary operator OPERATOR
// (functor OP) for Self as hidden friends: a Self-Self overload plus
// Self-scalar and scalar-Self overloads. The scalar overloads participate
// only when is_compatible_scalar<T> holds. The scalar is converted to the
// element type explicitly before a Self is constructed from it, so the braced
// initializer never has to perform an implicit (possibly narrowing)
// conversion, e.g. for a `double` scalar combined with a `float` vec.
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
  template <typename Op> \
  struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
    friend result_t<OP> operator OPERATOR(const Self & lhs, \
                                          const Self & rhs) { \
      return VecOperators::apply<OP>(lhs, rhs); \
    } \
    template <typename T> \
    friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
    operator OPERATOR(const Self & lhs, const T & rhs) { \
      return VecOperators::apply<OP>( \
          lhs, Self{static_cast<element_type>(rhs)}); \
    } \
    template <typename T> \
    friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
    operator OPERATOR(const T & lhs, const Self & rhs) { \
      return VecOperators::apply<OP>( \
          Self{static_cast<element_type>(lhs)}, rhs); \
    } \
  };

// Defines the OpMixin specialization for the compound-assignment operator
// OPERATOR that corresponds to functor OP. Both overloads evaluate OP on the
// two operands, assign the result to the left operand, and return a reference
// to that left operand. The scalar overload participates only when
// is_compatible_scalar<T> holds.
#define __SYCL_VEC_OPASSIGN_MIXIN(OP, OPERATOR) \
  template <typename Op> \
  struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OpAssign<OP>>>> { \
    friend Self &operator OPERATOR(Self & dst, const Self & src) { \
      return dst = OP{}(dst, src); \
    } \
    template <typename T> \
    friend std::enable_if_t<is_compatible_scalar<T>, Self &> \
    operator OPERATOR(Self & dst, const T & scalar) { \
      return dst = OP{}(dst, scalar); \
    } \
  };

// Defines the OpMixin specialization that supplies unary operator OPERATOR
// (functor OP) for Self as a hidden friend forwarding to apply<OP>.
#define __SYCL_VEC_UOP_MIXIN(OP, OPERATOR) \
  template <typename Op> \
  struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
    friend result_t<OP> operator OPERATOR(const Self & operand) { \
      return VecOperators::apply<OP>(operand); \
    } \
  };

#endif

// Expand the three mixin-generating macros (binary, compound-assignment,
// unary) through __SYCL_INSTANTIATE_OPERATORS, whose definition is outside
// this view.
__SYCL_INSTANTIATE_OPERATORS(__SYCL_VEC_BINOP_MIXIN,
__SYCL_VEC_OPASSIGN_MIXIN, __SYCL_VEC_UOP_MIXIN)

// The generator macros are no longer needed once expanded; undefine them.
#undef __SYCL_VEC_UOP_MIXIN
#undef __SYCL_VEC_OPASSIGN_MIXIN
#undef __SYCL_VEC_BINOP_MIXIN

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename Op>
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, std::bit_not<void>>>> {
template <typename T = typename from_incomplete<Self>::element_type>
Expand All @@ -356,6 +455,7 @@ template <typename Self> struct VecOperators {
return apply<std::bit_not<void>>(v);
}
};
#endif

template <typename... Op>
struct __SYCL_EBO CombineImpl : public OpMixin<Op>... {};
Expand All @@ -377,6 +477,7 @@ template <typename Self> struct VecOperators {
OpAssign<ShiftRight>, IncDec> {};
};

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// Primary vec_arith template, compiled only when
// __SYCL_USE_LIBSYCL8_VEC_IMPL is set: an empty class that inherits the
// operator set VecOperators<vec<DataT, NumElements>>::Combined.
template <typename DataT, int NumElements>
class vec_arith : public VecOperators<vec<DataT, NumElements>>::Combined {};

Expand Down Expand Up @@ -427,6 +528,7 @@ class vec_arith<std::byte, NumElements>
}
};
#endif // (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
#endif

#undef __SYCL_INSTANTIATE_OPERATORS

Expand Down
31 changes: 23 additions & 8 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,18 @@ template <typename DataT> class vec_base<DataT, 1> {
// Provides a cross-platform vector class template that works efficiently on
// SYCL devices as well as in host C++ code.
template <typename DataT, int NumElements>
class __SYCL_EBO vec
: public detail::vec_arith<DataT, NumElements>,
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
// Keep it last to simplify ABI layout test:
public detail::vec_base<DataT, NumElements> {
class __SYCL_EBO vec :
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
public detail::vec_arith<DataT, NumElements>,
#else
public detail::VecOperators<vec<DataT, NumElements>>::Combined,
#endif
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
// Keep it last to simplify ABI layout test:
public detail::vec_base<DataT, NumElements> {
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
"DataT must be cv-unqualified");

Expand Down Expand Up @@ -408,6 +412,7 @@ class __SYCL_EBO vec
constexpr vec &operator=(const vec &) = default;
constexpr vec &operator=(vec &&) = default;

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// Template required to prevent ambiguous overload with the copy assignment
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
Expand All @@ -427,6 +432,14 @@ class __SYCL_EBO vec
*this = Rhs.template as<vec>();
return *this;
}
#else
// Assignment from any type T that is implicitly convertible to DataT: the
// value is cast to DataT, a vec is constructed from it, and that vec is
// assigned to *this.
template <typename T>
std::enable_if_t<std::is_convertible_v<T, DataT>, vec &>
operator=(const T &Rhs) {
  return *this = vec{static_cast<DataT>(Rhs)};
}
#endif

__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
static constexpr size_t get_count() { return size(); }
Expand Down Expand Up @@ -536,8 +549,10 @@ class __SYCL_EBO vec
int... T5>
friend class detail::SwizzleOp;
template <typename T1, int T2> friend class __SYCL_EBO vec;
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// To allow arithmetic operators access private members of vec.
template <typename T1, int T2> friend class detail::vec_arith;
#endif
};
///////////////////////// class sycl::vec /////////////////////////

Expand Down
2 changes: 2 additions & 0 deletions sycl/test-e2e/Basic/vector/byte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ int main() {
assert(SwizByte2Neg[0] == ~SwizByte2B[0]);
}

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
{
// std::byte is not an arithmetic type and it only supports the following
// overloads of >> and << operators.
Expand Down Expand Up @@ -207,6 +208,7 @@ int main() {
assert(SwizShiftRight[0] == SwizByte2Shift[0] >> 3 &&
SwizShiftLeft[1] == SwizByte2Shift[1] << 3);
}
#endif
}

return 0;
Expand Down
4 changes: 2 additions & 2 deletions sycl/test-e2e/Basic/vector/vec_binary_scalar_order.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ bool CheckResult(sycl::vec<T1, N> V, T2 Ref) {
constexpr T RefVal = 2; \
VecT InVec{static_cast<T>(RefVal)}; \
{ \
VecT OutVecsDevice[2]; \
ResT OutVecsDevice[2]; \
T OutRefsDevice[2]; \
{ \
sycl::buffer<VecT, 1> OutVecsBuff{OutVecsDevice, 2}; \
sycl::buffer<ResT, 1> OutVecsBuff{OutVecsDevice, 2}; \
sycl::buffer<T, 1> OutRefsBuff{OutRefsDevice, 2}; \
Q.submit([&](sycl::handler &CGH) { \
sycl::accessor OutVecsAcc{OutVecsBuff, CGH, sycl::read_write}; \
Expand Down
16 changes: 8 additions & 8 deletions sycl/test-e2e/DeviceLib/built-ins/vector_integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ int main() {

// abs
{
s::uint2 r{0};
s::int2 r{0};
{
s::buffer<s::uint2, 1> BufR(&r, s::range<1>(1));
s::buffer<s::int2, 1> BufR(&r, s::range<1>(1));
s::queue myQueue;
myQueue.submit([&](s::handler &cgh) {
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
Expand All @@ -214,8 +214,8 @@ int main() {
});
});
}
unsigned int r1 = r.x();
unsigned int r2 = r.y();
int r1 = r.x();
int r2 = r.y();
assert(r1 == 5);
assert(r2 == 2);
}
Expand All @@ -240,9 +240,9 @@ int main() {

// abs_diff
{
s::uint2 r{0};
s::int2 r{0};
{
s::buffer<s::uint2, 1> BufR(&r, s::range<1>(1));
s::buffer<s::int2, 1> BufR(&r, s::range<1>(1));
s::queue myQueue;
myQueue.submit([&](s::handler &cgh) {
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
Expand All @@ -251,8 +251,8 @@ int main() {
});
});
}
unsigned int r1 = r.x();
unsigned int r2 = r.y();
int r1 = r.x();
int r2 = r.y();
assert(r1 == 4);
assert(r2 == 1);
}
Expand Down
4 changes: 2 additions & 2 deletions sycl/test/basic_tests/vectors/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ using sw_double_2 = decltype(std::declval<vec<double, 4>>().swizzle<1, 2>());
// EXCEPT_IN_PREVIEW condition<>

static_assert( std::is_assignable_v<vec<half, 1>, half>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, float>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, double>);
static_assert( std::is_assignable_v<vec<half, 1>, float>);
static_assert( std::is_assignable_v<vec<half, 1>, double>);
static_assert( std::is_assignable_v<vec<half, 1>, vec<half, 1>>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, vec<float, 1>>);
static_assert(EXCEPT_IN_PREVIEW std::is_assignable_v<vec<half, 1>, vec<double, 1>>);
Expand Down