Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 109 additions & 4 deletions ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ void GatherUsedWindows(const TExprNode::TPtr& window, const TExprNode::TPtr& pro

TUsedColumns GatherUsedColumns(const TExprNode::TPtr& result, const TExprNode::TPtr& joinOps,
const TExprNode::TPtr& filter, const TExprNode::TPtr& groupExprs, const TExprNode::TPtr& having, const TExprNode::TPtr& extraSortColumns,
const TExprNode::TPtr& window, const TWindowsCtx& winCtx) {
const TExprNode::TPtr& window, const TWindowsCtx& winCtx, TVector<std::pair<TString, TString>>& joinUsingColumns) {
TUsedColumns usedColumns;
for (const auto& x : result->Tail().Children()) {
AddColumnsFromType(x->Child(1)->GetTypeAnn(), usedColumns);
Expand All @@ -924,6 +924,21 @@ TUsedColumns GatherUsedColumns(const TExprNode::TPtr& result, const TExprNode::T
for (ui32 i = 0; i < groupTuple->ChildrenSize(); ++i) {
auto join = groupTuple->Child(i);
auto joinType = join->Child(0)->Content();
if (join->ChildrenSize() > 2) {
Y_ENSURE(join->Child(1)->ChildrenSize() > 3, "Excepted at least 4 args there");
Y_ENSURE(join->Child(1)->IsAtom(), "Supported only USING clause there");
Y_ENSURE(join->Child(1)->Content() == "using", "Supported only USING clause there");
for (ui32 col = 0; col < join->Child(3)->ChildrenSize(); ++col) {
auto lr = join->Child(3)->Child(col);
if (lr->Child(0)->IsAtom()) {
usedColumns.insert(std::make_pair(TString(lr->Child(0)->Content()), std::make_pair(Max<ui32>(), TString())));
}
usedColumns.insert(std::make_pair(TString(lr->Child(1)->Content()), std::make_pair(Max<ui32>(), TString())));
usedColumns.erase(TString(join->Child(2)->Child(col)->Content()));
joinUsingColumns.emplace_back(ToString(groupNo), join->Child(2)->Child(col)->Content());
}
continue;
}
if (joinType != "cross") {
AddColumnsFromType(join->Tail().Child(0)->GetTypeAnn(), usedColumns);
}
Expand Down Expand Up @@ -1519,6 +1534,85 @@ std::tuple<TVector<ui32>, TExprNode::TListType> BuildJoinGroups(TPositionHandle
current = cartesian;
continue;
}
if (join->ChildrenSize() > 2) {
Y_ENSURE(join->Child(1)->IsAtom(), "expected only USING clause when join_ops children size > 2");
Y_ENSURE(join->Child(1)->Content() == "using", "expected only USING clause when join_ops children size > 2");
Y_ENSURE(join->ChildrenSize() > 3 && join->Child(3)->IsList(), "Excepted list of aliased columns in USING join");
auto left = current;
auto right = with;
TExprNode::TListType leftColumns;
TExprNode::TListType rightColumns;
TExprNode::TListType toRemove;
for (auto& col: join->Child(3)->ChildrenList()) {
if (col->Child(0)->IsAtom()) {
leftColumns.push_back(col->Child(0));
} else {
toRemove.push_back(col->Child(0)->Child(0));
leftColumns.push_back(col->Child(0)->Child(0));
}
rightColumns.push_back(col->Child(1));
}
current = BuildEquiJoin(pos, joinType, left, right, leftColumns, rightColumns, ctx);
auto secondStruct = ctx.Builder(pos)
.Lambda()
.Param("row")
.Callable("AsStruct")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
for (size_t i = 0; i < leftColumns.size(); ++i) {
parent.List(i)
.Add(0, join->Child(2)->Child(i))
.Callable(1, "Coalesce")
.Callable(0, "Member")
.Arg(0, "row")
.Add(1, leftColumns[i])
.Seal()
.Callable(1, "Member")
.Arg(0, "row")
.Add(1, rightColumns[i])
.Seal()
.Seal()
.Seal();
}
return parent;
})
.Seal()
.Seal()
.Build();
auto removeProjection = ctx.Builder(pos)
.Lambda()
.Param("row")
.Arg(0, "row")
.Seal().Build();
current = ctx.Builder(pos)
.Callable("OrderedMap")
.Add(0, current)
.Lambda(1)
.Param("row")
.Callable("FlattenMembers")
.List(0)
.Atom(0, "")
.Do([&](TExprNodeBuilder& parent) -> TExprNodeBuilder& {
if (toRemove.size()) {
parent.Callable(1, "RemoveMembers")
.Arg(0, "row")
.Add(1, ctx.NewList(pos, std::move(toRemove)))
.Seal();
} else {
parent.Arg(1, "row");
}
return parent;
})
.Seal()
.List(1)
.Atom(0, "")
.Apply(1, secondStruct).With(0, "row").Seal()
.Seal()
.Seal()
.Seal()
.Seal()
.Build();
continue;
}

auto predicate = join->Tail().TailPtr();
if (!IsDepended(predicate->Tail(), predicate->Head().Head())) {
Expand Down Expand Up @@ -1709,7 +1803,7 @@ std::tuple<TVector<ui32>, TExprNode::TListType> BuildJoinGroups(TPositionHandle
}

TExprNode::TPtr BuildCrossJoinsBetweenGroups(TPositionHandle pos, const TExprNode::TListType& joinGroups,
const TUsedColumns& usedColumns, const TVector<ui32>& groupForIndex, TExprContext& ctx) {
const TUsedColumns& usedColumns, const TVector<ui32>& groupForIndex, TExprContext& ctx, const TVector<std::pair<TString, TString>>& joinUsingColumns) {
TExprNode::TListType args;
for (ui32 i = 0; i < joinGroups.size(); ++i) {
args.push_back(ctx.Builder(pos)
Expand Down Expand Up @@ -1742,6 +1836,16 @@ TExprNode::TPtr BuildCrossJoinsBetweenGroups(TPositionHandle pos, const TExprNod
.Build());
}

for (const auto& x: joinUsingColumns) {
settings.push_back(ctx.Builder(pos)
.List()
.Atom(0, "rename")
.Atom(1, x.first + "." + x.second)
.Atom(2, x.second)
.Seal()
.Build());
}

auto settingsNode = ctx.NewList(pos, std::move(settings));
args.push_back(settingsNode);
return ctx.NewCallable(pos, "EquiJoin", std::move(args));
Expand Down Expand Up @@ -3307,7 +3411,8 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct
cleanedInputs.push_back(list);
} else {
// extract all used columns
auto usedColumns = GatherUsedColumns(result, joinOps, filter, groupExprs, having, extraSortColumns, window, winCtx);
TVector<std::pair<TString, TString>> joinUsingColumns;
auto usedColumns = GatherUsedColumns(result, joinOps, filter, groupExprs, having, extraSortColumns, window, winCtx, joinUsingColumns);

// fill index of input for each column
FillInputIndices(from, finalExtTypes, usedColumns, optCtx);
Expand All @@ -3323,7 +3428,7 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct
if (joinGroups.size() == 1) {
list = joinGroups.front();
} else {
list = BuildCrossJoinsBetweenGroups(node->Pos(), joinGroups, usedColumns, groupForIndex, ctx);
list = BuildCrossJoinsBetweenGroups(node->Pos(), joinGroups, usedColumns, groupForIndex, ctx, joinUsingColumns);
}
}
}
Expand Down
Loading