Skip to content

Add arithmetic kernels for float & double. #1291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions ydb/library/yql/core/arrow_kernels/registry/ut/registry_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,57 @@ Y_UNIT_TEST_SUITE(TKernelRegistryTest) {
});
}

Y_UNIT_TEST(TestAddSubMulOps) {
for (const auto oper : {TKernelRequestBuilder::EBinaryOp::Add, TKernelRequestBuilder::EBinaryOp::Sub, TKernelRequestBuilder::EBinaryOp::Mul}) {
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Uint8, EDataSlot::Uint16, EDataSlot::Uint32, EDataSlot::Uint64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddBinaryOp(oper, blockUint8Type, blockType, blockType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddBinaryOp(oper, blockType, blockUint8Type, blockType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddBinaryOp(oper, blockType, blockType, blockType);
});
}
}
}

Y_UNIT_TEST(TestDivModOps) {
for (const auto oper : {TKernelRequestBuilder::EBinaryOp::Div, TKernelRequestBuilder::EBinaryOp::Mod}) {
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Uint8, EDataSlot::Uint16, EDataSlot::Uint32, EDataSlot::Uint64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto rawType = ctx.template MakeType<TDataExprType>(slot);
const auto blockType = ctx.template MakeType<TBlockExprType>(rawType);
const auto returnType = EDataSlot::Float != slot && EDataSlot::Double != slot ?
ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TOptionalExprType>(rawType)) : blockType;
return b.AddBinaryOp(oper, blockUint8Type, blockType, returnType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto rawType = ctx.template MakeType<TDataExprType>(slot);
const auto blockType = ctx.template MakeType<TBlockExprType>(rawType);
const auto returnType = EDataSlot::Float != slot && EDataSlot::Double != slot ?
ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TOptionalExprType>(rawType)) : blockType;
return b.AddBinaryOp(oper, blockType, blockUint8Type, returnType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto rawType = ctx.template MakeType<TDataExprType>(slot);
const auto blockType = ctx.template MakeType<TBlockExprType>(rawType);
const auto returnType = EDataSlot::Float != slot && EDataSlot::Double != slot ?
ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TOptionalExprType>(rawType)) : blockType;
return b.AddBinaryOp(oper, blockType, blockType, returnType);
});
}
}
}

Y_UNIT_TEST(TestSize) {
TestOne([](auto& b,auto& ctx) {
auto blockStrType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::String));
Expand All @@ -121,17 +172,21 @@ Y_UNIT_TEST_SUITE(TKernelRegistryTest) {
}

Y_UNIT_TEST(TestMinus) {
TestOne([](auto& b,auto& ctx) {
auto blockInt32Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Int32));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Minus, blockInt32Type, blockInt32Type);
});
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot](auto& b,auto& ctx) {
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Minus, blockType, blockType);
});
}
}

Y_UNIT_TEST(TestAbs) {
TestOne([](auto& b,auto& ctx) {
auto blockInt32Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Int32));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Abs, blockInt32Type, blockInt32Type);
});
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot](auto& b,auto& ctx) {
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Abs, blockType, blockType);
});
}
}

Y_UNIT_TEST(TestCoalesece) {
Expand Down
8 changes: 4 additions & 4 deletions ydb/library/yql/minikql/arrow/mkql_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TTy
}

bool match = false;
switch (kernel->Family.NullMode) {
case TKernelFamily::ENullMode::Default:
switch (kernel->NullMode) {
case TKernel::ENullMode::Default:
match = returnIsOptional == hasOptionals;
break;
case TKernelFamily::ENullMode::AlwaysNull:
case TKernel::ENullMode::AlwaysNull:
match = returnIsOptional;
break;
case TKernelFamily::ENullMode::AlwaysNotNull:
case TKernel::ENullMode::AlwaysNotNull:
match = !returnIsOptional;
break;
}
Expand Down
59 changes: 0 additions & 59 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,6 @@ namespace NMiniKQL {

namespace {

class TForeignKernel : public TKernel {
public:
TForeignKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType,
const std::shared_ptr<arrow::compute::Function>& function)
: TKernel(family, argTypes, returnType)
, Function(function)
, ArrowKernel(ResolveKernel(Function, argTypes))
{}

const arrow::compute::ScalarKernel& GetArrowKernel() const final {
return ArrowKernel;
}

private:
static const arrow::compute::ScalarKernel& ResolveKernel(const std::shared_ptr<arrow::compute::Function>& function,
const std::vector<NUdf::TDataTypeId>& argTypes) {
std::vector<arrow::ValueDescr> args;
for (const auto& t : argTypes) {
args.emplace_back();
auto slot = NUdf::FindDataSlot(t);
MKQL_ENSURE(slot, "Unexpected data type");
MKQL_ENSURE(ConvertArrowType(*slot, args.back().type), "Can't get arrow type");
}

const auto kernel = ARROW_RESULT(function->DispatchExact(args));
return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
}

private:
const std::shared_ptr<arrow::compute::Function> Function;
const arrow::compute::ScalarKernel& ArrowKernel;
};

template <typename TInput1, typename TOutput>
void RegisterUnary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) {
auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));

std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id });
NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;

auto family = std::make_unique<TKernelFamilyBase>();
family->Adopt(argTypes, returnType, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func));

Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second);
}

template <typename TInput1, typename TInput2, typename TOutput>
void RegisterBinary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) {
auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));

std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id });
NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;

auto family = std::make_unique<TKernelFamilyBase>();
family->Adopt(argTypes, returnType, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func));

Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second);
}

void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, TKernelFamilyMap& kernelFamilyMap) {
RegisterAdd(registry);
RegisterAdd(kernelFamilyMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ inline T Abs(T v) {

template<typename TInput, typename TOutput>
struct TAbs : public TSimpleArithmeticUnary<TInput, TOutput, TAbs<TInput, TOutput>> {
static constexpr auto NullMode = TKernel::ENullMode::Default;

static TOutput Do(TInput val)
{
return Abs<TInput>(val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace {

template<typename TLeft, typename TRight, typename TOutput>
struct TAdd : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TAdd<TLeft, TRight, TOutput>> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;

static TOutput Do(TOutput left, TOutput right)
{
Expand Down Expand Up @@ -193,7 +193,7 @@ void RegisterAdd(IBuiltinFunctionRegistry& registry) {
}

void RegisterAdd(TKernelFamilyMap& kernelFamilyMap) {
kernelFamilyMap["Add"] = std::make_unique<TBinaryNumericKernelFamily<TAdd>>();
kernelFamilyMap["Add"] = std::make_unique<TBinaryNumericKernelFamily<TAdd, TAdd>>();
}

void RegisterAggrAdd(IBuiltinFunctionRegistry& registry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ template<typename TLeft, typename TRight, typename TOutput>
struct TDiv : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TDiv<TLeft, TRight, TOutput>> {
static_assert(std::is_floating_point<TOutput>::value, "expected floating point");

static constexpr auto NullMode = TKernel::ENullMode::Default;

static TOutput Do(TOutput left, TOutput right)
{
return left / right;
Expand All @@ -29,7 +31,7 @@ template <typename TLeft, typename TRight, typename TOutput>
struct TIntegralDiv {
static_assert(std::is_integral<TOutput>::value, "integral type expected");

static constexpr bool DefaultNulls = false;
static constexpr auto NullMode = TKernel::ENullMode::AlwaysNull;

static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right)
{
Expand Down Expand Up @@ -60,7 +62,7 @@ struct TIntegralDiv {
const auto result = PHINode::Create(type, 2, "result", done);
result->addIncoming(zero, block);

if (std::is_signed<TOutput>() && sizeof(TOutput) <= sizeof(TLeft)) {
if constexpr (std::is_signed<TOutput>() && sizeof(TOutput) <= sizeof(TLeft)) {
const auto min = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lv, ConstantInt::get(lv->getType(), Min<TOutput>()), "min", block);
const auto one = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, ConstantInt::get(rv->getType(), -1), "one", block);
const auto two = BinaryOperator::CreateAnd(min, one, "two", block);
Expand Down Expand Up @@ -167,7 +169,7 @@ void RegisterDiv(IBuiltinFunctionRegistry& registry) {
}

void RegisterDiv(TKernelFamilyMap& kernelFamilyMap) {
kernelFamilyMap["Div"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralDiv>>(TKernelFamily::ENullMode::AlwaysNull);
kernelFamilyMap["Div"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralDiv, TDiv>>();
}

} // namespace NMiniKQL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct TEqualsOp;

template<typename TLeft, typename TRight>
struct TEqualsOp<TLeft, TRight, bool> : public TEquals<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down Expand Up @@ -190,7 +190,7 @@ struct TDiffDateEqualsOp;

template<typename TLeft, typename TRight>
struct TDiffDateEqualsOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateEquals<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template <typename TLeft, typename TRight, bool Aggr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct TGreaterOp;

template<typename TLeft, typename TRight>
struct TGreaterOp<TLeft, TRight, bool> : public TGreater<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down Expand Up @@ -183,7 +183,7 @@ struct TDiffDateGreaterOp;

template<typename TLeft, typename TRight>
struct TDiffDateGreaterOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateGreater<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct TGreaterOrEqualOp;

template<typename TLeft, typename TRight>
struct TGreaterOrEqualOp<TLeft, TRight, bool> : public TGreaterOrEqual<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down Expand Up @@ -183,7 +183,7 @@ struct TDiffDateGreaterOrEqualOp;

template<typename TLeft, typename TRight>
struct TDiffDateGreaterOrEqualOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateGreaterOrEqual<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down
Loading