Skip to content

Commit

Permalink
Fix memory leaks in custom executor
Browse files Browse the repository at this point in the history
This reverts commit c43e429.
  • Loading branch information
Y-- committed Oct 10, 2024
1 parent 39ad690 commit 5dd26f3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 21 deletions.
30 changes: 24 additions & 6 deletions include/pgduckdb/pgduckdb_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ extern "C" {
}

#include "duckdb/common/exception.hpp"
#include "duckdb/common/error_data.hpp"

#include <vector>
#include <string>
Expand All @@ -30,7 +31,6 @@ template <typename T, typename FuncType, typename... FuncArgs>
T
PostgresFunctionGuard(FuncType postgres_function, FuncArgs... args) {
T return_value;
bool error = false;
MemoryContext ctx = CurrentMemoryContext;
ErrorData *edata = nullptr;
// clang-format off
Expand All @@ -43,11 +43,10 @@ PostgresFunctionGuard(FuncType postgres_function, FuncArgs... args) {
MemoryContextSwitchTo(ctx);
edata = CopyErrorData();
FlushErrorState();
error = true;
}
PG_END_TRY();
// clang-format on
if (error) {
if (edata) {
throw duckdb::Exception(duckdb::ExceptionType::EXECUTOR, edata->message);
}
return return_value;
Expand All @@ -56,7 +55,6 @@ PostgresFunctionGuard(FuncType postgres_function, FuncArgs... args) {
template <typename FuncType, typename... FuncArgs>
void
PostgresFunctionGuard(FuncType postgres_function, FuncArgs... args) {
bool error = false;
MemoryContext ctx = CurrentMemoryContext;
ErrorData *edata = nullptr;
// clang-format off
Expand All @@ -69,13 +67,33 @@ PostgresFunctionGuard(FuncType postgres_function, FuncArgs... args) {
MemoryContextSwitchTo(ctx);
edata = CopyErrorData();
FlushErrorState();
error = true;
}
PG_END_TRY();
// clang-format on
if (error) {
if (edata) {
throw duckdb::Exception(duckdb::ExceptionType::EXECUTOR, edata->message);
}
}

template <typename FuncType, typename... FuncArgs>
const char*
DuckDBFunctionGuard(FuncType duckdb_function, FuncArgs... args) {
try {
duckdb_function(args...);
} catch (duckdb::Exception &ex) {
duckdb::ErrorData edata(ex.what());
return pstrdup(edata.Message().c_str());
} catch (std::exception &ex) {
const auto msg = ex.what();
if (msg[0] == '{') {
duckdb::ErrorData edata(ex.what());
return pstrdup(edata.Message().c_str());
} else {
return pstrdup(ex.what());
}
}

return nullptr;
}

} // namespace pgduckdb
55 changes: 40 additions & 15 deletions src/pgduckdb_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern "C" {
#include "pgduckdb/pgduckdb_types.hpp"
#include "pgduckdb/pgduckdb_duckdb.hpp"
#include "pgduckdb/pgduckdb_planner.hpp"
#include "pgduckdb/pgduckdb_utils.hpp"

/* global variables */
CustomScanMethods duckdb_scan_scan_methods;
Expand All @@ -36,9 +37,21 @@ typedef struct DuckdbScanState {

static void
CleanupDuckdbScanState(DuckdbScanState *state) {
MemoryContextReset(state->css.ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
ExecClearTuple(state->css.ss.ss_ScanTupleSlot);

state->query_results.reset();
delete state->prepared_statement;
delete state->duckdb_connection;
state->current_data_chunk.reset();

if (state->prepared_statement) {
delete state->prepared_statement;
state->prepared_statement = nullptr;
}

if (state->duckdb_connection) {
delete state->duckdb_connection;
state->duckdb_connection = nullptr;
}
}

/* static callbacks */
Expand Down Expand Up @@ -92,25 +105,28 @@ ExecuteQuery(DuckdbScanState *state) {
ParamExternData tmp_workspace;

/* give hook a chance in case parameter is dynamic */
if (pg_params->paramFetch != NULL)
if (pg_params->paramFetch != NULL) {
pg_param = pg_params->paramFetch(pg_params, i + 1, false, &tmp_workspace);
else
} else {
pg_param = &pg_params->params[i];
}

if (pg_param->isnull) {
duckdb_params.push_back(duckdb::Value());
} else {
if (!OidIsValid(pg_param->ptype)) {
elog(ERROR, "parameter with invalid type during execution");
}
} else if (OidIsValid(pg_param->ptype)) {
duckdb_params.push_back(pgduckdb::ConvertPostgresParameterToDuckValue(pg_param->value, pg_param->ptype));
} else {
std::ostringstream oss;
oss << "parameter " << i << " has an invalid type (" << pg_param->ptype << ") during query execution";
throw duckdb::Exception(duckdb::ExceptionType::EXECUTOR, oss.str().c_str());
}
}

auto pending = prepared.PendingQuery(duckdb_params, true);
if (pending->HasError()) {
elog(ERROR, "DuckDB execute returned an error: %s", pending->GetError().c_str());
return pending->ThrowError();
}

duckdb::PendingExecutionResult execution_result;
do {
execution_result = pending->ExecuteTask();
Expand All @@ -121,16 +137,16 @@ ExecuteQuery(DuckdbScanState *state) {
// Wait for all tasks to terminate
executor.CancelTasks();
// Delete the scan state
CleanupDuckdbScanState(state);
// Process the interrupt on the Postgres side
ProcessInterrupts();
elog(ERROR, "Query cancelled");
throw duckdb::Exception(duckdb::ExceptionType::EXECUTOR, "Query cancelled");
}
} while (!duckdb::PendingQueryResult::IsResultReady(execution_result));

if (execution_result == duckdb::PendingExecutionResult::EXECUTION_ERROR) {
CleanupDuckdbScanState(state);
elog(ERROR, "(PGDuckDB/ExecuteQuery) %s", pending->GetError().c_str());
return pending->ThrowError();
}

query_results = pending->Execute();
state->column_count = query_results->ColumnCount();
state->is_executed = true;
Expand All @@ -144,7 +160,11 @@ Duckdb_ExecCustomScan(CustomScanState *node) {

bool already_executed = duckdb_scan_state->is_executed;
if (!already_executed) {
ExecuteQuery(duckdb_scan_state);
auto err_msg = pgduckdb::DuckDBFunctionGuard(ExecuteQuery, duckdb_scan_state);
if (err_msg) {
Duckdb_EndCustomScan(node);
elog(ERROR, "(PGDuckDB/ExecuteQuery) %s", err_msg);
}
}

if (duckdb_scan_state->fetch_next) {
Expand Down Expand Up @@ -204,7 +224,12 @@ Duckdb_ReScanCustomScan(CustomScanState *node) {
void
Duckdb_ExplainCustomScan(CustomScanState *node, List *ancestors, ExplainState *es) {
DuckdbScanState *duckdb_scan_state = (DuckdbScanState *)node;
ExecuteQuery(duckdb_scan_state);
auto err_msg = pgduckdb::DuckDBFunctionGuard(ExecuteQuery, duckdb_scan_state);
if (err_msg) {
Duckdb_EndCustomScan(node);
elog(ERROR, "(PGDuckDB/Duckdb_ExecCustomScan) %s", err_msg);
}

auto chunk = duckdb_scan_state->query_results->Fetch();
if (!chunk || chunk->size() == 0) {
return;
Expand Down

0 comments on commit 5dd26f3

Please sign in to comment.