Skip to content

Commit

Permalink
Fully support aggregate on part of an input block (#8064)
Browse files Browse the repository at this point in the history
ref #7738
  • Loading branch information
windtalker authored Sep 11, 2023
1 parent 1539364 commit f1d2d46
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 36 deletions.
3 changes: 2 additions & 1 deletion dbms/src/Common/FailPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ namespace DB
M(skip_seek_before_read_dmfile) \
M(exception_after_large_write_exceed) \
M(proactive_flush_force_set_type) \
M(exception_when_fetch_disagg_pages)
M(exception_when_fetch_disagg_pages) \
M(force_agg_on_partial_block)

#define APPLY_FOR_PAUSEABLE_FAILPOINTS_ONCE(M) \
M(pause_with_alter_locks_acquired) \
Expand Down
10 changes: 7 additions & 3 deletions dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,12 @@ void ParallelAggregatingBlockInputStream::Handler::onBlock(Block & block, size_t
auto & data = *parent.many_data[thread_num];
auto & agg_process_info = parent.threads_data[thread_num].agg_process_info;
agg_process_info.resetBlock(block);
parent.aggregator.executeOnBlock(agg_process_info, data, thread_num);
if (data.need_spill)
parent.aggregator.spill(data, thread_num);
do
{
parent.aggregator.executeOnBlock(agg_process_info, data, thread_num);
if (data.need_spill)
parent.aggregator.spill(data, thread_num);
} while (!agg_process_info.allBlockDataHandled());

parent.threads_data[thread_num].src_rows += block.rows();
parent.threads_data[thread_num].src_bytes += block.bytes();
Expand Down Expand Up @@ -273,6 +276,7 @@ void ParallelAggregatingBlockInputStream::execute()
aggregator.executeOnBlock(agg_process_info, data, 0);
if (data.need_spill)
aggregator.spill(data, 0);
assert(agg_process_info.allBlockDataHandled());
}
}

Expand Down
60 changes: 58 additions & 2 deletions dbms/src/Flash/tests/gtest_aggregation_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

namespace DB
{
/// Declaration of the failpoint name that the WRAP_FOR_AGG_PARTIAL_BLOCK_*
/// macros in this file enable and disable. It is only declared here
/// (`extern`); the definition lives outside this file.
namespace FailPoints
{
extern const char force_agg_on_partial_block[];
} // namespace FailPoints
namespace tests
{
#define DT DecimalField<Decimal32>
Expand Down Expand Up @@ -232,6 +236,17 @@ class AggExecutorTestRunner : public ExecutorTest
ColumnWithUInt64 col_pr{1, 2, 0, 3290124, 968933, 3125, 31236, 4327, 80000};
};

/// Runs the statements placed between WRAP_FOR_AGG_PARTIAL_BLOCK_START and
/// WRAP_FOR_AGG_PARTIAL_BLOCK_END twice: first with the
/// `force_agg_on_partial_block` failpoint enabled, then with it disabled, so
/// the failpoint is left disabled after a normal run of the pair.
///
/// The macro pair expands to a single `for` statement and introduces no name
/// into the enclosing scope, so it may be used several times in one scope
/// without wrapping each use in extra braces. Inside the wrapped statements the
/// current mode is available as `partial_block`.
///
/// NOTE(review): if the wrapped statements leave the loop early during the
/// first (enabled) iteration -- e.g. a fatal gtest assertion returning from the
/// test body or an exception propagating -- the disabling call is not reached
/// and the failpoint stays enabled. Confirm whether later tests rely on it
/// being off.
#define WRAP_FOR_AGG_PARTIAL_BLOCK_START                                              \
    for (const bool partial_block : {true, false})                                    \
    {                                                                                 \
        if (partial_block)                                                            \
            FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \
        else                                                                          \
            FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block);

/// Closes the loop body opened by WRAP_FOR_AGG_PARTIAL_BLOCK_START.
#define WRAP_FOR_AGG_PARTIAL_BLOCK_END }

/// Guarantee the correctness of group by
TEST_F(AggExecutorTestRunner, GroupBy)
try
Expand Down Expand Up @@ -340,7 +355,9 @@ try
for (size_t i = 0; i < test_num; ++i)
{
request = buildDAGRequest(std::make_pair(db_name, table_types), {}, group_by_exprs[i], projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}

Expand Down Expand Up @@ -397,7 +414,9 @@ try
for (size_t i = 0; i < test_num; ++i)
{
request = buildDAGRequest(std::make_pair(db_name, table_types), {}, group_by_exprs[i], projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}

Expand Down Expand Up @@ -429,7 +448,9 @@ try
for (size_t i = 0; i < test_num; ++i)
{
request = buildDAGRequest(std::make_pair(db_name, table_name), agg_funcs[i], group_by_exprs[i], projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}

/// Min function tests
Expand All @@ -448,7 +469,9 @@ try
for (size_t i = 0; i < test_num; ++i)
{
request = buildDAGRequest(std::make_pair(db_name, table_name), agg_funcs[i], group_by_exprs[i], projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
CATCH
Expand Down Expand Up @@ -506,7 +529,9 @@ try
{
request
= buildDAGRequest(std::make_pair(db_name, table_name), {agg_funcs[i]}, group_by_exprs[i], projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
CATCH
Expand Down Expand Up @@ -574,7 +599,9 @@ try
{agg_func},
group_by_exprs[i],
projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
{
Expand All @@ -586,7 +613,9 @@ try
{agg_func},
group_by_exprs[i],
projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
for (auto collation_id : {0, static_cast<int>(TiDB::ITiDBCollator::BINARY)})
Expand Down Expand Up @@ -623,7 +652,9 @@ try
{agg_func},
group_by_exprs[i],
projections[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect_cols[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
}
Expand All @@ -636,7 +667,9 @@ try
executeAndAssertColumnsEqual(request, {{toNullableVec<String>({"banana"})}});

request = context.scan("aggnull_test", "t1").aggregation({}, {col("s1")}).build(context);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, {{toNullableVec<String>("s1", {{}, "banana"})}});
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
CATCH

Expand All @@ -648,7 +681,9 @@ try
= {toNullableVec<Int64>({3}), toNullableVec<Int64>({1}), toVec<UInt64>({6})};
auto test_single_function = [&](size_t index) {
auto request = context.scan("test_db", "test_table").aggregation({functions[index]}, {}).build(context);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, {functions_result[index]});
WRAP_FOR_AGG_PARTIAL_BLOCK_END
};
for (size_t i = 0; i < functions.size(); ++i)
test_single_function(i);
Expand All @@ -669,7 +704,9 @@ try
results.push_back(functions_result[k]);

auto request = context.scan("test_db", "test_table").aggregation(funcs, {}).build(context);
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, results);
WRAP_FOR_AGG_PARTIAL_BLOCK_END

funcs.pop_back();
results.pop_back();
Expand Down Expand Up @@ -705,7 +742,9 @@ try
context.context->setSetting(
"group_by_two_level_threshold",
Field(static_cast<UInt64>(two_level_threshold)));
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, expect);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
}
Expand Down Expand Up @@ -736,6 +775,7 @@ try
"group_by_two_level_threshold",
Field(static_cast<UInt64>(two_level_threshold)));
context.context->setSetting("max_block_size", Field(static_cast<UInt64>(block_size)));
WRAP_FOR_AGG_PARTIAL_BLOCK_START
auto blocks = getExecuteStreamsReturnBlocks(request, concurrency);
size_t actual_row = 0;
for (auto & block : blocks)
Expand All @@ -744,6 +784,7 @@ try
actual_row += block.rows();
}
ASSERT_EQ(actual_row, expect_rows[i]);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
}
Expand Down Expand Up @@ -857,6 +898,7 @@ try
"group_by_two_level_threshold",
Field(static_cast<UInt64>(two_level_threshold)));
context.context->setSetting("max_block_size", Field(static_cast<UInt64>(block_size)));
WRAP_FOR_AGG_PARTIAL_BLOCK_START
auto blocks = getExecuteStreamsReturnBlocks(request, concurrency);
for (auto & block : blocks)
{
Expand All @@ -881,6 +923,7 @@ try
vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName(),
false));
}
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
}
Expand All @@ -907,12 +950,20 @@ try
executeAndAssertColumnsEqual(request, {});

request = context.receive("empty_recv", 5).aggregation({Max(col("s1"))}, {col("s2")}, 5).build(context);
executeAndAssertColumnsEqual(request, {});
{
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, {});
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}

request = context.scan("test_db", "empty_table")
.aggregation({Count(lit(Field(static_cast<UInt64>(1))))}, {})
.build(context);
executeAndAssertColumnsEqual(request, {toVec<UInt64>({0})});
{
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(request, {toVec<UInt64>({0})});
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
CATCH

Expand Down Expand Up @@ -961,10 +1012,15 @@ try
auto baseline = executeStreams(gen_request(1), 1);
for (size_t exchange_concurrency : exchange_receiver_concurrency)
{
WRAP_FOR_AGG_PARTIAL_BLOCK_START
executeAndAssertColumnsEqual(gen_request(exchange_concurrency), baseline);
WRAP_FOR_AGG_PARTIAL_BLOCK_END
}
}
CATCH

#undef WRAP_FOR_AGG_PARTIAL_BLOCK_START
#undef WRAP_FOR_AGG_PARTIAL_BLOCK_END

} // namespace tests
} // namespace DB
26 changes: 26 additions & 0 deletions dbms/src/Flash/tests/gtest_spill_aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

namespace DB
{
/// Declaration of the failpoint name that the WRAP_FOR_AGG_PARTIAL_BLOCK_*
/// macros in this file enable and disable. It is only declared here
/// (`extern`); the definition lives outside this file.
namespace FailPoints
{
extern const char force_agg_on_partial_block[];
} // namespace FailPoints

namespace tests
{
class SpillAggregationTestRunner : public DB::tests::ExecutorTest
Expand All @@ -27,6 +32,17 @@ class SpillAggregationTestRunner : public DB::tests::ExecutorTest
void initializeContext() override { ExecutorTest::initializeContext(); }
};

/// Runs the statements placed between WRAP_FOR_AGG_PARTIAL_BLOCK_START and
/// WRAP_FOR_AGG_PARTIAL_BLOCK_END twice: first with the
/// `force_agg_on_partial_block` failpoint enabled, then with it disabled, so
/// the failpoint is left disabled after a normal run of the pair.
///
/// The macro pair expands to a single `for` statement and introduces no name
/// into the enclosing scope, so it may be used several times in one scope
/// without wrapping each use in extra braces. Inside the wrapped statements the
/// current mode is available as `partial_block`.
///
/// NOTE(review): if the wrapped statements leave the loop early during the
/// first (enabled) iteration -- e.g. a fatal gtest assertion returning from the
/// test body or an exception propagating -- the disabling call is not reached
/// and the failpoint stays enabled. Confirm whether later tests rely on it
/// being off.
#define WRAP_FOR_AGG_PARTIAL_BLOCK_START                                              \
    for (const bool partial_block : {true, false})                                    \
    {                                                                                 \
        if (partial_block)                                                            \
            FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \
        else                                                                          \
            FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block);

/// Closes the loop body opened by WRAP_FOR_AGG_PARTIAL_BLOCK_START.
#define WRAP_FOR_AGG_PARTIAL_BLOCK_END }

#define WRAP_FOR_SPILL_TEST_BEGIN \
std::vector<bool> pipeline_bools{false, true}; \
for (auto enable_pipeline : pipeline_bools) \
Expand Down Expand Up @@ -82,9 +98,11 @@ try
context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast<UInt64>(1)));
/// don't use `executeAndAssertColumnsEqual` since it takes too long to run
/// test single thread aggregation
WRAP_FOR_AGG_PARTIAL_BLOCK_START
ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1));
/// test parallel aggregation
ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams));
WRAP_FOR_AGG_PARTIAL_BLOCK_END
/// enable spill and use small max_cached_data_bytes_in_spiller
context.context->setSetting("max_cached_data_bytes_in_spiller", Field(static_cast<UInt64>(total_data_size / 200)));
/// test single thread aggregation
Expand Down Expand Up @@ -226,6 +244,7 @@ try
Field(static_cast<UInt64>(max_bytes_before_external_agg)));
context.context->setSetting("max_block_size", Field(static_cast<UInt64>(max_block_size)));
WRAP_FOR_SPILL_TEST_BEGIN
WRAP_FOR_AGG_PARTIAL_BLOCK_START
auto blocks = getExecuteStreamsReturnBlocks(request, concurrency);
for (auto & block : blocks)
{
Expand All @@ -250,6 +269,7 @@ try
vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName(),
false));
}
WRAP_FOR_AGG_PARTIAL_BLOCK_END
WRAP_FOR_SPILL_TEST_END
}
}
Expand Down Expand Up @@ -377,6 +397,7 @@ try
Field(static_cast<UInt64>(max_bytes_before_external_agg)));
context.context->setSetting("max_block_size", Field(static_cast<UInt64>(max_block_size)));
WRAP_FOR_SPILL_TEST_BEGIN
WRAP_FOR_AGG_PARTIAL_BLOCK_START
auto blocks = getExecuteStreamsReturnBlocks(request, concurrency);
for (auto & block : blocks)
{
Expand All @@ -401,6 +422,7 @@ try
vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName(),
false));
}
WRAP_FOR_AGG_PARTIAL_BLOCK_END
WRAP_FOR_SPILL_TEST_END
}
}
Expand Down Expand Up @@ -474,14 +496,18 @@ try
/// don't use `executeAndAssertColumnsEqual` since it takes too long to run
auto request = gen_request(exchange_concurrency);
WRAP_FOR_SPILL_TEST_BEGIN
WRAP_FOR_AGG_PARTIAL_BLOCK_START
ASSERT_COLUMNS_EQ_UR(baseline, executeStreams(request, exchange_concurrency));
WRAP_FOR_AGG_PARTIAL_BLOCK_END
WRAP_FOR_SPILL_TEST_END
}
}
CATCH

#undef WRAP_FOR_SPILL_TEST_BEGIN
#undef WRAP_FOR_SPILL_TEST_END
#undef WRAP_FOR_AGG_PARTIAL_BLOCK_START
#undef WRAP_FOR_AGG_PARTIAL_BLOCK_END

} // namespace tests
} // namespace DB
Loading

0 comments on commit f1d2d46

Please sign in to comment.