Skip to content

Commit 0259594

Browse files
committed
Fix submdspan_canonicalize_slices for one slice
1 parent c847a54 commit 0259594

File tree

3 files changed

+210
-47
lines changed

3 files changed

+210
-47
lines changed

include/experimental/__p2630_bits/submdspan_extents.hpp

Lines changed: 159 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -407,17 +407,20 @@ struct extents_constructor<0, Extents, NewStaticExtents...> {
407407

408408
#if defined(MDSPAN_ENABLE_P3663)
409409

410-
namespace impl {
410+
namespace detail {
411411

412412
template<class IndexType, class OtherIndexType>
413+
requires(std::is_signed_v<std::remove_cvref_t<OtherIndexType>> ||
414+
std::is_unsigned_v<std::remove_cvref_t<OtherIndexType>>)
413415
constexpr auto index_cast(OtherIndexType&& i) noexcept {
414-
using OIT = std::remove_cvref_t<OtherIndexType>;
415-
if (std::is_signed_v<OIT> || std::is_unsigned_v<OIT>) {
416-
return i;
417-
}
418-
else {
419-
return static_cast<IndexType>(i);
420-
}
416+
return i;
417+
}
418+
419+
template<class IndexType, class OtherIndexType>
420+
requires(! std::is_signed_v<std::remove_cvref_t<OtherIndexType>> &&
421+
!std::is_unsigned_v<std::remove_cvref_t<OtherIndexType>>)
422+
constexpr auto index_cast(OtherIndexType&& i) noexcept {
423+
return static_cast<IndexType>(i);
421424
}
422425

423426
template<class IndexType, class S>
@@ -430,11 +433,16 @@ constexpr auto canonical_ice(S s) {
430433
//
431434
// TODO Preconditions: If S is a signed or unsigned integer type,
432435
// then s is representable as a value of type IndexType.
436+
//
437+
// TODO NOT IN PROPOSAL: index-cast result needs to be
438+
// cast again to IndexType, so that we don't get a weird
439+
// constant_wrapper whose value has a different type
440+
// than the second template argument.
433441
if constexpr (__mdspan_integral_constant_like<S>) {
434-
return std::constant_wrapper<index_cast<IndexType>(S::value), IndexType>{};
442+
return std::constant_wrapper<static_cast<IndexType>(index_cast<IndexType>(S::value)), IndexType>{};
435443
}
436444
else {
437-
return index_cast<IndexType>(s);
445+
return static_cast<IndexType>(index_cast<IndexType>(s));
438446
}
439447
}
440448

@@ -461,18 +469,30 @@ enum class check_static_bounds_result {
461469

462470
// TODO It's impossible to write an "if constexpr" check for
463471
// "structured binding into two elements is well-formed." Thus, we
464-
// write check_static_bounds only for canonical slice types as inputs
465-
// -- that is, we invoke check_static_bounds post-canonicalization.
466-
//
467-
// This may suggest a change in wording, though only if
468-
// we need to call check_static_bounds on pre-canonicalized slices.
469-
470-
template<size_t k, class IndexType, size_t... Exts, class... Slices>
472+
// must assume that the input Slices are all valid slice types.
473+
// One way to do that is to invoke this only post-canonicalization.
474+
// Another way is to rely on submdspan_canonicalize_slices to be
475+
// ill-formed if called with an invalid slice type. We can do the
476+
// latter in submdspan_canonicalize_slices by expressing the four
477+
// possible categories of valid slice types in if constexpr, with
478+
// the final else attempting the structured binding into two elements.
479+
480+
// TODO NOT IN PROPOSAL: Consider rewriting to use only $S_k$
481+
// and not $s_k$ in check-static-bounds, since we can't use
482+
// the actual function parameter in a function that we want
483+
// to work in a constant expression.
484+
485+
// TODO NOT IN PROPOSAL: Taking slices parameter(s) makes use
486+
// of check_static_bounds not a constant expression.
487+
// Instead, make Slices... a template parameter pack.
488+
489+
// TODO NOT IN PROPOSAL: It's easier to have a single Slice
490+
// as a template parameter pack. This makes sense because
491+
// the function only tests one slice (the k-th one) anyway.
492+
template<size_t k, class S_k, class IndexType, size_t... Exts>
471493
constexpr check_static_bounds_result check_static_bounds(
472-
const extents<IndexType, Exts...>&, Slices... slices)
494+
const extents<IndexType, Exts...>&)
473495
{
474-
auto s_k = slices...[k];
475-
using S_k = decltype(s_k);
476496
if constexpr (std::is_convertible_v<S_k, full_extent_t>) {
477497
return check_static_bounds_result::in_bounds;
478498
}
@@ -484,7 +504,7 @@ template<size_t k, class IndexType, size_t... Exts, class... Slices>
484504
else if constexpr (Exts...[k] != dynamic_extent && Exts...[k] <= de_ice(S_k{})) {
485505
return check_static_bounds_result::out_of_bounds;
486506
}
487-
else if constexpr (Exts...[k] != dynamic_extent && de_ice(s_k) < Exts...[k]) {
507+
else if constexpr (Exts...[k] != dynamic_extent && de_ice(S_k{}) < Exts...[k]) {
488508
return check_static_bounds_result::in_bounds;
489509
}
490510
else {
@@ -495,35 +515,35 @@ template<size_t k, class IndexType, size_t... Exts, class... Slices>
495515
return check_static_bounds_result::unknown;
496516
}
497517
}
498-
else if constexpr (detail::is_strided_slice<S_k>::value) {
518+
else if constexpr (is_strided_slice<S_k>::value) {
499519
if constexpr (__mdspan_integral_constant_like<typename S_k::offset_type>) {
500-
if constexpr (de_ice(s_k.offset) < 0) {
520+
if constexpr (de_ice(S_k{}.offset) < 0) {
501521
return check_static_bounds_result::out_of_bounds; // 14.3.1
502522
}
503523
else if constexpr (
504-
Exts...[k] != dynamic_extent && Exts...[k] < de_ice(s_k.offset))
524+
Exts...[k] != dynamic_extent && Exts...[k] < de_ice(S_k{}.offset))
505525
{
506526
return check_static_bounds_result::out_of_bounds; // 14.3.2
507527
}
508528
else if constexpr (
509529
__mdspan_integral_constant_like<typename S_k::extent_type> &&
510-
de_ice(s_k.offset) + de_ice(s_k.extent) < 0)
530+
de_ice(S_k{}.offset) + de_ice(S_k{}.extent) < 0)
511531
{
512532
return check_static_bounds_result::out_of_bounds; // 14.3.3
513533
}
514534
else if constexpr (
515535
Exts...[k] != dynamic_extent &&
516536
__mdspan_integral_constant_like<typename S_k::extent_type> &&
517-
Exts...[k] < de_ice(s_k.offset) + de_ice(s_k.extent))
537+
Exts...[k] < de_ice(S_k{}.offset) + de_ice(S_k{}.extent))
518538
{
519539
return check_static_bounds_result::out_of_bounds; // 14.3.4
520540
}
521541
else if constexpr (
522542
Exts...[k] != dynamic_extent &&
523543
__mdspan_integral_constant_like<typename S_k::extent_type> &&
524-
0 <= de_ice(s_k.offset) &&
525-
de_ice(s_k.offset) <= de_ice(s_k.offset) + de_ice(s_k.extent) &&
526-
de_ice(s_k.offset) + de_ice(s_k.extent) <= Exts...[k])
544+
0 <= de_ice(S_k{}.offset) &&
545+
de_ice(S_k{}.offset) <= de_ice(S_k{}.offset) + de_ice(S_k{}.extent) &&
546+
de_ice(S_k{}.offset) + de_ice(S_k{}.extent) <= Exts...[k])
527547
{
528548
return check_static_bounds_result::in_bounds; // 14.3.5
529549
}
@@ -539,7 +559,9 @@ template<size_t k, class IndexType, size_t... Exts, class... Slices>
539559
// NOTE: This case means that check_static_bounds cannot be
540560
// well-formed if it didn't fall into one of the above cases
541561
// and if it can't be destructured into two elements.
542-
auto [s_k0, s_k1] = s_k;
562+
563+
// We can't use s_k here, because it's not a constant expression.
564+
auto [s_k0, s_k1] = S_k{};
543565
using S_k0 = decltype(s_k0);
544566
using S_k1 = decltype(s_k1);
545567
if constexpr (__mdspan_integral_constant_like<S_k0>) {
@@ -583,31 +605,127 @@ template<size_t k, class IndexType, size_t... Exts, class... Slices>
583605
}
584606
}
585607
}
586-
} // namespace impl
587608

588-
template<class IndexType>
609+
template<class T>
610+
constexpr bool is_constant_wrapper = false;
611+
612+
template<class IndexType, IndexType Value>
613+
constexpr bool is_constant_wrapper<std::constant_wrapper<Value, IndexType>> = true;
614+
615+
// [mdspan.sub.slices] 1
616+
template<class IndexType, class T>
617+
constexpr bool is_canonical_submdspan_index_type =
618+
std::is_same_v<T, IndexType> || (
619+
is_constant_wrapper<T> &&
620+
std::is_same_v<typename T::value_type, IndexType>
621+
);
622+
623+
// [mdspan.sub.slices] 2
624+
template<class IndexType, class Slice>
625+
MDSPAN_INLINE_FUNCTION
626+
constexpr bool is_canonical_slice_type() {
627+
if constexpr (
628+
std::is_same_v<Slice, full_extent_t> || // 2.1
629+
is_canonical_submdspan_index_type<IndexType, Slice>) // 2.2
630+
{
631+
return true;
632+
}
633+
else if constexpr (is_strided_slice<Slice>::value) { // 2.3
634+
if constexpr ( // 2.3.1
635+
is_canonical_submdspan_index_type<IndexType, typename Slice::offset_type> &&
636+
is_canonical_submdspan_index_type<IndexType, typename Slice::extent_type> &&
637+
is_canonical_submdspan_index_type<IndexType, typename Slice::stride_type>)
638+
{
639+
if constexpr (
640+
is_constant_wrapper<typename Slice::stride_type> &&
641+
is_constant_wrapper<typename Slice::extent_type>)
642+
{
643+
constexpr auto Stride = de_ice(typename Slice::stride_type{});
644+
constexpr auto Extent = de_ice(typename Slice::extent_type{});
645+
return Extent == 0 || Stride > 0; // 2.3.2
646+
}
647+
else {
648+
return false;
649+
}
650+
}
651+
else {
652+
return false;
653+
}
654+
}
655+
else {
656+
return false;
657+
}
658+
}
659+
660+
// [mdspan.sub.slices] 3
661+
template<size_t k, class IndexType, size_t... Extents, class Slice>
589662
MDSPAN_INLINE_FUNCTION
590663
constexpr auto
591-
submdspan_canonicalize_slices(const extents<IndexType>&)
664+
is_canonical_kth_submdspan_slice_type(const extents<IndexType, Extents...>& exts, Slice slice)
592665
{
593-
return std::tuple{};
666+
if constexpr (! is_canonical_slice_type<IndexType, Slice>()) {
667+
return false; // 3.1
668+
}
669+
else { // 3.2
670+
return check_static_bounds<k, decltype(slice)>(exts) != check_static_bounds_result::out_of_bounds;
671+
}
594672
}
595673

596-
template<class IndexType, size_t Extent>
674+
// [mdspan.sub.slices] 11
675+
template<size_t k, class Slice, class IndexType, size_t... Extents>
597676
MDSPAN_INLINE_FUNCTION
598677
constexpr auto
599-
submdspan_canonicalize_slices(const extents<IndexType, Extent>&, full_extent_t)
600-
{
601-
return std::tuple{full_extent};
678+
submdspan_canonicalize_one_slice(const extents<IndexType, Extents...>& exts, Slice s) {
679+
// Part of [mdspan.sub.slices] 9.
680+
// This could be combined with the if constexpr branches below.
681+
//
682+
// NOTE This is not a constant expression (because it takes exts).
683+
static_assert(check_static_bounds<k, decltype(s)>(exts) != check_static_bounds_result::out_of_bounds);
684+
685+
// TODO Check Precondition that s is a valid k-th submdspan slice for exts.
686+
687+
if constexpr (std::is_convertible_v<Slice, full_extent_t>) {
688+
return full_extent; // 11.1
689+
}
690+
else if constexpr (std::is_convertible_v<Slice, IndexType>) {
691+
return canonical_ice<IndexType>(s); // 11.2
692+
}
693+
else if constexpr (is_strided_slice<Slice>::value) { // 11.3
694+
return strided_slice{
695+
.offset = canonical_ice<IndexType>(s.offset),
696+
.extent = canonical_ice<IndexType>(s.extent),
697+
.stride = canonical_ice<IndexType>(s.stride)
698+
};
699+
}
700+
else { // 11.4
701+
auto [s_k0, s_k1] = s;
702+
using S_k0 = decltype(s_k0);
703+
using S_k1 = decltype(s_k1);
704+
static_assert(std::is_convertible_v<S_k0, IndexType>);
705+
static_assert(std::is_convertible_v<S_k1, IndexType>);
706+
return strided_slice{
707+
.offset = canonical_ice<IndexType>(s_k0),
708+
.extent = subtract_ice<IndexType>(s_k0, s_k1),
709+
.stride = std::cw<IndexType(1)>
710+
};
711+
}
602712
}
603713

714+
} // namespace detail
715+
604716
template<class IndexType, size_t... Extents, class... Slices>
717+
requires (sizeof...(Slices) == sizeof...(Extents)) // [mdspan.sub.slices] 8
605718
MDSPAN_INLINE_FUNCTION
606719
constexpr auto
607-
submdspan_canonicalize_slices(const extents<IndexType, Extents...>&, Slices... slices)
720+
submdspan_canonicalize_slices(const extents<IndexType, Extents...>& exts, Slices... slices)
608721
{
609-
static_assert(sizeof...(Slices) == 0, "General case not implemented yet");
610-
return std::tuple{slices...};
722+
return [&]<size_t... Inds>(std::index_sequence<Inds...>) {
723+
return std::tuple{
724+
// This is ill-formed if slices...[Inds] is not a valid slice type.
725+
// That implements the Mandates clause of [mdspan.sub.slices] 9.
726+
detail::submdspan_canonicalize_one_slice<Inds>(exts, slices...[Inds])...
727+
};
728+
} (std::make_index_sequence<sizeof...(Slices)>{});
611729
}
612730
#endif // MDSPAN_ENABLE_P3663
613731

include/experimental/__p2630_bits/submdspan_mapping.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#endif
4343

4444
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
45+
4546
//******************************************
4647
// Return type of submdspan_mapping overloads
4748
//******************************************
@@ -52,6 +53,36 @@ template <class LayoutMapping> struct submdspan_mapping_result {
5253

5354
namespace detail {
5455

56+
#if defined(MDSPAN_ENABLE_P3663)
57+
template<layout_mapping_alike LayoutMapping>
58+
constexpr auto
59+
submdspan_mapping_with_full_extents(const LayoutMapping& mapping) {
60+
using extents_type = typename LayoutMapping::extents_type;
61+
return [&] <size_t... Inds> (std::index_sequence<Inds...>) {
62+
return submdspan_mapping(mapping, ((void) Inds, full_extent)...);
63+
} (std::make_index_sequence<extents_type::rank()>{});
64+
}
65+
66+
template<class T>
67+
constexpr bool is_submdspan_mapping_result = false;
68+
69+
template<class LayoutMapping>
70+
constexpr bool is_submdspan_mapping_result<
71+
submdspan_mapping_result<LayoutMapping>> = true;
72+
73+
template<class LayoutMapping>
74+
concept submdspan_mapping_result =
75+
is_submdspan_mapping_result<LayoutMapping>;
76+
77+
template<class LayoutMapping>
78+
concept mapping_sliceable_with_full_extents =
79+
requires(const LayoutMapping& mapping) {
80+
{
81+
submdspan_mapping_with_full_extents(mapping)
82+
} -> submdspan_mapping_result;
83+
};
84+
#endif // MDSPAN_ENABLE_P3663
85+
5586
// We use const Slice& and not Slice&& because the various
5687
// submdspan_mapping_impl overloads use their slices arguments
5788
// multiple times. This makes perfect forwarding not useful, but we

tests/test_canonicalize_slices.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424

2525
namespace {
2626

27+
template<class T>
28+
constexpr bool slice_equal(const T& left, const T& right) {
29+
return left == right;
30+
}
31+
32+
// full_extent_t lacks operator==
2733
constexpr bool slice_equal(Kokkos::full_extent_t, Kokkos::full_extent_t) {
2834
return true;
2935
}
@@ -38,12 +44,6 @@ constexpr bool slice_equal(const Left&, Kokkos::full_extent_t) {
3844
return std::is_convertible_v<Left, Kokkos::full_extent_t>;
3945
}
4046

41-
template<class Left, class Right>
42-
constexpr bool slice_equal(const Left&, const Right&) {
43-
static_assert(false, "slice_equal not implemented for this case");
44-
return false;
45-
}
46-
4747
template<class ExpectedResult, class InputExtents, class... Slices>
4848
void
4949
test_canonicalize_slices(
@@ -78,4 +78,18 @@ TEST(CanonicalizeSlices, Rank1_full) {
7878
test_canonicalize_slices(expected_result, Kokkos::extents<size_t, Kokkos::dynamic_extent>{}, full);
7979
}
8080

81+
TEST(CanonicalizeSlices, Rank1_integer_dynamic) {
82+
constexpr auto slice0 = int(7u);
83+
constexpr auto expected_slices = std::tuple{size_t(7u)};
84+
constexpr auto exts = Kokkos::extents<size_t, 10>{};
85+
test_canonicalize_slices(expected_slices, exts, slice0);
86+
}
87+
88+
TEST(CanonicalizeSlices, Rank1_integer_static) {
89+
constexpr auto slice0 = std::integral_constant<int, 7>{};
90+
constexpr auto expected_slices = std::tuple{std::cw<size_t(7u)>};
91+
constexpr auto exts = Kokkos::extents<size_t, 10>{};
92+
test_canonicalize_slices(expected_slices, exts, slice0);
93+
}
94+
8195
} // namespace (anonymous)

0 commit comments

Comments
 (0)