Skip to content

Commit 5a99902

Browse files
author
Vadim Averin
committed
Improve interface of simple stateful wide flow nodes
1 parent 0c5fdb4 commit 5a99902

File tree

4 files changed

+38
-22
lines changed

4 files changed

+38
-22
lines changed

ydb/library/yql/minikql/comp_nodes/mkql_skip.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,22 @@ using TBaseComputation = TStatefulFlowCodegeneratorNode<TSkipFlowWrapper>;
121121
class TWideSkipWrapper : public TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWrapper, ui64> {
122122
using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWrapper, ui64>;
123123
public:
124-
TWideSkipWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 )
124+
TWideSkipWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size)
125125
: TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
126126
, Flow(flow)
127127
, Count(count)
128+
, StubsIndex(mutables.IncrementWideFieldsIndex(size))
128129
{}
129130

130131
void InitState(ui64& count, TComputationContext& ctx) const {
131132
count = Count->GetValue(ctx).Get<ui64>();
132133
}
133134

134-
EProcessResult DoProcess(ui64& skipCount, TComputationContext& , EFetchResult fetchRes, NUdf::TUnboxedValue*const* ) const {
135+
NUdf::TUnboxedValue*const* PrepareInput(ui64& skipCount, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
136+
return skipCount == 0 ? output : ctx.WideFields.data() + StubsIndex;
137+
}
138+
139+
EProcessResult DoProcess(ui64& skipCount, TComputationContext&, EFetchResult fetchRes, NUdf::TUnboxedValue*const*, NUdf::TUnboxedValue*const*) const {
135140
if (fetchRes == EFetchResult::One && skipCount) {
136141
skipCount--;
137142
return EProcessResult::Fetch;
@@ -147,6 +152,7 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWrapp
147152

148153
IComputationWideFlowNode* const Flow;
149154
IComputationNode* const Count;
155+
const ui32 StubsIndex;
150156
};
151157

152158
class TSkipStreamWrapper : public TMutableComputationNode<TSkipStreamWrapper> {

ydb/library/yql/minikql/comp_nodes/mkql_take.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideTakeWrapp
107107
count = Count->GetValue(ctx).Get<ui64>();
108108
}
109109

110-
EProcessResult DoProcess(ui64& takeCount, TComputationContext& , EFetchResult fetchRes, NUdf::TUnboxedValue*const* ) const {
110+
NUdf::TUnboxedValue*const* PrepareInput(ui64& takeCount, TComputationContext&, NUdf::TUnboxedValue*const* output) const {
111+
return takeCount != 0 ? output : nullptr;
112+
}
113+
114+
EProcessResult DoProcess(ui64& takeCount, TComputationContext& , EFetchResult fetchRes, NUdf::TUnboxedValue*const*, NUdf::TUnboxedValue*const*) const {
111115
if (takeCount == 0) {
112116
return EProcessResult::Finish;
113117
} else if (fetchRes == EFetchResult::One) {

ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class TBaseWideFilterWrapper {
2727
return ctx.WideFields.data() + WideFieldsIndex;
2828
}
2929

30-
void PrepareArguments(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
30+
NUdf::TUnboxedValue*const* PrepareArguments(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
3131
auto** fields = GetFields(ctx);
3232

3333
for (auto i = 0U; i < Items.size(); ++i) {
@@ -36,6 +36,8 @@ class TBaseWideFilterWrapper {
3636
else
3737
fields[i] = output[i];
3838
}
39+
40+
return fields;
3941
}
4042

4143
void FillOutputs(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
@@ -47,15 +49,6 @@ class TBaseWideFilterWrapper {
4749
*out = *fields[i];
4850
}
4951

50-
bool ApplyPredicate(TComputationContext& ctx, NUdf::TUnboxedValue*const* values) const {
51-
auto **fields = GetFields(ctx);
52-
PrepareArguments(ctx, values);
53-
for (size_t idx = 0; idx < Items.size(); idx++) {
54-
*fields[idx] = *values[idx];
55-
}
56-
return Predicate->GetValue(ctx).Get<bool>();
57-
}
58-
5952
#ifndef MKQL_DISABLE_CODEGEN
6053
template<bool ReplaceOriginalGetter = true>
6154
Value* GenGetPredicate(const TCodegenContext& ctx,
@@ -159,12 +152,16 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideFilterWit
159152
limit = Limit->GetValue(ctx).Get<ui64>();
160153
}
161154

162-
EProcessResult DoProcess(ui64& limit, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
155+
NUdf::TUnboxedValue*const* PrepareInput(ui64& limit, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
156+
return limit != 0 ? PrepareArguments(ctx, output) : nullptr;
157+
}
158+
159+
EProcessResult DoProcess(ui64& limit, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const*, NUdf::TUnboxedValue*const* values) const {
163160
if (limit == 0) {
164161
return EProcessResult::Finish;
165162
}
166163
if (fetchRes == EFetchResult::One) {
167-
if (ApplyPredicate(ctx, values)) {
164+
if (Predicate->GetValue(ctx).Get<bool>()) {
168165
FillOutputs(ctx, values);
169166
limit--;
170167
return EProcessResult::One;
@@ -200,12 +197,16 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideTakeWhile
200197
stop = false;
201198
}
202199

203-
TBaseComputation::EProcessResult DoProcess(bool& stop, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
200+
NUdf::TUnboxedValue*const* PrepareInput(bool& stop, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
201+
return !stop ? output : PrepareArguments(ctx, output);
202+
}
203+
204+
TBaseComputation::EProcessResult DoProcess(bool& stop, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const*, NUdf::TUnboxedValue*const* values) const {
204205
if (stop) {
205206
return TBaseComputation::EProcessResult::Finish;
206207
}
207208
if (fetchRes == EFetchResult::One) {
208-
const bool predicate = ApplyPredicate(ctx, values);
209+
const bool predicate = Predicate->GetValue(ctx).Get<bool>();
209210
if (!predicate) {
210211
stop = true;
211212
}
@@ -240,9 +241,13 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWhile
240241
start = false;
241242
}
242243

243-
TBaseComputation::EProcessResult DoProcess(bool& start, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
244+
NUdf::TUnboxedValue*const* PrepareInput(bool&, TComputationContext&, NUdf::TUnboxedValue*const* output) const {
245+
return output;
246+
}
247+
248+
TBaseComputation::EProcessResult DoProcess(bool& start, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const*, NUdf::TUnboxedValue*const* values) const {
244249
if (!start && fetchRes == EFetchResult::One) {
245-
const bool predicate = ApplyPredicate(ctx, values);
250+
const bool predicate = Predicate->GetValue(ctx).Get<bool>();
246251
if (!predicate) {
247252
start = true;
248253
}

ydb/library/yql/minikql/computation/mkql_computation_node_codegen_impl.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class TSimpleStatefulWideFlowCodegeneratorNode
3636
for (size_t pos = 0; pos < width; pos++) {
3737
valuePtrsVec[pos] = valuesVec.data() + pos;
3838
}
39-
auto res = static_cast<const TDerived*>(this)->DoProcess(*static_cast<TState*>(state.GetRawPtr()), ctx, fetchRes, valuePtrsVec.data());
39+
auto res = static_cast<const TDerived*>(this)->DoProcess(*static_cast<TState*>(state.GetRawPtr()), ctx, fetchRes, valuePtrsVec.data(), valuePtrsVec.data());
4040
for (size_t pos = 0; pos < width; pos++) {
4141
values[pos] = valuesVec[pos].Release();
4242
}
@@ -52,8 +52,9 @@ class TSimpleStatefulWideFlowCodegeneratorNode
5252
}
5353
EProcessResult res = EProcessResult::Fetch;
5454
while (res == EProcessResult::Fetch) {
55-
auto fetchRes = SourceFlow->FetchValues(ctx, output);
56-
res = static_cast<const TDerived*>(this)->DoProcess(*static_cast<TState*>(state.GetRawPtr()), ctx, fetchRes, output);
55+
auto *const *input = static_cast<const TDerived*>(this)->PrepareInput(*static_cast<TState*>(state.GetRawPtr()), ctx, output);
56+
auto fetchRes = input ? SourceFlow->FetchValues(ctx, input) : EFetchResult::One;
57+
res = static_cast<const TDerived*>(this)->DoProcess(*static_cast<TState*>(state.GetRawPtr()), ctx, fetchRes, input, output);
5758
}
5859
return static_cast<EFetchResult>(res);
5960
}

0 commit comments

Comments
 (0)