Skip to content

Commit

Permalink
implement StackCheck in ExpressionBinder
Browse files Browse the repository at this point in the history
  • Loading branch information
lnkuiper committed Jul 17, 2023
1 parent fe946a6 commit f222b41
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 84 deletions.
34 changes: 34 additions & 0 deletions src/include/duckdb/common/stack_checker.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// duckdb/common/stack_checker.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

namespace duckdb {

template <class RECURSIVE_CLASS>
class StackChecker {
public:
StackChecker(RECURSIVE_CLASS &recursive_class_p, idx_t stack_usage_p)
: recursive_class(recursive_class_p), stack_usage(stack_usage_p) {
recursive_class.stack_depth += stack_usage;
}
~StackChecker() {
recursive_class.stack_depth -= stack_usage;
}
StackChecker(StackChecker &&other) noexcept
: recursive_class(other.recursive_class), stack_usage(other.stack_usage) {
other.stack_usage = 0;
}
StackChecker(const StackChecker &) = delete;

private:
RECURSIVE_CLASS &recursive_class;
idx_t stack_usage;
};

} // namespace duckdb
18 changes: 3 additions & 15 deletions src/include/duckdb/parser/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/constants.hpp"
#include "duckdb/common/enums/expression_type.hpp"
#include "duckdb/common/stack_checker.hpp"
#include "duckdb/common/types.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/parser/group_by_node.hpp"
Expand All @@ -26,7 +27,6 @@
namespace duckdb {

class ColumnDefinition;
class StackChecker;
struct OrderByNode;
struct CopyInfo;
struct CommonTableExpressionInfo;
Expand All @@ -39,7 +39,7 @@ struct PivotColumn;
//! The transformer class is responsible for transforming the internal Postgres
//! parser representation into the DuckDB representation
class Transformer {
friend class StackChecker;
friend class StackChecker<Transformer>;

struct CreatePivotEntry {
string enum_name;
Expand Down Expand Up @@ -343,7 +343,7 @@ class Transformer {
idx_t stack_depth;

void InitializeStackCheck();
StackChecker StackCheck(idx_t extra_stack = 1);
StackChecker<Transformer> StackCheck(idx_t extra_stack = 1);

public:
template <class T>
Expand All @@ -356,18 +356,6 @@ class Transformer {
}
};

class StackChecker {
public:
StackChecker(Transformer &transformer, idx_t stack_usage);
~StackChecker();
StackChecker(StackChecker &&) noexcept;
StackChecker(const StackChecker &) = delete;

private:
Transformer &transformer;
idx_t stack_usage;
};

vector<string> ReadPgListToString(duckdb_libpgquery::PGList *column_list);

} // namespace duckdb
14 changes: 13 additions & 1 deletion src/include/duckdb/planner/expression_binder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
#pragma once

#include "duckdb/common/exception.hpp"
#include "duckdb/common/stack_checker.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/parser/expression/bound_expression.hpp"
#include "duckdb/parser/parsed_expression.hpp"
#include "duckdb/parser/tokens.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/common/unordered_map.hpp"

namespace duckdb {

Expand Down Expand Up @@ -51,6 +52,8 @@ struct BindResult {
};

class ExpressionBinder {
friend class StackChecker<ExpressionBinder>;

public:
ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder = false);
virtual ~ExpressionBinder();
Expand Down Expand Up @@ -110,6 +113,15 @@ class ExpressionBinder {

void ReplaceMacroParametersRecursive(unique_ptr<ParsedExpression> &expr);

private:
//! Maximum stack depth
static constexpr const idx_t MAXIMUM_STACK_DEPTH = 128;
//! Current stack depth
idx_t stack_depth = DConstants::INVALID_INDEX;

void InitializeStackCheck();
StackChecker<ExpressionBinder> StackCheck(const ParsedExpression &expr, idx_t extra_stack = 1);

protected:
BindResult BindExpression(BetweenExpression &expr, idx_t depth);
BindResult BindExpression(CaseExpression &expr, idx_t depth);
Expand Down
18 changes: 2 additions & 16 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,6 @@

namespace duckdb {

StackChecker::StackChecker(Transformer &transformer_p, idx_t stack_usage_p)
: transformer(transformer_p), stack_usage(stack_usage_p) {
transformer.stack_depth += stack_usage;
}

StackChecker::~StackChecker() {
transformer.stack_depth -= stack_usage;
}

StackChecker::StackChecker(StackChecker &&other) noexcept
: transformer(other.transformer), stack_usage(other.stack_usage) {
other.stack_usage = 0;
}

Transformer::Transformer(ParserOptions &options)
: parent(nullptr), options(options), stack_depth(DConstants::INVALID_INDEX) {
}
Expand Down Expand Up @@ -59,15 +45,15 @@ void Transformer::InitializeStackCheck() {
stack_depth = 0;
}

StackChecker Transformer::StackCheck(idx_t extra_stack) {
StackChecker<Transformer> Transformer::StackCheck(idx_t extra_stack) {
auto &root = RootTransformer();
D_ASSERT(root.stack_depth != DConstants::INVALID_INDEX);
if (root.stack_depth + extra_stack >= options.max_expression_depth) {
throw ParserException("Max expression depth limit of %lld exceeded. Use \"SET max_expression_depth TO x\" to "
"increase the maximum expression depth.",
options.max_expression_depth);
}
return StackChecker(root, extra_stack);
return StackChecker<Transformer>(root, extra_stack);
}

unique_ptr<SQLStatement> Transformer::TransformStatement(duckdb_libpgquery::PGNode &stmt) {
Expand Down
49 changes: 0 additions & 49 deletions src/planner/binder/expression/bind_macro_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,50 +44,6 @@ void ExpressionBinder::ReplaceMacroParametersRecursive(unique_ptr<ParsedExpressi
*expr, [&](unique_ptr<ParsedExpression> &child) { ReplaceMacroParametersRecursive(child); });
}

static void DetectInfiniteMacroRecursion(ClientContext &context, unique_ptr<ParsedExpression> &expr,
reference_set_t<CatalogEntry> &expanded_macros) {
optional_ptr<ScalarMacroCatalogEntry> recursive_macro;
switch (expr->GetExpressionClass()) {
case ExpressionClass::FUNCTION: {
auto &func = expr->Cast<FunctionExpression>();
auto function = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, func.catalog, func.schema,
func.function_name, OnEntryNotFound::RETURN_NULL);
if (function && function->type == CatalogType::MACRO_ENTRY) {
if (expanded_macros.find(*function) != expanded_macros.end()) {
throw BinderException("Infinite recursion detected in macro \"%s\"", func.function_name);
} else {
recursive_macro = &function->Cast<ScalarMacroCatalogEntry>();
}
}
break;
}
case ExpressionClass::SUBQUERY: {
// replacing parameters within a subquery is slightly different
auto &sq = (expr->Cast<SubqueryExpression>()).subquery;
ParsedExpressionIterator::EnumerateQueryNodeChildren(*sq->node, [&](unique_ptr<ParsedExpression> &child) {
DetectInfiniteMacroRecursion(context, child, expanded_macros);
});
break;
}
default: // fall through
break;
}
// unfold child expressions
if (recursive_macro) {
auto &macro_def = recursive_macro->function->Cast<ScalarMacroFunction>();
auto rec_expr = macro_def.expression->Copy();
expanded_macros.insert(*recursive_macro);
ParsedExpressionIterator::EnumerateChildren(*rec_expr, [&](unique_ptr<ParsedExpression> &child) {
DetectInfiniteMacroRecursion(context, child, expanded_macros);
});
expanded_macros.erase(*recursive_macro);
} else {
ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr<ParsedExpression> &child) {
DetectInfiniteMacroRecursion(context, child, expanded_macros);
});
}
}

BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacroCatalogEntry &macro_func, idx_t depth,
unique_ptr<ParsedExpression> &expr) {
// recast function so we can access the scalar member function->expression
Expand Down Expand Up @@ -126,11 +82,6 @@ BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacro
// replace current expression with stored macro expression
expr = macro_def.expression->Copy();

// detect infinite recursion
reference_set_t<CatalogEntry> expanded_macros;
expanded_macros.insert(macro_func);
DetectInfiniteMacroRecursion(context, expr, expanded_macros);

// now replace the parameters
ReplaceMacroParametersRecursive(expr);

Expand Down
20 changes: 20 additions & 0 deletions src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace duckdb {

ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder)
: binder(binder), context(context) {
InitializeStackCheck();
if (replace_binder) {
stored_binder = &binder.GetActiveBinder();
binder.SetActiveBinder(*this);
Expand All @@ -28,7 +29,26 @@ ExpressionBinder::~ExpressionBinder() {
}
}

void ExpressionBinder::InitializeStackCheck() {
if (binder.HasActiveBinder()) {
stack_depth = binder.GetActiveBinder().stack_depth;
} else {
stack_depth = 0;
}
}

StackChecker<ExpressionBinder> ExpressionBinder::StackCheck(const ParsedExpression &expr, idx_t extra_stack) {
D_ASSERT(stack_depth != DConstants::INVALID_INDEX);
if (stack_depth + extra_stack >= MAXIMUM_STACK_DEPTH) {
throw BinderException("Maximum recursion depth exceeded (Maximum: %llu) while binding \"%s\"",
MAXIMUM_STACK_DEPTH, expr.ToString());
}
return StackChecker<ExpressionBinder>(*this, extra_stack);
}

BindResult ExpressionBinder::BindExpression(unique_ptr<ParsedExpression> &expr, idx_t depth, bool root_expression) {
auto stack_checker = StackCheck(*expr);

auto &expr_ref = *expr;
switch (expr_ref.expression_class) {
case ExpressionClass::BETWEEN:
Expand Down
21 changes: 18 additions & 3 deletions test/sql/catalog/function/test_recursive_macro.test
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ CREATE MACRO "sum"(x) AS (CASE WHEN sum(x) IS NULL THEN 0 ELSE sum(x) END);
statement error
SELECT sum(1);
----
Binder Error: Infinite recursion detected
Binder Error: Maximum recursion depth exceeded

statement error
SELECT sum(1) WHERE 42=0
----
Binder Error: Infinite recursion detected
Binder Error: Maximum recursion depth exceeded

statement ok
DROP MACRO sum
Expand Down Expand Up @@ -45,4 +45,19 @@ create or replace macro m1(a) as m2(a)+1;
statement error
select m2(42);
----
Binder Error: Infinite recursion detected
Binder Error: Maximum recursion depth exceeded

# also table macros
statement ok
create macro m3(a) as a+1;

statement ok
create macro m4(a) as table select m3(a);

statement ok
create or replace macro m3(a) as (from m4(42));

statement error
select m3(42);
----
Binder Error: Maximum recursion depth exceeded

0 comments on commit f222b41

Please sign in to comment.