Skip to content

Add unary kernels. #859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions ydb/library/yql/core/arrow_kernels/registry/ut/registry_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ Y_UNIT_TEST_SUITE(TKernelRegistryTest) {
});
}

Y_UNIT_TEST(TestMinus) {
TestOne([](auto& b,auto& ctx) {
auto blockInt32Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Int32));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Minus, blockInt32Type, blockInt32Type);
});
}

Y_UNIT_TEST(TestAbs) {
TestOne([](auto& b,auto& ctx) {
auto blockInt32Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Int32));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Abs, blockInt32Type, blockInt32Type);
});
}

Y_UNIT_TEST(TestCoalesece) {
TestOne([](auto& b,auto& ctx) {
auto blockStringType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::String));
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/core/arrow_kernels/request/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ ui32 TKernelRequestBuilder::AddUnaryOp(EUnaryOp op, const TTypeAnnotationNode* a
Items_.emplace_back(Pb_.BlockNot(arg));
break;
case EUnaryOp::Size:
case EUnaryOp::Minus:
case EUnaryOp::Abs:
Items_.emplace_back(Pb_.BlockFunc(ToString(op), returnType, { arg }));
break;
}
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/core/arrow_kernels/request/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class TKernelRequestBuilder {
enum class EUnaryOp {
Not,
Size,
Minus,
Abs
};

enum class EBinaryOp {
Expand Down
19 changes: 19 additions & 0 deletions ydb/library/yql/core/expr_nodes/yql_expr_nodes.json
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,25 @@
"VarArgBase": "TExprBase",
"Match": {"Type": "Callable", "Name": "Xor"}
},
{
"Name": "TCoUnaryArithmetic",
"Base": "TCallable",
"Match": {"Type": "CallableBase"},
"Builder": {"Generate": "None"},
"Children": [
{"Index": 0, "Name": "Arg", "Type": "TExprBase"}
]
},
{
"Name": "TCoMinus",
"Base": "TCoUnaryArithmetic",
"Match": {"Type": "Callable", "Name": "Minus"}
},
{
"Name": "TCoAbs",
"Base": "TCoUnaryArithmetic",
"Match": {"Type": "Callable", "Name": "Abs"}
},
{
"Name": "TCoBinaryArithmetic",
"Base": "TCallable",
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, TKernelFamily
RegisterRotRight(registry);
RegisterPlus(registry);
RegisterMinus(registry);
RegisterMinus(kernelFamilyMap);
RegisterBitNot(registry);
RegisterCountBits(registry);
RegisterAbs(registry);
RegisterAbs(kernelFamilyMap);
RegisterConvert(registry);
RegisterConcat(registry);
RegisterSubstring(registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,9 @@ void RegisterAbs(IBuiltinFunctionRegistry& registry) {
NDecimal::RegisterUnaryFunction<TDecimalAbs, TUnaryArgsOpt>(registry, "Abs");
}

void RegisterAbs(TKernelFamilyMap& kernelFamilyMap) {
kernelFamilyMap["Abs"] = std::make_unique<TUnaryNumericKernelFamily<TAbs>>();
}

} // namespace NMiniKQL
} // namespace NKikimr
Original file line number Diff line number Diff line change
Expand Up @@ -800,9 +800,11 @@ void RegisterRotLeft(IBuiltinFunctionRegistry& registry);
void RegisterRotRight(IBuiltinFunctionRegistry& registry);
void RegisterPlus(IBuiltinFunctionRegistry& registry);
void RegisterMinus(IBuiltinFunctionRegistry& registry);
void RegisterMinus(TKernelFamilyMap& kernelFamilyMap);
void RegisterBitNot(IBuiltinFunctionRegistry& registry);
void RegisterCountBits(IBuiltinFunctionRegistry& registry);
void RegisterAbs(IBuiltinFunctionRegistry& registry);
void RegisterAbs(TKernelFamilyMap& kernelFamilyMap);
void RegisterConvert(IBuiltinFunctionRegistry& registry);
void RegisterConcat(IBuiltinFunctionRegistry& registry);
void RegisterSubstring(IBuiltinFunctionRegistry& registry);
Expand Down Expand Up @@ -1087,6 +1089,54 @@ private:
const arrow::compute::ScalarKernel ArrowKernel;
};

template<typename TInput, typename TOutput, class TFuncInstance>
struct TUnaryKernelExecs : TUnaryKernelExecsBase<TUnaryKernelExecs<TInput, TOutput, TFuncInstance>>
{
static arrow::Status ExecScalar(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
if (const auto& arg = batch.values.front(); !arg.scalar()->is_valid) {
*res = arrow::MakeNullScalar(GetPrimitiveDataType<TOutput>());
} else {
const auto val = GetPrimitiveScalarValue<TInput>(*arg.scalar());
*res = MakeScalarDatum<TOutput>(TFuncInstance::Do(val));
}
return arrow::Status::OK();
}

static arrow::Status ExecArray(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
static_assert(!std::is_same<TOutput, bool>::value);

const auto& arg = batch.values.front();
auto& resArr = *res->array();

const auto& arr = *arg.array();
auto length = arr.length;
const auto values = arr.GetValues<TInput>(1);
auto resValues = resArr.GetMutableValues<TOutput>(1);
for (int64_t i = 0; i < length; ++i) {
resValues[i] = TFuncInstance::Do(values[i]);
}

return arrow::Status::OK();
}
};

template<typename TInput, typename TOutput,
template<typename, typename> class TFunc>
void AddUnaryKernel(TKernelFamilyBase& owner) {
using TInputLayout = typename TInput::TLayout;
using TOutputLayout = typename TOutput::TLayout;

using TFuncInstance = TFunc<TInputLayout, TOutputLayout>;
using TExecs = TUnaryKernelExecs<TInputLayout, TOutputLayout, TFuncInstance>;

std::vector<NUdf::TDataTypeId> argTypes({ TInput::Id });
NUdf::TDataTypeId returnType = TOutput::Id;

arrow::compute::ScalarKernel k({ GetPrimitiveInputArrowType<TInputLayout>() }, GetPrimitiveOutputArrowType<TOutputLayout>(), &TExecs::Exec);
k.null_handling = owner.NullMode == TKernelFamily::ENullMode::Default ? arrow::compute::NullHandling::INTERSECTION : arrow::compute::NullHandling::COMPUTED_PREALLOCATE;
owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
}

template<typename TInput1, typename TInput2, typename TOutput,
template<typename, typename, typename> class TFunc>
void AddBinaryKernel(TKernelFamilyBase& owner) {
Expand Down Expand Up @@ -1125,6 +1175,19 @@ void AddBinaryKernelPoly(TKernelFamilyBase& owner) {
owner.Adopt(argTypes, returnType, std::make_unique<TPlainKernel>(owner, argTypes, returnType, k));
}

template<template<typename, typename> class TFunc>
void AddUnaryIntegralKernels(TKernelFamilyBase& owner) {
AddUnaryKernel<NUdf::TDataType<i8>, NUdf::TDataType<i8>, TFunc>(owner);
AddUnaryKernel<NUdf::TDataType<i16>, NUdf::TDataType<i16>, TFunc>(owner);
AddUnaryKernel<NUdf::TDataType<i32>, NUdf::TDataType<i32>, TFunc>(owner);
AddUnaryKernel<NUdf::TDataType<i64>, NUdf::TDataType<i64>, TFunc>(owner);

AddUnaryKernel<NUdf::TDataType<ui8>, NUdf::TDataType<ui8>, TFunc>(owner);
AddUnaryKernel<NUdf::TDataType<ui16>, NUdf::TDataType<ui16>, TFunc>(owner);
AddUnaryKernel<NUdf::TDataType<ui32>, NUdf::TDataType<ui32>, TFunc>(owner);
AddUnaryKernel<NUdf::TDataType<ui64>, NUdf::TDataType<ui64>, TFunc>(owner);
}

template<template<typename, typename, typename> class TFunc>
void AddBinaryIntegralKernels(TKernelFamilyBase& owner) {
AddBinaryKernel<NUdf::TDataType<ui8>, NUdf::TDataType<ui8>, NUdf::TDataType<ui8>, TFunc>(owner);
Expand Down Expand Up @@ -1210,6 +1273,16 @@ public:
}
};

template<template<typename, typename> class TFunc>
class TUnaryNumericKernelFamily : public TKernelFamilyBase {
public:
TUnaryNumericKernelFamily(TKernelFamily::ENullMode nullMode = TKernelFamily::ENullMode::Default)
: TKernelFamilyBase(nullMode)
{
AddUnaryIntegralKernels<TFunc>(*this);
}
};

template<typename TInput1, typename TInput2,
template<typename, typename, typename> class TFunc>
void AddBinaryPredicateKernel(TKernelFamilyBase& owner) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,9 @@ void RegisterMinus(IBuiltinFunctionRegistry& registry) {
RegisterFunctionUnOpt<NUdf::TDataType<NUdf::TInterval>, NUdf::TDataType<NUdf::TInterval>, TMinus, TUnaryArgsOpt>(registry, "Minus");
}

void RegisterMinus(TKernelFamilyMap& kernelFamilyMap) {
kernelFamilyMap["Minus"] = std::make_unique<TUnaryNumericKernelFamily<TMinus>>();
}

} // namespace NMiniKQL
} // namespace NKikimr