Skip to content

Commit a609ed2

Browse files
committed
Fix canonical_ice & 1 strided_slice
submdspan_canonicalize_slices now works for one strided_slice. Fix canonical_ice (wrong order in subtraction; wording is fine).
1 parent 0259594 commit a609ed2

File tree

2 files changed

+103
-5
lines changed

2 files changed

+103
-5
lines changed

include/experimental/__p2630_bits/submdspan_extents.hpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ constexpr auto canonical_ice(S s) {
448448

449449
template<class IndexType, class X, class Y>
450450
constexpr auto subtract_ice(X x, Y y) {
451-
return canonical_ice<IndexType>(x) - canonical_ice<IndexType>(y);
451+
return canonical_ice<IndexType>(y) - canonical_ice<IndexType>(x);
452452
}
453453

454454
template<class T>
@@ -560,10 +560,19 @@ template<size_t k, class S_k, class IndexType, size_t... Exts>
560560
// well-formed if it didn't fall into one of the above cases
561561
// and if it can't be destructured into two elements.
562562

563-
// We can't use s_k here, because it's not a constant expression.
564-
auto [s_k0, s_k1] = S_k{};
565-
using S_k0 = decltype(s_k0);
566-
using S_k1 = decltype(s_k1);
563+
// We can't use s_k on the right-hand side here, because it's not a constant expression.
564+
// We can't use S_k{} here either, because that presumes that it's default constructible.
565+
// We can only use std::declval<S_k>() in an unevaluated context.
566+
auto get_first = [] (S_k s_k) {
567+
auto [s_k0, _] = s_k;
568+
return s_k0;
569+
};
570+
auto get_second = [] (S_k s_k) {
571+
auto [_, s_k1] = s_k;
572+
return s_k1;
573+
};
574+
using S_k0 = decltype(get_first(std::declval<S_k>()));
575+
using S_k1 = decltype(get_second(std::declval<S_k>()));
567576
if constexpr (__mdspan_integral_constant_like<S_k0>) {
568577
if constexpr (de_ice(S_k0{}) < 0) {
569578
return check_static_bounds_result::out_of_bounds; // 14.4.1
@@ -703,6 +712,10 @@ submdspan_canonicalize_one_slice(const extents<IndexType, Extents...>& exts, Sli
703712
using S_k1 = decltype(s_k1);
704713
static_assert(std::is_convertible_v<S_k0, IndexType>);
705714
static_assert(std::is_convertible_v<S_k1, IndexType>);
715+
716+
static_assert(std::is_same_v<decltype(canonical_ice<IndexType>(s_k0)), IndexType>);
717+
static_assert(std::is_same_v<decltype(subtract_ice<IndexType>(s_k0, s_k1)), IndexType>);
718+
706719
return strided_slice{
707720
.offset = canonical_ice<IndexType>(s_k0),
708721
.extent = subtract_ice<IndexType>(s_k0, s_k1),

tests/test_canonicalize_slices.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,61 @@
2222
# error "This file requires MDSPAN_ENABLE_P3663=ON"
2323
#endif
2424

25+
namespace my_test {
26+
27+
template<class First, class Second>
28+
struct my_aggregate_pair {
29+
First first;
30+
Second second;
31+
};
32+
33+
// Not an aggregate, to force use of the tuple protocol.
34+
template<class First, class Second>
35+
class my_nonaggregate_pair {
36+
public:
37+
constexpr my_nonaggregate_pair(First first, Second second)
38+
: first_(first), second_(second)
39+
{}
40+
41+
template<std::size_t Index, class Self>
42+
constexpr decltype(auto) get(this Self&& self) {
43+
if constexpr (Index == 0) {
44+
return self.first_;
45+
}
46+
else if constexpr (Index == 1) {
47+
return self.second_;
48+
}
49+
else {
50+
static_assert(false, "Invalid index");
51+
}
52+
}
53+
54+
private:
55+
First first_;
56+
Second second_;
57+
};
58+
59+
} // namespace my_test
60+
61+
template<class First, class Second>
62+
struct std::tuple_size<my_test::my_nonaggregate_pair<First, Second>>
63+
: std::integral_constant<std::size_t, 2> {};
64+
65+
template<std::size_t Index, class First, class Second>
66+
struct std::tuple_element<Index, my_test::my_nonaggregate_pair<First, Second>> {
67+
static_assert(false, "Invalid index");
68+
};
69+
70+
template<class First, class Second>
71+
struct std::tuple_element<0, my_test::my_nonaggregate_pair<First, Second>> {
72+
using type = First;
73+
};
74+
75+
template<class First, class Second>
76+
struct std::tuple_element<1, my_test::my_nonaggregate_pair<First, Second>> {
77+
using type = Second;
78+
};
79+
2580
namespace {
2681

2782
template<class T>
@@ -44,6 +99,14 @@ constexpr bool slice_equal(const Left&, Kokkos::full_extent_t) {
4499
return std::is_convertible_v<Left, Kokkos::full_extent_t>;
45100
}
46101

102+
template<class OffsetType, class ExtentType, class StrideType>
103+
constexpr bool slice_equal(
104+
const Kokkos::strided_slice<OffsetType, ExtentType, StrideType>& left,
105+
const Kokkos::strided_slice<OffsetType, ExtentType, StrideType>& right)
106+
{
107+
return left.offset == right.offset && left.extent == right.extent && left.stride == right.stride;
108+
}
109+
47110
template<class ExpectedResult, class InputExtents, class... Slices>
48111
void
49112
test_canonicalize_slices(
@@ -92,4 +155,26 @@ TEST(CanonicalizeSlices, Rank1_integer_static) {
92155
test_canonicalize_slices(expected_slices, exts, slice0);
93156
}
94157

158+
TEST(CanonicalizeSlices, Rank1_aggregate_pair) {
159+
constexpr auto slice0 = my_test::my_aggregate_pair<int, int>{7, 11};
160+
constexpr auto expected_slices = std::tuple{Kokkos::strided_slice{
161+
.offset = size_t(7u),
162+
.extent = (size_t(11u) - size_t(7u)),
163+
.stride = std::cw<size_t(1u)>
164+
}};
165+
constexpr auto exts = Kokkos::extents<size_t, 13>{};
166+
test_canonicalize_slices(expected_slices, exts, slice0);
167+
}
168+
169+
TEST(CanonicalizeSlices, Rank1_nonaggregate_pair) {
170+
constexpr auto slice0 = my_test::my_nonaggregate_pair<int, int>(7, 11);
171+
constexpr auto expected_slices = std::tuple{Kokkos::strided_slice{
172+
.offset = size_t(7u),
173+
.extent = (size_t(11u) - size_t(7u)),
174+
.stride = std::cw<size_t(1u)>
175+
}};
176+
constexpr auto exts = Kokkos::extents<size_t, 13>{};
177+
test_canonicalize_slices(expected_slices, exts, slice0);
178+
}
179+
95180
} // namespace (anonymous)

0 commit comments

Comments
 (0)