diff --git a/tools/pythonpkg/duckdb-stubs/__init__.pyi b/tools/pythonpkg/duckdb-stubs/__init__.pyi index 7d67657635dd..4c013edab718 100644 --- a/tools/pythonpkg/duckdb-stubs/__init__.pyi +++ b/tools/pythonpkg/duckdb-stubs/__init__.pyi @@ -56,7 +56,6 @@ import polars # stubgen override - This should probably not be exposed apilevel: str comment: token_type -default_connection: DuckDBPyConnection identifier: token_type keyword: token_type numeric_const: token_type @@ -600,6 +599,8 @@ class token_type: def __members__(self) -> object: ... def connect(database: Union[str, Path] = ..., read_only: bool = ..., config: dict = ...) -> DuckDBPyConnection: ... +def default_connection() -> DuckDBPyConnection: ... +def set_default_connection(connection: DuckDBPyConnection) -> None: ... def tokenize(query: str) -> List[Any]: ... # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py. diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py index c0f89df87a16..e6c6cc4b3cb3 100644 --- a/tools/pythonpkg/duckdb/__init__.py +++ b/tools/pythonpkg/duckdb/__init__.py @@ -251,7 +251,6 @@ __version__, apilevel, comment, - default_connection, identifier, keyword, numeric_const, @@ -269,7 +268,6 @@ "__version__", "apilevel", "comment", - "default_connection", "identifier", "keyword", "numeric_const", @@ -283,11 +281,15 @@ from .duckdb import ( - connect + connect, + default_connection, + set_default_connection, ) _exported_symbols.extend([ - "connect" + "connect", + "default_connection", + "set_default_connection", ]) # Exceptions diff --git a/tools/pythonpkg/duckdb_python.cpp b/tools/pythonpkg/duckdb_python.cpp index b08300a1e6c1..e059c7dfb95a 100644 --- a/tools/pythonpkg/duckdb_python.cpp +++ b/tools/pythonpkg/duckdb_python.cpp @@ -1078,7 +1078,11 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT m.attr("__git_revision__") = DuckDB::SourceID(); m.attr("__interactive__") = DuckDBPyConnection::DetectAndGetEnvironment(); m.attr("__jupyter__") = DuckDBPyConnection::IsJupyter(); - m.attr("default_connection") = DuckDBPyConnection::DefaultConnection(); + m.def("default_connection", &DuckDBPyConnection::DefaultConnection, + "Retrieve the connection currently registered as the default to be used by the module"); + m.def("set_default_connection", &DuckDBPyConnection::SetDefaultConnection, + "Register the provided connection as the default to be used by the module", + py::arg("connection").none(false)); m.attr("apilevel") = "2.0"; m.attr("threadsafety") = 1; m.attr("paramstyle") = "qmark"; diff --git a/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp b/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp index 13acaffe5883..50c1d2d0dcef 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp @@ -41,6 +41,28 @@ class RegisteredArrow : public RegisteredObject { unique_ptr arrow_factory; }; +struct DefaultConnectionHolder { +public: + DefaultConnectionHolder() { + } + ~DefaultConnectionHolder() { + } + +public: + DefaultConnectionHolder(const DefaultConnectionHolder &other) = delete; + DefaultConnectionHolder(DefaultConnectionHolder &&other) = delete; + DefaultConnectionHolder &operator=(const DefaultConnectionHolder &other) = delete; + DefaultConnectionHolder &operator=(DefaultConnectionHolder &&other) = delete; + +public: + shared_ptr Get(); + void Set(shared_ptr conn); + +private: + shared_ptr connection; + mutex l; +}; + struct ConnectionGuard { public: ConnectionGuard() { @@ -161,6 +183,7 @@ struct DuckDBPyConnection : public enable_shared_from_this { static bool DetectAndGetEnvironment(); static bool IsJupyter(); static shared_ptr DefaultConnection(); + static void SetDefaultConnection(shared_ptr conn); static PythonImportCache *ImportCache(); static bool IsInteractive(); @@ -310,7 +333,7 @@ struct DuckDBPyConnection : public enable_shared_from_this { bool FileSystemIsRegistered(const string &name); //! Default connection to an in-memory database - static shared_ptr default_connection; + static DefaultConnectionHolder default_connection; //! Caches and provides an interface to get frequently used modules+subtypes static shared_ptr import_cache; diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index 7c52e0241502..8a8521a49a54 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -66,7 +66,7 @@ namespace duckdb { -shared_ptr DuckDBPyConnection::default_connection = nullptr; // NOLINT: allow global +DefaultConnectionHolder DuckDBPyConnection::default_connection; // NOLINT: allow global DBInstanceCache instance_cache; // NOLINT: allow global shared_ptr DuckDBPyConnection::import_cache = nullptr; // NOLINT: allow global PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global @@ -1753,6 +1753,20 @@ void DuckDBPyConnection::LoadExtension(const string &extension) { ExtensionHelper::LoadExternalExtension(*connection.context, extension); } +shared_ptr DefaultConnectionHolder::Get() { + lock_guard guard(l); + if (!connection) { + py::dict config_dict; + connection = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict); + } + return connection; +} + +void DefaultConnectionHolder::Set(shared_ptr conn) { + lock_guard guard(l); + connection = conn; +} + void DuckDBPyConnection::Cursors::AddCursor(shared_ptr conn) { lock_guard l(lock); @@ -2025,11 +2039,11 @@ case_insensitive_map_t DuckDBPyConnection::TransformPythonPa } shared_ptr DuckDBPyConnection::DefaultConnection() { - if (!default_connection) { - py::dict config_dict; - default_connection = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict); - } - return default_connection; + return default_connection.Get(); +} + +void DuckDBPyConnection::SetDefaultConnection(shared_ptr connection) { + return default_connection.Set(std::move(connection)); } PythonImportCache *DuckDBPyConnection::ImportCache() { @@ -2074,7 +2088,7 @@ void DuckDBPyConnection::Exit(DuckDBPyConnection &self, const py::object &exc_ty } void DuckDBPyConnection::Cleanup() { - default_connection.reset(); + default_connection.Set(nullptr); import_cache.reset(); } diff --git a/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py b/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py index 9707dc368321..7d04b04b1ce4 100644 --- a/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py +++ b/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py @@ -6,17 +6,27 @@ @pytest.fixture() def tbl_table(): - con = duckdb.default_connection + con = duckdb.default_connection() con.execute("drop table if exists tbl CASCADE") con.execute("create table tbl (i integer)") yield con.execute('drop table tbl CASCADE') +@pytest.fixture() +def scoped_default(duckdb_cursor): + default = duckdb.connect(':default:') + duckdb.set_default_connection(duckdb_cursor) + # Overwrite the default connection + yield + # Set it back on finalizing of the function + duckdb.set_default_connection(default) + + class TestRAPIQuery(object): @pytest.mark.parametrize('steps', [1, 2, 3, 4]) def test_query_chain(self, steps): - con = duckdb.default_connection + con = duckdb.default_connection() amount = int(1000000) rel = None for _ in range(steps): @@ -28,7 +38,7 @@ def test_query_chain(self, steps): @pytest.mark.parametrize('input', [[5, 4, 3], [], [1000]]) def test_query_table(self, tbl_table, input): - con = duckdb.default_connection + con = duckdb.default_connection() rel = con.table("tbl") for row in input: rel.insert([row]) @@ -38,7 +48,7 @@ def test_query_table(self, tbl_table, input): assert result.fetchall() == [tuple([x]) for x in input] def test_query_table_basic(self, tbl_table): - con = duckdb.default_connection + con = duckdb.default_connection() rel = con.table("tbl") # Querying a table relation rel = rel.query("x", "select 5") @@ -46,7 +56,7 @@ def test_query_table_basic(self, tbl_table): assert result.fetchall() == [(5,)] def test_query_table_qualified(self, duckdb_cursor): - con = duckdb.default_connection + con = duckdb.default_connection() con.execute("create schema fff") # Create table in fff schema @@ -54,7 +64,7 @@ def test_query_table_qualified(self, duckdb_cursor): assert con.table("fff.t2").fetchall() == [(1,)] def test_query_insert_into_relation(self, tbl_table): - con = duckdb.default_connection + con = duckdb.default_connection() rel = con.query("select i from range(1000) tbl(i)") # Can't insert into this, not a table relation with pytest.raises(duckdb.InvalidInputException): @@ -79,7 +89,7 @@ def test_query_non_select_fail(self, duckdb_cursor): rel.query("relation", "create table tbl as select * from not_a_valid_view") def test_query_table_unrelated(self, tbl_table): - con = duckdb.default_connection + con = duckdb.default_connection() rel = con.table("tbl") # Querying a table relation rel = rel.query("x", "select 5") @@ -131,3 +141,52 @@ def test_replacement_scan_recursion(self, duckdb_cursor): other_rel = duckdb_cursor.sql('select a from rel') res = other_rel.fetchall() assert res == [(84,)] + + def test_set_default_connection(self, scoped_default): + duckdb.sql("create table t as select 42") + assert duckdb.table('t').fetchall() == [(42,)] + con = duckdb.connect(':default:') + + # Uses the same db as the module + assert con.table('t').fetchall() == [(42,)] + + con2 = duckdb.connect() + con2.sql("create table t as select 21") + assert con2.table('t').fetchall() == [(21,)] + # Change the db used by the module + duckdb.set_default_connection(con2) + + with pytest.raises(duckdb.CatalogException, match='Table with name d does not exist'): + con2.table('d').fetchall() + + assert duckdb.table('t').fetchall() == [(21,)] + + duckdb.sql("create table d as select [1,2,3]") + + assert duckdb.table('d').fetchall() == [([1, 2, 3],)] + assert con2.table('d').fetchall() == [([1, 2, 3],)] + + def test_set_default_connection_error(self, scoped_default): + with pytest.raises(TypeError, match='Invoked with: None'): + # set_default_connection does not allow None + duckdb.set_default_connection(None) + + with pytest.raises(TypeError, match='Invoked with: 5'): + duckdb.set_default_connection(5) + + assert duckdb.sql("select 42").fetchall() == [(42,)] + duckdb.close() + + with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed!'): + duckdb.sql("select 42").fetchall() + + con2 = duckdb.connect() + duckdb.set_default_connection(con2) + assert duckdb.sql("select 42").fetchall() == [(42,)] + + con3 = duckdb.connect() + con3.close() + duckdb.set_default_connection(con3) + + with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed!'): + duckdb.sql("select 42").fetchall() diff --git a/tools/pythonpkg/tests/fast/test_parameter_list.py b/tools/pythonpkg/tests/fast/test_parameter_list.py index 2421c76e08e1..6db0325c42be 100644 --- a/tools/pythonpkg/tests/fast/test_parameter_list.py +++ b/tools/pythonpkg/tests/fast/test_parameter_list.py @@ -25,6 +25,6 @@ def test_exception(self, duckdb_cursor, pandas): res = conn.execute("select count(*) from bool_table where a =?", [df_in]) def test_explicit_nan_param(self): - con = duckdb.default_connection + con = duckdb.default_connection() res = con.execute('select isnan(cast(? as double))', (float("nan"),)) assert res.fetchone()[0] == True diff --git a/tools/pythonpkg/tests/fast/test_parquet.py b/tools/pythonpkg/tests/fast/test_parquet.py index 498fe53ef3c2..51d8d27677f7 100644 --- a/tools/pythonpkg/tests/fast/test_parquet.py +++ b/tools/pythonpkg/tests/fast/test_parquet.py @@ -130,7 +130,7 @@ def test_parquet_binary_as_string_pragma(self, duckdb_cursor): assert res[0] == (b'foo',) def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): - duckdb.default_connection.execute("PRAGMA binary_as_string=1") + duckdb.execute("PRAGMA binary_as_string=1") rel = duckdb.from_parquet(filename, True) assert rel.types == [VARCHAR]