|  | 
|  | 1 | +// Licensed to the Apache Software Foundation (ASF) under one | 
|  | 2 | +// or more contributor license agreements.  See the NOTICE file | 
|  | 3 | +// distributed with this work for additional information | 
|  | 4 | +// regarding copyright ownership.  The ASF licenses this file | 
|  | 5 | +// to you under the Apache License, Version 2.0 (the | 
|  | 6 | +// "License"); you may not use this file except in compliance | 
|  | 7 | +// with the License.  You may obtain a copy of the License at | 
|  | 8 | +// | 
|  | 9 | +//   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 10 | +// | 
|  | 11 | +// Unless required by applicable law or agreed to in writing, | 
|  | 12 | +// software distributed under the License is distributed on an | 
|  | 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | 14 | +// KIND, either express or implied.  See the License for the | 
|  | 15 | +// specific language governing permissions and limitations | 
|  | 16 | +// under the License. | 
|  | 17 | + | 
|  | 18 | +// Vector kernels for pairwise computation | 
|  | 19 | + | 
|  | 20 | +#include <iostream> | 
|  | 21 | +#include <memory> | 
|  | 22 | +#include "arrow/builder.h" | 
|  | 23 | +#include "arrow/compute/api_vector.h" | 
|  | 24 | +#include "arrow/compute/exec.h" | 
|  | 25 | +#include "arrow/compute/function.h" | 
|  | 26 | +#include "arrow/compute/kernel.h" | 
|  | 27 | +#include "arrow/compute/kernels/base_arithmetic_internal.h" | 
|  | 28 | +#include "arrow/compute/kernels/codegen_internal.h" | 
|  | 29 | +#include "arrow/compute/registry.h" | 
|  | 30 | +#include "arrow/compute/util.h" | 
|  | 31 | +#include "arrow/status.h" | 
|  | 32 | +#include "arrow/type.h" | 
|  | 33 | +#include "arrow/type_fwd.h" | 
|  | 34 | +#include "arrow/type_traits.h" | 
|  | 35 | +#include "arrow/util/bit_util.h" | 
|  | 36 | +#include "arrow/util/checked_cast.h" | 
|  | 37 | +#include "arrow/util/logging.h" | 
|  | 38 | +#include "arrow/visit_type_inline.h" | 
|  | 39 | + | 
|  | 40 | +namespace arrow::compute::internal { | 
|  | 41 | + | 
|  | 42 | +// We reuse the kernel exec function of a scalar binary function to compute pairwise | 
|  | 43 | +// results. For example, for pairwise_diff, we reuse subtract's kernel exec. | 
|  | 44 | +struct PairwiseState : KernelState { | 
|  | 45 | +  PairwiseState(const PairwiseOptions& options, ArrayKernelExec scalar_exec) | 
|  | 46 | +      : periods(options.periods), scalar_exec(scalar_exec) {} | 
|  | 47 | + | 
|  | 48 | +  int64_t periods; | 
|  | 49 | +  ArrayKernelExec scalar_exec; | 
|  | 50 | +}; | 
|  | 51 | + | 
|  | 52 | +/// A generic pairwise implementation that can be reused by different ops. | 
|  | 53 | +Status PairwiseExecImpl(KernelContext* ctx, const ArraySpan& input, | 
|  | 54 | +                        const ArrayKernelExec& scalar_exec, int64_t periods, | 
|  | 55 | +                        ArrayData* result) { | 
|  | 56 | +  // We only compute values in the region where the input-with-offset overlaps | 
|  | 57 | +  // the original input. The margin where these do not overlap gets filled with null. | 
|  | 58 | +  auto margin_length = std::min(abs(periods), input.length); | 
|  | 59 | +  auto computed_length = input.length - margin_length; | 
|  | 60 | +  auto margin_start = periods > 0 ? 0 : computed_length; | 
|  | 61 | +  auto computed_start = periods > 0 ? margin_length : 0; | 
|  | 62 | +  auto left_start = computed_start; | 
|  | 63 | +  auto right_start = margin_length - computed_start; | 
|  | 64 | +  // prepare bitmap | 
|  | 65 | +  bit_util::ClearBitmap(result->buffers[0]->mutable_data(), margin_start, margin_length); | 
|  | 66 | +  for (int64_t i = computed_start; i < computed_start + computed_length; i++) { | 
|  | 67 | +    if (input.IsValid(i) && input.IsValid(i - periods)) { | 
|  | 68 | +      bit_util::SetBit(result->buffers[0]->mutable_data(), i); | 
|  | 69 | +    } else { | 
|  | 70 | +      bit_util::ClearBit(result->buffers[0]->mutable_data(), i); | 
|  | 71 | +    } | 
|  | 72 | +  } | 
|  | 73 | +  // prepare input span | 
|  | 74 | +  ArraySpan left(input); | 
|  | 75 | +  left.SetSlice(left_start, computed_length); | 
|  | 76 | +  ArraySpan right(input); | 
|  | 77 | +  right.SetSlice(right_start, computed_length); | 
|  | 78 | +  // prepare output span | 
|  | 79 | +  ArraySpan output_span; | 
|  | 80 | +  output_span.SetMembers(*result); | 
|  | 81 | +  output_span.offset = computed_start; | 
|  | 82 | +  output_span.length = computed_length; | 
|  | 83 | +  ExecResult output{output_span}; | 
|  | 84 | +  // execute scalar function | 
|  | 85 | +  RETURN_NOT_OK(scalar_exec(ctx, ExecSpan({left, right}, computed_length), &output)); | 
|  | 86 | + | 
|  | 87 | +  return Status::OK(); | 
|  | 88 | +} | 
|  | 89 | + | 
|  | 90 | +Status PairwiseExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { | 
|  | 91 | +  const auto& state = checked_cast<const PairwiseState&>(*ctx->state()); | 
|  | 92 | +  auto input = batch[0].array; | 
|  | 93 | +  RETURN_NOT_OK(PairwiseExecImpl(ctx, batch[0].array, state.scalar_exec, state.periods, | 
|  | 94 | +                                 out->array_data_mutable())); | 
|  | 95 | +  return Status::OK(); | 
|  | 96 | +} | 
|  | 97 | + | 
|  | 98 | +const FunctionDoc pairwise_diff_doc( | 
|  | 99 | +    "Compute first order difference of an array", | 
|  | 100 | +    ("Computes the first order difference of an array, It internally calls \n" | 
|  | 101 | +     "the scalar function \"subtract\" to compute \n differences, so its \n" | 
|  | 102 | +     "behavior and supported types are the same as \n" | 
|  | 103 | +     "\"subtract\". The period can be specified in :struct:`PairwiseOptions`.\n" | 
|  | 104 | +     "\n" | 
|  | 105 | +     "Results will wrap around on integer overflow. Use function \n" | 
|  | 106 | +     "\"pairwise_diff_checked\" if you want overflow to return an error."), | 
|  | 107 | +    {"input"}, "PairwiseOptions"); | 
|  | 108 | + | 
|  | 109 | +const FunctionDoc pairwise_diff_checked_doc( | 
|  | 110 | +    "Compute first order difference of an array", | 
|  | 111 | +    ("Computes the first order difference of an array, It internally calls \n" | 
|  | 112 | +     "the scalar function \"subtract_checked\" (or the checked variant) to compute \n" | 
|  | 113 | +     "differences, so its behavior and supported types are the same as \n" | 
|  | 114 | +     "\"subtract_checked\". The period can be specified in :struct:`PairwiseOptions`.\n" | 
|  | 115 | +     "\n" | 
|  | 116 | +     "This function returns an error on overflow. For a variant that doesn't \n" | 
|  | 117 | +     "fail on overflow, use function \"pairwise_diff\"."), | 
|  | 118 | +    {"input"}, "PairwiseOptions"); | 
|  | 119 | + | 
|  | 120 | +const PairwiseOptions* GetDefaultPairwiseOptions() { | 
|  | 121 | +  static const auto kDefaultPairwiseOptions = PairwiseOptions::Defaults(); | 
|  | 122 | +  return &kDefaultPairwiseOptions; | 
|  | 123 | +} | 
|  | 124 | + | 
// Plain aggregate grouping an input type, an output type and a kernel exec
// function.
// NOTE(review): nothing in the visible part of this file references this
// struct -- confirm whether it is still needed.
struct PairwiseKernelData {
  InputType input;
  OutputType output;
  ArrayKernelExec exec;
};
|  | 130 | + | 
|  | 131 | +void RegisterPairwiseDiffKernels(std::string_view func_name, | 
|  | 132 | +                                 std::string_view base_func_name, const FunctionDoc& doc, | 
|  | 133 | +                                 FunctionRegistry* registry) { | 
|  | 134 | +  VectorKernel kernel; | 
|  | 135 | +  kernel.can_execute_chunkwise = false; | 
|  | 136 | +  kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; | 
|  | 137 | +  kernel.mem_allocation = MemAllocation::PREALLOCATE; | 
|  | 138 | +  kernel.init = OptionsWrapper<PairwiseOptions>::Init; | 
|  | 139 | +  auto func = std::make_shared<VectorFunction>(std::string(func_name), Arity::Unary(), | 
|  | 140 | +                                               doc, GetDefaultPairwiseOptions()); | 
|  | 141 | + | 
|  | 142 | +  auto base_func_result = registry->GetFunction(std::string(base_func_name)); | 
|  | 143 | +  DCHECK_OK(base_func_result.status()); | 
|  | 144 | +  const auto& base_func = checked_cast<const ScalarFunction&>(**base_func_result); | 
|  | 145 | +  DCHECK_EQ(base_func.arity().num_args, 2); | 
|  | 146 | + | 
|  | 147 | +  for (const auto& base_func_kernel : base_func.kernels()) { | 
|  | 148 | +    const auto& base_func_kernel_sig = base_func_kernel->signature; | 
|  | 149 | +    if (!base_func_kernel_sig->in_types()[0].Equals( | 
|  | 150 | +            base_func_kernel_sig->in_types()[1])) { | 
|  | 151 | +      continue; | 
|  | 152 | +    } | 
|  | 153 | +    OutputType out_type(base_func_kernel_sig->out_type()); | 
|  | 154 | +    // Need to wrap base output resolver | 
|  | 155 | +    if (out_type.kind() == OutputType::COMPUTED) { | 
|  | 156 | +      out_type = | 
|  | 157 | +          OutputType([base_resolver = base_func_kernel_sig->out_type().resolver()]( | 
|  | 158 | +                         KernelContext* ctx, const std::vector<TypeHolder>& input_types) { | 
|  | 159 | +            return base_resolver(ctx, {input_types[0], input_types[0]}); | 
|  | 160 | +          }); | 
|  | 161 | +    } | 
|  | 162 | + | 
|  | 163 | +    kernel.signature = | 
|  | 164 | +        KernelSignature::Make({base_func_kernel_sig->in_types()[0]}, out_type); | 
|  | 165 | +    kernel.exec = PairwiseExec; | 
|  | 166 | +    kernel.init = [scalar_exec = base_func_kernel->exec](KernelContext* ctx, | 
|  | 167 | +                                                         const KernelInitArgs& args) { | 
|  | 168 | +      return std::make_unique<PairwiseState>( | 
|  | 169 | +          checked_cast<const PairwiseOptions&>(*args.options), scalar_exec); | 
|  | 170 | +    }; | 
|  | 171 | +    DCHECK_OK(func->AddKernel(kernel)); | 
|  | 172 | +  } | 
|  | 173 | + | 
|  | 174 | +  DCHECK_OK(registry->AddFunction(std::move(func))); | 
|  | 175 | +} | 
|  | 176 | + | 
|  | 177 | +void RegisterVectorPairwise(FunctionRegistry* registry) { | 
|  | 178 | +  RegisterPairwiseDiffKernels("pairwise_diff", "subtract", pairwise_diff_doc, registry); | 
|  | 179 | +  RegisterPairwiseDiffKernels("pairwise_diff_checked", "subtract_checked", | 
|  | 180 | +                              pairwise_diff_checked_doc, registry); | 
|  | 181 | +} | 
|  | 182 | + | 
|  | 183 | +}  // namespace arrow::compute::internal | 
0 commit comments