Skip to content

YQL-9517: Over BlockExpandChunked #1366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions ydb/library/yql/core/type_ann/type_ann_blocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,24 @@ IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& inpu
return IGraphTransformer::TStatus::Error;
}

TTypeAnnotationNode::TListType itemTypes;
TTypeAnnotationNode::TListType blockItemTypes;
if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

auto flowItemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
if (input->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) {
if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

itemTypes = input->Head().GetTypeAnn()->Cast<TStreamExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
} else {
if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

itemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
}

bool allScalars = AllOf(itemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
if (allScalars) {
output = input->HeadPtr();
return IGraphTransformer::TStatus::Repeat;
Expand Down
68 changes: 61 additions & 7 deletions ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,52 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedW
const size_t WideFieldsIndex_;
};

class TBlockExpandChunkedStreamWrapper : public TMutableComputationNode<TBlockExpandChunkedStreamWrapper> {
using TBaseComputation = TMutableComputationNode<TBlockExpandChunkedStreamWrapper>;
class TExpanderState : public TComputationValue<TExpanderState> {
using TBase = TComputationValue<TExpanderState>;
public:
TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width)
: TBase(memInfo), HolderFactory_(ctx.HolderFactory), State_(ctx.HolderFactory.Create<TBlockState>(width)), Stream_(stream) {}

NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
auto& s = *static_cast<TBlockState*>(State_.AsBoxed().Get());
if (!s.Count) {
s.ClearValues();
auto result = Stream_.WideFetch(s.Values.data(), width);
if (NUdf::EFetchStatus::Ok != result) {
return result;
}
s.FillArrays();
}

const auto sliceSize = s.Slice();
for (size_t i = 0; i < width; ++i) {
output[i] = s.Get(sliceSize, HolderFactory_, i);
}
return NUdf::EFetchStatus::Ok;
}

private:
const THolderFactory& HolderFactory_;
NUdf::TUnboxedValue State_;
NUdf::TUnboxedValue Stream_;
};
public:
TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width)
: TBaseComputation(mutables, EValueRepresentation::Boxed)
, Stream_(stream)
, Width_(width) {}

NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
return ctx.HolderFactory.Create<TExpanderState>(ctx, std::move(Stream_->GetValue(ctx)), Width_);
}
void RegisterDependencies() const override {}
private:
IComputationNode* const Stream_;
const size_t Width_;
};

} // namespace

IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
Expand Down Expand Up @@ -1184,13 +1230,21 @@ IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNod

IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());

const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
const auto wideComponents = GetWideComponents(flowType);

const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
if (callable.GetInput(0).GetStaticType()->IsStream()) {
const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType());
const auto wideComponents = GetWideComponents(streamType);
const auto computation = dynamic_cast<IComputationNode*>(LocateNode(ctx.NodeLocator, callable, 0));

MKQL_ENSURE(computation != nullptr, "Expected computation node");
return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size());
} else {
const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
const auto wideComponents = GetWideComponents(flowType);

const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
}
}

}
Expand Down
30 changes: 26 additions & 4 deletions ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test
return multiOptional;
}

std::vector<TType*> ValidateBlockStreamType(const TType* streamType) {
const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
std::vector<TType*> streamItems;
streamItems.reserve(wideComponents.size());
bool isScalar;
for (size_t i = 0; i < wideComponents.size(); ++i) {
auto blockType = AS_TYPE(TBlockType, wideComponents[i]);
isScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
auto withoutBlock = blockType->GetItemType();
streamItems.push_back(withoutBlock);
}

MKQL_ENSURE(isScalar, "Last column should be scalar");
MKQL_ENSURE(AS_TYPE(TDataType, streamItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
return streamItems;
}

std::vector<TType*> ValidateBlockFlowType(const TType* flowType) {
const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
Expand Down Expand Up @@ -1550,10 +1568,14 @@ TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex)
return TRuntimeNode(callableBuilder.Build(), false);
}

TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode flow) {
ValidateBlockFlowType(flow.GetStaticType());
TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType());
callableBuilder.Add(flow);
TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) {
if (comp.GetStaticType()->IsStream()) {
ValidateBlockStreamType(comp.GetStaticType());
} else {
ValidateBlockFlowType(comp.GetStaticType());
}
TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType());
callableBuilder.Add(comp);
return TRuntimeNode(callableBuilder.Build(), false);
}

Expand Down
11 changes: 6 additions & 5 deletions ydb/library/yql/providers/dq/opt/dqs_opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,13 @@ namespace NYql::NDqs {

YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration";
return Build<TCoWideFromBlocks>(ctx, node->Pos())
.Input(Build<TCoToFlow>(ctx, node->Pos())
.Input(
Build<TCoToFlow>(ctx, node->Pos())
.Input(Build<TDqReadBlockWideWrap>(ctx, node->Pos())
.Input(readWideWrap.Input())
.Flags(readWideWrap.Flags())
.Token(readWideWrap.Token())
.Done())
.Input(readWideWrap.Input())
.Flags(readWideWrap.Flags())
.Token(readWideWrap.Token())
.Done().Ptr())
.Done())
.Done().Ptr();
}, ctx, optSettings);
Expand Down
10 changes: 5 additions & 5 deletions ydb/library/yql/providers/yt/provider/yql_yt_mkql_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,11 @@ void RegisterDqYtMkqlCompilers(NCommon::TMkqlCallableCompilerBase& compiler, con
for (const auto& flag : wrapper.Flags())
if (solid = flag.Value() == "Solid")
break;

if (solid)
return BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight);
else
return BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight);
return ctx.ProgramBuilder.BlockExpandChunked(
solid
? BuildDqYtInputCall<false>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight)
: BuildDqYtInputCall<true>(outputType, inputItemType, cluster, tokenName, ytRead.Input(), state, ctx, inflight, timeout, true && inflight)
);
}

return TRuntimeNode();
Expand Down