Skip to content

Commit

Permalink
add patch to unwrap nested tuples of iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
trxcllnt committed Apr 5, 2024
1 parent 366657b commit 23bdc22
Show file tree
Hide file tree
Showing 2 changed files with 384 additions and 3 deletions.
376 changes: 376 additions & 0 deletions rapids-cmake/cpm/patches/cccl/unwrap_nested_tuple_of_iterators.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,376 @@
diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/tuple b/libcudacxx/include/cuda/std/detail/libcxx/include/tuple
index a02f31fa8..a3e6f6e5d 100644
--- a/libcudacxx/include/cuda/std/detail/libcxx/include/tuple
+++ b/libcudacxx/include/cuda/std/detail/libcxx/include/tuple
@@ -197,6 +197,10 @@ template <class... Types>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

+template<class> // primary template: reports false for every type that has no specialization
+struct __is_tuple_of_iterator_references : false_type
+{};
+
// __tuple_leaf
struct __tuple_leaf_default_constructor_tag {};

@@ -851,6 +855,15 @@ public:
_Tp...>::template __tuple_like_constraints<_Tuple>,
__invalid_tuple_constraints>;

+ // Horrible hack to make tuple_of_iterator_references work: converting constructor that builds this tuple from the argument's __to_tuple<_Tp...>() result
+ template <class _TupleOfIteratorReferences,
+ __enable_if_t<__is_tuple_of_iterator_references<_TupleOfIteratorReferences>::value, int> = 0, // participates only for types the trait reports as true
+ __enable_if_t<(tuple_size<_TupleOfIteratorReferences>::value == sizeof...(_Tp)), int> = 0> // and only when the element count equals sizeof...(_Tp)
+ _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 tuple(_TupleOfIteratorReferences&& __t) // NOTE(review): the trait is tested on the deduced type without stripping references, and an lvalue argument deduces a reference type - confirm whether lvalues are meant to select this overload
+ : tuple(_CUDA_VSTD::forward<_TupleOfIteratorReferences>(__t).template __to_tuple<_Tp...>( // delegates to another tuple constructor with the converted value
+ __make_tuple_indices_t<sizeof...(_Tp)>()))
+ {}
+
template <
class _Tuple, class _Constraints = __tuple_like_constraints<_Tuple>,
__enable_if_t<!_PackExpandsToThisTuple<_Tuple>::value, int> = 0,
diff --git a/thrust/testing/zip_function.cu b/thrust/testing/zip_function.cu
index a1545a1a1..9f038f907 100644
--- a/thrust/testing/zip_function.cu
+++ b/thrust/testing/zip_function.cu
@@ -2,13 +2,17 @@

#if THRUST_CPP_DIALECT >= 2011 && !defined(THRUST_LEGACY_GCC)

-#include <unittest/unittest.h>
+#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
+#include <thrust/remove.h>
+#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/zip_function.h>

#include <iostream>

+#include <unittest/unittest.h>
+
using namespace unittest;

struct SumThree
@@ -67,4 +71,98 @@ struct TestZipFunctionTransform
};
VariableUnitTest<TestZipFunctionTransform, ThirtyTwoBitTypes> TestZipFunctionTransformInstance;

+struct RemovePred // Two-argument predicate: true when both uint32_t values of the tuple argument are equal
+{
+ __host__ __device__ bool operator()(const thrust::tuple<uint32_t, uint32_t>& ele1, const float&) // the float argument is unnamed and never read
+ {
+ return thrust::get<0>(ele1) == thrust::get<1>(ele1);
+ }
+};
+template <typename T>
+struct TestZipFunctionMixed // Test: remove_if over zip(zip(vecA, vecB), vecC) using RemovePred wrapped by make_zip_function
+{
+ void operator()()
+ {
+ thrust::device_vector<uint32_t> vecA{0, 0, 2, 0};
+ thrust::device_vector<uint32_t> vecB{0, 2, 2, 2};
+ thrust::device_vector<float> vecC{88.0f, 88.0f, 89.0f, 89.0f};
+ thrust::device_vector<float> expected{88.0f, 89.0f}; // vecC values at indices 1 and 3, the positions where vecA != vecB
+
+ auto inputKeyItBegin =
+ thrust::make_zip_iterator(thrust::make_zip_iterator(vecA.begin(), vecB.begin()), vecC.begin()); // outer zip of (inner zip of vecA/vecB, vecC)
+ auto endIt =
+ thrust::remove_if(inputKeyItBegin, inputKeyItBegin + vecA.size(), thrust::make_zip_function(RemovePred{})); // removes elements whose vecA and vecB values are equal
+ auto numEle = endIt - inputKeyItBegin; // number of elements kept in front of the returned end iterator
+ vecA.resize(numEle);
+ vecB.resize(numEle);
+ vecC.resize(numEle);
+
+ ASSERT_EQUAL(numEle, 2);
+ ASSERT_EQUAL(vecC, expected);
+ }
+};
+SimpleUnitTest<TestZipFunctionMixed, type_list<int, float> > TestZipFunctionMixedInstance; // T is not used by the test body; presumably the harness runs it once per listed type - confirm
+
+struct NestedFunctionCall // Unary predicate over (id, ((a0, a1), (b0, b1))): true when a0 == b0 or a1 == b1
+{
+ __host__ __device__ bool
+ operator()(const thrust::tuple<uint32_t, thrust::tuple<thrust::tuple<int, int>, thrust::tuple<int, int>>>& idAndPt)
+ {
+ thrust::tuple<thrust::tuple<int, int>, thrust::tuple<int, int>> ele1 = thrust::get<1>(idAndPt); // element 0 (the id) is never read
+ thrust::tuple<int, int> p1 = thrust::get<0>(ele1);
+ thrust::tuple<int, int> p2 = thrust::get<1>(ele1);
+ return thrust::get<0>(p1) == thrust::get<0>(p2) || thrust::get<1>(p1) == thrust::get<1>(p2);
+ }
+};
+
+template <typename T>
+struct TestNestedZipFunction // Test: transform over a zip iterator whose elements are nested tuples, consumed by NestedFunctionCall
+{
+ void operator()()
+ {
+ thrust::device_vector<int> PX{0, 1, 2, 3}; // first coordinate of each point
+ thrust::device_vector<int> PY{0, 1, 2, 2}; // second coordinate of each point
+ thrust::device_vector<uint32_t> SS{0, 1, 2}; // index into PX/PY of each element's first point
+ thrust::device_vector<uint32_t> ST{1, 2, 3}; // index into PX/PY of each element's second point
+ // Element i pairs point SS[i] with point ST[i]; only element 2, (2,2)-(3,2), shares a coordinate.
+
+ auto segIt = thrust::make_zip_iterator(
+ thrust::make_zip_iterator(thrust::make_permutation_iterator(PX.begin(), SS.begin()),
+ thrust::make_permutation_iterator(PY.begin(), SS.begin())),
+ thrust::make_zip_iterator(thrust::make_permutation_iterator(PX.begin(), ST.begin()),
+ thrust::make_permutation_iterator(PY.begin(), ST.begin())));
+ auto idAndSegIt = thrust::make_zip_iterator(thrust::make_counting_iterator(0u), segIt); // prepend a running id to every element
+
+ thrust::device_vector<bool> isMH{false, false, false};
+ thrust::device_vector<bool> expected{false, false, true};
+ thrust::transform(idAndSegIt, idAndSegIt + SS.size(), isMH.begin(), NestedFunctionCall{});
+ ASSERT_EQUAL(isMH, expected);
+ }
+};
+SimpleUnitTest<TestNestedZipFunction, type_list<int, float> > TestNestedZipFunctionInstance; // T is not used by the test body; presumably the harness runs it once per listed type - confirm
+
+struct SortPred { // Comparison functor: orders nested tuples by the int at index 1 of the outer tuple
+ __device__ __forceinline__ // NOTE(review): __device__ only, whereas the other predicates in this file are __host__ __device__ - confirm this is intended
+ bool operator()(const thrust::tuple<thrust::tuple<int, int>, int>& a,
+ const thrust::tuple<thrust::tuple<int, int>, int>& b) {
+ return thrust::get<1>(a) < thrust::get<1>(b); // the inner tuple at index 0 does not take part in the comparison
+ }
+};
+template <typename T>
+struct TestNestedZipFunction2 // Test: thrust::sort over zip(zip(A, B), C) with SortPred; it has no ASSERT, so it only checks that the call compiles and runs
+{
+ void operator()()
+ {
+ thrust::device_vector<int> A(5); // three vectors constructed with a size of 5 and no explicit values
+ thrust::device_vector<int> B(5);
+ thrust::device_vector<int> C(5);
+ auto n = A.size();
+
+ auto tupleIt = thrust::make_zip_iterator(cuda::std::begin(A), cuda::std::begin(B));
+ auto nestedTupleIt = thrust::make_zip_iterator(tupleIt, cuda::std::begin(C)); // elements have the shape ((A, B), C)
+ thrust::sort(nestedTupleIt, nestedTupleIt + n, SortPred{}); // SortPred compares the C component only
+ }
+};
+SimpleUnitTest<TestNestedZipFunction2, type_list<int, float> > TestNestedZipFunctionInstance2; // T is not used by the test body; presumably the harness runs it once per listed type - confirm
+
#endif // THRUST_CPP_DIALECT
diff --git a/thrust/thrust/iterator/detail/tuple_of_iterator_references.h b/thrust/thrust/iterator/detail/tuple_of_iterator_references.h
index 1bb721909..91f4fcc65 100644
--- a/thrust/thrust/iterator/detail/tuple_of_iterator_references.h
+++ b/thrust/thrust/iterator/detail/tuple_of_iterator_references.h
@@ -26,111 +26,124 @@
# pragma system_header
#endif // no system header

-#include <cuda/std/type_traits>
-#include <cuda/std/tuple>
-
-#include <thrust/tuple.h>
-#include <thrust/pair.h>
-#include <thrust/detail/reference_forward_declaration.h>
#include <thrust/detail/raw_reference_cast.h>
+#include <thrust/detail/reference_forward_declaration.h>
+#include <thrust/pair.h>
+#include <thrust/tuple.h>
+
+#include <cuda/std/tuple>
+#include <cuda/std/type_traits>

THRUST_NAMESPACE_BEGIN

namespace detail
{

-template<
- typename... Ts
->
- class tuple_of_iterator_references : public thrust::tuple<Ts...>
+template <typename... Ts >
+class tuple_of_iterator_references; // forward declaration so maybe_unwrap_nested can be specialized on it before the class is defined
+
+template <class U, class T>
+struct maybe_unwrap_nested // generic case: hands back the argument converted to U
+{
+ __host__ __device__ U operator()(const T& t) const
+ {
+ return t; // relies on T being implicitly convertible to U
+ }
+};
+
+template <class... Us, class... Ts>
+struct maybe_unwrap_nested<thrust::tuple<Us...>, tuple_of_iterator_references<Ts...>> // nested case: a tuple_of_iterator_references element requested as a thrust::tuple
{
- public:
- using super_t = thrust::tuple<Ts...>;
- using super_t::super_t;
+ __host__ __device__ thrust::tuple<Us...> operator()(const tuple_of_iterator_references<Ts...>& t) const
+ {
+ return t.template __to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{}); // converts through the element's own __to_tuple, which applies maybe_unwrap_nested again per element
+ }
+};

- inline __host__ __device__
- tuple_of_iterator_references()
+template < typename... Ts >
+class tuple_of_iterator_references : public thrust::tuple<Ts...> // thrust::tuple subclass adding converting assignments, a tuple conversion operator and an rvalue swap
+{
+public:
+ using super_t = thrust::tuple<Ts...>;
+ using super_t::super_t; // inherit the constructors of thrust::tuple<Ts...>
+
+ inline __host__ __device__ tuple_of_iterator_references()
: super_t()
- {}
+ {}

- // allow implicit construction from tuple<refs>
- inline __host__ __device__
- tuple_of_iterator_references(const super_t& other)
+ // allow implicit construction from tuple<refs>
+ inline __host__ __device__ tuple_of_iterator_references(const super_t& other)
: super_t(other)
- {}
+ {}

- inline __host__ __device__
- tuple_of_iterator_references(super_t&& other)
+ inline __host__ __device__ tuple_of_iterator_references(super_t&& other)
: super_t(::cuda::std::move(other))
- {}
-
- // allow assignment from tuples
- // XXX might be worthwhile to guard this with an enable_if is_assignable
- __thrust_exec_check_disable__
- template<typename... Us>
- inline __host__ __device__
- tuple_of_iterator_references &operator=(const thrust::tuple<Us...> &other)
- {
- super_t::operator=(other);
- return *this;
- }
-
- // allow assignment from pairs
- // XXX might be worthwhile to guard this with an enable_if is_assignable
- __thrust_exec_check_disable__
- template<typename U1, typename U2>
- inline __host__ __device__
- tuple_of_iterator_references &operator=(const thrust::pair<U1,U2> &other)
- {
- super_t::operator=(other);
- return *this;
- }
-
- // allow assignment from reference<tuple>
- // XXX perhaps we should generalize to reference<T>
- // we could captures reference<pair> this way
- __thrust_exec_check_disable__
- template<typename Pointer, typename Derived, typename... Us>
- inline __host__ __device__
- tuple_of_iterator_references&
- operator=(const thrust::reference<thrust::tuple<Us...>, Pointer, Derived> &other)
- {
- typedef thrust::tuple<Us...> tuple_type;
-
- // XXX perhaps this could be accelerated
- super_t::operator=(tuple_type{other});
- return *this;
- }
-
- template<class... Us, ::cuda::std::__enable_if_t<sizeof...(Us) == sizeof...(Ts), int> = 0>
- inline __host__ __device__
- constexpr operator thrust::tuple<Us...>() const {
- return to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{});
- }
-
- // this overload of swap() permits swapping tuple_of_iterator_references returned as temporaries from
- // iterator dereferences
- template<class... Us>
- inline __host__ __device__
- friend void swap(tuple_of_iterator_references&& x, tuple_of_iterator_references<Us...>&& y)
- {
- x.swap(y);
- }
-
-private:
- template<class... Us, size_t... Id>
- inline __host__ __device__
- constexpr thrust::tuple<Us...> to_tuple(::cuda::std::__tuple_indices<Id...>) const {
- return {get<Id>(*this)...};
- }
+ {}
+
+ // allow assignment from tuples
+ // XXX might be worthwhile to guard this with an enable_if is_assignable
+ __thrust_exec_check_disable__ template <typename... Us>
+ inline __host__ __device__ tuple_of_iterator_references& operator=(const thrust::tuple<Us...>& other)
+ {
+ super_t::operator=(other);
+ return *this;
+ }
+
+ // allow assignment from pairs
+ // XXX might be worthwhile to guard this with an enable_if is_assignable
+ __thrust_exec_check_disable__ template <typename U1, typename U2>
+ inline __host__ __device__ tuple_of_iterator_references& operator=(const thrust::pair<U1, U2>& other)
+ {
+ super_t::operator=(other);
+ return *this;
+ }
+
+ // allow assignment from reference<tuple>
+ // XXX perhaps we should generalize to reference<T>
+ // we could capture reference<pair> this way
+ __thrust_exec_check_disable__ template <typename Pointer, typename Derived, typename... Us>
+ inline __host__ __device__ tuple_of_iterator_references&
+ operator=(const thrust::reference<thrust::tuple<Us...>, Pointer, Derived>& other)
+ {
+ typedef thrust::tuple<Us...> tuple_type;
+
+ // XXX perhaps this could be accelerated
+ super_t::operator=(tuple_type{other}); // materialize the referenced tuple as a value, then assign element-wise through the base class
+ return *this;
+ }
+
+ template <class... Us, ::cuda::std::__enable_if_t<sizeof...(Us) == sizeof...(Ts), int> = 0> // conversion is offered only when the element counts match
+ inline __host__ __device__ constexpr operator thrust::tuple<Us...>() const
+ {
+ return __to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{});
+ }
+
+ // this overload of swap() permits swapping tuple_of_iterator_references returned as temporaries from
+ // iterator dereferences
+ template <class... Us>
+ inline __host__ __device__ friend void swap(tuple_of_iterator_references&& x, tuple_of_iterator_references<Us...>&& y)
+ {
+ x.swap(y);
+ }
+
+ template <class... Us, size_t... Id> // public (the former private: section is removed): called by maybe_unwrap_nested and by the cuda::std::tuple converting constructor
+ inline __host__ __device__ constexpr thrust::tuple<Us...> __to_tuple(::cuda::std::__tuple_indices<Id...>) const
+ {
+ return {maybe_unwrap_nested<Us, Ts>{}(get<Id>(*this))...}; // element-wise conversion; an element that is itself a tuple_of_iterator_references is unwrapped by the maybe_unwrap_nested specialization
+ }
+};

-} // end detail
+} // namespace detail

THRUST_NAMESPACE_END

_LIBCUDACXX_BEGIN_NAMESPACE_STD

+template <class... Ts>
+struct __is_tuple_of_iterator_references<THRUST_NS_QUALIFIER::detail::tuple_of_iterator_references<Ts...>> // matches the unqualified class type only; no cv- or reference-qualified form is specialized here
+ : integral_constant<bool, true>
+{};
+
// define tuple_size, tuple_element, etc.
template <class... Ts>
struct tuple_size<THRUST_NS_QUALIFIER::detail::tuple_of_iterator_references<Ts...>>
@@ -145,7 +158,8 @@ struct tuple_element<Id, THRUST_NS_QUALIFIER::detail::tuple_of_iterator_referenc
_LIBCUDACXX_END_NAMESPACE_STD

// structured bindings suppport
-namespace std {
+namespace std
+{

template <class... Ts>
struct tuple_size<THRUST_NS_QUALIFIER::detail::tuple_of_iterator_references<Ts...>>
Loading

0 comments on commit 23bdc22

Please sign in to comment.