Skip to content

YQL-16896: Common type inferring for SELECT combinators #843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
69 changes: 62 additions & 7 deletions ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,18 @@ TExprNode::TPtr BuildCrossJoinsBetweenGroups(TPositionHandle pos, const TExprNod
return ctx.NewCallable(pos, "EquiJoin", std::move(args));
}

TExprNode::TPtr BuildProjectionLambda(TPositionHandle pos, const TExprNode::TPtr& result, bool subLink, bool emitPgStar, TExprContext& ctx) {
TExprNode::TPtr BuildProjectionLambda(TPositionHandle pos, const TExprNode::TPtr& result, const TStructExprType* finalType,
const TColumnOrder& nodeColumnOrder, const TColumnOrder& setItemColumnOrder,
bool subLink, bool emitPgStar, TExprContext& ctx) {

YQL_ENSURE(nodeColumnOrder.size() == setItemColumnOrder.size());
TMap<TStringBuf, TStringBuf> columnNamesMap;
if (!emitPgStar) {
for (size_t i = 0; i < nodeColumnOrder.size(); ++i) {
columnNamesMap[setItemColumnOrder[i]] = nodeColumnOrder[i];
}
}

return ctx.Builder(pos)
.Lambda()
.Param("row")
Expand All @@ -1705,26 +1716,68 @@ TExprNode::TPtr BuildProjectionLambda(TPositionHandle pos, const TExprNode::TPtr
.Seal();
listBuilder.Seal();
};

auto addAtomToListWithCast = [&addAtomToList] (TExprNodeBuilder& listBuilder, TExprNode* x,
const TTypeAnnotationNode* expectedTypeNode) -> void {
auto actualType = x->GetTypeAnn()->Cast<TPgExprType>();
Y_ENSURE(expectedTypeNode);
const auto expectedType = expectedTypeNode->Cast<TPgExprType>();

if (actualType == expectedType) {
addAtomToList(listBuilder, x);
return;
}
listBuilder.Add(0, x->HeadPtr());
listBuilder.Callable(1, "PgCast")
.Apply(0, x->TailPtr())
.With(0, "row")
.Seal()
.Callable(1, "PgType")
.Atom(0, NPg::LookupType(expectedType->GetId()).Name)
.Seal();
listBuilder.Seal();
};

for (const auto& x : result->Tail().Children()) {
if (x->HeadPtr()->IsAtom()) {
if (!emitPgStar) {
const auto& columnName = x->Child(0)->Content();
auto listBuilder = parent.List(index++);
addAtomToList(listBuilder, x.Get());
addAtomToListWithCast(listBuilder, x.Get(), finalType->FindItemType(columnNamesMap[columnName]));
}
} else {
auto type = x->Child(1)->GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>();
Y_ENSURE(type);

for (const auto& item : type->GetItems()) {
TStringBuf column = item->GetName();
auto columnName = subLink ? column : NTypeAnnImpl::RemoveAlias(column);

auto listBuilder = parent.List(index++);
if (overrideColumns.contains(columnName)) {
// we never get here while processing SELECTs,
// so no need to add PgCasts due to query combining with UNION ALL et al
addAtomToList(listBuilder, overrideColumns[columnName]);
} else {
listBuilder.Atom(0, columnName);
listBuilder.Callable(1, "Member")
.Arg(0, "row")
.Atom(1, column);
listBuilder.Seal();

const auto expectedType = finalType->FindItemType(columnNamesMap[columnName]);
if (item->GetItemType() == expectedType) {
listBuilder.Callable(1, "Member")
.Arg(0, "row")
.Atom(1, column)
.Seal();
} else {
listBuilder.Callable(1, "PgCast")
.Callable(0, "Member")
.Arg(0, "row")
.Atom(1, column)
.Seal()
.Callable(1, "PgType")
.Atom(0, NPg::LookupType(expectedType->Cast<TPgExprType>()->GetId()).Name)
.Seal()
.Seal();
}
}
}
}
Expand Down Expand Up @@ -3159,7 +3212,9 @@ TExprNode::TPtr ExpandPgSelectImpl(const TExprNode::TPtr& node, TExprContext& ct
}
} else {
YQL_ENSURE(result);
TExprNode::TPtr projectionLambda = BuildProjectionLambda(node->Pos(), result, subLinkId.Defined(), emitPgStar, ctx);
auto finalType = node->GetTypeAnn()->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
Y_ENSURE(finalType);
TExprNode::TPtr projectionLambda = BuildProjectionLambda(node->Pos(), result, finalType, *order, *childOrder, subLinkId.Defined(), emitPgStar, ctx);
TExprNode::TPtr projectionArg = projectionLambda->Head().HeadPtr();
TExprNode::TPtr projectionRoot = projectionLambda->TailPtr();
TVector<TString> inputAliases;
Expand Down
138 changes: 125 additions & 13 deletions ydb/library/yql/core/type_ann/type_ann_pg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ydb/library/yql/parser/pg_wrapper/interface/utils.h>

#include <util/generic/set.h>
#include <util/generic/hash.h>

namespace NYql {

Expand Down Expand Up @@ -64,8 +65,8 @@ bool ValidateInputTypes(TExprNode& node, TExprContext& ctx) {
return true;
}

TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TContext& ctx) {
return ctx.Expr.Builder(node->Pos())
TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TExprContext& ctx) {
return ctx.Builder(node->Pos())
.Callable("PgCast")
.Add(0, std::move(node))
.Callable(1, "PgType")
Expand All @@ -75,6 +76,113 @@ TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TContext& ct
.Build();
};

// Returns the set-operation node from `setOps` that is applied while set item #n
// is the topmost entry of the operand stack.
//
// `setOps` is walked in order as a stack program: an op whose content is "push"
// places the next set item (numbered by push order, starting at 0) on the stack;
// any other op is treated as a combinator that removes one entry from the stack.
//
// setItems - only its child count is used, to size the operand stack.
// setOps   - the list of ops to walk.
// n        - index of the set item whose combinator is requested.
//
// Throws via Y_ENSURE when `setOps` is malformed (more pushes than set items, or
// a combinator met with fewer than two entries on the stack). Falling off the
// end of `setOps` without a match is treated as unreachable.
TExprNodePtr FindLeftCombinatorOfNthSetItem(const TExprNode* setItems, const TExprNode* setOps, ui32 n) {
    TVector<ui32> setItemsStack(setItems->ChildrenSize());
    i32 sp = -1;        // index of the stack top; -1 means the stack is empty
    ui32 itemIdx = 0;   // number of the next set item to be pushed
    for (const auto& op : setOps->Children()) {
        if (op->Content() == "push") {
            // Never write past the preallocated stack.
            Y_ENSURE(itemIdx < setItemsStack.size());
            setItemsStack[++sp] = itemIdx++;
        } else {
            // Validate the depth BEFORE reading the top: after the pop below at
            // least one entry must remain, so two are required here.
            Y_ENSURE(1 <= sp);
            if (setItemsStack[sp] == n) {
                return op;
            }
            --sp;
        }
    }
    Y_UNREACHABLE();
}

// Infers the common row type of the set items of a PG SELECT, matching columns
// by position.
//
// For every set item the function checks that it is a list of structs with a
// known column order, and that it has the same number of columns as the first
// set item. The PG type ids found at each column position are then passed to
// NPg::LookupCommonType, and the resulting struct uses the column names of the
// first set item together with the common PG type of each position.
//
// pos               - position used when validating the resulting struct type.
// setItems          - node whose children are the set items.
// setOps            - set operations; used only to name the combinator in the
//                     "same number of columns" error message.
// resultColumnOrder - [out] column order of the first set item.
// resultStructType  - [out] the inferred struct type.
// ctx               - context used for type lookup, type creation and errors.
//
// Returns TStatus::Ok on success; TStatus::Error after reporting an issue to
// ctx.Expr otherwise.
IGraphTransformer::TStatus InferPgCommonType(TPositionHandle pos, const TExprNode* setItems, const TExprNode* setOps,
    TColumnOrder& resultColumnOrder, const TStructExprType*& resultStructType, TExtContext& ctx)
{
    const size_t setItemsCnt = setItems->ChildrenSize();

    // pgTypes[j] accumulates the PG type ids seen at column position j,
    // one entry per set item.
    TVector<TVector<ui32>> pgTypes;
    size_t fieldsCnt = 0;

    for (size_t i = 0; i < setItemsCnt; ++i) {
        const auto* child = setItems->Child(i);

        if (!EnsureListType(*child, ctx.Expr)) {
            return IGraphTransformer::TStatus::Error;
        }
        auto itemType = child->GetTypeAnn()->Cast<TListExprType>()->GetItemType();
        YQL_ENSURE(itemType);

        if (!EnsureStructType(child->Pos(), *itemType, ctx.Expr)) {
            return IGraphTransformer::TStatus::Error;
        }

        auto childColumnOrder = ctx.Types.LookupColumnOrder(*child);
        if (!childColumnOrder) {
            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Pos()), TStringBuilder()
                << "Input #" << i << " does not have ordered columns. "
                << "Consider making column order explicit by using SELECT with column names"));
            return IGraphTransformer::TStatus::Error;
        }

        if (0 == i) {
            // The first set item defines the result column names and count.
            resultColumnOrder = *childColumnOrder;
            fieldsCnt = resultColumnOrder.size();

            pgTypes.resize(fieldsCnt);
            for (auto& columnTypes : pgTypes) {
                columnTypes.reserve(setItemsCnt);
            }
        } else if (childColumnOrder->size() != fieldsCnt) {
            TExprNodePtr combinator = FindLeftCombinatorOfNthSetItem(setItems, setOps, i);
            Y_ENSURE(combinator);

            // Turn e.g. "union_all" into "UNION" for the error message.
            TString op(combinator->Content());
            if (op.EndsWith("_all")) {
                op.erase(op.length() - 4);
            }
            op.to_upper();

            ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(child->Pos()), TStringBuilder()
                << "each " << op << " query must have the same number of columns"));

            return IGraphTransformer::TStatus::Error;
        }

        const auto structType = itemType->Cast<TStructExprType>();
        size_t j = 0;
        for (const auto& col : *childColumnOrder) {
            const auto itemIdx = structType->FindItem(col);
            YQL_ENSURE(itemIdx);
            pgTypes[j].push_back(structType->GetItems()[*itemIdx]->GetItemType()->Cast<TPgExprType>()->GetId());
            ++j;
        }
    }

    TVector<const TItemExprType*> structItems;
    structItems.reserve(fieldsCnt);
    for (size_t j = 0; j < fieldsCnt; ++j) {
        // Initialised so that a callee that fails to fill it in is detected
        // below instead of dereferencing an indeterminate pointer.
        const NPg::TTypeDesc* commonType = nullptr;

        // NOTE(review): the position callback takes the j-th child of set item
        // #i to locate the offending column — confirm that the children of a
        // set item node line up with column positions.
        if (const auto issue = NPg::LookupCommonType(pgTypes[j],
            [j, setItems, &ctx](size_t i) {
                return ctx.Expr.GetPosition(setItems->Child(i)->Child(j)->Pos());
            }, commonType))
        {
            ctx.Expr.AddError(*issue);
            return IGraphTransformer::TStatus::Error;
        }
        YQL_ENSURE(commonType);

        structItems.push_back(ctx.Expr.MakeType<TItemExprType>(resultColumnOrder[j],
            ctx.Expr.MakeType<TPgExprType>(commonType->TypeId)));
    }

    resultStructType = ctx.Expr.MakeType<TStructExprType>(structItems);
    if (!resultStructType->Validate(pos, ctx.Expr)) {
        return IGraphTransformer::TStatus::Error;
    }

    return IGraphTransformer::TStatus::Ok;
}

IGraphTransformer::TStatus PgStarWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureArgsCount(*input, 0, ctx.Expr)) {
Expand Down Expand Up @@ -212,12 +320,12 @@ IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode
const auto& fargTypes = (*procPtr)->ArgTypes;
for (size_t i = 0; i < argTypes.size(); ++i) {
if (IsCastRequired(argTypes[i], fargTypes[i])) {
children[i+3] = WrapWithPgCast(std::move(children[i+3]), fargTypes[i], ctx);
children[i+3] = WrapWithPgCast(std::move(children[i+3]), fargTypes[i], ctx.Expr);
}
}
output = ctx.Expr.NewCallable(input->Pos(), "PgResolvedCall", std::move(children));
} else if (const auto* typePtr = std::get_if<const NPg::TTypeDesc*>(&procOrType)) {
output = WrapWithPgCast(std::move(children[2]), (*typePtr)->TypeId, ctx);
output = WrapWithPgCast(std::move(children[2]), (*typePtr)->TypeId, ctx.Expr);
} else {
Y_UNREACHABLE();
}
Expand Down Expand Up @@ -454,16 +562,16 @@ IGraphTransformer::TStatus PgOpWrapper(const TExprNode::TPtr& input, TExprNode::
switch(oper.Kind) {
case NPg::EOperKind::LeftUnary:
if (IsCastRequired(argTypes[0], oper.RightType)) {
children[1] = WrapWithPgCast(std::move(children[1]), oper.RightType, ctx);
children[1] = WrapWithPgCast(std::move(children[1]), oper.RightType, ctx.Expr);
}
break;

case NYql::NPg::EOperKind::Binary:
if (IsCastRequired(argTypes[0], oper.LeftType)) {
children[1] = WrapWithPgCast(std::move(children[1]), oper.LeftType, ctx);
children[1] = WrapWithPgCast(std::move(children[1]), oper.LeftType, ctx.Expr);
}
if (IsCastRequired(argTypes[1], oper.RightType)) {
children[2] = WrapWithPgCast(std::move(children[2]), oper.RightType, ctx);
children[2] = WrapWithPgCast(std::move(children[2]), oper.RightType, ctx.Expr);
}
break;

Expand Down Expand Up @@ -648,7 +756,7 @@ IGraphTransformer::TStatus PgAggWrapper(const TExprNode::TPtr& input, TExprNode:
for (ui32 i = 0; i < argTypes.size(); ++i, ++argIdx) {
if (IsCastRequired(argTypes[i], aggDesc.ArgTypes[i])) {
auto& argNode = input->ChildRef(argIdx);
argNode = WrapWithPgCast(std::move(argNode), aggDesc.ArgTypes[i], ctx);
argNode = WrapWithPgCast(std::move(argNode), aggDesc.ArgTypes[i], ctx.Expr);
needRetype = true;
}
}
Expand Down Expand Up @@ -4155,7 +4263,7 @@ IGraphTransformer::TStatus PgValuesListWrapper(const TExprNode::TPtr& input, TEx
if (item->GetTypeAnn()->Cast<TPgExprType>()->GetId() == commonTypes[j]) {
rowValues.push_back(item);
} else {
rowValues.push_back(WrapWithPgCast(std::move(item), commonTypes[j], ctx));
rowValues.push_back(WrapWithPgCast(std::move(item), commonTypes[j], ctx.Expr));
}
}
resultValues.push_back(ctx.Expr.NewList(value->Pos(), std::move(rowValues)));
Expand Down Expand Up @@ -4338,7 +4446,11 @@ IGraphTransformer::TStatus PgSelectWrapper(const TExprNode::TPtr& input, TExprNo

TColumnOrder resultColumnOrder;
const TStructExprType* resultStructType = nullptr;
auto status = InferPositionalUnionType(input->Pos(), setItems->ChildrenList(), resultColumnOrder, resultStructType, ctx);

auto status = (1 == setItems->ChildrenSize() && HasSetting(*setItems->Child(0)->Child(0), "unknowns_allowed"))
? InferPositionalUnionType(input->Pos(), setItems->ChildrenList(), resultColumnOrder, resultStructType, ctx)
: InferPgCommonType(input->Pos(), setItems, setOps, resultColumnOrder, resultStructType, ctx);

if (status != IGraphTransformer::TStatus::Ok) {
return status;
}
Expand Down Expand Up @@ -4471,7 +4583,7 @@ IGraphTransformer::TStatus PgArrayWrapper(const TExprNode::TPtr& input, TExprNod
if (argTypes[i] == elemType) {
castArrayElems.push_back(child);
} else {
castArrayElems.push_back(WrapWithPgCast(std::move(child), elemType, ctx));
castArrayElems.push_back(WrapWithPgCast(std::move(child), elemType, ctx.Expr));
}
}
output = ctx.Expr.NewCallable(input->Pos(), "PgArray", std::move(castArrayElems));
Expand Down Expand Up @@ -4587,7 +4699,7 @@ IGraphTransformer::TStatus PgLikeWrapper(const TExprNode::TPtr& input, TExprNode
if (argTypes[i] != textTypeId) {
if (argTypes[i] == NPg::UnknownOid) {
auto& argNode = input->ChildRef(i);
argNode = WrapWithPgCast(std::move(argNode), textTypeId, ctx);
argNode = WrapWithPgCast(std::move(argNode), textTypeId, ctx.Expr);
return IGraphTransformer::TStatus::Repeat;
}
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
Expand Down Expand Up @@ -4656,7 +4768,7 @@ IGraphTransformer::TStatus PgInWrapper(const TExprNode::TPtr& input, TExprNode::
if (itemTypePg && inputTypePg && itemTypePg != inputTypePg) {
if (inputTypePg == NPg::UnknownOid) {

input->ChildRef(0) = WrapWithPgCast(std::move(input->Child(0)), itemTypePg, ctx);
input->ChildRef(0) = WrapWithPgCast(std::move(input->Child(0)), itemTypePg, ctx.Expr);
return IGraphTransformer::TStatus::Repeat;
}
if (itemTypePg == NPg::UnknownOid) {
Expand Down
1 change: 1 addition & 0 deletions ydb/library/yql/core/type_ann/type_ann_pg.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
namespace NYql {
namespace NTypeAnnImpl {

// Builds a "PgCast" callable at the position of `node`, with `node` as its first
// argument and a "PgType" callable as its second (presumably naming the PG type
// identified by `targetTypeId` — confirm against the definition).
TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TExprContext& ctx);
TString MakeAliasedColumn(TStringBuf alias, TStringBuf column);
const TItemExprType* AddAlias(const TString& alias, const TItemExprType* item, TExprContext& ctx);
TStringBuf RemoveAlias(TStringBuf column);
Expand Down
11 changes: 6 additions & 5 deletions ydb/library/yql/sql/pg/pg_sql.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ class TConverter : public IPGParseEvents {
}
}

bool hasCombiningQueries = (1 < setItems.size());

TAstNode* sort = nullptr;
if (ListLength(value->sortClause) > 0) {
Expand All @@ -716,7 +717,7 @@ class TConverter : public IPGParseEvents {
return nullptr;
}

auto sort = ParseSortBy(CAST_NODE_EXT(PG_SortBy, T_SortBy, node), setItems.size() == 1, true);
auto sort = ParseSortBy(CAST_NODE_EXT(PG_SortBy, T_SortBy, node), !hasCombiningQueries, true);
if (!sort) {
return nullptr;
}
Expand All @@ -728,7 +729,7 @@ class TConverter : public IPGParseEvents {
}

TVector<TAstNode*> setItemNodes;
for (size_t id = 0; id < setItems.size(); id++) {
for (size_t id = 0; id < setItems.size(); ++id) {
const auto& x = setItems[id];
bool hasDistinctAll = false;
TVector<TAstNode*> distinctOnItems;
Expand Down Expand Up @@ -1051,11 +1052,11 @@ class TConverter : public IPGParseEvents {
setItemOptions.push_back(QL(QA("distinct_on"), distinctOn));
}

if (setItems.size() == 1 && sort) {
if (!hasCombiningQueries && sort) {
setItemOptions.push_back(QL(QA("sort"), sort));
}

if (unknownsAllowed) {
if (unknownsAllowed || hasCombiningQueries) {
setItemOptions.push_back(QL(QA("unknowns_allowed")));
}

Expand Down Expand Up @@ -1106,7 +1107,7 @@ class TConverter : public IPGParseEvents {
selectOptions.push_back(QL(QA("set_items"), QVL(setItemNodes.data(), setItemNodes.size())));
selectOptions.push_back(QL(QA("set_ops"), QVL(setOpsNodes.data(), setOpsNodes.size())));

if (setItems.size() > 1 && sort) {
if (hasCombiningQueries && sort) {
selectOptions.push_back(QL(QA("sort"), sort));
}

Expand Down
Loading