Skip to content

[SYCL] Fix 1-element vec ambiguities #17722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 2, 2025
Merged
8 changes: 8 additions & 0 deletions sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@

#include <sycl/detail/defines_elementary.hpp>

// Selects which sycl::vec implementation the headers compile. A non-zero value
// keeps the code guarded by `#if __SYCL_USE_LIBSYCL8_VEC_IMPL`; zero switches
// to the alternative (`#else` / `#if !__SYCL_USE_LIBSYCL8_VEC_IMPL`) code.
// The macro may be predefined by the user; otherwise it defaults to 0 when
// __INTEL_PREVIEW_BREAKING_CHANGES is defined and to 1 when it is not.
#ifndef __SYCL_USE_LIBSYCL8_VEC_IMPL
#if defined(__INTEL_PREVIEW_BREAKING_CHANGES)
#define __SYCL_USE_LIBSYCL8_VEC_IMPL 0
#else
#define __SYCL_USE_LIBSYCL8_VEC_IMPL 1
#endif
#endif

namespace sycl {
inline namespace _V1 {
template <typename DataT, int NumElements> class __SYCL_EBO vec;
Expand Down
102 changes: 102 additions & 0 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ struct IncDec {};

template <class T> static constexpr bool not_fp = !is_vgenfloat_v<T>;

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
// True for every element type except `std::byte`. Written out by hand instead
// of going through `is_byte_v` so that this header does not pick up extra
// dependencies on the `half`/`bfloat16` headers. When the standard library is
// configured without `std::byte` (`_HAS_STD_BYTE == 0`) no type can be a byte,
// so the trait is unconditionally true.
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
template <class T>
static constexpr bool not_byte = !std::is_same_v<T, std::byte>;
#else
template <class T> static constexpr bool not_byte = true;
#endif
#endif

// To provide information about operators availability depending on vec/swizzle
// element type.
template <typename Op, typename T>
Expand All @@ -80,6 +92,7 @@ inline constexpr bool is_op_available_for_type<OpAssign<Op>, T> =
inline constexpr bool is_op_available_for_type<OP, T> = COND;

// clang-format off
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
__SYCL_OP_AVAILABILITY(std::plus<void> , true)
__SYCL_OP_AVAILABILITY(std::minus<void> , true)
__SYCL_OP_AVAILABILITY(std::multiplies<void> , true)
Expand Down Expand Up @@ -110,6 +123,38 @@ __SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(UnaryPlus , true)

__SYCL_OP_AVAILABILITY(IncDec , true)
#else
__SYCL_OP_AVAILABILITY(std::plus<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::minus<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::multiplies<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::divides<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::modulus<void> , not_fp<T>)

__SYCL_OP_AVAILABILITY(std::bit_and<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(std::bit_or<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(std::bit_xor<void> , not_fp<T>)

__SYCL_OP_AVAILABILITY(std::equal_to<void> , true)
__SYCL_OP_AVAILABILITY(std::not_equal_to<void> , true)
__SYCL_OP_AVAILABILITY(std::less<void> , true)
__SYCL_OP_AVAILABILITY(std::greater<void> , true)
__SYCL_OP_AVAILABILITY(std::less_equal<void> , true)
__SYCL_OP_AVAILABILITY(std::greater_equal<void> , true)

__SYCL_OP_AVAILABILITY(std::logical_and<void> , not_byte<T> && not_fp<T>)
__SYCL_OP_AVAILABILITY(std::logical_or<void> , not_byte<T> && not_fp<T>)

__SYCL_OP_AVAILABILITY(ShiftLeft , not_byte<T> && not_fp<T>)
__SYCL_OP_AVAILABILITY(ShiftRight , not_byte<T> && not_fp<T>)

// Unary
__SYCL_OP_AVAILABILITY(std::negate<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::logical_not<void> , not_byte<T>)
__SYCL_OP_AVAILABILITY(std::bit_not<void> , not_fp<T>)
__SYCL_OP_AVAILABILITY(UnaryPlus , not_byte<T>)

__SYCL_OP_AVAILABILITY(IncDec , not_byte<T>)
#endif
// clang-format on

#undef __SYCL_OP_AVAILABILITY
Expand Down Expand Up @@ -188,6 +233,12 @@ template <typename Self> struct VecOperators {
using element_type = typename from_incomplete<Self>::element_type;
static constexpr int N = from_incomplete<Self>::size();

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename T>
static constexpr bool is_compatible_scalar =
std::is_convertible_v<T, typename from_incomplete<Self>::element_type>;
#endif

template <typename Op>
using result_t = std::conditional_t<
is_logical<Op>, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
Expand Down Expand Up @@ -293,6 +344,7 @@ template <typename Self> struct VecOperators {
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, IncDec>>>
: public IncDecImpl<Self> {};

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
Expand Down Expand Up @@ -341,13 +393,60 @@ template <typename Self> struct VecOperators {
friend auto operator OPERATOR(const Self &v) { return apply<OP>(v); } \
};

#else

// Defines the OpMixin specialization selected when `Op` is the functor `OP`.
// It contributes three hidden-friend overloads of `operator OPERATOR`, all
// returning `result_t<OP>`:
//   * vec  OPERATOR vec    - non-template, forwards both operands to apply<OP>;
//   * vec  OPERATOR scalar - template, enabled only if is_compatible_scalar<T>;
//   * scalar OPERATOR vec  - template, enabled only if is_compatible_scalar<T>.
// The scalar overloads wrap the scalar as `Self{...}` before calling apply<OP>.
// NOTE(review): `static_cast<T>` on a `const T &` performs no conversion; the
// conversion to the element type is left to Self's constructor. Confirm this
// is intended (as opposed to casting to the element type explicitly).
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
friend result_t<OP> operator OPERATOR(const Self & lhs, \
const Self & rhs) { \
return VecOperators::apply<OP>(lhs, rhs); \
} \
template <typename T> \
friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
operator OPERATOR(const Self & lhs, const T & rhs) { \
return VecOperators::apply<OP>(lhs, Self{static_cast<T>(rhs)}); \
} \
template <typename T> \
friend std::enable_if_t<is_compatible_scalar<T>, result_t<OP>> \
operator OPERATOR(const T & lhs, const Self & rhs) { \
return VecOperators::apply<OP>(Self{static_cast<T>(lhs)}, rhs); \
} \
};

// Defines the OpMixin specialization selected when `Op` is `OpAssign<OP>`.
// It contributes two hidden-friend compound-assignment overloads of
// `operator OPERATOR`, both returning a reference to the left operand:
//   * vec OPERATOR vec    - non-template;
//   * vec OPERATOR scalar - template, enabled only if is_compatible_scalar<T>.
// Each one evaluates the functor `OP{}` on the two operands and assigns the
// result back to `lhs`.
#define __SYCL_VEC_OPASSIGN_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OpAssign<OP>>>> { \
friend Self &operator OPERATOR(Self & lhs, const Self & rhs) { \
lhs = OP{}(lhs, rhs); \
return lhs; \
} \
template <typename T> \
friend std::enable_if_t<is_compatible_scalar<T>, Self &> \
operator OPERATOR(Self & lhs, const T & rhs) { \
lhs = OP{}(lhs, rhs); \
return lhs; \
} \
};

// Defines the OpMixin specialization selected when `Op` is the unary functor
// `OP`. It contributes a single non-template hidden-friend
// `operator OPERATOR(const Self &)` that forwards to apply<OP> and returns
// `result_t<OP>`.
#define __SYCL_VEC_UOP_MIXIN(OP, OPERATOR) \
template <typename Op> \
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
friend result_t<OP> operator OPERATOR(const Self & v) { \
return apply<OP>(v); \
} \
};

#endif

__SYCL_INSTANTIATE_OPERATORS(__SYCL_VEC_BINOP_MIXIN,
__SYCL_VEC_OPASSIGN_MIXIN, __SYCL_VEC_UOP_MIXIN)

#undef __SYCL_VEC_UOP_MIXIN
#undef __SYCL_VEC_OPASSIGN_MIXIN
#undef __SYCL_VEC_BINOP_MIXIN

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename Op>
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, std::bit_not<void>>>> {
template <typename T = typename from_incomplete<Self>::element_type>
Expand All @@ -356,6 +455,7 @@ template <typename Self> struct VecOperators {
return apply<std::bit_not<void>>(v);
}
};
#endif

template <typename... Op>
struct __SYCL_EBO CombineImpl : public OpMixin<Op>... {};
Expand All @@ -377,6 +477,7 @@ template <typename Self> struct VecOperators {
OpAssign<ShiftRight>, IncDec> {};
};

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename DataT, int NumElements>
class vec_arith : public VecOperators<vec<DataT, NumElements>>::Combined {};

Expand Down Expand Up @@ -427,6 +528,7 @@ class vec_arith<std::byte, NumElements>
}
};
#endif // (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
#endif

#undef __SYCL_INSTANTIATE_OPERATORS

Expand Down
110 changes: 81 additions & 29 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@

// See vec::DataType definitions for more details
#ifndef __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE
#if defined(__INTEL_PREVIEW_BREAKING_CHANGES)
#define __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE 1
#else
#define __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE 0
#endif
#define __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE !__SYCL_USE_LIBSYCL8_VEC_IMPL
#endif

#if !defined(__HAS_EXT_VECTOR_TYPE__) && defined(__SYCL_DEVICE_ONLY__)
Expand Down Expand Up @@ -125,31 +121,35 @@ template <typename T> class GetOp {
//
// must go through `v.x()` returning a swizzle, then its `operator==` returning
// vec<int, 1> and we want that code to compile.
template <typename Self> class ScalarConversionOperatorMixIn {
using T = typename from_incomplete<Self>::element_type;
template <typename Self> class ScalarConversionOperatorsMixIn {
using element_type = typename from_incomplete<Self>::element_type;

public:
operator T() const { return (*static_cast<const Self *>(this))[0]; }
operator element_type() const {
return (*static_cast<const Self *>(this))[0];
}

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
// Explicit conversion to any type `T` other than `element_type` itself.
//  * The first defaulted parameter excludes `T == element_type`, which is
//    already covered by the non-template conversion operator.
//  * The second defaulted parameter removes this overload (SFINAE) unless
//    `static_cast<T>(element_type)` is a well-formed expression.
// The result is element 0 of `Self`, converted with `static_cast<T>`.
template <
typename T, typename = std::enable_if_t<!std::is_same_v<T, element_type>>,
typename =
std::void_t<decltype(static_cast<T>(std::declval<element_type>()))>>
explicit operator T() const {
return static_cast<T>((*static_cast<const Self *>(this))[0]);
}
#endif
};

template <typename T>
inline constexpr bool is_fundamental_or_half_or_bfloat16 =
std::is_fundamental_v<T> || std::is_same_v<std::remove_const_t<T>, half> ||
std::is_same_v<std::remove_const_t<T>, ext::oneapi::bfloat16>;

// Proposed SYCL specification changes have sycl::vec having different ctors
// available based on the number of elements. Without C++20's concepts we'll
// have to use partial specialization to represent that. This is a helper to do
// that. An alternative could be to have different specializations of the
// `sycl::vec` itself but then we'd need to outline all the common interfaces to
// re-use them.
//
// Note: the functional changes haven't been implemented yet, we've split
// vec_base in advance as a way to make changes easier to review/verify.
//
// Another note: `vector_t` is going to be removed, so corresponding ctor was
// kept inside `sycl::vec` to have all `vector_t` functionality in a single
// place.
// Per SYCL specification sycl::vec has different ctors available based on the
// number of elements. Without C++20's concepts we'd have to use partial
// specialization to represent that. This is a helper to do that. An alternative
// could be to have different specializations of the `sycl::vec` itself but then
// we'd need to outline all the common interfaces to re-use them.
template <typename DataT, int NumElements> class vec_base {
// https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
// It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
Expand Down Expand Up @@ -271,20 +271,61 @@ template <typename DataT, int NumElements> class vec_base {
: vec_base{VecArgArrayCreator<DataT, argTN...>::Create(args...),
std::make_index_sequence<NumElements>()} {}
};

#if !__SYCL_USE_LIBSYCL8_VEC_IMPL
// Specialization of vec_base for one-element vectors. It owns the storage
// (`m_Data`) and provides the constructors that are specific to
// NumElements == 1; in particular, construction from a scalar is implicit
// here.
template <typename DataT> class vec_base<DataT, 1> {
  // Underlying storage type. A plain C array is used when
  // __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE is non-zero, or when std::array has
  // the same size and alignment as the C array; otherwise std::array is used.
  using DataType = std::conditional_t<
#if __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE
      true,
#else
      sizeof(std::array<DataT, 1>) == sizeof(DataT[1]) &&
          alignof(std::array<DataT, 1>) == alignof(DataT[1]),
#endif
      DataT[1], std::array<DataT, 1>>;

protected:
  // Alignment of the storage: the size of the storage, capped at 64 bytes.
  static constexpr int alignment = (std::min)((size_t)64, sizeof(DataType));
  alignas(alignment) DataType m_Data{};

public:
  constexpr vec_base() = default;
  constexpr vec_base(const vec_base &) = default;
  constexpr vec_base(vec_base &&) = default;
  constexpr vec_base &operator=(const vec_base &) = default;
  constexpr vec_base &operator=(vec_base &&) = default;

  // Not `explicit` on purpose, differs from NumElements > 1.
  constexpr vec_base(const DataT &arg) : m_Data{{arg}} {}

  // FIXME: Temporary workaround because swizzle's `operator DataT` is a
  // template.
  //
  // Implicit construction from a one-element swizzle whose element type is
  // convertible to DataT. All three constraints must use `enable_if_t`:
  // a bare `std::enable_if<B>` is a valid type for any `B` and therefore
  // would not remove the overload when the condition is false.
  template <typename Swizzle,
            typename = std::enable_if_t<is_swizzle_v<Swizzle>>,
            typename = std::enable_if_t<Swizzle::size() == 1>,
            typename = std::enable_if_t<
                std::is_convertible_v<typename Swizzle::element_type, DataT>>>
  constexpr vec_base(const Swizzle &other)
      : vec_base(static_cast<DataT>(other)) {}
};
#endif
} // namespace detail

///////////////////////// class sycl::vec /////////////////////////
// Provides a cross-platform vector class template that works efficiently on
// SYCL devices as well as in host C++ code.
template <typename DataT, int NumElements>
class __SYCL_EBO vec
: public detail::vec_arith<DataT, NumElements>,
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
// Keep it last to simplify ABI layout test:
public detail::vec_base<DataT, NumElements> {
class __SYCL_EBO vec :
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
public detail::vec_arith<DataT, NumElements>,
#else
public detail::VecOperators<vec<DataT, NumElements>>::Combined,
#endif
public detail::ApplyIf<
NumElements == 1,
detail::ScalarConversionOperatorsMixIn<vec<DataT, NumElements>>>,
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
// Keep it last to simplify ABI layout test:
public detail::vec_base<DataT, NumElements> {
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
"DataT must be cv-unqualified");

Expand Down Expand Up @@ -367,6 +408,7 @@ class __SYCL_EBO vec
constexpr vec &operator=(const vec &) = default;
constexpr vec &operator=(vec &&) = default;

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// Template required to prevent ambiguous overload with the copy assignment
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
Expand All @@ -386,6 +428,14 @@ class __SYCL_EBO vec
*this = Rhs.template as<vec>();
return *this;
}
#else
// Assignment from a scalar: accepts any `T` implicitly convertible to DataT,
// converts it to DataT, constructs a vec from the converted value and assigns
// that vec to *this.
template <typename T>
std::enable_if_t<std::is_convertible_v<T, DataT>, vec &>
operator=(const T &Rhs) {
  return *this = vec{static_cast<DataT>(Rhs)};
}
#endif

__SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead")
static constexpr size_t get_count() { return size(); }
Expand Down Expand Up @@ -495,8 +545,10 @@ class __SYCL_EBO vec
int... T5>
friend class detail::SwizzleOp;
template <typename T1, int T2> friend class __SYCL_EBO vec;
#if __SYCL_USE_LIBSYCL8_VEC_IMPL
// To allow arithmetic operators access private members of vec.
template <typename T1, int T2> friend class detail::vec_arith;
#endif
};
///////////////////////// class sycl::vec /////////////////////////

Expand Down
2 changes: 2 additions & 0 deletions sycl/test-e2e/Basic/vector/byte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ int main() {
assert(SwizByte2Neg[0] == ~SwizByte2B[0]);
}

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
{
// std::byte is not an arithmetic type and it only supports the following
// overloads of >> and << operators.
Expand Down Expand Up @@ -207,6 +208,7 @@ int main() {
assert(SwizShiftRight[0] == SwizByte2Shift[0] >> 3 &&
SwizShiftLeft[1] == SwizByte2Shift[1] << 3);
}
#endif
}

return 0;
Expand Down
4 changes: 2 additions & 2 deletions sycl/test-e2e/Basic/vector/vec_binary_scalar_order.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ bool CheckResult(sycl::vec<T1, N> V, T2 Ref) {
constexpr T RefVal = 2; \
VecT InVec{static_cast<T>(RefVal)}; \
{ \
VecT OutVecsDevice[2]; \
ResT OutVecsDevice[2]; \
T OutRefsDevice[2]; \
{ \
sycl::buffer<VecT, 1> OutVecsBuff{OutVecsDevice, 2}; \
sycl::buffer<ResT, 1> OutVecsBuff{OutVecsDevice, 2}; \
sycl::buffer<T, 1> OutRefsBuff{OutRefsDevice, 2}; \
Q.submit([&](sycl::handler &CGH) { \
sycl::accessor OutVecsAcc{OutVecsBuff, CGH, sycl::read_write}; \
Expand Down
Loading
Loading