Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 144 additions & 168 deletions cpp/src/arrow/compute/kernels/vector_rank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,113 @@ void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_sel
}
}

struct RankingsEmitter {
virtual ~RankingsEmitter() = default;
virtual bool NeedsDuplicates() = 0;
virtual Result<Datum> CreateRankings(ExecContext* ctx,
const NullPartitionResult& sorted) = 0;
// Returns the address of a function-local static options object that is
// built from RankOptions::Defaults() the first time this function is called.
// Every call yields the same pointer.
const RankOptions* GetDefaultRankOptions() {
  static const auto default_options = RankOptions::Defaults();
  return &default_options;
}

// Returns the address of a function-local static options object that is
// built from RankPercentileOptions::Defaults() the first time this function
// is called. Every call yields the same pointer.
const RankPercentileOptions* GetDefaultPercentileRankOptions() {
  static const auto default_options = RankPercentileOptions::Defaults();
  return &default_options;
}

// Sorts the index range [indices_begin, indices_end) for a single Array and,
// when requested, runs MarkDuplicates over the sorted result.
//
// The sorter is looked up with GetArraySorter(*physical_type); `input` is
// re-wrapped as the concrete array class of ArrowType so that values can be
// read through GetView().
//
// Returns the NullPartitionResult produced by the sorter, or the error
// raised while obtaining or running the sorter.
template <typename ArrowType>
Result<NullPartitionResult> DoSortAndMarkDuplicate(
    ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const Array& input,
    const std::shared_ptr<DataType>& physical_type, const SortOrder order,
    const NullPlacement null_placement, bool needs_duplicates) {
  using ConcreteArray = typename TypeTraits<ArrowType>::ArrayType;
  using ViewTraits = GetViewType<ArrowType>;

  ARROW_ASSIGN_OR_RAISE(auto sort_fn, GetArraySorter(*physical_type));

  ConcreteArray typed_input(input.data());
  const ArraySortOptions sort_options(order, null_placement);
  ARROW_ASSIGN_OR_RAISE(
      auto partition,
      sort_fn(indices_begin, indices_end, typed_input, 0, sort_options, ctx));

  if (!needs_duplicates) {
    return partition;
  }

  // The selector maps an index to its logical value so that MarkDuplicates
  // can compare neighbouring entries of the sorted result.
  MarkDuplicates(partition, [&typed_input](int64_t index) {
    return ViewTraits::LogicalValue(typed_input.GetView(index));
  });
  return partition;
}

// Sorts the index range [indices_begin, indices_end) for a ChunkedArray and,
// when requested, runs MarkDuplicates over the sorted result.
//
// The chunks are first converted with GetPhysicalChunks(); if that yields no
// chunks, a default-constructed NullPartitionResult is returned without
// sorting. Otherwise the result of SortChunkedArray() (or its error) is
// returned.
template <typename ArrowType>
Result<NullPartitionResult> DoSortAndMarkDuplicate(
    ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end,
    const ChunkedArray& input, const std::shared_ptr<DataType>& physical_type,
    const SortOrder order, const NullPlacement null_placement, bool needs_duplicates) {
  auto chunks = GetPhysicalChunks(input, physical_type);
  if (chunks.empty()) {
    return NullPartitionResult{};
  }

  ARROW_ASSIGN_OR_RAISE(auto partition,
                        SortChunkedArray(ctx, indices_begin, indices_end, physical_type,
                                         chunks, order, null_placement));
  if (!needs_duplicates) {
    return partition;
  }

  // The resolver is built from a span over chunk_ptrs (presumably a
  // non-owning view -- confirm), so chunk_ptrs stays in scope for the whole
  // MarkDuplicates call.
  const auto chunk_ptrs = GetArrayPointers(chunks);
  MarkDuplicates(partition,
                 [resolver = ChunkedArrayResolver(span(chunk_ptrs))](int64_t index) {
                   return resolver.Resolve(index).Value<ArrowType>();
                 });
  return partition;
}

// Type visitor that dispatches on the physical type of `input` and forwards
// to the matching DoSortAndMarkDuplicate<TYPE>() instantiation.
//
// InputType is the container type handed to DoSortAndMarkDuplicate (the
// overloads visible in this file accept Array and ChunkedArray).
//
// The object stores a reference to `input` and raw pointers to the index
// range, so it must not outlive any of them.
template <typename InputType>
class SortAndMarkDuplicate : public TypeVisitor {
public:
// Captures all sorting parameters; the physical type is derived once from
// input.type() via GetPhysicalType().
SortAndMarkDuplicate(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end,
const InputType& input, const SortOrder order,
const NullPlacement null_placement, const bool needs_duplicate)
: TypeVisitor(),
ctx_(ctx),
indices_begin_(indices_begin),
indices_end_(indices_end),
input_(input),
order_(order),
null_placement_(null_placement),
needs_duplicates_(needs_duplicate),
physical_type_(GetPhysicalType(input.type())) {}

// Visits the physical type and returns a copy of the result stored by the
// generated Visit() overload; a non-OK Status from Accept() is propagated.
Result<NullPartitionResult> Run() {
RETURN_NOT_OK(physical_type_->Accept(this));
return sorted_;
}

// Generates one Visit(const TYPE&) overload for each type listed by
// VISIT_SORTABLE_PHYSICAL_TYPES (presumably overriding the corresponding
// TypeVisitor virtuals -- confirm against TypeVisitor). The `type` argument
// is only used for overload selection; the sort result is stored in sorted_.
#define VISIT(TYPE) \
Status Visit(const TYPE& type) { \
ARROW_ASSIGN_OR_RAISE( \
sorted_, DoSortAndMarkDuplicate<TYPE>(ctx_, indices_begin_, indices_end_, \
input_, physical_type_, order_, \
null_placement_, needs_duplicates_)); \
return Status::OK(); \
}

VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)

#undef VISIT

private:
ExecContext* ctx_;
// Bounds of the index buffer passed through to the sort routines.
uint64_t* indices_begin_;
uint64_t* indices_end_;
// Borrowed reference to the caller's input.
const InputType& input_;
const SortOrder order_;
const NullPlacement null_placement_;
// Whether MarkDuplicates should be run after sorting.
const bool needs_duplicates_;
const std::shared_ptr<DataType> physical_type_;
// Filled in by the generated Visit() overload; default-constructed until then.
NullPartitionResult sorted_{};
};

// A helper class that emits rankings for the "rank_percentile" function
struct PercentileRankingsEmitter : public RankingsEmitter {
explicit PercentileRankingsEmitter(double factor) : factor_(factor) {}

bool NeedsDuplicates() override { return true; }
struct PercentileRanker {
explicit PercentileRanker(double factor) : factor_(factor) {}

Result<Datum> CreateRankings(ExecContext* ctx,
const NullPartitionResult& sorted) override {
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) {
const int64_t length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
MakeMutableFloat64Array(length, ctx->memory_pool()));
Expand Down Expand Up @@ -114,14 +206,10 @@ struct PercentileRankingsEmitter : public RankingsEmitter {
};

// A helper class that emits rankings for the "rank" function
struct OrdinalRankingsEmitter : public RankingsEmitter {
explicit OrdinalRankingsEmitter(RankOptions::Tiebreaker tiebreaker)
: tiebreaker_(tiebreaker) {}
struct OrdinalRanker {
explicit OrdinalRanker(RankOptions::Tiebreaker tiebreaker) : tiebreaker_(tiebreaker) {}

bool NeedsDuplicates() override { return tiebreaker_ != RankOptions::First; }

Result<Datum> CreateRankings(ExecContext* ctx,
const NullPartitionResult& sorted) override {
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) {
const int64_t length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
MakeMutableUInt64Array(length, ctx->memory_pool()));
Expand Down Expand Up @@ -186,119 +274,6 @@ struct OrdinalRankingsEmitter : public RankingsEmitter {
const RankOptions::Tiebreaker tiebreaker_;
};

// Returns the address of a function-local static options object that is
// built from RankOptions::Defaults() the first time this function is called.
// Every call yields the same pointer.
const RankOptions* GetDefaultRankOptions() {
  static const auto default_options = RankOptions::Defaults();
  return &default_options;
}

// Returns the address of a function-local static options object that is
// built from RankPercentileOptions::Defaults() the first time this function
// is called. Every call yields the same pointer.
const RankPercentileOptions* GetDefaultPercentileRankOptions() {
  static const auto default_options = RankPercentileOptions::Defaults();
  return &default_options;
}

// CRTP base shared by the Ranker<T> specializations.
//
// Run() visits the physical type of the input; each generated Visit()
// overload forwards to RankerType::SortAndMarkDuplicates<TYPE>(), which is
// expected to fill sorted_. Run() then hands sorted_ to the emitter's
// CreateRankings().
//
// The object stores a reference to `input`, raw pointers to the index range
// and a raw pointer to the emitter, so it must not outlive any of them.
template <typename InputType, typename RankerType>
class RankerMixin : public TypeVisitor {
public:
// Captures all ranking parameters; the physical type is derived once from
// input.type() via GetPhysicalType().
RankerMixin(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end,
const InputType& input, const SortOrder order,
const NullPlacement null_placement, RankingsEmitter* emitter)
: TypeVisitor(),
ctx_(ctx),
indices_begin_(indices_begin),
indices_end_(indices_end),
input_(input),
order_(order),
null_placement_(null_placement),
physical_type_(GetPhysicalType(input.type())),
emitter_(emitter) {}

// Sorts (through the visitor dispatch) and returns the rankings produced by
// the emitter; a non-OK Status from Accept() is propagated.
Result<Datum> Run() {
RETURN_NOT_OK(physical_type_->Accept(this));
return emitter_->CreateRankings(ctx_, sorted_);
}

// Generates one Visit(const TYPE&) overload for each type listed by
// VISIT_SORTABLE_PHYSICAL_TYPES. The `type` argument is only used for
// overload selection; the work is done by the derived class.
#define VISIT(TYPE) \
Status Visit(const TYPE& type) { \
return static_cast<RankerType*>(this)->template SortAndMarkDuplicates<TYPE>(); \
}

VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)

#undef VISIT

protected:
ExecContext* ctx_;
// Bounds of the index buffer passed through to the sort routines.
uint64_t* indices_begin_;
uint64_t* indices_end_;
// Borrowed reference to the caller's input.
const InputType& input_;
const SortOrder order_;
const NullPlacement null_placement_;
const std::shared_ptr<DataType> physical_type_;
// Raw pointer to the emitter; nothing in this class deletes it.
RankingsEmitter* emitter_;
// Written by the derived class's SortAndMarkDuplicates(); default-constructed
// until then.
NullPartitionResult sorted_{};
};

// Primary template is declared but not defined here; this file provides
// explicit specializations for Array and ChunkedArray.
template <typename T>
class Ranker;

// Ranker specialization for a single contiguous Array.
template <>
class Ranker<Array> : public RankerMixin<Array, Ranker<Array>> {
 public:
  using RankerMixin::RankerMixin;

  // Sorts the index range against the input array, stores the result in
  // sorted_, and runs MarkDuplicates when the emitter asks for it.
  //
  // InType is the physical Arrow type selected by the visitor dispatch.
  // Returns the error raised while obtaining or running the sorter, or OK.
  template <typename InType>
  Status SortAndMarkDuplicates() {
    using ConcreteArray = typename TypeTraits<InType>::ArrayType;
    using ViewTraits = GetViewType<InType>;

    ARROW_ASSIGN_OR_RAISE(auto sort_fn, GetArraySorter(*physical_type_));

    ConcreteArray typed_input(input_.data());
    const ArraySortOptions sort_options(order_, null_placement_);
    ARROW_ASSIGN_OR_RAISE(
        sorted_,
        sort_fn(indices_begin_, indices_end_, typed_input, 0, sort_options, ctx_));

    if (!emitter_->NeedsDuplicates()) {
      return Status::OK();
    }

    // The selector maps an index to its logical value so that MarkDuplicates
    // can compare entries of the sorted result.
    MarkDuplicates(sorted_, [&typed_input](int64_t index) {
      return ViewTraits::LogicalValue(typed_input.GetView(index));
    });
    return Status::OK();
  }
};

// Ranker specialization for a ChunkedArray.
template <>
class Ranker<ChunkedArray> : public RankerMixin<ChunkedArray, Ranker<ChunkedArray>> {
 public:
  // Forwards every argument to RankerMixin, then converts the input's chunks
  // once with GetPhysicalChunks() using the base's input_ and physical_type_.
  template <typename... Args>
  explicit Ranker(Args&&... args)
      : RankerMixin(std::forward<Args>(args)...),
        physical_chunks_(GetPhysicalChunks(input_, physical_type_)) {}

  // Sorts the index range across all chunks, stores the result in sorted_,
  // and runs MarkDuplicates when the emitter asks for it.
  //
  // When there are no chunks, sorted_ is left untouched and OK is returned.
  // InType is the physical Arrow type selected by the visitor dispatch.
  template <typename InType>
  Status SortAndMarkDuplicates() {
    if (physical_chunks_.empty()) {
      return Status::OK();
    }

    ARROW_ASSIGN_OR_RAISE(
        sorted_, SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_,
                                  physical_chunks_, order_, null_placement_));

    if (!emitter_->NeedsDuplicates()) {
      return Status::OK();
    }

    // The resolver is built from a span over chunk_ptrs (presumably a
    // non-owning view -- confirm), so chunk_ptrs stays in scope for the whole
    // MarkDuplicates call.
    const auto chunk_ptrs = GetArrayPointers(physical_chunks_);
    MarkDuplicates(sorted_,
                   [resolver = ChunkedArrayResolver(span(chunk_ptrs))](int64_t index) {
                     return resolver.Resolve(index).Value<InType>();
                   });
    return Status::OK();
  }

 private:
  const ArrayVector physical_chunks_;
};

const FunctionDoc rank_doc(
"Compute ordinal ranks of an array (1-based)",
("This function computes a rank of the input array.\n"
Expand All @@ -324,6 +299,7 @@ const FunctionDoc rank_percentile_doc(
"in RankPercentileOptions."),
{"input"}, "RankPercentileOptions");

template <typename Derived>
class RankMetaFunctionBase : public MetaFunction {
public:
using MetaFunction::MetaFunction;
Expand All @@ -348,66 +324,66 @@ class RankMetaFunctionBase : public MetaFunction {
}

protected:
struct UnpackedOptions {
SortOrder order{SortOrder::Ascending};
NullPlacement null_placement;
std::unique_ptr<RankingsEmitter> emitter;
};

virtual UnpackedOptions UnpackOptions(const FunctionOptions&) const = 0;

template <typename T>
Result<Datum> Rank(const T& input, const FunctionOptions& function_options,
ExecContext* ctx) const {
auto options = UnpackOptions(function_options);
const auto& options =
checked_cast<const typename Derived::FunctionOptionsType&>(function_options);

SortOrder order = SortOrder::Ascending;
if (!options.sort_keys.empty()) {
order = options.sort_keys[0].order;
}

int64_t length = input.length();
ARROW_ASSIGN_OR_RAISE(auto indices,
MakeMutableUInt64Array(length, ctx->memory_pool()));
auto* indices_begin = indices->GetMutableValues<uint64_t>(1);
auto* indices_end = indices_begin + length;
std::iota(indices_begin, indices_end, 0);
auto needs_duplicates = Derived::NeedsDuplicates(options);
ARROW_ASSIGN_OR_RAISE(
auto sorted, SortAndMarkDuplicate(ctx, indices_begin, indices_end, input, order,
options.null_placement, needs_duplicates)
.Run());

Ranker<T> ranker(ctx, indices_begin, indices_end, input, options.order,
options.null_placement, options.emitter.get());
return ranker.Run();
auto ranker = Derived::GetRanker(options);
return ranker.CreateRankings(ctx, sorted);
}
};

class RankMetaFunction : public RankMetaFunctionBase {
class RankMetaFunction : public RankMetaFunctionBase<RankMetaFunction> {
public:
RankMetaFunction()
: RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {}
using FunctionOptionsType = RankOptions;
using RankerType = OrdinalRanker;

protected:
UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override {
const auto& options = checked_cast<const RankOptions&>(function_options);
UnpackedOptions unpacked{
SortOrder::Ascending, options.null_placement,
std::make_unique<OrdinalRankingsEmitter>(options.tiebreaker)};
if (!options.sort_keys.empty()) {
unpacked.order = options.sort_keys[0].order;
}
return unpacked;
static bool NeedsDuplicates(const RankOptions& options) {
return options.tiebreaker != RankOptions::First;
}

static RankerType GetRanker(const RankOptions& options) {
return RankerType(options.tiebreaker);
}

RankMetaFunction()
: RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {}
};

class RankPercentileMetaFunction : public RankMetaFunctionBase {
class RankPercentileMetaFunction
: public RankMetaFunctionBase<RankPercentileMetaFunction> {
public:
using FunctionOptionsType = RankPercentileOptions;
using RankerType = PercentileRanker;

static bool NeedsDuplicates(const RankPercentileOptions&) { return true; }

static RankerType GetRanker(const RankPercentileOptions& options) {
return RankerType(options.factor);
}

RankPercentileMetaFunction()
: RankMetaFunctionBase("rank_percentile", Arity::Unary(), rank_percentile_doc,
GetDefaultPercentileRankOptions()) {}

protected:
UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override {
const auto& options = checked_cast<const RankPercentileOptions&>(function_options);
UnpackedOptions unpacked{SortOrder::Ascending, options.null_placement,
std::make_unique<PercentileRankingsEmitter>(options.factor)};
if (!options.sort_keys.empty()) {
unpacked.order = options.sort_keys[0].order;
}
return unpacked;
}
};

} // namespace
Expand Down
Loading