Skip to content

Commit

Permalink
cartesian_product: Implement internal iteration and rewrite random-…
Browse files Browse the repository at this point in the history
…access

`inc` to fix bugs.
  • Loading branch information
brycelelbach committed Jul 31, 2023
1 parent cc6067a commit c3d7880
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 18 deletions.
111 changes: 94 additions & 17 deletions include/flux/op/cartesian_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define FLUX_OP_CARTESIAN_PRODUCT_HPP_INCLUDED

#include <flux/core.hpp>
#include <flux/core/numeric.hpp>
#include <flux/op/from.hpp>

#include <tuple>
Expand Down Expand Up @@ -88,21 +89,36 @@ struct cartesian_product_traits_base {

template <std::size_t I, typename Self>
static constexpr auto ra_inc_impl(Self& self, cursor_type<Self>& cur, distance_t offset)
-> cursor_type<Self>&
-> cursor_type<Self>&
{
if (offset == 0)
return cur;

auto& base = std::get<I>(self.bases_);
auto& this_index = std::get<I>(cur);
auto this_size = flux::size(base);

auto this_sz = flux::size(base);
auto this_offset = offset % this_sz;
auto next_offset = offset/this_sz;
this_index += offset;

// Adjust this cursor by the corrected offset
flux::inc(base, std::get<I>(cur), this_offset);
if (this_index >= 0 && this_index < this_size)
return cur;

// Call the next level down if necessary
if constexpr (I > 0) {
if (next_offset != 0) {
ra_inc_impl<I-1>(self, cur, next_offset);
// If the new index overflows the maximum or underflows zero, calculate the carryover and fix it.
else {
offset = this_index / this_size;
this_index %= this_size;

// Correct for negative index which may happen when underflowing.
if (this_index < 0) {
this_index += this_size;
--offset;
}

// Call the next level down if necessary.
if constexpr (I > 0) {
if (offset != 0) {
ra_inc_impl<I-1>(self, cur, offset);
}
}
}

Expand Down Expand Up @@ -157,25 +173,25 @@ struct cartesian_product_traits_base {
}

template <typename Self>
requires (bidirectional_sequence<const_like_t<Self, Bases>> && ...) &&
(bounded_sequence<const_like_t<Self, Bases>> && ...)
requires (bidirectional_sequence<const_like_t<Self, Bases>> && ...) &&
(bounded_sequence<const_like_t<Self, Bases>> && ...)
static constexpr auto dec(Self& self, cursor_type<Self>& cur) -> cursor_type<Self>&
{
return dec_impl<sizeof...(Bases) - 1>(self, cur);
}

template <typename Self>
requires (random_access_sequence<const_like_t<Self, Bases>> && ...) &&
(sized_sequence<const_like_t<Self, Bases>> && ...)
requires (random_access_sequence<const_like_t<Self, Bases>> && ...) &&
(sized_sequence<const_like_t<Self, Bases>> && ...)
static constexpr auto inc(Self& self, cursor_type<Self>& cur, distance_t offset)
-> cursor_type<Self>&
-> cursor_type<Self>&
{
return ra_inc_impl<sizeof...(Bases) - 1>(self, cur, offset);
}

template <typename Self>
requires (random_access_sequence<const_like_t<Self, Bases>> && ...) &&
(sized_sequence<const_like_t<Self, Bases>> && ...)
requires (random_access_sequence<const_like_t<Self, Bases>> && ...) &&
(sized_sequence<const_like_t<Self, Bases>> && ...)
static constexpr auto distance(Self& self,
cursor_type<Self> const& from,
cursor_type<Self> const& to) -> distance_t
Expand All @@ -193,6 +209,29 @@ struct cartesian_product_traits_base {
}
};

template <typename Self, typename I>
struct cartesian_product_partial_cursor;

template <typename Self>
struct cartesian_product_partial_cursor<Self,
std::integral_constant<std::size_t, std::tuple_size_v<cursor_t<Self>> - 1>>
{
using type = std::tuple<std::tuple_element_t<std::tuple_size_v<cursor_t<Self>> - 1, cursor_t<Self>>>;
};

template <typename Self, typename I>
struct cartesian_product_partial_cursor
{
using type = decltype(std::tuple_cat(
std::declval<std::tuple<std::tuple_element_t<I::value, cursor_t<Self>>>>(),
std::declval<typename cartesian_product_partial_cursor<Self,
std::integral_constant<std::size_t, I::value + 1>>::type>()));
};

template <typename Self, std::size_t I>
using cartesian_product_partial_cursor_t =
typename cartesian_product_partial_cursor<Self, std::integral_constant<std::size_t, I>>::type;


} // end namespace detail

Expand All @@ -216,6 +255,37 @@ struct sequence_traits<detail::cartesian_product_adaptor<Bases...>>
}(std::index_sequence_for<Bases...>{});
}

template <std::size_t I, typename Self, typename Function,
typename... PartialCursor>
static constexpr auto for_each_while_impl(Self& self,
Function&& func,
PartialCursor&&... partial_cursor)
-> std::tuple<bool, detail::cartesian_product_partial_cursor_t<Self, I>>
{
// We need to iterate right to left.
if constexpr (I == sizeof...(Bases) - 1) {
bool keep_going = true;
auto this_current = flux::for_each_while(std::get<I>(self.bases_),
[&](auto&& elem) {
keep_going = std::invoke(func,
cursor_t<Self>(FLUX_FWD(partial_cursor)..., FLUX_FWD(elem)));
return keep_going;
});
return std::tuple(keep_going, std::tuple(std::move(this_current)));
} else {
bool keep_going = true;
detail::cartesian_product_partial_cursor_t<Self, I+1> nested_current;
auto this_current = flux::for_each_while(std::get<I>(self.bases_),
[&](auto&& elem) {
std::tie(keep_going, nested_current) = for_each_while_impl<I+1>(
self, func, FLUX_FWD(partial_cursor)..., FLUX_FWD(elem));
return keep_going;
});
return std::tuple(keep_going,
std::tuple_cat(std::tuple(std::move(this_current)), std::move(nested_current)));
}
}

public:
using value_type = std::tuple<value_t<Bases>...>;

Expand All @@ -242,6 +312,13 @@ struct sequence_traits<detail::cartesian_product_adaptor<Bases...>>
{
return read_(flux::move_at_unchecked, self, cur);
}

template <typename Self, typename Function>
static constexpr auto for_each_while(Self& self, Function&& func)
-> cursor_t<Self>
{
return std::get<1>(for_each_while_impl<0>(self, FLUX_FWD(func)));
}
};

FLUX_EXPORT inline constexpr auto cartesian_product = detail::cartesian_product_fn{};
Expand Down
164 changes: 163 additions & 1 deletion test/test_cartesian_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <flux/op/cartesian_product.hpp>
#include <flux/op/reverse.hpp>
#include <flux/source/iota.hpp>
#include <flux/source/empty.hpp>
#include <flux/source/iota.hpp>
#include <flux/op/for_each.hpp>
Expand Down Expand Up @@ -110,6 +111,167 @@ constexpr bool test_cartesian_product()
}
}

{
std::array arr1{100, 200};
std::array arr2{1.0f, 2.0f, 3.0f, 4.0f};
std::array arr3{0ULL, 2ULL, 4ULL};

auto cart = flux::cartesian_product(flux::mut_ref(arr1), flux::mut_ref(arr2), flux::mut_ref(arr3));

using C = decltype(cart);

static_assert(flux::sequence<C>);
static_assert(flux::multipass_sequence<C>);
static_assert(flux::bidirectional_sequence<C>);
static_assert(flux::random_access_sequence<C>);
static_assert(not flux::contiguous_sequence<C>);
static_assert(flux::bounded_sequence<C>);
static_assert(flux::sized_sequence<C>);

static_assert(flux::sequence<C const>);
static_assert(flux::multipass_sequence<C const>);
static_assert(flux::bidirectional_sequence<C const>);
static_assert(flux::random_access_sequence<C const>);
static_assert(not flux::contiguous_sequence<C const>);
static_assert(flux::bounded_sequence<C const>);
static_assert(flux::sized_sequence<C const>);

static_assert(std::same_as<flux::element_t<C>, std::tuple<int&, float&, unsigned long long&>>);
static_assert(std::same_as<flux::value_t<C>, std::tuple<int, float, unsigned long long>>);
static_assert(std::same_as<flux::rvalue_element_t<C>, std::tuple<int&&, float&&, unsigned long long&&>>);

static_assert(std::same_as<flux::element_t<C const>, std::tuple<int&, float&, unsigned long long&>>);
static_assert(std::same_as<flux::value_t<C const>, std::tuple<int, float, unsigned long long>>);
static_assert(std::same_as<flux::rvalue_element_t<C const>, std::tuple<int&&, float&&, unsigned long long&&>>);

STATIC_CHECK(flux::size(cart) == 2 * 4 * 3);

STATIC_CHECK(check_equal(cart, {
std::tuple{100, 1.0f, 0ULL},
std::tuple{100, 1.0f, 2ULL},
std::tuple{100, 1.0f, 4ULL},
std::tuple{100, 2.0f, 0ULL},
std::tuple{100, 2.0f, 2ULL},
std::tuple{100, 2.0f, 4ULL},
std::tuple{100, 3.0f, 0ULL},
std::tuple{100, 3.0f, 2ULL},
std::tuple{100, 3.0f, 4ULL},
std::tuple{100, 4.0f, 0ULL},
std::tuple{100, 4.0f, 2ULL},
std::tuple{100, 4.0f, 4ULL},
std::tuple{200, 1.0f, 0ULL},
std::tuple{200, 1.0f, 2ULL},
std::tuple{200, 1.0f, 4ULL},
std::tuple{200, 2.0f, 0ULL},
std::tuple{200, 2.0f, 2ULL},
std::tuple{200, 2.0f, 4ULL},
std::tuple{200, 3.0f, 0ULL},
std::tuple{200, 3.0f, 2ULL},
std::tuple{200, 3.0f, 4ULL},
std::tuple{200, 4.0f, 0ULL},
std::tuple{200, 4.0f, 2ULL},
std::tuple{200, 4.0f, 4ULL}
}));

STATIC_CHECK(flux::distance(cart, cart.first(), cart.last()) == 2 * 4 * 3);

{
auto cur = flux::next(cart, cart.first(), 3);
STATIC_CHECK(cart[cur] == std::tuple{100, 2.0f, 0ULL});
flux::inc(cart, cur, -3);
STATIC_CHECK(cart[cur] == std::tuple{100, 1.0f, 0ULL});
}
}

{
auto cart = flux::cartesian_product(flux::ints(0, 4), flux::ints(0, 2), flux::ints(0, 3));

using C = decltype(cart);

static_assert(flux::sequence<C>);
static_assert(flux::multipass_sequence<C>);
static_assert(flux::bidirectional_sequence<C>);
static_assert(flux::random_access_sequence<C>);
static_assert(not flux::contiguous_sequence<C>);
static_assert(flux::bounded_sequence<C>);
static_assert(flux::sized_sequence<C>);

static_assert(flux::sequence<C const>);
static_assert(flux::multipass_sequence<C const>);
static_assert(flux::bidirectional_sequence<C const>);
static_assert(flux::random_access_sequence<C const>);
static_assert(not flux::contiguous_sequence<C const>);
static_assert(flux::bounded_sequence<C const>);
static_assert(flux::sized_sequence<C const>);

static_assert(std::same_as<flux::element_t<C>, std::tuple<long, long, long>>);
static_assert(std::same_as<flux::value_t<C>, std::tuple<long, long, long>>);
static_assert(std::same_as<flux::rvalue_element_t<C>, std::tuple<long, long, long>>);

static_assert(std::same_as<flux::element_t<C const>, std::tuple<long, long, long>>);
static_assert(std::same_as<flux::value_t<C const>, std::tuple<long, long, long>>);
static_assert(std::same_as<flux::rvalue_element_t<C const>, std::tuple<long, long, long>>);

STATIC_CHECK(flux::size(cart) == 4 * 2 * 3);

STATIC_CHECK(check_equal(cart, {
std::tuple{0, 0, 0},
std::tuple{0, 0, 1},
std::tuple{0, 0, 2},
std::tuple{0, 1, 0},
std::tuple{0, 1, 1},
std::tuple{0, 1, 2},
std::tuple{1, 0, 0},
std::tuple{1, 0, 1},
std::tuple{1, 0, 2},
std::tuple{1, 1, 0},
std::tuple{1, 1, 1},
std::tuple{1, 1, 2},
std::tuple{2, 0, 0},
std::tuple{2, 0, 1},
std::tuple{2, 0, 2},
std::tuple{2, 1, 0},
std::tuple{2, 1, 1},
std::tuple{2, 1, 2},
std::tuple{3, 0, 0},
std::tuple{3, 0, 1},
std::tuple{3, 0, 2},
std::tuple{3, 1, 0},
std::tuple{3, 1, 1},
std::tuple{3, 1, 2},
}));

STATIC_CHECK(flux::distance(cart, cart.first(), cart.last()) == 4 * 2 * 3);

{
STATIC_CHECK(flux::next(cart, cart.first(), 6) == std::tuple{1, 0, 0});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), 1) == std::tuple{1, 0, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), 2) == std::tuple{1, 0, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), 3) == std::tuple{1, 1, 0});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), 4) == std::tuple{1, 1, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), 5) == std::tuple{1, 1, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), -1) == std::tuple{0, 1, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), -2) == std::tuple{0, 1, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), -3) == std::tuple{0, 1, 0});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), -4) == std::tuple{0, 0, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 6), -5) == std::tuple{0, 0, 1});

STATIC_CHECK(flux::next(cart, cart.first(), 11) == std::tuple{1, 1, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), 1) == std::tuple{2, 0, 0});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), 2) == std::tuple{2, 0, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), 3) == std::tuple{2, 0, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), 4) == std::tuple{2, 1, 0});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), 5) == std::tuple{2, 1, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), -1) == std::tuple{1, 1, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), -2) == std::tuple{1, 1, 0});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), -3) == std::tuple{1, 0, 2});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), -4) == std::tuple{1, 0, 1});
STATIC_CHECK(flux::next(cart, flux::next(cart, cart.first(), 11), -5) == std::tuple{1, 0, 0});
}
}

// TODO: Product with a zero-sized sequence works and produces an empty sequence

// Test unpack()
{
int vals[3][3] = {};
Expand All @@ -136,4 +298,4 @@ static_assert(test_cartesian_product());
TEST_CASE("cartesian_product")
{
REQUIRE(test_cartesian_product());
}
}

0 comments on commit c3d7880

Please sign in to comment.