Skip to content

Commit

Permalink
Add constraint for layout_left|right|stride::stride(), and add test
Browse files Browse the repository at this point in the history
Addresses kokkos#201
  • Loading branch information
youyu3 committed Feb 3, 2023
1 parent 0a1ce8c commit b421aa5
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 7 deletions.
6 changes: 6 additions & 0 deletions include/experimental/__p0009_bits/layout_left.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ class layout_left::mapping {
MDSPAN_INLINE_FUNCTION constexpr bool is_exhaustive() const noexcept { return true; }
MDSPAN_INLINE_FUNCTION constexpr bool is_strided() const noexcept { return true; }

MDSPAN_TEMPLATE_REQUIRES(
class E = Extents,
/* requires */ (
E::rank() > 0
)
)
MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type i) const noexcept {
index_type value = 1;
Expand Down
6 changes: 6 additions & 0 deletions include/experimental/__p0009_bits/layout_right.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ class layout_right::mapping {
MDSPAN_INLINE_FUNCTION constexpr bool is_exhaustive() const noexcept { return true; }
MDSPAN_INLINE_FUNCTION constexpr bool is_strided() const noexcept { return true; }

MDSPAN_TEMPLATE_REQUIRES(
class E = Extents,
/* requires */ (
E::rank() > 0
)
)
MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type i) const noexcept {
index_type value = 1;
Expand Down
6 changes: 6 additions & 0 deletions include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,12 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION static constexpr bool is_strided() noexcept { return true; }


MDSPAN_TEMPLATE_REQUIRES(
class E = Extents,
/* requires */ (
E::rank() > 0
)
)
MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type r) const noexcept {
return __strides_storage()[r];
Expand Down
8 changes: 5 additions & 3 deletions tests/test_exhaustive_layouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,11 @@ TYPED_TEST(TestLayoutConversion, implicit_conversion) {
#endif
map1 = typename TestFixture::map_1_t(map2);

for(size_t r=0; r != this->exts1.rank(); ++r) {
ASSERT_EQ(map1.extents().extent(r), map2.extents().extent(r));
ASSERT_EQ(map1.stride(r), map2.stride(r));
if constexpr (this->exts1.rank() > 0) {
for(size_t r=0; r != this->exts1.rank(); ++r) {
ASSERT_EQ(map1.extents().extent(r), map2.extents().extent(r));
ASSERT_EQ(map1.stride(r), map2.stride(r));
}
}
}

32 changes: 32 additions & 0 deletions tests/test_layout_ctors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,35 @@ TEST(TestLayoutRightCTAD, test_layout_right_ctad) {
ASSERT_TRUE(m.is_exhaustive());
}
#endif

template< class T, class RankType, class = void >
struct is_stride_avail : std::false_type {};

template< class T, class RankType >
struct is_stride_avail< T
, RankType
, std::enable_if_t< std::is_same< decltype( std::declval<T>().stride( std::declval<RankType>() ) )
, typename T::index_type
>::value
>
> : std::true_type {};

TEST(TestLayoutLeftStrideConstraint, test_layout_left_stride_constraint) {
stdex::extents<int,16> ext1d{};
stdex::layout_left::mapping m1d{ext1d};
ASSERT_TRUE ((is_stride_avail< decltype(m1d), int >::value));

stdex::extents<int> ext0d{};
stdex::layout_left::mapping m0d{ext0d};
ASSERT_FALSE((is_stride_avail< decltype(m0d), int >::value));
}

TEST(TestLayoutRightStrideConstraint, test_layout_right_stride_constraint) {
stdex::extents<int,16> ext1d{};
stdex::layout_right::mapping m1d{ext1d};
ASSERT_TRUE ((is_stride_avail< decltype(m1d), int >::value));

stdex::extents<int> ext0d{};
stdex::layout_right::mapping m0d{ext0d};
ASSERT_FALSE((is_stride_avail< decltype(m0d), int >::value));
}
21 changes: 21 additions & 0 deletions tests/test_layout_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,24 @@ TEST(TestLayoutStrideCTAD, test_ctad) {
}
#endif

template< class T, class RankType, class = void >
struct is_stride_avail : std::false_type {};

template< class T, class RankType >
struct is_stride_avail< T
, RankType
, std::enable_if_t< std::is_same< decltype( std::declval<T>().stride( std::declval<RankType>() ) )
, typename T::index_type
>::value
>
> : std::true_type {};

TEST(TestLayoutStrideStrideConstraint, test_layout_stride_stride_constraint) {
stdex::extents<int,16> ext1d{};
stdex::layout_stride::mapping m1d{ext1d, std::array<int,1>{1}};
ASSERT_TRUE ((is_stride_avail< decltype(m1d), int >::value));

stdex::extents<int> ext0d{};
stdex::layout_stride::mapping m0d{ext0d, std::array<int,0>{}};
ASSERT_FALSE((is_stride_avail< decltype(m0d), int >::value));
}
14 changes: 10 additions & 4 deletions tests/test_mdarray_ctors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,23 @@ void check_correctness(MDA& m, size_t rank, size_t rank_dynamic,
bool exhaustive) {
ASSERT_EQ(m.rank(), rank);
ASSERT_EQ(m.rank_dynamic(), rank_dynamic);
if(rank>0) {
if (rank>0) {
ASSERT_EQ(m.extent(0), extent_0);
ASSERT_EQ(m.stride(0), stride_0);
if constexpr (m.rank()>0) {
ASSERT_EQ(m.stride(0), stride_0);
}
}
if(rank>1) {
ASSERT_EQ(m.extent(1), extent_1);
ASSERT_EQ(m.stride(1), stride_1);
if constexpr (m.rank()>0) {
ASSERT_EQ(m.stride(1), stride_1);
}
}
if(rank>2) {
ASSERT_EQ(m.extent(2), extent_2);
ASSERT_EQ(m.stride(2), stride_2);
if constexpr (m.rank()>0) {
ASSERT_EQ(m.stride(2), stride_2);
}
}
if(ptr_matches)
ASSERT_EQ(m.data(),ptr);
Expand Down

0 comments on commit b421aa5

Please sign in to comment.