Skip to content

Commit f71631d

Browse files
authored
Add decimal support to arrow reader (#6151)
1 parent 39c75e8 commit f71631d

File tree

34 files changed

+766
-104
lines changed

34 files changed

+766
-104
lines changed

ydb/core/kqp/ut/scheme/kqp_scheme_ut.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7732,7 +7732,7 @@ Y_UNIT_TEST_SUITE(KqpOlapTypes) {
77327732
testHelper.ReadData("SELECT dec FROM `/Root/ColumnTableTest` WHERE id=2", "[[\"inf\"]]");
77337733
testHelper.ReadData("SELECT dec FROM `/Root/ColumnTableTest` WHERE id=3", "[[\"-inf\"]]");
77347734
testHelper.ReadData("SELECT dec FROM `/Root/ColumnTableTest` WHERE id=4", "[[\"nan\"]]");
7735-
testHelper.ReadData("SELECT dec FROM `/Root/ColumnTableTest` WHERE id=5", "[[\"-nan\"]]");
7735+
testHelper.ReadData("SELECT dec FROM `/Root/ColumnTableTest` WHERE id=5", "[[\"nan\"]]");
77367736
testHelper.ReadData("SELECT id FROM `/Root/ColumnTableTest` WHERE dec=CAST(\"10.1\" As Decimal(22,9))", "[[1]]");
77377737
testHelper.ReadData("SELECT id FROM `/Root/ColumnTableTest` WHERE dec=CAST(\"inf\" As Decimal(22,9)) ORDER BY id", "[[2];[8]]");
77387738
testHelper.ReadData("SELECT id FROM `/Root/ColumnTableTest` WHERE dec=CAST(\"-inf\" As Decimal(22,9)) ORDER BY id", "[[3];[9]]");

ydb/library/yql/minikql/arrow/arrow_util.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,76 +74,87 @@ struct TPrimitiveDataType;
7474
template<>
7575
struct TPrimitiveDataType<bool> {
7676
using TLayout = ui8;
77+
using TArithmetic = ui8;
7778
using TResult = arrow::UInt8Type;
7879
using TScalarResult = arrow::UInt8Scalar;
7980
};
8081

8182
template<>
8283
struct TPrimitiveDataType<i8> {
8384
using TLayout = i8;
85+
using TArithmetic = i8;
8486
using TResult = arrow::Int8Type;
8587
using TScalarResult = arrow::Int8Scalar;
8688
};
8789

8890
template<>
8991
struct TPrimitiveDataType<ui8> {
9092
using TLayout = ui8;
93+
using TArithmetic = ui8;
9194
using TResult = arrow::UInt8Type;
9295
using TScalarResult = arrow::UInt8Scalar;
9396
};
9497

9598
template<>
9699
struct TPrimitiveDataType<i16> {
97100
using TLayout = i16;
101+
using TArithmetic = i16;
98102
using TResult = arrow::Int16Type;
99103
using TScalarResult = arrow::Int16Scalar;
100104
};
101105

102106
template<>
103107
struct TPrimitiveDataType<ui16> {
104108
using TLayout = ui16;
109+
using TArithmetic = ui16;
105110
using TResult = arrow::UInt16Type;
106111
using TScalarResult = arrow::UInt16Scalar;
107112
};
108113

109114
template<>
110115
struct TPrimitiveDataType<i32> {
111116
using TLayout = i32;
117+
using TArithmetic = i32;
112118
using TResult = arrow::Int32Type;
113119
using TScalarResult = arrow::Int32Scalar;
114120
};
115121

116122
template<>
117123
struct TPrimitiveDataType<ui32> {
118124
using TLayout = ui32;
125+
using TArithmetic = ui32;
119126
using TResult = arrow::UInt32Type;
120127
using TScalarResult = arrow::UInt32Scalar;
121128
};
122129

123130
template<>
124131
struct TPrimitiveDataType<i64> {
125132
using TLayout = i64;
133+
using TArithmetic = i64;
126134
using TResult = arrow::Int64Type;
127135
using TScalarResult = arrow::Int64Scalar;
128136
};
129137

130138
template<>
131139
struct TPrimitiveDataType<ui64> {
132140
using TLayout = ui64;
141+
using TArithmetic = ui64;
133142
using TResult = arrow::UInt64Type;
134143
using TScalarResult = arrow::UInt64Scalar;
135144
};
136145

137146
template<>
138147
struct TPrimitiveDataType<float> {
139148
using TLayout = float;
149+
using TArithmetic = float;
140150
using TResult = arrow::FloatType;
141151
using TScalarResult = arrow::FloatScalar;
142152
};
143153

144154
template<>
145155
struct TPrimitiveDataType<double> {
146156
using TLayout = double;
157+
using TArithmetic = double;
147158
using TResult = arrow::DoubleType;
148159
using TScalarResult = arrow::DoubleScalar;
149160
};
@@ -160,6 +171,31 @@ struct TPrimitiveDataType<NYql::NUdf::TUtf8> {
160171
using TScalarResult = arrow::StringScalar;
161172
};
162173

174+
template<>
175+
struct TPrimitiveDataType<NYql::NDecimal::TInt128> {
176+
using TArithmetic = NYql::NDecimal::TDecimal;
177+
178+
class TResult: public arrow::FixedSizeBinaryType
179+
{
180+
public:
181+
TResult(): arrow::FixedSizeBinaryType(16)
182+
{ }
183+
};
184+
185+
186+
class TScalarResult: public arrow::FixedSizeBinaryScalar
187+
{
188+
public:
189+
TScalarResult(std::shared_ptr<arrow::Buffer> value)
190+
: arrow::FixedSizeBinaryScalar(std::move(value), arrow::fixed_size_binary(16))
191+
{ }
192+
193+
TScalarResult()
194+
: arrow::FixedSizeBinaryScalar(arrow::fixed_size_binary(16))
195+
{ }
196+
};
197+
};
198+
163199
template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
164200
inline arrow::Datum MakeScalarDatum(T value) {
165201
return arrow::Datum(std::make_shared<typename TPrimitiveDataType<T>::TScalarResult>(value));
@@ -179,3 +215,14 @@ inline std::shared_ptr<arrow::DataType> GetPrimitiveDataType() {
179215
using NYql::NUdf::TTypedBufferBuilder;
180216

181217
}
218+
219+
namespace arrow {
220+
221+
template <>
222+
struct TypeTraits<typename NKikimr::NMiniKQL::TPrimitiveDataType<NYql::NDecimal::TInt128>::TResult> {
223+
static inline std::shared_ptr<DataType> type_singleton() {
224+
return arrow::fixed_size_binary(16);
225+
}
226+
};
227+
228+
}

ydb/library/yql/minikql/comp_nodes/mkql_block_agg_minmax.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mkql_block_agg_minmax.h"
2+
#include "mkql_block_agg_state_helper.h"
23

34
#include <ydb/library/yql/minikql/mkql_node_cast.h>
45
#include <ydb/library/yql/minikql/mkql_node_builder.h>
@@ -95,6 +96,12 @@ constexpr TIn InitialStateValue() {
9596
} else {
9697
return -std::numeric_limits<TIn>::infinity();
9798
}
99+
} else if constexpr (std::is_same_v<TIn, NYql::NDecimal::TInt128>) {
100+
if constexpr (IsMin) {
101+
return NYql::NDecimal::Nan();
102+
} else {
103+
return -NYql::NDecimal::Inf();
104+
}
98105
} else {
99106
if constexpr (IsMin) {
100107
return std::numeric_limits<TIn>::max();
@@ -129,7 +136,7 @@ class TColumnBuilder : public IAggColumnBuilder {
129136
}
130137

131138
void Add(const void* state) final {
132-
auto typedState = static_cast<const TStateType*>(state);
139+
auto typedState = MakeStateWrapper<TStateType>(state);
133140
if constexpr (IsNullable) {
134141
if (!typedState->IsValid) {
135142
Builder_.Add(TBlockItem());
@@ -620,8 +627,9 @@ class TMinMaxBlockFixedAggregator<TCombineAllTag, IsNullable, IsScalar, TIn, IsM
620627
Y_UNUSED(type);
621628
}
622629

623-
void InitState(void* state) final {
624-
new(state) TStateType();
630+
void InitState(void* ptr) final {
631+
TStateType state;
632+
WriteUnaligned<TStateType>(ptr, state);
625633
}
626634

627635
void DestroyState(void* state) noexcept final {
@@ -630,18 +638,18 @@ class TMinMaxBlockFixedAggregator<TCombineAllTag, IsNullable, IsScalar, TIn, IsM
630638
}
631639

632640
void AddMany(void* state, const NUdf::TUnboxedValue* columns, ui64 batchLength, std::optional<ui64> filtered) final {
633-
auto typedState = static_cast<TStateType*>(state);
641+
auto typedState = MakeStateWrapper<TStateType>(state);
634642
Y_UNUSED(batchLength);
635643
const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
636644
if constexpr (IsScalar) {
637645
Y_ENSURE(datum.is_scalar());
638646
if constexpr (IsNullable) {
639647
if (datum.scalar()->is_valid) {
640-
typedState->Value = datum.scalar_as<TInScalar>().value;
648+
typedState->Value = TIn(Cast(datum.scalar_as<TInScalar>().value));
641649
typedState->IsValid = 1;
642650
}
643651
} else {
644-
typedState->Value = datum.scalar_as<TInScalar>().value;
652+
typedState->Value = TIn(Cast(datum.scalar_as<TInScalar>().value));
645653
}
646654
} else {
647655
const auto& array = datum.array();
@@ -706,7 +714,7 @@ class TMinMaxBlockFixedAggregator<TCombineAllTag, IsNullable, IsScalar, TIn, IsM
706714
}
707715

708716
NUdf::TUnboxedValue FinishOne(const void* state) final {
709-
auto typedState = static_cast<const TStateType*>(state);
717+
auto typedState = MakeStateWrapper<TStateType>(state);
710718
if constexpr (IsNullable) {
711719
if (!typedState->IsValid) {
712720
return NUdf::TUnboxedValuePod();
@@ -727,11 +735,11 @@ static void PushValueToState(TState<IsNullable, TIn, IsMin>* typedState, const a
727735
Y_ENSURE(datum.is_scalar());
728736
if constexpr (IsNullable) {
729737
if (datum.scalar()->is_valid) {
730-
typedState->Value = datum.scalar_as<TInScalar>().value;
738+
typedState->Value = TIn(Cast(datum.scalar_as<TInScalar>().value));
731739
typedState->IsValid = 1;
732740
}
733741
} else {
734-
typedState->Value = datum.scalar_as<TInScalar>().value;
742+
typedState->Value = TIn(Cast(datum.scalar_as<TInScalar>().value));
735743
}
736744
} else {
737745
const auto &array = datum.array();
@@ -767,7 +775,8 @@ class TMinMaxBlockFixedAggregator<TCombineKeysTag, IsNullable, IsScalar, TIn, Is
767775
}
768776

769777
void InitKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
770-
new(state) TStateType();
778+
TStateType st;
779+
WriteUnaligned<TStateType>(state, st);
771780
UpdateKey(state, batchNum, columns, row);
772781
}
773782

@@ -778,9 +787,9 @@ class TMinMaxBlockFixedAggregator<TCombineKeysTag, IsNullable, IsScalar, TIn, Is
778787

779788
void UpdateKey(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
780789
Y_UNUSED(batchNum);
781-
auto typedState = static_cast<TStateType*>(state);
790+
auto typedState = MakeStateWrapper<TStateType>(state);
782791
const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
783-
PushValueToState<IsNullable, IsScalar, TIn, IsMin>(typedState, datum, row);
792+
PushValueToState<IsNullable, IsScalar, TIn, IsMin>(typedState.Get(), datum, row);
784793
}
785794

786795
std::unique_ptr<IAggColumnBuilder> MakeStateBuilder(ui64 size) final {
@@ -807,7 +816,8 @@ class TMinMaxBlockFixedAggregator<TFinalizeKeysTag, IsNullable, IsScalar, TIn, I
807816
}
808817

809818
void LoadState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
810-
new(state) TStateType();
819+
TStateType st;
820+
WriteUnaligned<TStateType>(state, st);
811821
UpdateState(state, batchNum, columns, row);
812822
}
813823

@@ -818,9 +828,9 @@ class TMinMaxBlockFixedAggregator<TFinalizeKeysTag, IsNullable, IsScalar, TIn, I
818828

819829
void UpdateState(void* state, ui64 batchNum, const NUdf::TUnboxedValue* columns, ui64 row) final {
820830
Y_UNUSED(batchNum);
821-
auto typedState = static_cast<TStateType*>(state);
831+
auto typedState = MakeStateWrapper<TStateType>(state);
822832
const auto& datum = TArrowBlock::From(columns[ArgColumn_]).GetDatum();
823-
PushValueToState<IsNullable, IsScalar, TIn, IsMin>(typedState, datum, row);
833+
PushValueToState<IsNullable, IsScalar, TIn, IsMin>(typedState.Get(), datum, row);
824834
}
825835

826836
std::unique_ptr<IAggColumnBuilder> MakeResultBuilder(ui64 size) final {
@@ -963,6 +973,8 @@ std::unique_ptr<typename TTag::TPreparedAggregator> PrepareMinMax(TTupleType* tu
963973
return PrepareMinMaxFixed<TTag, float, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn);
964974
case NUdf::EDataSlot::Double:
965975
return PrepareMinMaxFixed<TTag, double, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn);
976+
case NUdf::EDataSlot::Decimal:
977+
return PrepareMinMaxFixed<TTag, NYql::NDecimal::TInt128, IsMin>(dataType, isOptional, isScalar, filterColumn, argColumn);
966978
default:
967979
throw yexception() << "Unsupported MIN/MAX input type";
968980
}
ydb/library/yql/minikql/comp_nodes/mkql_block_agg_state_helper.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#pragma once
2+
3+
#include <util/system/unaligned_mem.h>
4+
5+
namespace NKikimr {
6+
namespace NMiniKQL {
7+
8+
template <typename T, bool IsConst = std::is_const_v<T>>
9+
class TStateWrapper;
10+
11+
template <typename T>
12+
class TStateWrapper<T, true> {
13+
public:
14+
TStateWrapper(const void* ptr)
15+
: State_(ReadUnaligned<typename std::remove_const<T>::type>(ptr))
16+
{ }
17+
18+
T* Get() {
19+
return &State_;
20+
}
21+
22+
T* operator->() {
23+
return Get();
24+
}
25+
26+
private:
27+
T State_;
28+
};
29+
30+
template <typename T>
31+
class TStateWrapper<T, false> {
32+
public:
33+
TStateWrapper(void* ptr)
34+
: State_(ReadUnaligned<T>(ptr))
35+
, Ptr_(ptr)
36+
{ }
37+
38+
~TStateWrapper() {
39+
WriteUnaligned<T>(Ptr_, State_);
40+
}
41+
42+
T* Get() {
43+
return &State_;
44+
}
45+
46+
T* operator->() {
47+
return Get();
48+
}
49+
50+
private:
51+
T State_;
52+
void* Ptr_;
53+
};
54+
55+
template <typename T>
56+
inline TStateWrapper<T> MakeStateWrapper(void* ptr) {
57+
return TStateWrapper<T>(ptr);
58+
}
59+
60+
template <typename T>
61+
inline TStateWrapper<const T> MakeStateWrapper(const void* ptr) {
62+
return TStateWrapper<const T>(ptr);
63+
}
64+
65+
// Generic overload of Cast: hands the argument back unchanged.
template<typename T>
inline T Cast(T value) {
    return value;
}
69+
70+
// Decodes a decimal from the leading sizeof(NYql::NDecimal::TDecimal) bytes of
// an Arrow buffer.
//
// The bytes are copied out with ReadUnaligned instead of dereferencing a
// reinterpret_cast'ed pointer: buffer->data() is a byte pointer with no
// alignment guarantee for a 128-bit value, so a direct typed load could be a
// misaligned access.
//
// `buffer` must be non-null and hold at least sizeof(NYql::NDecimal::TDecimal)
// bytes; this is not checked here.
inline NYql::NDecimal::TDecimal Cast(const std::shared_ptr<arrow::Buffer>& buffer) {
    return ReadUnaligned<NYql::NDecimal::TDecimal>(buffer->data());
}
73+
74+
}
75+
}

0 commit comments

Comments
 (0)