diff --git a/src/include/duckdb/common/stack_checker.hpp b/src/include/duckdb/common/stack_checker.hpp new file mode 100644 index 000000000000..a2375e8ef966 --- /dev/null +++ b/src/include/duckdb/common/stack_checker.hpp @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/stack_checker.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +template +class StackChecker { +public: + StackChecker(RECURSIVE_CLASS &recursive_class_p, idx_t stack_usage_p) + : recursive_class(recursive_class_p), stack_usage(stack_usage_p) { + recursive_class.stack_depth += stack_usage; + } + ~StackChecker() { + recursive_class.stack_depth -= stack_usage; + } + StackChecker(StackChecker &&other) noexcept + : recursive_class(other.recursive_class), stack_usage(other.stack_usage) { + other.stack_usage = 0; + } + StackChecker(const StackChecker &) = delete; + +private: + RECURSIVE_CLASS &recursive_class; + idx_t stack_usage; +}; + +} // namespace duckdb diff --git a/src/include/duckdb/parser/transformer.hpp b/src/include/duckdb/parser/transformer.hpp index d9db17138533..849e8740683a 100644 --- a/src/include/duckdb/parser/transformer.hpp +++ b/src/include/duckdb/parser/transformer.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/constants.hpp" #include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/stack_checker.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/parser/group_by_node.hpp" @@ -26,7 +27,6 @@ namespace duckdb { class ColumnDefinition; -class StackChecker; struct OrderByNode; struct CopyInfo; struct CommonTableExpressionInfo; @@ -39,7 +39,7 @@ struct PivotColumn; //! The transformer class is responsible for transforming the internal Postgres //! parser representation into the DuckDB representation class Transformer { - friend class StackChecker; + friend class StackChecker; struct CreatePivotEntry { string enum_name; @@ -343,7 +343,7 @@ class Transformer { idx_t stack_depth; void InitializeStackCheck(); - StackChecker StackCheck(idx_t extra_stack = 1); + StackChecker StackCheck(idx_t extra_stack = 1); public: template @@ -356,18 +356,6 @@ class Transformer { } }; -class StackChecker { -public: - StackChecker(Transformer &transformer, idx_t stack_usage); - ~StackChecker(); - StackChecker(StackChecker &&) noexcept; - StackChecker(const StackChecker &) = delete; - -private: - Transformer &transformer; - idx_t stack_usage; -}; - vector ReadPgListToString(duckdb_libpgquery::PGList *column_list); } // namespace duckdb diff --git a/src/include/duckdb/planner/expression_binder.hpp b/src/include/duckdb/planner/expression_binder.hpp index 3efc79964ee9..11e5a882c937 100644 --- a/src/include/duckdb/planner/expression_binder.hpp +++ b/src/include/duckdb/planner/expression_binder.hpp @@ -9,11 +9,12 @@ #pragma once #include "duckdb/common/exception.hpp" +#include "duckdb/common/stack_checker.hpp" +#include "duckdb/common/unordered_map.hpp" #include "duckdb/parser/expression/bound_expression.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/tokens.hpp" #include "duckdb/planner/expression.hpp" -#include "duckdb/common/unordered_map.hpp" namespace duckdb { @@ -51,6 +52,8 @@ struct BindResult { }; class ExpressionBinder { + friend class StackChecker; + public: ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder = false); virtual ~ExpressionBinder(); @@ -110,6 +113,15 @@ class ExpressionBinder { void ReplaceMacroParametersRecursive(unique_ptr &expr); +private: + //! Maximum stack depth + static constexpr const idx_t MAXIMUM_STACK_DEPTH = 128; + //! Current stack depth + idx_t stack_depth = DConstants::INVALID_INDEX; + + void InitializeStackCheck(); + StackChecker StackCheck(const ParsedExpression &expr, idx_t extra_stack = 1); + protected: BindResult BindExpression(BetweenExpression &expr, idx_t depth); BindResult BindExpression(CaseExpression &expr, idx_t depth); diff --git a/src/parser/transformer.cpp b/src/parser/transformer.cpp index dafb83e9e30e..c6af4d92c53a 100644 --- a/src/parser/transformer.cpp +++ b/src/parser/transformer.cpp @@ -9,20 +9,6 @@ namespace duckdb { -StackChecker::StackChecker(Transformer &transformer_p, idx_t stack_usage_p) - : transformer(transformer_p), stack_usage(stack_usage_p) { - transformer.stack_depth += stack_usage; -} - -StackChecker::~StackChecker() { - transformer.stack_depth -= stack_usage; -} - -StackChecker::StackChecker(StackChecker &&other) noexcept - : transformer(other.transformer), stack_usage(other.stack_usage) { - other.stack_usage = 0; -} - Transformer::Transformer(ParserOptions &options) : parent(nullptr), options(options), stack_depth(DConstants::INVALID_INDEX) { } @@ -59,7 +45,7 @@ void Transformer::InitializeStackCheck() { stack_depth = 0; } -StackChecker Transformer::StackCheck(idx_t extra_stack) { +StackChecker Transformer::StackCheck(idx_t extra_stack) { auto &root = RootTransformer(); D_ASSERT(root.stack_depth != DConstants::INVALID_INDEX); if (root.stack_depth + extra_stack >= options.max_expression_depth) { @@ -67,7 +53,7 @@ StackChecker Transformer::StackCheck(idx_t extra_stack) { "increase the maximum expression depth.", options.max_expression_depth); } - return StackChecker(root, extra_stack); + return StackChecker(root, extra_stack); } unique_ptr Transformer::TransformStatement(duckdb_libpgquery::PGNode &stmt) { diff --git a/src/planner/binder/expression/bind_macro_expression.cpp b/src/planner/binder/expression/bind_macro_expression.cpp index c83a29d79262..cce36d49d3ac 100644 --- a/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/planner/binder/expression/bind_macro_expression.cpp @@ -44,50 +44,6 @@ void ExpressionBinder::ReplaceMacroParametersRecursive(unique_ptr &child) { ReplaceMacroParametersRecursive(child); }); } -static void DetectInfiniteMacroRecursion(ClientContext &context, unique_ptr &expr, - reference_set_t &expanded_macros) { - optional_ptr recursive_macro; - switch (expr->GetExpressionClass()) { - case ExpressionClass::FUNCTION: { - auto &func = expr->Cast(); - auto function = Catalog::GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, func.catalog, func.schema, - func.function_name, OnEntryNotFound::RETURN_NULL); - if (function && function->type == CatalogType::MACRO_ENTRY) { - if (expanded_macros.find(*function) != expanded_macros.end()) { - throw BinderException("Infinite recursion detected in macro \"%s\"", func.function_name); - } else { - recursive_macro = &function->Cast(); - } - } - break; - } - case ExpressionClass::SUBQUERY: { - // replacing parameters within a subquery is slightly different - auto &sq = (expr->Cast()).subquery; - ParsedExpressionIterator::EnumerateQueryNodeChildren(*sq->node, [&](unique_ptr &child) { - DetectInfiniteMacroRecursion(context, child, expanded_macros); - }); - break; - } - default: // fall through - break; - } - // unfold child expressions - if (recursive_macro) { - auto ¯o_def = recursive_macro->function->Cast(); - auto rec_expr = macro_def.expression->Copy(); - expanded_macros.insert(*recursive_macro); - ParsedExpressionIterator::EnumerateChildren(*rec_expr, [&](unique_ptr &child) { - DetectInfiniteMacroRecursion(context, child, expanded_macros); - }); - expanded_macros.erase(*recursive_macro); - } else { - ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - DetectInfiniteMacroRecursion(context, child, expanded_macros); - }); - } -} - BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacroCatalogEntry ¯o_func, idx_t depth, unique_ptr &expr) { // recast function so we can access the scalar member function->expression @@ -126,11 +82,6 @@ BindResult ExpressionBinder::BindMacro(FunctionExpression &function, ScalarMacro // replace current expression with stored macro expression expr = macro_def.expression->Copy(); - // detect infinite recursion - reference_set_t expanded_macros; - expanded_macros.insert(macro_func); - DetectInfiniteMacroRecursion(context, expr, expanded_macros); - // now replace the parameters ReplaceMacroParametersRecursive(expr); diff --git a/src/planner/expression_binder.cpp b/src/planner/expression_binder.cpp index 1f1dcbb5c726..266b4102c5c0 100644 --- a/src/planner/expression_binder.cpp +++ b/src/planner/expression_binder.cpp @@ -10,6 +10,7 @@ namespace duckdb { ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) : binder(binder), context(context) { + InitializeStackCheck(); if (replace_binder) { stored_binder = &binder.GetActiveBinder(); binder.SetActiveBinder(*this); @@ -28,7 +29,26 @@ ExpressionBinder::~ExpressionBinder() { } } +void ExpressionBinder::InitializeStackCheck() { + if (binder.HasActiveBinder()) { + stack_depth = binder.GetActiveBinder().stack_depth; + } else { + stack_depth = 0; + } +} + +StackChecker ExpressionBinder::StackCheck(const ParsedExpression &expr, idx_t extra_stack) { + D_ASSERT(stack_depth != DConstants::INVALID_INDEX); + if (stack_depth + extra_stack >= MAXIMUM_STACK_DEPTH) { + throw BinderException("Maximum recursion depth exceeded (Maximum: %llu) while binding \"%s\"", + MAXIMUM_STACK_DEPTH, expr.ToString()); + } + return StackChecker(*this, extra_stack); +} + BindResult ExpressionBinder::BindExpression(unique_ptr &expr, idx_t depth, bool root_expression) { + auto stack_checker = StackCheck(*expr); + auto &expr_ref = *expr; switch (expr_ref.expression_class) { case ExpressionClass::BETWEEN: diff --git a/test/sql/catalog/function/test_recursive_macro.test b/test/sql/catalog/function/test_recursive_macro.test index 80e36c010104..be9739561679 100644 --- a/test/sql/catalog/function/test_recursive_macro.test +++ b/test/sql/catalog/function/test_recursive_macro.test @@ -8,12 +8,12 @@ CREATE MACRO "sum"(x) AS (CASE WHEN sum(x) IS NULL THEN 0 ELSE sum(x) END); statement error SELECT sum(1); ---- -Binder Error: Infinite recursion detected +Binder Error: Maximum recursion depth exceeded statement error SELECT sum(1) WHERE 42=0 ---- -Binder Error: Infinite recursion detected +Binder Error: Maximum recursion depth exceeded statement ok DROP MACRO sum @@ -45,4 +45,19 @@ create or replace macro m1(a) as m2(a)+1; statement error select m2(42); ---- -Binder Error: Infinite recursion detected +Binder Error: Maximum recursion depth exceeded + +# also table macros +statement ok +create macro m3(a) as a+1; + +statement ok +create macro m4(a) as table select m3(a); + +statement ok +create or replace macro m3(a) as (from m4(42)); + +statement error +select m3(42); +---- +Binder Error: Maximum recursion depth exceeded