Skip to content
59 changes: 40 additions & 19 deletions projects/composablekernel/include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,31 +597,52 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
{
};

template <typename SeqMap>
struct sequence_map_inverse
// Invert a permutation sequence: given X2Y = {a, b, c, ...}, compute Y2X where Y2X[X2Y[i]] = i
// Example: Sequence<2,0,1> (meaning pos0->2, pos1->0, pos2->1) inverts to Sequence<1,2,0>
//
// Why this implementation is faster to compile than the recursive template it replaces:
//
// The old recursive approach instantiated a new helper type for each element, and every step
// also produced a freshly modified working sequence:
// sequence_map_inverse_impl<X2Y, Y2X_0, 0, N> -> sequence_map_inverse_impl<X2Y, Y2X_1, 1, N-1>
// -> ... -> sequence_map_inverse_impl<X2Y, Y2X_N, N, 0>
// Each "->" is a new type the compiler must create, track, and manage. For N elements, that's a
// chain of N nested instantiations (plus the terminating specialization), each with overhead
// (name mangling, debug info, symbol table entries).
//
// This implementation uses an ordinary for loop inside a constexpr function to build the inverse
// in O(N) operations. For input Sequence<2,0,1>, the loop sets result[input[pos]] = pos for each
// position:
// pos=0: result[2]=0, pos=1: result[0]=1, pos=2: result[1]=2
// The filled array is then expanded into a Sequence by a single pack expansion, so the inverse
// permutation is built in one pass with O(1) template instantiation depth.
//
template <index_t... Is>
struct sequence_map_inverse<Sequence<Is...>>
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
private:
struct InverseArray
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});

using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
index_t data[sizeof...(Is)] = {};
};

template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
static constexpr auto build_inverse()
{
using type = WorkingY2X;
};
InverseArray result{};
constexpr index_t input[] = {Is...};
for(index_t pos = 0; pos < static_cast<index_t>(sizeof...(Is)); ++pos)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @cgmillette: the for-loop is more readable, and its impact on build time is not measurable.

{
result.data[input[pos]] = pos;
}
return result;
}

using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
SeqMap::Size()>::type;
static constexpr InverseArray inverse = build_inverse();

template <index_t... Positions>
static constexpr auto compute(Sequence<Positions...>)
{
return Sequence<inverse.data[Positions]...>{};
}

public:
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};

template <index_t... Xs, index_t... Ys>
Expand Down
Loading