Skip to content

<mdspan>: Cache fwd-prod-of-extents and rev-prod-of-extents when all extents are static #3715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions stl/inc/mdspan
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,6 @@ public:
}
}

// TRANSITION, LWG ISSUE? I believe that this function should return 'index_type'
_NODISCARD constexpr index_type _Fwd_prod_of_extents(const rank_type _Idx) const noexcept {
index_type _Result = 1;
for (rank_type _Dim = 0; _Dim < _Idx; ++_Dim) {
_Result *= extent(_Dim);
}
return _Result;
}

// TRANSITION, LWG ISSUE? I believe that this function should return 'index_type'
_NODISCARD constexpr index_type _Rev_prod_of_extents(const rank_type _Idx) const noexcept {
index_type _Result = 1;
for (rank_type _Dim = _Idx + 1; _Dim < _Rank; ++_Dim) {
_Result *= extent(_Dim);
}
return _Result;
}

_NODISCARD static consteval bool _Is_index_space_size_representable() {
if constexpr (rank_dynamic() == 0 && rank() > 0) {
return _STD in_range<index_type>((_Extents * ...));
Expand Down Expand Up @@ -293,6 +275,82 @@ inline constexpr bool _Is_extents = false;
template <class _IndexType, size_t... _Args>
inline constexpr bool _Is_extents<extents<_IndexType, _Args...>> = true;

template <class _Extents>
requires _Is_extents<_Extents>
class _Fwd_prod_of_extents {
public:
_NODISCARD static constexpr _Extents::index_type _Calculate(const _Extents& _Exts, const size_t _Idx) noexcept {
if constexpr (_Extents::rank() == 0) {
return 1;
} else {
typename _Extents::index_type _Result = 1;
for (size_t _Dim = 0; _Dim < _Idx; ++_Dim) {
_Result *= _Exts.extent(_Dim);
}
return _Result;
}
}
};

template <class _IndexType, size_t... _Extents>
requires ((_Extents != dynamic_extent) && ...)
class _Fwd_prod_of_extents<extents<_IndexType, _Extents...>> {
private:
using _Ty = extents<_IndexType, _Extents...>;

_NODISCARD static consteval auto _Make_prods() noexcept {
array<typename _Ty::index_type, _Ty::rank() + 1> _Result;
_Result.front() = 1;
for (size_t _Dim = 1; _Dim < _Ty::_Rank + 1; ++_Dim) {
_Result[_Dim] = static_cast<_Ty::index_type>(_Result[_Dim - 1] * _Ty::static_extent(_Dim - 1));
}
return _Result;
}

static constexpr array<typename _Ty::index_type, _Ty::rank() + 1> _Cache = _Make_prods();

public:
_NODISCARD static constexpr _Ty::index_type _Calculate(const _Ty&, const size_t _Idx) noexcept {
return _Cache[_Idx];
}
};

template <class _Extents>
requires _Is_extents<_Extents> && (_Extents::rank() > 0)
class _Rev_prod_of_extents {
public:
_NODISCARD static constexpr _Extents::index_type _Calculate(const _Extents& _Exts, const size_t _Idx) noexcept {
typename _Extents::index_type _Result = 1;
for (size_t _Dim = _Idx + 1; _Dim < _Extents::_Rank; ++_Dim) {
_Result *= _Exts.extent(_Dim);
}
return _Result;
}
};

template <class _IndexType, size_t... _Extents>
requires ((_Extents != dynamic_extent) && ...)
class _Rev_prod_of_extents<extents<_IndexType, _Extents...>> {
private:
using _Ty = extents<_IndexType, _Extents...>;

_NODISCARD static consteval auto _Make_prods() noexcept {
array<typename _Ty::index_type, _Ty::rank()> _Result;
_Result.back() = 1;
for (size_t _Dim = _Ty::_Rank; _Dim-- > 1;) {
_Result[_Dim - 1] = static_cast<_Ty::index_type>(_Result[_Dim] * _Ty::static_extent(_Dim));
}
return _Result;
}

static constexpr array<typename _Ty::index_type, _Ty::rank()> _Cache = _Make_prods();

public:
_NODISCARD static constexpr _Ty::index_type _Calculate(const _Ty&, const size_t _Idx) noexcept {
return _Cache[_Idx];
}
};

template <class _Layout, class _Mapping>
inline constexpr bool _Is_mapping_of =
is_same_v<typename _Layout::template mapping<typename _Mapping::extents_type>, _Mapping>;
Expand Down Expand Up @@ -382,7 +440,7 @@ public:
}

_NODISCARD constexpr index_type required_span_size() const noexcept {
return _Exts._Fwd_prod_of_extents(extents_type::_Rank);
return _Fwd_prod_of_extents<extents_type>::_Calculate(_Exts, extents_type::_Rank);
}

template <class... _IndexTypes>
Expand Down Expand Up @@ -425,7 +483,7 @@ public:
{
_STL_VERIFY(_Idx < extents_type::_Rank,
"Value of i must be less than extents_type::rank() (N4950 [mdspan.layout.left.obs]/6).");
return _Exts._Fwd_prod_of_extents(_Idx);
return _Fwd_prod_of_extents<extents_type>::_Calculate(_Exts, _Idx);
}

template <class _OtherExtents>
Expand Down Expand Up @@ -518,7 +576,7 @@ public:
}

_NODISCARD constexpr index_type required_span_size() const noexcept {
return _Exts._Fwd_prod_of_extents(extents_type::_Rank);
return _Fwd_prod_of_extents<extents_type>::_Calculate(_Exts, extents_type::_Rank);
}

template <class... _IndexTypes>
Expand Down Expand Up @@ -561,7 +619,7 @@ public:
{
_STL_VERIFY(_Idx < extents_type::_Rank,
"Value of i must be less than extents_type::rank() (N4950 [mdspan.layout.right.obs]/6).");
return _Exts._Rev_prod_of_extents(_Idx);
return _Rev_prod_of_extents<extents_type>::_Calculate(_Exts, _Idx);
}

template <class _OtherExtents>
Expand Down Expand Up @@ -742,7 +800,7 @@ public:
if constexpr (extents_type::rank() == 0) {
return true;
} else {
return required_span_size() == _Exts._Fwd_prod_of_extents(extents_type::_Rank);
return required_span_size() == _Fwd_prod_of_extents<extents_type>::_Calculate(_Exts, extents_type::_Rank);
}
}

Expand Down Expand Up @@ -960,7 +1018,8 @@ public:
}

_NODISCARD constexpr size_type size() const noexcept {
return static_cast<size_type>(_Map.extents()._Fwd_prod_of_extents(extents_type::_Rank));
return static_cast<size_type>(
_Fwd_prod_of_extents<extents_type>::_Calculate(_Map.extents(), extents_type::_Rank));
}

_NODISCARD constexpr bool empty() const noexcept {
Expand Down