Skip to content

Commit 70651ae

Browse files
author
Vadim Averin
committed
Generalize DoGenGetValues of wide While-variations
1 parent 5cc4e57 commit 70651ae

File tree

1 file changed

+48
-125
lines changed

1 file changed

+48
-125
lines changed

ydb/library/yql/minikql/comp_nodes/mkql_wide_filter.cpp

Lines changed: 48 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ class TBaseWideFilterWrapper {
4646
if (Predicate == Items[i] || Items[i]->GetDependencesCount() > 0U)
4747
*out = *fields[i];
4848
}
49+
50+
bool ApplyPredicate(TComputationContext& ctx, NUdf::TUnboxedValue*const* values) const {
51+
auto **fields = GetFields(ctx);
52+
PrepareArguments(ctx, values);
53+
for (size_t idx = 0; idx < Items.size(); idx++) {
54+
*fields[idx] = *values[idx];
55+
}
56+
return Predicate->GetValue(ctx).Get<bool>();
57+
}
58+
4959
#ifndef MKQL_DISABLE_CODEGEN
5060
template<bool ReplaceOriginalGetter = true>
5161
Value* GenGetPredicate(const TCodegenContext& ctx,
@@ -154,12 +164,7 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideFilterWit
154164
return EProcessResult::Finish;
155165
}
156166
if (fetchRes == EFetchResult::One) {
157-
auto **fields = GetFields(ctx);
158-
PrepareArguments(ctx, values);
159-
for (size_t idx = 0; idx < Items.size(); idx++) {
160-
*fields[idx] = *values[idx];
161-
}
162-
if (Predicate->GetValue(ctx).Get<bool>()) {
167+
if (ApplyPredicate(ctx, values)) {
163168
FillOutputs(ctx, values);
164169
limit--;
165170
return EProcessResult::One;
@@ -182,80 +187,37 @@ using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideFilterWit
182187
};
183188

184189
template<bool Inclusive>
185-
class TWideTakeWhileWrapper : public TStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>>, public TBaseWideFilterWrapper {
186-
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>>;
190+
class TWideTakeWhileWrapper : public TSimpleStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>, bool>, public TBaseWideFilterWrapper {
191+
using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrapper<Inclusive>, bool>;
187192
public:
188193
TWideTakeWhileWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items,
189194
IComputationNode* predicate)
190195
: TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
191196
, TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate)
192197
{}
193198

194-
EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
195-
if (!state.IsInvalid()) {
196-
return EFetchResult::Finish;
197-
}
198-
199-
PrepareArguments(ctx, output);
200-
201-
auto **fields = GetFields(ctx);
202-
203-
if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
204-
return result;
205-
206-
const bool predicate = Predicate->GetValue(ctx).Get<bool>();
207-
if (!predicate)
208-
state = NUdf::TUnboxedValuePod();
199+
void InitState(bool& stop, TComputationContext& ctx) const {
200+
stop = false;
201+
}
209202

210-
if (Inclusive || predicate) {
211-
FillOutputs(ctx, output);
212-
return EFetchResult::One;
203+
TBaseComputation::EProcessResult DoProcess(bool& stop, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
204+
if (stop) {
205+
return TBaseComputation::EProcessResult::Finish;
213206
}
214-
215-
return EFetchResult::Finish;
207+
if (fetchRes == EFetchResult::One) {
208+
const bool predicate = ApplyPredicate(ctx, values);
209+
if (!predicate) {
210+
stop = true;
211+
}
212+
if (Inclusive || predicate) {
213+
FillOutputs(ctx, values);
214+
return TBaseComputation::EProcessResult::One;
215+
}
216+
return TBaseComputation::EProcessResult::Finish;
217+
}
218+
return static_cast<TBaseComputation::EProcessResult>(fetchRes);
216219
}
217-
#ifndef MKQL_DISABLE_CODEGEN
218-
ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
219-
auto& context = ctx.Codegen.GetContext();
220-
221-
const auto resultType = Type::getInt32Ty(context);
222-
223-
const auto work = BasicBlock::Create(context, "work", ctx.Func);
224-
const auto test = BasicBlock::Create(context, "test", ctx.Func);
225-
const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
226-
const auto done = BasicBlock::Create(context, "done", ctx.Func);
227-
228-
const auto result = PHINode::Create(resultType, 4U, "result", done);
229-
result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(EFetchResult::Finish)), block);
230-
231-
const auto state = new LoadInst(Type::getInt128Ty(context), statePtr, "state", block);
232-
const auto finished = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetTrue(context), "finished", block);
233-
234-
BranchInst::Create(done, work, IsValid(statePtr, block), block);
235-
236-
block = work;
237-
auto status = GetNodeValues(Flow, ctx, block);
238-
const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status.first, ConstantInt::get(resultType, 0), "special", block);
239-
result->addIncoming(status.first, block);
240-
BranchInst::Create(done, test, special, block);
241-
242-
block = test;
243-
244-
const auto predicate = GenGetPredicate(ctx, status.second, block);
245-
result->addIncoming(status.first, block);
246-
BranchInst::Create(done, stop, predicate, block);
247-
248-
block = stop;
249220

250-
new StoreInst(GetEmpty(context), statePtr, block);
251-
result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(Inclusive ? EFetchResult::One: EFetchResult::Finish)), block);
252-
253-
BranchInst::Create(done, block);
254-
255-
block = done;
256-
return {result, std::move(status.second)};
257-
}
258-
#endif
259221
private:
260222
void RegisterDependencies() const final {
261223
if (const auto flow = this->FlowDependsOn(Flow)) {
@@ -266,72 +228,33 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeWhileWrappe
266228
};
267229

268230
template<bool Inclusive>
269-
class TWideSkipWhileWrapper : public TStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>>, public TBaseWideFilterWrapper {
270-
using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>>;
231+
class TWideSkipWhileWrapper : public TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>, bool>, public TBaseWideFilterWrapper {
232+
using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWhileWrapper<Inclusive>, bool>;
271233
public:
272234
TWideSkipWhileWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* predicate)
273235
: TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
274236
, TBaseWideFilterWrapper(mutables, flow, std::move(items), predicate)
275237
{}
276238

277-
EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
278-
if (!state.IsInvalid()) {
279-
return Flow->FetchValues(ctx, output);
280-
}
281-
282-
auto **fields = GetFields(ctx);
283-
284-
do {
285-
PrepareArguments(ctx, output);
286-
if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
287-
return result;
288-
} while (Predicate->GetValue(ctx).Get<bool>());
289-
290-
state = NUdf::TUnboxedValuePod();
239+
void InitState(bool& start, TComputationContext& ctx) const {
240+
start = false;
241+
}
291242

292-
if constexpr (Inclusive)
293-
return Flow->FetchValues(ctx, output);
294-
else {
295-
FillOutputs(ctx, output);
296-
return EFetchResult::One;
243+
TBaseComputation::EProcessResult DoProcess(bool& start, TComputationContext& ctx, EFetchResult fetchRes, NUdf::TUnboxedValue*const* values) const {
244+
if (!start && fetchRes == EFetchResult::One) {
245+
const bool predicate = ApplyPredicate(ctx, values);
246+
if (!predicate) {
247+
start = true;
248+
}
249+
if (!Inclusive && !predicate) {
250+
FillOutputs(ctx, values);
251+
return TBaseComputation::EProcessResult::One;
252+
}
253+
return TBaseComputation::EProcessResult::Fetch;
297254
}
255+
return static_cast<TBaseComputation::EProcessResult>(fetchRes);
298256
}
299-
#ifndef MKQL_DISABLE_CODEGEN
300-
ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
301-
auto& context = ctx.Codegen.GetContext();
302-
303-
const auto resultType = Type::getInt32Ty(context);
304-
305-
const auto work = BasicBlock::Create(context, "work", ctx.Func);
306-
const auto test = BasicBlock::Create(context, "test", ctx.Func);
307-
const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
308-
const auto done = BasicBlock::Create(context, "done", ctx.Func);
309-
310-
BranchInst::Create(work, block);
311-
312-
block = work;
313-
314-
const auto status = GetNodeValues(Flow, ctx, block);
315-
316-
const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, status.first, ConstantInt::get(resultType, 0), "special", block);
317-
const auto passtrought = BinaryOperator::CreateOr(special, IsValid(statePtr, block), "passtrought", block);
318-
BranchInst::Create(done, test, passtrought, block);
319257

320-
block = test;
321-
322-
const auto predicate = GenGetPredicate<false>(ctx, status.second, block);
323-
BranchInst::Create(work, stop, predicate, block);
324-
325-
block = stop;
326-
327-
new StoreInst(GetEmpty(context), statePtr, block);
328-
329-
BranchInst::Create(Inclusive ? work : done, block);
330-
331-
block = done;
332-
return status;
333-
}
334-
#endif
335258
private:
336259
void RegisterDependencies() const final {
337260
if (const auto flow = this->FlowDependsOn(Flow)) {

0 commit comments

Comments
 (0)