Skip to content

Commit

Permalink
Merge pull request kokkos#224 from youyu3/fix-195-required-span-size
Browse files Browse the repository at this point in the history
Fix required_span_size for layout_stride, after kokkos#195
  • Loading branch information
crtrott authored Jan 6, 2023
2 parents 973ef64 + d2afc33 commit 8cad638
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 38 deletions.
46 changes: 8 additions & 38 deletions include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ struct layout_stride {
#endif
}

// Combines the extents selected by Idx... into a single index_type value; each
// extent is converted to index_type before it is combined.
// NOTE(review): presumably _MDSPAN_FOLD_TIMES_RIGHT expands to a right fold
// over operator* with the second argument (1) as the initial value, making the
// result the product of the selected extents (and 1 for an empty Idx pack) --
// confirm against the macro definition.
// The first parameter is unnamed and never read; it only allows SizeType and
// Ep... to be deduced from the argument. The extent values themselves are read
// through extents().extent(Idx).
template<class SizeType, ::std::size_t ... Ep, ::std::size_t ... Idx>
_MDSPAN_HOST_DEVICE
constexpr index_type __get_size(extents<SizeType, Ep...>,integer_sequence<::std::size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT( static_cast<index_type>(extents().extent(Idx)), 1 );
}

//----------------------------------------------------------------------------

template <class>
Expand Down Expand Up @@ -391,7 +397,7 @@ struct layout_stride {
for(unsigned r = 0; r < extents_type::rank(); r++) {
// Return early if any of the extents are zero
if(extents().extent(r)==0) return 0;
span_size = std::max(span_size, static_cast<index_type>(extents().extent(r) * __strides_storage()[r]));
span_size += ( static_cast<index_type>(extents().extent(r) - 1 ) * __strides_storage()[r]);
}
return span_size;
}
Expand All @@ -418,43 +424,7 @@ struct layout_stride {

// Compile-time constant: this mapping always reports itself as unique.
MDSPAN_INLINE_FUNCTION static constexpr bool is_unique() noexcept { return true; }
// Reports whether this mapping is exhaustive, i.e. whether the span it needs
// (required_span_size()) is exactly as large as the value __get_size() computes
// from all extent_type::rank() extents.
//
// Why a size comparison: is_unique() is always true for this mapping, so no two
// multi-indices share an offset; the required span can then only equal the
// element count when no offset inside the span is left unused.
//
// This single comparison is the whole implementation on purpose. A stride-chain
// search built on std::iota / std::find_if is not used here because it
//   * cannot run in device code (it had to answer `false` under __CUDA_ARCH__),
//   * is easy to get wrong (matching extents instead of strides, or indexing
//     with an entry that has already been overwritten by a sentinel), and
//   * would make any statement placed after it unreachable.
//
// Returns: true when the two sizes are equal, false otherwise. Never throws.
MDSPAN_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 bool is_exhaustive() const noexcept {
  return required_span_size() == __get_size(extents(), make_index_sequence<extents_type::rank()>());
}
// Compile-time constant: this mapping always reports itself as strided.
MDSPAN_INLINE_FUNCTION static constexpr bool is_strided() noexcept { return true; }

Expand Down
3 changes: 3 additions & 0 deletions tests/test_layout_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ TEST(TestLayoutStrideSpanConstruction, test_from_span_construction) {
ASSERT_EQ(m1.stride(1), 128);
ASSERT_EQ(m1.strides()[0], 1);
ASSERT_EQ(m1.strides()[1], 128);
ASSERT_EQ(m1.required_span_size(),1+15*1+31*128);
ASSERT_FALSE(m1.is_exhaustive());
}
#endif
Expand All @@ -101,6 +102,7 @@ TEST(TestLayoutStrideListInitialization, test_list_initialization) {
ASSERT_EQ(m.stride(1), 128);
ASSERT_EQ(m.strides()[0], 1);
ASSERT_EQ(m.strides()[1], 128);
ASSERT_EQ(m.required_span_size(),1+15*1+31*128);
ASSERT_FALSE(m.is_exhaustive());
}

Expand Down Expand Up @@ -129,6 +131,7 @@ TEST(TestLayoutStrideCTAD, test_ctad) {
ASSERT_EQ(m1.stride(1), 128);
ASSERT_EQ(m1.strides()[0], 1);
ASSERT_EQ(m1.strides()[1], 128);
ASSERT_EQ(m1.required_span_size(),1+15*1+31*128);
ASSERT_FALSE(m1.is_exhaustive());

// TODO These won't work with our current implementation, because the array will
Expand Down

0 comments on commit 8cad638

Please sign in to comment.