Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions ydb/library/yql/providers/common/pushdown/collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,38 @@ bool CompareCanBePushed(const TCoCompare& compare, const TExprNode* lambdaArg, c
return true;
}

bool SqlInCanBePushed(const TCoSqlIn& sqlIn, const TExprNode* lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) {
const TExprBase& expr = sqlIn.Collection();
const TExprBase& lookup = sqlIn.Lookup();

if (!CheckExpressionNodeForPushdown(lookup, lambdaArg, settings)) {
return false;
}

TExprNode::TPtr collection;
if (expr.Ref().IsList()) {
collection = expr.Ptr();
} else if (auto maybeAsList = expr.Maybe<TCoAsList>()) {
collection = maybeAsList.Cast().Ptr();
} else {
return false;
}

const TTypeAnnotationNode* inputType = lambdaBody.Ptr()->GetTypeAnn();
for (auto& child : collection->Children()) {
if (!CheckExpressionNodeForPushdown(TExprBase(child), lambdaArg, settings)) {
return false;
}

if (!settings.IsEnabled(TSettings::EFeatureFlag::DoNotCheckCompareArgumentsTypes)) {
if (!IsComparableTypes(lookup, TExprBase(child), false, inputType, settings)) {
return false;
}
}
}
return true;
}

bool SafeCastCanBePushed(const TCoFlatMap& flatmap, const TExprNode* lambdaArg, const TSettings& settings) {
/*
* There are three ways of comparison in following format:
Expand Down Expand Up @@ -560,6 +592,9 @@ void CollectPredicates(const TExprBase& predicate, TPredicateNode& predicateTree
CollectExpressionPredicate(predicateTree, predicate.Cast<TCoMember>(), lambdaArg);
} else if (settings.IsEnabled(TSettings::EFeatureFlag::JustPassthroughOperators) && (predicate.Maybe<TCoIf>() || predicate.Maybe<TCoJust>())) {
CollectChildrenPredicates(predicate.Ref(), predicateTree, lambdaArg, lambdaBody, settings);
} else if (settings.IsEnabled(TSettings::EFeatureFlag::InOperator) && predicate.Maybe<TCoSqlIn>()) {
auto sqlIn = predicate.Cast<TCoSqlIn>();
predicateTree.CanBePushed = SqlInCanBePushed(sqlIn, lambdaArg, lambdaBody, settings);
} else {
predicateTree.CanBePushed = false;
}
Expand Down
3 changes: 2 additions & 1 deletion ydb/library/yql/providers/common/pushdown/settings.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ struct TSettings {
UnaryOperators = 1 << 15, // -, Abs, Size
DoNotCheckCompareArgumentsTypes = 1 << 16,
TimestampCtor = 1 << 17,
JustPassthroughOperators = 1 << 18 // if + coalesce + just
JustPassthroughOperators = 1 << 18, // if + coalesce + just
InOperator = 1 << 19 // IN()
};

explicit TSettings(NLog::EComponent logComponent)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace NYql {
TString FormatIsNull(const TPredicate_TIsNull& isNull);
TString FormatIsNotNull(const TPredicate_TIsNotNull& isNotNull);
TString FormatPredicate(const TPredicate& predicate, bool topLevel);
TString FormatIn(const TPredicate_TIn& in);

namespace {

Expand Down Expand Up @@ -141,6 +142,32 @@ namespace NYql {
return SerializeExpression(exists.Optional(), expressionProto, arg, err);
}

// Serializes `<lookup> IN (<collection>)` into the `in` payload of `proto`.
//
// The lookup expression is written to `in.value`; every child of the collection is
// appended to `in.set` in order. The collection must be a list node or a TCoAsList
// callable, otherwise "unknown operation: <content>" is appended to `err`.
//
// Returns true only if the lookup and all collection elements were serialized.
// On failure `proto` may be left partially filled, so callers should discard it.
bool SerializeSqlIn(const TCoSqlIn& sqlIn, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err) {
    auto* dstProto = proto->mutable_in();
    const TExprBase& expr = sqlIn.Collection();
    const TExprBase& lookup = sqlIn.Lookup();

    // The result must be propagated: silently ignoring a failed lookup would produce
    // an IN predicate with an unset/partial left-hand side and still report success.
    if (!SerializeExpression(lookup, dstProto->mutable_value(), arg, err)) {
        return false;
    }

    TExprNode::TPtr collection;
    if (expr.Ref().IsList()) {
        collection = expr.Ptr();
    } else if (auto maybeAsList = expr.Maybe<TCoAsList>()) {
        collection = maybeAsList.Cast().Ptr();
    } else {
        err << "unknown operation: " << expr.Ref().Content();
        return false;
    }

    for (auto& child : collection->Children()) {
        if (!SerializeExpression(TExprBase(child), dstProto->add_set(), arg, err)) {
            return false;
        }
    }
    return true;
}

bool SerializeAnd(const TCoAnd& andExpr, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err) {
auto* dstProto = proto->mutable_conjunction();
for (const auto& child : andExpr.Ptr()->Children()) {
Expand Down Expand Up @@ -196,6 +223,9 @@ namespace NYql {
if (auto exists = predicate.Maybe<TCoExists>()) {
return SerializeExists(exists.Cast(), proto, arg, err);
}
if (auto sqlIn = predicate.Maybe<TCoSqlIn>()) {
return SerializeSqlIn(sqlIn.Cast(), proto, arg, err);
}

err << "unknown predicate: " << predicate.Raw()->Content();
return false;
Expand Down Expand Up @@ -401,6 +431,18 @@ namespace NYql {
return left + operation + right;
}

// Renders an IN predicate as "<value> IN (<item>,<item>,...)".
// Items are formatted with FormatExpression and joined with "," (no space).
TString FormatIn(const TPredicate_TIn& in) {
    const auto lookupText = FormatExpression(in.value());

    TString itemsText;
    for (const auto& item : in.set()) {
        const auto rendered = FormatExpression(item);
        // A separator is emitted only once something has already been accumulated.
        if (!itemsText.empty()) {
            itemsText += ",";
        }
        itemsText += rendered;
    }

    return lookupText + " IN (" + itemsText + ")";
}

TString FormatPredicate(const TPredicate& predicate, bool topLevel ) {
switch (predicate.payload_case()) {
case TPredicate::PAYLOAD_NOT_SET:
Expand All @@ -419,6 +461,8 @@ namespace NYql {
return FormatComparison(predicate.comparison());
case TPredicate::kBoolExpression:
return FormatExpression(predicate.bool_expression().value());
case TPredicate::kIn:
return FormatIn(predicate.in());
default:
throw yexception() << "UnimplementedPredicateType, payload_case " << static_cast<ui64>(predicate.payload_case());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
{"Index": 2, "Name": "Columns", "Type": "TExprBase"},
{"Index": 3, "Name": "Settings", "Type": "TCoNameValueTupleList"},
{"Index": 4, "Name": "Token", "Type": "TCoSecureParam"},
{"Index": 5, "Name": "FilterPredicate", "Type": "TCoLambda"}
{"Index": 5, "Name": "FilterPredicate", "Type": "TCoAtom"}
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,6 @@ class TPqDataSourceTypeAnnotationTransformer : public TVisitorTransformerBase {
return TStatus::Error;
}

auto rowSchema = topic.RowSpec().Ref().GetTypeAnn()->Cast<TTypeExprType>()->GetType()->Cast<TStructExprType>();

const TStatus filterAnnotationStatus = NYql::NPushdown::AnnotateFilterPredicate(input.Ptr(), TDqPqTopicSource::idx_FilterPredicate, rowSchema, ctx);
if (filterAnnotationStatus != TStatus::Ok) {
return filterAnnotationStatus;
}

if (topic.Metadata().Empty()) {
input.Ptr()->SetTypeAnn(ctx.MakeType<TStreamExprType>(ctx.MakeType<TDataExprType>(EDataSlot::String)));
Expand Down
18 changes: 4 additions & 14 deletions ydb/library/yql/providers/pq/provider/yql_pq_dq_integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,7 @@ class TPqDqIntegration: public TDqIntegrationBase {
auto row = Build<TCoArgument>(ctx, read->Pos())
.Name("row")
.Done();
auto emptyPredicate = Build<TCoLambda>(ctx, read->Pos())
.Args({row})
.Body<TCoBool>()
.Literal().Build("true")
.Build()
.Done().Ptr();
TString emptyPredicate;

return Build<TDqSourceWrap>(ctx, read->Pos())
.Input<TDqPqTopicSource>()
Expand All @@ -155,7 +150,7 @@ class TPqDqIntegration: public TDqIntegrationBase {
.Token<TCoSecureParam>()
.Name().Build(token)
.Build()
.FilterPredicate(emptyPredicate)
.FilterPredicate().Value(emptyPredicate).Build()
.Build()
.RowType(ExpandType(pqReadTopic.Pos(), *rowType, ctx))
.DataSource(pqReadTopic.DataSource().Cast<TCoDataSource>())
Expand Down Expand Up @@ -264,13 +259,8 @@ class TPqDqIntegration: public TDqIntegrationBase {
}

NYql::NConnector::NApi::TPredicate predicateProto;
if (auto predicate = topicSource.FilterPredicate(); !NYql::IsEmptyFilterPredicate(predicate)) {
TStringBuilder err;
if (!NYql::SerializeFilterPredicate(predicate, &predicateProto, err)) {
ctx.AddWarning(TIssue(ctx.GetPosition(node.Pos()), "Failed to serialize filter predicate for source: " + err));
predicateProto.Clear();
}
}
auto serializedProto = topicSource.FilterPredicate().Ref().Content();
YQL_ENSURE (predicateProto.ParseFromString(serializedProto));

sharedReading = sharedReading && (format == "json_each_row" || format == "raw");
TString predicateSql = NYql::FormatWhere(predicateProto);
Expand Down
28 changes: 23 additions & 5 deletions ydb/library/yql/providers/pq/provider/yql_pq_logical_opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include <ydb/library/yql/providers/common/pushdown/physical_opt.h>
#include <ydb/library/yql/providers/common/pushdown/predicate_node.h>

#include <ydb/library/yql/providers/generic/connector/api/service/protos/connector.pb.h>
#include <ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.h>

namespace NYql {

using namespace NNodes;
Expand All @@ -27,7 +30,7 @@ namespace {
: NPushdown::TSettings(NLog::EComponent::ProviderGeneric)
{
using EFlag = NPushdown::TSettings::EFeatureFlag;
Enable(EFlag::ExpressionAsPredicate | EFlag::ArithmeticalExpressions | EFlag::ImplicitConversionToInt64 | EFlag::StringTypes | EFlag::LikeOperator | EFlag::DoNotCheckCompareArgumentsTypes);
Enable(EFlag::ExpressionAsPredicate | EFlag::ArithmeticalExpressions | EFlag::ImplicitConversionToInt64 | EFlag::StringTypes | EFlag::LikeOperator | EFlag::DoNotCheckCompareArgumentsTypes | EFlag::InOperator);
}
};

Expand Down Expand Up @@ -250,15 +253,30 @@ class TPqLogicalOptProposalTransformer : public TOptimizeTransformerBase {
return node;
}
TDqPqTopicSource dqPqTopicSource = maybeDqPqTopicSource.Cast();
if (!IsEmptyFilterPredicate(dqPqTopicSource.FilterPredicate())) {
if (!dqPqTopicSource.FilterPredicate().Ref().Content().empty()) {
YQL_CLOG(TRACE, ProviderPq) << "Push filter. Lambda is already not empty";
return node;
}

auto newFilterLambda = MakePushdownPredicate(flatmap.Lambda(), ctx, node.Pos(), TPushdownSettings());
if (!newFilterLambda) {
return node;
}

auto predicate = newFilterLambda.Cast();
if (NYql::IsEmptyFilterPredicate(predicate)) {
return node;
}

TStringBuilder err;
NYql::NConnector::NApi::TPredicate predicateProto;
if (!NYql::SerializeFilterPredicate(predicate, &predicateProto, err)) {
ctx.AddWarning(TIssue(ctx.GetPosition(node.Pos()), "Failed to serialize filter predicate for source: " + err));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

почему это не ошибка?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Тут не все предикаты поддерживаются. Выше MakePushdownPredicate что-то проверяет, но возможны всё равно различия. В идеале да, нужно фейлить запрос

return node;
}

TString serializedProto;
YQL_ENSURE(predicateProto.SerializeToString(&serializedProto));
YQL_CLOG(INFO, ProviderPq) << "Build new TCoFlatMap with predicate";

if (maybeExtractMembers) {
Expand All @@ -270,7 +288,7 @@ class TPqLogicalOptProposalTransformer : public TOptimizeTransformerBase {
.InitFrom(dqSourceWrap)
.Input<TDqPqTopicSource>()
.InitFrom(dqPqTopicSource)
.FilterPredicate(newFilterLambda.Cast())
.FilterPredicate().Value(serializedProto).Build()
.Build()
.Build()
.Build()
Expand All @@ -282,7 +300,7 @@ class TPqLogicalOptProposalTransformer : public TOptimizeTransformerBase {
.InitFrom(dqSourceWrap)
.Input<TDqPqTopicSource>()
.InitFrom(dqPqTopicSource)
.FilterPredicate(newFilterLambda.Cast())
.FilterPredicate().Value(serializedProto).Build()
.Build()
.Build()
.Done();
Expand Down
55 changes: 28 additions & 27 deletions ydb/tests/fq/yds/test_row_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ def wait_row_dispatcher_sensor_value(kikimr, sensor, expected_count, exact_match

class TestPqRowDispatcher(TestYdsBase):

# Runs one streaming query end to end and verifies both its output and the
# predicate text reported in the query's transient issues.
#
# kikimr, client      -- test fixtures passed through to the helper functions
# sql                 -- full query text to start
# input               -- messages passed to self.write_stream
# output              -- exact list expected back from self.output_topic
# expected_predicate  -- substring that must occur in the transient issues
def run_and_check(self, kikimr, client, sql, input, output, expected_predicate):
query_id = start_yds_query(kikimr, client, sql)
# Presumably blocks until one row dispatcher session actor exists -- confirm
# against wait_actor_count.
wait_actor_count(kikimr, "FQ_ROW_DISPATCHER_SESSION", 1)

self.write_stream(input)
# Read exactly len(output) messages and require an exact match.
assert self.read_stream(len(output), topic_path=self.output_topic) == output

stop_yds_query(client, query_id)
# Presumably waits for the session actor to go away after the stop -- confirm.
wait_actor_count(kikimr, "FQ_ROW_DISPATCHER_SESSION", 0)

# The transient issues are stringified and searched for the expected predicate.
issues = str(client.describe_query(query_id).result.query.transient_issue)
assert expected_predicate in issues, "Incorrect Issues: " + issues

@yq_v1
def test_read_raw_format_with_row_dispatcher(self, kikimr, client):
client.create_yds_connection(
Expand Down Expand Up @@ -253,7 +266,6 @@ def test_nested_types(self, kikimr, client):
issues = str(client.describe_query(query_id).result.query.transient_issue)
assert "Row dispatcher will use the predicate:" in issues, "Incorrect Issues: " + issues

@yq_v1
def test_nested_types_without_predicate(self, kikimr, client):
client.create_yds_connection(
YDS_CONNECTION, os.getenv("YDB_DATABASE"), os.getenv("YDB_ENDPOINT"), shared_reading=True
Expand Down Expand Up @@ -284,7 +296,7 @@ def test_nested_types_without_predicate(self, kikimr, client):
stop_yds_query(client, query_id)

@yq_v1
def test_filter(self, kikimr, client):
def test_filters(self, kikimr, client):
client.create_yds_connection(
YDS_CONNECTION, os.getenv("YDB_DATABASE"), os.getenv("YDB_ENDPOINT"), shared_reading=True
)
Expand All @@ -293,34 +305,23 @@ def test_filter(self, kikimr, client):
sql = Rf'''
INSERT INTO {YDS_CONNECTION}.`{self.output_topic}`
SELECT Cast(time as String) FROM {YDS_CONNECTION}.`{self.input_topic}`
WITH (format=json_each_row, SCHEMA (time UInt64 NOT NULL, data String NOT NULL, event String NOT NULL))
WHERE time > 101 and
data = "hello2" and
event IS NOT DISTINCT FROM "event2" and
event IS DISTINCT FROM "event1";'''

query_id = start_yds_query(kikimr, client, sql)
wait_actor_count(kikimr, "FQ_ROW_DISPATCHER_SESSION", 1)

WITH (format=json_each_row, SCHEMA (time UInt64 NOT NULL, data String NOT NULL, event String NOT NULL)) WHERE '''
data = [
'{"time": 101, "data": "hello1", "event": "event1"}',
'{"time": 102, "data": "hello2", "event": "event2"}',
]

self.write_stream(data)
'{"time": 102, "data": "hello2", "event": "event2"}']
filter = "time > 101;"
expected = ['102']
assert self.read_stream(len(expected), topic_path=self.output_topic) == expected

wait_actor_count(kikimr, "DQ_PQ_READ_ACTOR", 1)

stop_yds_query(client, query_id)
# Assert that all read rules were removed after query stops
read_rules = list_read_rules(self.input_topic)
assert len(read_rules) == 0, read_rules
wait_actor_count(kikimr, "FQ_ROW_DISPATCHER_SESSION", 0)

issues = str(client.describe_query(query_id).result.query.transient_issue)
assert "Row dispatcher will use the predicate: WHERE (`time` > 101" in issues, "Incorrect Issues: " + issues
self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE `time` > 101')
filter = 'data = "hello2"'
self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE `data` = \\"hello2\\"')
filter = ' event IS NOT DISTINCT FROM "event2"'
self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE `event` IS NOT DISTINCT FROM \\"event2\\"')
filter = ' event IS DISTINCT FROM "event1"'
self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE `event` IS DISTINCT FROM \\"event1\\"')
filter = 'event IN ("event2")'
self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE `event` IN (\\"event2\\")')
filter = 'event IN ("1", "2", "3", "4", "5", "6", "7", "event2")'
self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE `event` IN (\\"1\\"')

@yq_v1
def test_filter_missing_fields(self, kikimr, client):
Expand Down
Loading