Skip to content

[NFC][SYCL] Switch to std:: equivalents for utilities in stl_type_traits.hpp #7668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sycl/include/sycl/atomic_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ template <sycl::access::address_space AS> struct IsValidAtomicRefAddressSpace {

// DefaultOrder parameter is limited to read-modify-write orders
template <memory_order Order>
using IsValidDefaultOrder = bool_constant<Order == memory_order::relaxed ||
Order == memory_order::acq_rel ||
Order == memory_order::seq_cst>;
using IsValidDefaultOrder = std::bool_constant<Order == memory_order::relaxed ||
Order == memory_order::acq_rel ||
Order == memory_order::seq_cst>;

template <memory_order ReadModifyWriteOrder> struct memory_order_traits;

Expand Down
18 changes: 9 additions & 9 deletions sycl/include/sycl/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ class buffer : public detail::buffer_plain,
// using same requirement for contiguous container as std::span
template <class Container>
using EnableIfContiguous =
detail::void_t<std::enable_if_t<std::is_convertible<
detail::remove_pointer_t<
decltype(std::declval<Container>().data())> (*)[],
const T (*)[]>::value>,
decltype(std::declval<Container>().size())>;
std::void_t<std::enable_if_t<std::is_convertible<
detail::remove_pointer_t<
decltype(std::declval<Container>().data())> (*)[],
const T (*)[]>::value>,
decltype(std::declval<Container>().size())>;
template <class It>
using EnableIfItInputIterator = std::enable_if_t<
std::is_convertible<typename std::iterator_traits<It>::iterator_category,
Expand Down Expand Up @@ -344,9 +344,9 @@ class buffer : public detail::buffer_plain,
using IteratorValueType =
detail::iterator_value_type_t<InputIterator>;
using IteratorNonConstValueType =
detail::remove_const_t<IteratorValueType>;
std::remove_const_t<IteratorValueType>;
using IteratorPointerToNonConstValueType =
detail::add_pointer_t<IteratorNonConstValueType>;
std::add_pointer_t<IteratorNonConstValueType>;
std::copy(first, last,
static_cast<IteratorPointerToNonConstValueType>(ToPtr));
},
Expand Down Expand Up @@ -377,9 +377,9 @@ class buffer : public detail::buffer_plain,
using IteratorValueType =
detail::iterator_value_type_t<InputIterator>;
using IteratorNonConstValueType =
detail::remove_const_t<IteratorValueType>;
std::remove_const_t<IteratorValueType>;
using IteratorPointerToNonConstValueType =
detail::add_pointer_t<IteratorNonConstValueType>;
std::add_pointer_t<IteratorNonConstValueType>;
std::copy(first, last,
static_cast<IteratorPointerToNonConstValueType>(ToPtr));
},
Expand Down
5 changes: 2 additions & 3 deletions sycl/include/sycl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,8 @@ class __SYCL_EXPORT context : public detail::OwnerLessBase<context> {
friend decltype(Obj::impl) detail::getSyclObjImpl(const Obj &SyclObject);

template <class T>
friend
typename detail::add_pointer_t<typename decltype(T::impl)::element_type>
detail::getRawSyclObjImpl(const T &SyclObject);
friend typename std::add_pointer_t<typename decltype(T::impl)::element_type>
detail::getRawSyclObjImpl(const T &SyclObject);

template <class T>
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ template <class Obj> decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject) {
// must make sure the returned pointer is not captured in a field or otherwise
// stored - i.e. must live only as on-stack value.
template <class T>
typename detail::add_pointer_t<typename decltype(T::impl)::element_type>
typename std::add_pointer_t<typename decltype(T::impl)::element_type>
getRawSyclObjImpl(const T &SyclObject) {
return SyclObject.impl.get();
}
Expand Down
51 changes: 27 additions & 24 deletions sycl/include/sycl/detail/generic_type_lists.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,36 +425,39 @@ using marray_byte_list = type_list<marray<std::byte, 1>, marray<std::byte, 2>,
#endif

// integer types
using scalar_signed_integer_list = type_list<
conditional_t<std::is_signed<char>::value,
using scalar_signed_integer_list =
type_list<std::conditional_t<
std::is_signed<char>::value,
type_list<scalar_default_char_list, scalar_signed_char_list>,
scalar_signed_char_list>,
scalar_signed_short_list, scalar_signed_int_list, scalar_signed_long_list,
scalar_signed_longlong_list>;
scalar_signed_short_list, scalar_signed_int_list,
scalar_signed_long_list, scalar_signed_longlong_list>;

using vector_signed_integer_list = type_list<
conditional_t<std::is_signed<char>::value,
using vector_signed_integer_list =
type_list<std::conditional_t<
std::is_signed<char>::value,
type_list<vector_default_char_list, vector_signed_char_list>,
vector_signed_char_list>,
vector_signed_short_list, vector_signed_int_list, vector_signed_long_list,
vector_signed_longlong_list>;
vector_signed_short_list, vector_signed_int_list,
vector_signed_long_list, vector_signed_longlong_list>;

using marray_signed_integer_list = type_list<
conditional_t<std::is_signed<char>::value,
using marray_signed_integer_list =
type_list<std::conditional_t<
std::is_signed<char>::value,
type_list<marray_default_char_list, marray_signed_char_list>,
marray_signed_char_list>,
marray_signed_short_list, marray_signed_int_list, marray_signed_long_list,
marray_signed_longlong_list>;
marray_signed_short_list, marray_signed_int_list,
marray_signed_long_list, marray_signed_longlong_list>;

using signed_integer_list =
type_list<scalar_signed_integer_list, vector_signed_integer_list,
marray_signed_integer_list>;

using scalar_unsigned_integer_list =
type_list<conditional_t<std::is_unsigned<char>::value,
type_list<scalar_default_char_list,
scalar_unsigned_char_list>,
scalar_unsigned_char_list>,
type_list<std::conditional_t<std::is_unsigned<char>::value,
type_list<scalar_default_char_list,
scalar_unsigned_char_list>,
scalar_unsigned_char_list>,
scalar_unsigned_short_list, scalar_unsigned_int_list,
scalar_unsigned_long_list, scalar_unsigned_longlong_list
#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
Expand All @@ -464,10 +467,10 @@ using scalar_unsigned_integer_list =
>;

using vector_unsigned_integer_list =
type_list<conditional_t<std::is_unsigned<char>::value,
type_list<vector_default_char_list,
vector_unsigned_char_list>,
vector_unsigned_char_list>,
type_list<std::conditional_t<std::is_unsigned<char>::value,
type_list<vector_default_char_list,
vector_unsigned_char_list>,
vector_unsigned_char_list>,
vector_unsigned_short_list, vector_unsigned_int_list,
vector_unsigned_long_list, vector_unsigned_longlong_list
#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
Expand All @@ -477,10 +480,10 @@ using vector_unsigned_integer_list =
>;

using marray_unsigned_integer_list =
type_list<conditional_t<std::is_unsigned<char>::value,
type_list<marray_default_char_list,
marray_unsigned_char_list>,
marray_unsigned_char_list>,
type_list<std::conditional_t<std::is_unsigned<char>::value,
type_list<marray_default_char_list,
marray_unsigned_char_list>,
marray_unsigned_char_list>,
marray_unsigned_short_list, marray_unsigned_int_list,
marray_unsigned_long_list, marray_unsigned_longlong_list
#if __cplusplus >= 201703L && (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
Expand Down
45 changes: 23 additions & 22 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,17 @@ template <typename T>
using is_geninteger64bit = is_gen_based_on_type_sizeof<T, 8, is_geninteger>;

template <typename T>
using is_genintptr = bool_constant<
using is_genintptr = std::bool_constant<
is_pointer<T>::value && is_genint<remove_pointer_t<T>>::value &&
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;

template <typename T>
using is_genfloatptr = bool_constant<
using is_genfloatptr = std::bool_constant<
is_pointer<T>::value && is_genfloat<remove_pointer_t<T>>::value &&
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;

template <typename T>
using is_genptr = bool_constant<
using is_genptr = std::bool_constant<
is_pointer<T>::value && is_gentype<remove_pointer_t<T>>::value &&
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;

Expand Down Expand Up @@ -424,10 +424,10 @@ using mptr_or_vec_elem_type_t = typename mptr_or_vec_elem_type<T>::type;
// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
// sizeof(IN). expected to handle scalar types.
template <typename T, typename T8, typename T16, typename T32, typename T64>
using select_apply_cl_scalar_t =
conditional_t<sizeof(T) == 1, T8,
conditional_t<sizeof(T) == 2, T16,
conditional_t<sizeof(T) == 4, T32, T64>>>;
using select_apply_cl_scalar_t = std::conditional_t<
sizeof(T) == 1, T8,
std::conditional_t<sizeof(T) == 2, T16,
std::conditional_t<sizeof(T) == 4, T32, T64>>>;

// Shortcuts for selecting scalar int/unsigned int/fp type.
template <typename T>
Expand All @@ -447,21 +447,21 @@ using select_cl_scalar_float_t =

template <typename T>
using select_cl_scalar_integral_t =
conditional_t<std::is_signed<T>::value,
select_cl_scalar_integral_signed_t<T>,
select_cl_scalar_integral_unsigned_t<T>>;
std::conditional_t<std::is_signed<T>::value,
select_cl_scalar_integral_signed_t<T>,
select_cl_scalar_integral_unsigned_t<T>>;

// select_cl_scalar_t picks corresponding cl_* type for input
// scalar T or returns T if T is not scalar.
template <typename T>
using select_cl_scalar_t = conditional_t<
using select_cl_scalar_t = std::conditional_t<
std::is_integral<T>::value, select_cl_scalar_integral_t<T>,
conditional_t<
std::conditional_t<
std::is_floating_point<T>::value, select_cl_scalar_float_t<T>,
// half is a special case: it is implemented differently on host and
// device and therefore, might lower to different types
conditional_t<std::is_same<T, half>::value,
sycl::detail::half_impl::BIsRepresentationT, T>>>;
std::conditional_t<std::is_same<T, half>::value,
sycl::detail::half_impl::BIsRepresentationT, T>>>;

// select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type
// of a vector type T, pointer type substitution, and scalar type substitution.
Expand All @@ -476,9 +476,10 @@ struct select_cl_vector_or_scalar_or_ptr<
// select_cl_scalar_t returns _Float16, so, we try to instantiate vec
// class with _Float16 DataType, which is not expected there
// So, leave vector<half, N> as-is
vec<conditional_t<std::is_same<mptr_or_vec_elem_type_t<T>, half>::value,
mptr_or_vec_elem_type_t<T>,
select_cl_scalar_t<mptr_or_vec_elem_type_t<T>>>,
vec<std::conditional_t<
std::is_same<mptr_or_vec_elem_type_t<T>, half>::value,
mptr_or_vec_elem_type_t<T>,
select_cl_scalar_t<mptr_or_vec_elem_type_t<T>>>,
T::size()>;
};

Expand Down Expand Up @@ -547,10 +548,10 @@ using SelectMatchingOpenCLType_t =
// Converts T to OpenCL friendly
//
template <typename T>
using ConvertToOpenCLType_t = conditional_t<
using ConvertToOpenCLType_t = std::conditional_t<
TryToGetVectorT<SelectMatchingOpenCLType_t<T>>::value,
typename TryToGetVectorT<SelectMatchingOpenCLType_t<T>>::type,
conditional_t<
std::conditional_t<
TryToGetPointerT<SelectMatchingOpenCLType_t<T>>::value,
typename TryToGetPointerVecT<SelectMatchingOpenCLType_t<T>>::type,
SelectMatchingOpenCLType_t<T>>>;
Expand Down Expand Up @@ -593,12 +594,12 @@ template <typename T> inline constexpr bool msbIsSet(const T x) {
// TODO: marray support isn't implemented yet.
template <typename T>
using common_rel_ret_t =
conditional_t<is_vgentype<T>::value, make_singed_integer_t<T>, bool>;
std::conditional_t<is_vgentype<T>::value, make_singed_integer_t<T>, bool>;

// TODO: Remove this when common_rel_ret_t is promoted.
template <typename T>
using internal_host_rel_ret_t =
conditional_t<is_vgentype<T>::value, make_singed_integer_t<T>, int>;
std::conditional_t<is_vgentype<T>::value, make_singed_integer_t<T>, int>;
#else
// SYCL 1.2.1 4.13.7 (Relation functions), e.g.
//
Expand All @@ -612,7 +613,7 @@ using internal_host_rel_ret_t =
// Fixing it would be an ABI-breaking change so isn't done.
template <typename T>
using common_rel_ret_t =
conditional_t<is_vgentype<T>::value, make_singed_integer_t<T>, int>;
std::conditional_t<is_vgentype<T>::value, make_singed_integer_t<T>, int>;
template <typename T> using internal_host_rel_ret_t = common_rel_ret_t<T>;
#endif

Expand Down
20 changes: 11 additions & 9 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ template <typename Group> bool GroupAny(bool pred) {
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
// FIXME: Do not special-case for half once all backends support all data types.
template <typename T>
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
!std::is_same<T, half>::value>;
using is_native_broadcast =
std::bool_constant<detail::is_arithmetic<T>::value &&
!std::is_same<T, half>::value>;

template <typename T, typename IdT = size_t>
using EnableIfNativeBroadcast = std::enable_if_t<
Expand All @@ -115,7 +116,7 @@ using EnableIfNativeBroadcast = std::enable_if_t<
// Bitcast broadcasts can be implemented using a single SPIR-V GroupBroadcast
// intrinsic, but require type-punning via an appropriate integer type
template <typename T>
using is_bitcast_broadcast = bool_constant<
using is_bitcast_broadcast = std::bool_constant<
!is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;

Expand All @@ -132,20 +133,21 @@ using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
// - At most one 32-bit, 16-bit and 8-bit chunk left over
template <typename T>
using is_generic_broadcast =
bool_constant<!is_native_broadcast<T>::value &&
!is_bitcast_broadcast<T>::value &&
std::is_trivially_copyable<T>::value>;
std::bool_constant<!is_native_broadcast<T>::value &&
!is_bitcast_broadcast<T>::value &&
std::is_trivially_copyable<T>::value>;

template <typename T, typename IdT = size_t>
using EnableIfGenericBroadcast = std::enable_if_t<
is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;

// FIXME: Disable widening once all backends support all data types.
template <typename T>
using WidenOpenCLTypeTo32_t = conditional_t<
using WidenOpenCLTypeTo32_t = std::conditional_t<
std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
cl_uint, T>>;
std::conditional_t<std::is_same<T, cl_uchar>() ||
std::is_same<T, cl_ushort>(),
cl_uint, T>>;

// Broadcast with scalar local index
// Work-group supports any integral type
Expand Down
37 changes: 6 additions & 31 deletions sycl/include/sycl/detail/stl_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,9 @@ namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace detail {

// Type traits identical to those in std in newer versions. Can be removed when
// SYCL requires a newer version of the C++ standard.
// C++14
template <bool B, class T, class F>
using conditional_t = typename std::conditional<B, T, F>::type;

template <typename T>
using remove_const_t = typename std::remove_const<T>::type;

template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T> using add_pointer_t = typename std::add_pointer<T>::type;

// C++17
template <bool V> using bool_constant = std::integral_constant<bool, V>;

template <class...> using void_t = void;

// Custom type traits
template <typename T>
using allocator_value_type_t = typename std::allocator_traits<T>::value_type;

template <typename T>
using allocator_pointer_t = typename std::allocator_traits<T>::pointer;

// Custom type traits.
// FIXME: Those doesn't seem to be a part of any published/future C++ standard
// so should probably be moved to a different place.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably can drop stl_ prefix from the file name to avoid confusion. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have detail/type_traits.hpp already that has lots of custom stuff (and remove_pointer that seems to behave differently than std:: version). I think that the best would be to move this code to detail/type_traits.hpp.

As for this file, I'm planning to add C++20's remove_cvref_t here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following up on the previous comment - sycl::detail::remove_pointer difference comes from sycl::multi_ptr support, I think.

template <typename T>
using iterator_category_t = typename std::iterator_traits<T>::iterator_category;

Expand All @@ -62,9 +37,9 @@ using iterator_to_const_type_t =
// https://en.cppreference.com/w/cpp/named_req/OutputIterator
template <typename T>
using output_iterator_requirements =
void_t<iterator_category_t<T>,
decltype(*std::declval<T>() =
std::declval<iterator_value_type_t<T>>())>;
std::void_t<iterator_category_t<T>,
decltype(*std::declval<T>() =
std::declval<iterator_value_type_t<T>>())>;

template <typename, typename = void> struct is_output_iterator {
static constexpr bool value = false;
Expand Down
Loading