Skip to content

Commit

Permalink
Merge pull request duckdb#12290 from lnkuiper/auto_cte_materialize
Browse files Browse the repository at this point in the history
Automatically materialize CTEs
  • Loading branch information
Mytherin authored Jul 11, 2024
2 parents 294b622 + c185b74 commit 6e0fc96
Show file tree
Hide file tree
Showing 36 changed files with 1,063 additions and 175 deletions.
31 changes: 31 additions & 0 deletions benchmark/tpch/cte/auto_cte_materialization.benchmark
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# name: benchmark/tpch/cte/auto_cte_materialization.benchmark
# description: Benchmark to check if automatic CTE materialization is working
# group: [cte]

name Automatic CTE materialization
group cte

require tpch

# generate the TPC-H data at scale factor 1
load
CALL dbgen(sf=1);

run
WITH my_cte AS (
SELECT
l_returnflag,
l_linestatus,
sum(l_quantity) AS sum_qty
FROM
lineitem
GROUP BY
l_returnflag,
l_linestatus
)
SELECT
(SELECT sum_qty FROM my_cte WHERE l_returnflag = 'A') +
(SELECT sum_qty FROM my_cte WHERE l_returnflag = 'R');

result I
75453860.00
42 changes: 17 additions & 25 deletions benchmark/tpch_plan_cost/queries/q15.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
WITH revenue AS (
SELECT
l_suppkey AS supplier_no,
sum(l_extendedprice * (1 - l_discount)) AS total_revenue
FROM
lineitem
WHERE
l_shipdate >= CAST('1996-01-01' AS date)
AND l_shipdate < CAST('1996-04-01' AS date)
GROUP BY
supplier_no
)
SELECT
s_suppkey,
s_name,
Expand All @@ -6,32 +18,12 @@ SELECT
total_revenue
FROM
supplier,
(
SELECT
l_suppkey AS supplier_no,
sum(l_extendedprice * (1 - l_discount)) AS total_revenue
FROM
lineitem
WHERE
l_shipdate >= CAST('1996-01-01' AS date)
AND l_shipdate < CAST('1996-04-01' AS date)
GROUP BY
supplier_no) revenue0
revenue
WHERE
s_suppkey = supplier_no
AND total_revenue = (
SELECT
max(total_revenue)
FROM (
SELECT
l_suppkey AS supplier_no,
sum(l_extendedprice * (1 - l_discount)) AS total_revenue
FROM
lineitem
WHERE
l_shipdate >= CAST('1996-01-01' AS date)
AND l_shipdate < CAST('1996-04-01' AS date)
GROUP BY
supplier_no) revenue1)
AND total_revenue = (
SELECT
max(total_revenue)
FROM revenue)
ORDER BY
s_suppkey;
2 changes: 1 addition & 1 deletion extension/tpch/dbgen/include/tpch_constants.hpp

Large diffs are not rendered by default.

36 changes: 14 additions & 22 deletions extension/tpch/dbgen/queries/q15.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
WITH revenue AS (
SELECT
l_suppkey AS supplier_no,
sum(l_extendedprice * (1 - l_discount)) AS total_revenue
FROM
lineitem
WHERE
l_shipdate >= CAST('1996-01-01' AS date)
AND l_shipdate < CAST('1996-04-01' AS date)
GROUP BY
supplier_no
)
SELECT
s_suppkey,
s_name,
Expand All @@ -6,32 +18,12 @@ SELECT
total_revenue
FROM
supplier,
(
SELECT
l_suppkey AS supplier_no,
sum(l_extendedprice * (1 - l_discount)) AS total_revenue
FROM
lineitem
WHERE
l_shipdate >= CAST('1996-01-01' AS date)
AND l_shipdate < CAST('1996-04-01' AS date)
GROUP BY
supplier_no) revenue0
revenue
WHERE
s_suppkey = supplier_no
AND total_revenue = (
SELECT
max(total_revenue)
FROM (
SELECT
l_suppkey AS supplier_no,
sum(l_extendedprice * (1 - l_discount)) AS total_revenue
FROM
lineitem
WHERE
l_shipdate >= CAST('1996-01-01' AS date)
AND l_shipdate < CAST('1996-04-01' AS date)
GROUP BY
supplier_no) revenue1)
FROM revenue)
ORDER BY
s_suppkey;
5 changes: 5 additions & 0 deletions src/common/enum_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4468,6 +4468,8 @@ const char* EnumUtil::ToChars<OptimizerType>(OptimizerType value) {
return "FILTER_PULLUP";
case OptimizerType::FILTER_PUSHDOWN:
return "FILTER_PUSHDOWN";
case OptimizerType::CTE_FILTER_PUSHER:
return "CTE_FILTER_PUSHER";
case OptimizerType::REGEX_RANGE:
return "REGEX_RANGE";
case OptimizerType::IN_CLAUSE:
Expand Down Expand Up @@ -4523,6 +4525,9 @@ OptimizerType EnumUtil::FromString<OptimizerType>(const char *value) {
if (StringUtil::Equals(value, "FILTER_PUSHDOWN")) {
return OptimizerType::FILTER_PUSHDOWN;
}
if (StringUtil::Equals(value, "CTE_FILTER_PUSHER")) {
return OptimizerType::CTE_FILTER_PUSHER;
}
if (StringUtil::Equals(value, "REGEX_RANGE")) {
return OptimizerType::REGEX_RANGE;
}
Expand Down
6 changes: 4 additions & 2 deletions src/common/enums/optimizer_type.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include "duckdb/common/enums/optimizer_type.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/common/exception/parser_exception.hpp"

#include "duckdb/common/exception.hpp"
#include "duckdb/common/exception/parser_exception.hpp"
#include "duckdb/common/string_util.hpp"

namespace duckdb {

Expand All @@ -14,6 +15,7 @@ static const DefaultOptimizerType internal_optimizer_types[] = {
{"expression_rewriter", OptimizerType::EXPRESSION_REWRITER},
{"filter_pullup", OptimizerType::FILTER_PULLUP},
{"filter_pushdown", OptimizerType::FILTER_PUSHDOWN},
{"cte_filter_pusher", OptimizerType::CTE_FILTER_PUSHER},
{"regex_range", OptimizerType::REGEX_RANGE},
{"in_clause", OptimizerType::IN_CLAUSE},
{"join_order", OptimizerType::JOIN_ORDER},
Expand Down
2 changes: 2 additions & 0 deletions src/execution/operator/aggregate/physical_hash_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,8 @@ string PhysicalHashAggregate::ParamsToString() const {
result += " Filter: " + aggregate.filter->GetName();
}
}
result += "\n[INFOSEPARATOR]\n";
result += StringUtil::Format("EC: %llu\n", estimated_cardinality);
return result;
}

Expand Down
8 changes: 5 additions & 3 deletions src/execution/operator/projection/physical_projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ PhysicalProjection::CreateJoinProjection(vector<LogicalType> proj_types, const v
}

string PhysicalProjection::ParamsToString() const {
string extra_info;
string result;
for (auto &expr : select_list) {
extra_info += expr->GetName() + "\n";
result += expr->GetName() + "\n";
}
return extra_info;
result += "\n[INFOSEPARATOR]\n";
result += StringUtil::Format("EC: %llu\n", estimated_cardinality);
return result;
}

} // namespace duckdb
48 changes: 31 additions & 17 deletions src/execution/operator/scan/physical_column_data_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,45 @@ PhysicalColumnDataScan::PhysicalColumnDataScan(vector<LogicalType> types, Physic
: PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(nullptr), cte_index(cte_index) {
}

class PhysicalColumnDataScanState : public GlobalSourceState {
class PhysicalColumnDataGlobalScanState : public GlobalSourceState {
public:
explicit PhysicalColumnDataScanState() : initialized(false) {
PhysicalColumnDataGlobalScanState(const ClientContext &context, const ColumnDataCollection &collection)
: max_threads(MaxValue<idx_t>(context.config.verify_parallelism ? collection.ChunkCount()
: collection.ChunkCount() / CHUNKS_PER_THREAD,
1)) {
collection.InitializeScan(global_scan_state);
}

//! The current position in the scan
ColumnDataScanState scan_state;
bool initialized;
idx_t MaxThreads() override {
return max_threads;
}

public:
ColumnDataParallelScanState global_scan_state;

static constexpr idx_t CHUNKS_PER_THREAD = 32;
const idx_t max_threads;
};

//! Local source state for PhysicalColumnDataScan.
//! Wraps the ColumnDataLocalScanState that GetData hands to ColumnDataCollection::Scan
//! together with the global scan state.
class PhysicalColumnDataLocalScanState : public LocalSourceState {
public:
//! Local scan position/bookkeeping used by ColumnDataCollection::Scan
ColumnDataLocalScanState local_scan_state;
};

unique_ptr<GlobalSourceState> PhysicalColumnDataScan::GetGlobalSourceState(ClientContext &context) const {
return make_uniq<PhysicalColumnDataScanState>();
return make_uniq<PhysicalColumnDataGlobalScanState>(context, *collection);
}

//! Creates a fresh local source state for this scan.
//! Neither the execution context nor the global source state is consulted (both parameters are unnamed).
unique_ptr<LocalSourceState> PhysicalColumnDataScan::GetLocalSourceState(ExecutionContext &,
GlobalSourceState &) const {
return make_uniq<PhysicalColumnDataLocalScanState>();
}

SourceResultType PhysicalColumnDataScan::GetData(ExecutionContext &context, DataChunk &chunk,
OperatorSourceInput &input) const {
auto &state = input.global_state.Cast<PhysicalColumnDataScanState>();
if (collection->Count() == 0) {
return SourceResultType::FINISHED;
}
if (!state.initialized) {
collection->InitializeScan(state.scan_state);
state.initialized = true;
}
collection->Scan(state.scan_state, chunk);

auto &gstate = input.global_state.Cast<PhysicalColumnDataGlobalScanState>();
auto &lstate = input.local_state.Cast<PhysicalColumnDataLocalScanState>();
collection->Scan(gstate.global_scan_state, lstate.local_scan_state, chunk);
return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
}

Expand Down Expand Up @@ -108,7 +121,8 @@ string PhysicalColumnDataScan::ParamsToString() const {
default:
break;
}

result += "\n[INFOSEPARATOR]\n";
result += StringUtil::Format("EC: %llu\n", estimated_cardinality);
return result;
}

Expand Down
1 change: 1 addition & 0 deletions src/include/duckdb/common/enums/optimizer_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ enum class OptimizerType : uint32_t {
EXPRESSION_REWRITER,
FILTER_PULLUP,
FILTER_PUSHDOWN,
CTE_FILTER_PUSHER,
REGEX_RANGE,
IN_CLAUSE,
JOIN_ORDER,
Expand Down
23 changes: 18 additions & 5 deletions src/include/duckdb/common/insertion_order_preserving_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

#pragma once

#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/unordered_set.hpp"
#include "duckdb/common/string.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/common/helper.hpp"
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/helper.hpp"
#include "duckdb/common/pair.hpp"
#include "duckdb/common/string.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/unordered_set.hpp"

namespace duckdb {

Expand Down Expand Up @@ -59,6 +59,14 @@ class InsertionOrderPreservingMap {
return map.end();
}

//! Reverse-begin iterator: forwards to the underlying entry vector, so reverse iteration
//! starts at the last stored entry.
typename VECTOR_TYPE::reverse_iterator rbegin() { // NOLINT: match stl API
return map.rbegin();
}

//! Reverse-end iterator: forwards to the underlying entry vector and marks the end of
//! a reverse traversal.
typename VECTOR_TYPE::reverse_iterator rend() { // NOLINT: match stl API
return map.rend();
}

typename VECTOR_TYPE::iterator find(const string &key) { // NOLINT: match stl API
auto entry = map_idx.find(key);
if (entry == map_idx.end()) {
Expand Down Expand Up @@ -92,6 +100,11 @@ class InsertionOrderPreservingMap {
map_idx[key] = map.size() - 1;
}

//! Appends a (key, value) entry at the back of the map, moving `value` into the new entry,
//! and points the index entry for `key` at the appended slot.
//! No check for an already-present key is performed here.
void insert(const string &key, V &&value) { // NOLINT: match stl API
	// the slot the new entry will occupy is the current size of the vector
	const auto appended_slot = map.size();
	map.push_back(make_pair(key, std::move(value)));
	map_idx[key] = appended_slot;
}

//! Appends an already-built (key, value) pair and registers its position in the index.
//! The index entry is written BEFORE the pair is moved into the vector: once `value` has
//! been moved from, value.first is in an unspecified (typically empty) state and must not
//! be used as the lookup key.
void insert(pair<string, V> &&value) { // NOLINT: match stl API
	// the slot the new entry will occupy is the current size of the vector
	map_idx[value.first] = map.size();
	map.push_back(std::move(value));
Expand Down
1 change: 1 addition & 0 deletions src/include/duckdb/common/optionally_owned_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include "duckdb/common/exception.hpp"
#include "duckdb/common/optional_ptr.hpp"
#include "duckdb/common/unique_ptr.hpp"

namespace duckdb {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#pragma once

#include "duckdb/common/types/column/column_data_collection.hpp"
#include "duckdb/common/optionally_owned_ptr.hpp"
#include "duckdb/common/types/column/column_data_collection.hpp"
#include "duckdb/execution/physical_operator.hpp"

namespace duckdb {
Expand All @@ -33,12 +33,18 @@ class PhysicalColumnDataScan : public PhysicalOperator {

public:
unique_ptr<GlobalSourceState> GetGlobalSourceState(ClientContext &context) const override;
unique_ptr<LocalSourceState> GetLocalSourceState(ExecutionContext &context,
GlobalSourceState &gstate) const override;
SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override;

//! Reports that this operator acts as a source (always true).
bool IsSource() const override {
return true;
}

//! Reports that this source supports parallel execution (always true).
bool ParallelSource() const override {
return true;
}

string ParamsToString() const override;

public:
Expand Down
Loading

0 comments on commit 6e0fc96

Please sign in to comment.