Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 17, 2024
1 parent 13aa602 commit 8a1c020
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 45 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ set(ARROW_COMPUTE_SRCS
compute/light_array_internal.cc
compute/ordering.cc
compute/registry.cc
compute/special_form.cc
compute/kernels/codegen_internal.cc
compute/kernels/ree_util_internal.cc
compute/kernels/scalar_cast_boolean.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ add_arrow_test(internals_test

add_arrow_compute_test(expression_test SOURCES expression_test.cc)

add_arrow_compute_test(special_form_test SOURCES special_form_test.cc)

add_arrow_benchmark(function_benchmark PREFIX "arrow-compute")

add_subdirectory(kernels)
Expand Down
88 changes: 70 additions & 18 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,25 @@ const DataType* Expression::type() const {
return parameter->type.type;
}

if (const Call* call = this->call()) {
return call->type.type;
if (const Special* special = this->special()) {
return special->type.type;
}

return CallNotNull(*this)->type.type;
}

bool Expression::selection_vector_aware() const {
DCHECK(IsBound());

if (literal() || field_ref()) {
return true;
}

return SpecialNotNull(*this)->type.type;
if (auto special = this->special()) {
return special->selection_vector_aware;
}

return CallNotNull(*this)->selection_vector_aware;
}

namespace {
Expand Down Expand Up @@ -193,6 +207,10 @@ std::string Expression::ToString() const {
return ref->ToString();
}

if (auto sp = special()) {
return sp->special_form->name + "(special)";
}

auto call = CallNotNull(*this);
auto binary = [&](std::string op) {
return "(" + call->arguments[0].ToString() + " " + op + " " +
Expand Down Expand Up @@ -261,6 +279,10 @@ bool Expression::Equals(const Expression& other) const {
return ref->Equals(*other.field_ref());
}

if (auto special = this->special()) {
return true;
}

auto call = CallNotNull(*this);
auto other_call = CallNotNull(other);

Expand Down Expand Up @@ -296,11 +318,11 @@ size_t Expression::hash() const {
return ref->hash();
}

if (auto c = call()) {
return c->hash;
if (auto special = this->special()) {
return special->hash;
}

return SpecialNotNull(*this)->hash;
return CallNotNull(*this)->hash;
}

bool Expression::IsBound() const {
Expand All @@ -324,6 +346,8 @@ bool Expression::IsScalarExpression() const {

if (field_ref()) return true;

if (special()) return true;

auto call = CallNotNull(*this);

for (const Expression& arg : call->arguments) {
Expand Down Expand Up @@ -384,6 +408,8 @@ bool Expression::IsSatisfiable() const {

if (field_ref()) return true;

if (special()) return true;

auto call = CallNotNull(*this);

// invert(true_unless_null(x)) is always false or null by definition
Expand Down Expand Up @@ -581,6 +607,11 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
kernel_context.SetState(call.kernel_state.get());
}

call.selection_vector_aware =
call.kernel->selection_vector_aware &&
std::all_of(call.arguments.begin(), call.arguments.end(),
[](const Expression& arg) { return arg.selection_vector_aware(); });

ARROW_ASSIGN_OR_RAISE(
call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types));
return Status::OK();
Expand Down Expand Up @@ -641,6 +672,9 @@ Result<Expression> BindNonRecursive(Expression::Special special,

ARROW_ASSIGN_OR_RAISE(special.type,
special.special_form->Resolve(&special.arguments, exec_context));
// Selection vector awareness-es of the subexpressions is fully taken over by the
// special form so not recursive.
special.selection_vector_aware = special.special_form->selection_vector_aware;
return Expression(std::move(special));
}

Expand All @@ -665,19 +699,19 @@ Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
return Expression{std::move(param)};
}

if (expr.call()) {
auto call = *expr.call();
for (auto& argument : call.arguments) {
if (expr.special()) {
auto special = *expr.special();
for (auto& argument : special.arguments) {
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context));
}
return BindNonRecursive(call, /*insert_implicit_casts=*/true, exec_context);
return BindNonRecursive(special, /*insert_implicit_casts=*/true, exec_context);
}

auto special = *SpecialNotNull(expr);
for (auto& argument : special.arguments) {
auto call = *CallNotNull(expr);
for (auto& argument : call.arguments) {
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context));
}
return BindNonRecursive(special, /*insert_implicit_casts=*/true, exec_context);
return BindNonRecursive(call, /*insert_implicit_casts=*/true, exec_context);
}

} // namespace
Expand Down Expand Up @@ -783,6 +817,8 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& i
"ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString());
}

DCHECK(!input.selection_vector || expr.selection_vector_aware());

if (auto lit = expr.literal()) return *lit;

if (auto param = expr.parameter()) {
Expand Down Expand Up @@ -870,19 +906,35 @@ std::vector<FieldRef> FieldsInExpression(const Expression& expr) {
return {*ref};
}

std::vector<FieldRef> fields;
for (const Expression& arg : CallNotNull(expr)->arguments) {
auto argument_fields = FieldsInExpression(arg);
std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields));
const auto& fields = [](const auto& expr) {
std::vector<FieldRef> fields;
for (const Expression& arg : expr->arguments) {
auto argument_fields = FieldsInExpression(arg);
std::move(argument_fields.begin(), argument_fields.end(),
std::back_inserter(fields));
}
return fields;
};

if (auto sp = expr.special()) {
return fields(sp);
}
return fields;

return fields(CallNotNull(expr));
}

bool ExpressionHasFieldRefs(const Expression& expr) {
if (expr.literal()) return false;

if (expr.field_ref()) return true;

if (auto sp = expr.special()) {
for (const Expression& arg : sp->arguments) {
if (ExpressionHasFieldRefs(arg)) return true;
}
return false;
}

for (const Expression& arg : CallNotNull(expr)->arguments) {
if (ExpressionHasFieldRefs(arg)) return true;
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ARROW_EXPORT Expression {
const Kernel* kernel = NULLPTR;
std::shared_ptr<KernelState> kernel_state;
TypeHolder type;
bool selection_vector_aware;

void ComputeHash();
};
Expand All @@ -70,6 +71,7 @@ class ARROW_EXPORT Expression {

// post-Bind properties:
TypeHolder type;
bool selection_vector_aware;

void ComputeHash();
};
Expand Down Expand Up @@ -134,6 +136,8 @@ class ARROW_EXPORT Expression {
// XXX someday
// NullGeneralization::type nullable() const;

bool selection_vector_aware() const;

struct Parameter {
FieldRef ref;

Expand Down
6 changes: 0 additions & 6 deletions cpp/src/arrow/compute/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ inline const Expression::Call* CallNotNull(const Expression& expr) {
return call;
}

inline const Expression::Special* SpecialNotNull(const Expression& expr) {
auto special = expr.special();
DCHECK_NE(special, nullptr);
return special;
}

inline std::vector<TypeHolder> GetTypes(const std::vector<Expression>& exprs) {
std::vector<TypeHolder> types(exprs.size());
for (size_t i = 0; i < exprs.size(); ++i) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,8 @@ struct ARROW_EXPORT Kernel {
/// so that the most optimized kernel supported on a host's processor can be chosen.
SimdLevel::type simd_level = SimdLevel::NONE;

bool selection_vector_aware = false;

// Additional kernel-specific data
std::shared_ptr<KernelState> data;
};
Expand Down
120 changes: 120 additions & 0 deletions cpp/src/arrow/compute/special_form.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/special_form.h"

#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/expression.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/registry.h"

namespace arrow::compute {

Result<Datum> Permute(const Datum& values, const Datum& indices, int64_t output_length) {
return Datum();
}

Result<TypeHolder> IfElseSpecialForm::Resolve(std::vector<Expression>* arguments,
ExecContext* exec_context) const {
ARROW_ASSIGN_OR_RAISE(auto function,
exec_context->func_registry()->GetFunction("if_else"));
std::vector<TypeHolder> types = GetTypes(*arguments);

// TODO: DispatchBest and implicit cast.
ARROW_ASSIGN_OR_RAISE(auto maybe_exact_match, function->DispatchExact(types));
compute::KernelContext kernel_context(exec_context, maybe_exact_match);
if (maybe_exact_match->init) {
const FunctionOptions* options = function->default_options();
ARROW_ASSIGN_OR_RAISE(
auto kernel_state,
maybe_exact_match->init(&kernel_context, {maybe_exact_match, types, options}));
kernel_context.SetState(kernel_state.get());
}
return maybe_exact_match->signature->out_type().Resolve(&kernel_context, types);
}

namespace {

Result<ExecBatch> TakeBySelectionVector(const ExecBatch& input,
const Datum& selection_vector,
ExecContext* exec_context) {
std::vector<Datum> values(input.num_values());
for (int i = 0; i < input.num_values(); ++i) {
ARROW_ASSIGN_OR_RAISE(
values[i], Take(input[i], selection_vector, TakeOptions{/*boundcheck=*/false},
exec_context));
}
return ExecBatch::Make(std::move(values), selection_vector.length());
}

std::shared_ptr<ChunkedArray> ChunkedArrayFromDatums(const std::vector<Datum>& datums) {
std::vector<std::shared_ptr<Array>> chunks;
for (const auto& datum : datums) {
DCHECK(datum.is_arraylike());
if (datum.is_array()) {
chunks.push_back(datum.make_array());
} else {
DCHECK(datum.is_chunked_array());
for (const auto& chunk : datum.chunked_array()->chunks()) {
chunks.push_back(chunk);
}
}
}
return std::make_shared<ChunkedArray>(std::move(chunks));
}

} // namespace

Result<Datum> IfElseSpecialForm::Execute(const std::vector<Expression>& arguments,
const ExecBatch& input,
compute::ExecContext* exec_context) const {
// TODO: Selection vector aware path.
DCHECK_EQ(input.selection_vector, nullptr);
DCHECK_EQ(arguments.size(), 3);

const auto& cond_expr = arguments[0];
DCHECK_EQ(cond_expr.type()->id(), Type::BOOL);
const auto& if_true_expr = arguments[1];
const auto& if_false_expr = arguments[2];
DCHECK_EQ(if_true_expr.type()->id(), if_false_expr.type()->id());

ARROW_ASSIGN_OR_RAISE(auto cond,
ExecuteScalarExpression(cond_expr, input, exec_context));
DCHECK_EQ(cond.type()->id(), Type::BOOL);

ARROW_ASSIGN_OR_RAISE(auto sel_true,
CallFunction("indices_nonzero", {cond}, exec_context));
ARROW_ASSIGN_OR_RAISE(auto input_true,
TakeBySelectionVector(input, sel_true, exec_context));
ARROW_ASSIGN_OR_RAISE(auto if_true,
ExecuteScalarExpression(if_true_expr, input_true, exec_context));

ARROW_ASSIGN_OR_RAISE(auto cond_inverted, CallFunction("invert", {cond}, exec_context));
ARROW_ASSIGN_OR_RAISE(auto sel_false,
CallFunction("indices_nonzero", {cond_inverted}, exec_context));
ARROW_ASSIGN_OR_RAISE(auto input_false,
TakeBySelectionVector(input, sel_false, exec_context));
ARROW_ASSIGN_OR_RAISE(
auto if_false, ExecuteScalarExpression(if_false_expr, input_false, exec_context));

auto result = ChunkedArrayFromDatums({if_true, if_false});
auto sel = ChunkedArrayFromDatums({sel_true, sel_false});
return Permute(result, sel, input.length);
}

} // namespace arrow::compute
Loading

0 comments on commit 8a1c020

Please sign in to comment.