Skip to content

Commit f002be0

Browse files
authored
YQL-16402: Implement key cache for Top and TopSort nodes (#4761)
1 parent 2d98fad commit f002be0

File tree

2 files changed

+240
-6
lines changed

2 files changed

+240
-6
lines changed

ydb/library/yql/minikql/comp_nodes/ut/mkql_sort_ut.cpp

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <ydb/library/yql/minikql/mkql_node.h>
66
#include <ydb/library/yql/minikql/mkql_program_builder.h>
77
#include <ydb/library/yql/minikql/mkql_string_util.h>
8+
#include <ydb/library/yql/public/udf/udf_helpers.h>
89

910
#include <ydb/library/yql/utils/sort.h>
1011

@@ -519,5 +520,225 @@ Y_UNIT_TEST_SUITE(TMiniKQLSortTest) {
519520
UNIT_ASSERT(copy == res);
520521
}
521522
}
523+
524+
Y_UNIT_TEST_SUITE(TMiniKQLStreamKeyExtractorCacheTest) {
    // Counts invocations of the EchoU64 UDF. The shared scenario below resets it
    // before building a graph and compares it with the number of input items.
    static thread_local size_t echoCounter;

    // Identity UDF over ui64 that increments echoCounter on every call, so a test
    // can observe how many times the key extractor was actually evaluated.
    SIMPLE_UDF(TEchoU64, ui64(ui64)) {
        Y_UNUSED(valueBuilder);
        echoCounter++;
        return args[0];
    }

    // Module exposing TEchoU64; the tests look the function up as "CountCalls.EchoU64".
    SIMPLE_MODULE(TCountCallsModule, TEchoU64);

    // Shared scenario for every test in this suite.
    //
    // Builds a list of `total` one-element tuples holding random ui64 values in
    // [0, 99], lets `buildRoot` wrap that list into the Top/TopSort program under
    // test (limit `n`, ascending flag `false`, key extractor routed through the
    // counting UDF), drains the resulting value with Fetch() and verifies that:
    //   * exactly `n` items were produced;
    //   * they match the first `n` inputs after a partial sort with std::greater;
    //   * the counting UDF ran exactly once per input item.
    //
    // `buildRoot` is invoked as
    //     buildRoot(pgmBuilder, list, limit, ascending, extractor)
    // and must return the program root as a TRuntimeNode.
    //
    // NOTE(review): the produced order is compared exactly for Top as well as for
    // TopSort, exactly as the original per-test code did - confirm that Top is
    // expected to yield sorted output.
    template <typename TBuildRoot>
    void CheckKeyExtractorRunsOncePerItem(const TBuildRoot& buildRoot) {
        echoCounter = 0;
        constexpr ui64 total = 999ULL;
        constexpr ui64 n = 17ULL;

        // An integer distribution is required here: std::uniform_real_distribution
        // is only defined for floating-point result types.
        std::uniform_int_distribution<ui64> urdist(0ULL, 99ULL);
        std::default_random_engine rand;
        rand.seed(std::time(nullptr));

        std::vector<ui64> expected;
        expected.reserve(total);
        std::generate_n(std::back_inserter(expected), total, [&]() { return urdist(rand); });

        // Register the counting module and rebuild the program builder on top of
        // a registry clone that contains it.
        TSetup<false> setup;
        NYql::NUdf::AddToStaticUdfRegistry<TCountCallsModule>();
        auto mutableRegistry = setup.FunctionRegistry->Clone();
        FillStaticModules(*mutableRegistry);
        setup.FunctionRegistry = mutableRegistry;
        setup.PgmBuilder.Reset(new TProgramBuilder(*setup.Env, *setup.FunctionRegistry));
        TProgramBuilder& pgmBuilder = *setup.PgmBuilder;

        std::array<TRuntimeNode, total> data;
        std::transform(expected.cbegin(), expected.cend(), data.begin(), [&](const ui64& v) {
            return pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<ui64>(v)});
        });

        const auto echoUdf = pgmBuilder.Udf("CountCalls.EchoU64");
        const auto tupleType = pgmBuilder.NewTupleType({pgmBuilder.NewDataType(NUdf::TDataType<ui64>::Id)});
        const auto ascending = pgmBuilder.NewTuple({pgmBuilder.NewDataLiteral<bool>(false)});
        const auto list = pgmBuilder.NewList(tupleType, data);
        // The key is the item's only field passed through the counting UDF.
        const auto extractor = [&pgmBuilder, echoUdf](TRuntimeNode item) {
            return pgmBuilder.NewTuple({pgmBuilder.Apply(echoUdf, {pgmBuilder.Nth(item, 0U)})});
        };
        const auto limit = pgmBuilder.NewDataLiteral<ui64>(n);

        const auto pgmRoot = buildRoot(pgmBuilder, list, limit, ascending, extractor);
        const auto graph = setup.BuildGraph(pgmRoot);
        const auto value = graph->GetValue();

        // Reference result: the n largest inputs, in descending order.
        NYql::FastPartialSort(expected.begin(), expected.begin() + n, expected.end(), std::greater<ui64>());
        expected.resize(n);

        std::vector<ui64> res;
        res.reserve(n);
        for (NUdf::TUnboxedValue item; NUdf::EFetchStatus::Ok == value.Fetch(item);) {
            res.emplace_back(item.GetElement(0U).Get<ui64>());
        }

        UNIT_ASSERT_VALUES_EQUAL(res.size(), n);
        UNIT_ASSERT(res == expected);
        UNIT_ASSERT_VALUES_EQUAL(echoCounter, total);
    }

    // TopSort over a stream built with Iterator().
    Y_UNIT_TEST(TestStreamTopSort) {
        CheckKeyExtractorRunsOncePerItem([](TProgramBuilder& pb, TRuntimeNode list, TRuntimeNode limit,
                                            TRuntimeNode ascending, const auto& extractor) {
            return pb.TopSort(pb.Iterator(list, {}), limit, ascending, extractor);
        });
    }

    // Top over a stream built with Iterator().
    Y_UNIT_TEST(TestStreamTop) {
        CheckKeyExtractorRunsOncePerItem([](TProgramBuilder& pb, TRuntimeNode list, TRuntimeNode limit,
                                            TRuntimeNode ascending, const auto& extractor) {
            return pb.Top(pb.Iterator(list, {}), limit, ascending, extractor);
        });
    }

    // TopSort over a flow built with ToFlow(), converted back with FromFlow().
    Y_UNIT_TEST(TestFlowTopSort) {
        CheckKeyExtractorRunsOncePerItem([](TProgramBuilder& pb, TRuntimeNode list, TRuntimeNode limit,
                                            TRuntimeNode ascending, const auto& extractor) {
            return pb.FromFlow(pb.TopSort(pb.ToFlow(list), limit, ascending, extractor));
        });
    }

    // Top over a flow built with ToFlow(), converted back with FromFlow().
    Y_UNIT_TEST(TestFlowTop) {
        CheckKeyExtractorRunsOncePerItem([](TProgramBuilder& pb, TRuntimeNode list, TRuntimeNode limit,
                                            TRuntimeNode ascending, const auto& extractor) {
            return pb.FromFlow(pb.Top(pb.ToFlow(list), limit, ascending, extractor));
        });
    }
}
522743
} // NMiniKQL
523744
} // NKikimr

ydb/library/yql/minikql/mkql_program_builder.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,13 +1887,20 @@ TRuntimeNode TProgramBuilder::BuildWideTopOrSort(const std::string_view& callabl
18871887

18881888
TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
18891889
if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
1890+
const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
1891+
const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
1892+
const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
1893+
return NewTuple({keyExtractor(item), item});
1894+
};
18901895

1891-
return FlatMap(Condense1(flow,
1896+
return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
18921897
[&](TRuntimeNode item) { return AsList(item); },
18931898
[this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
1894-
[&](TRuntimeNode item, TRuntimeNode state) { return KeepTop(count, state, item, ascending, keyExtractor); }
1899+
[&](TRuntimeNode item, TRuntimeNode state) {
1900+
return KeepTop(count, state, item, ascending, getKey);
1901+
}
18951902
),
1896-
[&](TRuntimeNode list) { return Top(list, count, ascending, keyExtractor); }
1903+
[&](TRuntimeNode list) { return Map(Top(list, count, ascending, getKey), getItem); }
18971904
);
18981905
}
18991906

@@ -1902,14 +1909,20 @@ TRuntimeNode TProgramBuilder::Top(TRuntimeNode flow, TRuntimeNode count, TRuntim
19021909

19031910
TRuntimeNode TProgramBuilder::TopSort(TRuntimeNode flow, TRuntimeNode count, TRuntimeNode ascending, const TUnaryLambda& keyExtractor) {
19041911
if (const auto flowType = flow.GetStaticType(); flowType->IsFlow() || flowType->IsStream()) {
1905-
return FlatMap(Condense1(flow,
1912+
const TUnaryLambda getKey = [&](TRuntimeNode item) { return Nth(item, 0U); };
1913+
const TUnaryLambda getItem = [&](TRuntimeNode item) { return Nth(item, 1U); };
1914+
const TUnaryLambda cacheKeyExtractor = [&](TRuntimeNode item) {
1915+
return NewTuple({keyExtractor(item), item});
1916+
};
1917+
1918+
return FlatMap(Condense1(Map(flow, cacheKeyExtractor),
19061919
[&](TRuntimeNode item) { return AsList(item); },
19071920
[this](TRuntimeNode, TRuntimeNode) { return NewDataLiteral<bool>(false); },
19081921
[&](TRuntimeNode item, TRuntimeNode state) {
1909-
return KeepTop(count, state, item, ascending, keyExtractor);
1922+
return KeepTop(count, state, item, ascending, getKey);
19101923
}
19111924
),
1912-
[&](TRuntimeNode list) { return TopSort(list, count, ascending, keyExtractor); }
1925+
[&](TRuntimeNode list) { return Map(TopSort(list, count, ascending, getKey), getItem); }
19131926
);
19141927
}
19151928

0 commit comments

Comments
 (0)