Skip to content

Commit

Permalink
[Python] Add set_default_connection to the duckdb module (duckdb#…
Browse files Browse the repository at this point in the history
…13442)

This method allows the user to change the connection that's used by the
module internally when invoking DuckDBPyConnection methods (such as
`.sql`, `.table` etc..) directly on the `duckdb` module.

I also considered adding a context manager for this, to basically create
a scope where the default_connection is overridden and reset back after
the scope ends, that might be a future PR.

### Incompatible change:
Hopefully nobody cares but I figured I'd mention it regardless,
`default_connection` is no longer an `attribute` on the module, instead
it's now a method.
It doesn't seem to be possible to create a property that can have
getters and setters on the module so turning it into
`default_connection()` and `set_default_connection(connection:
DuckDBPyConnection)` was the next best thing
  • Loading branch information
Mytherin authored Sep 25, 2024
2 parents 4198e08 + 48bcfc1 commit 403f944
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 23 deletions.
3 changes: 2 additions & 1 deletion tools/pythonpkg/duckdb-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ import polars
# stubgen override - This should probably not be exposed
apilevel: str
comment: token_type
default_connection: DuckDBPyConnection
identifier: token_type
keyword: token_type
numeric_const: token_type
Expand Down Expand Up @@ -600,6 +599,8 @@ class token_type:
def __members__(self) -> object: ...

def connect(database: Union[str, Path] = ..., read_only: bool = ..., config: dict = ...) -> DuckDBPyConnection: ...
def default_connection() -> DuckDBPyConnection: ...
def set_default_connection(connection: DuckDBPyConnection) -> None: ...
def tokenize(query: str) -> List[Any]: ...

# NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py.
Expand Down
10 changes: 6 additions & 4 deletions tools/pythonpkg/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@
__version__,
apilevel,
comment,
default_connection,
identifier,
keyword,
numeric_const,
Expand All @@ -269,7 +268,6 @@
"__version__",
"apilevel",
"comment",
"default_connection",
"identifier",
"keyword",
"numeric_const",
Expand All @@ -283,11 +281,15 @@


from .duckdb import (
connect
connect,
default_connection,
set_default_connection,
)

_exported_symbols.extend([
"connect"
"connect",
"default_connection",
"set_default_connection",
])

# Exceptions
Expand Down
6 changes: 5 additions & 1 deletion tools/pythonpkg/duckdb_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,11 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT
m.attr("__git_revision__") = DuckDB::SourceID();
m.attr("__interactive__") = DuckDBPyConnection::DetectAndGetEnvironment();
m.attr("__jupyter__") = DuckDBPyConnection::IsJupyter();
m.attr("default_connection") = DuckDBPyConnection::DefaultConnection();
m.def("default_connection", &DuckDBPyConnection::DefaultConnection,
"Retrieve the connection currently registered as the default to be used by the module");
m.def("set_default_connection", &DuckDBPyConnection::SetDefaultConnection,
"Register the provided connection as the default to be used by the module",
py::arg("connection").none(false));
m.attr("apilevel") = "2.0";
m.attr("threadsafety") = 1;
m.attr("paramstyle") = "qmark";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ class RegisteredArrow : public RegisteredObject {
unique_ptr<PythonTableArrowArrayStreamFactory> arrow_factory;
};

struct DefaultConnectionHolder {
public:
DefaultConnectionHolder() {
}
~DefaultConnectionHolder() {
}

public:
DefaultConnectionHolder(const DefaultConnectionHolder &other) = delete;
DefaultConnectionHolder(DefaultConnectionHolder &&other) = delete;
DefaultConnectionHolder &operator=(const DefaultConnectionHolder &other) = delete;
DefaultConnectionHolder &operator=(DefaultConnectionHolder &&other) = delete;

public:
shared_ptr<DuckDBPyConnection> Get();
void Set(shared_ptr<DuckDBPyConnection> conn);

private:
shared_ptr<DuckDBPyConnection> connection;
mutex l;
};

struct ConnectionGuard {
public:
ConnectionGuard() {
Expand Down Expand Up @@ -161,6 +183,7 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
static bool DetectAndGetEnvironment();
static bool IsJupyter();
static shared_ptr<DuckDBPyConnection> DefaultConnection();
static void SetDefaultConnection(shared_ptr<DuckDBPyConnection> conn);
static PythonImportCache *ImportCache();
static bool IsInteractive();

Expand Down Expand Up @@ -310,7 +333,7 @@ struct DuckDBPyConnection : public enable_shared_from_this<DuckDBPyConnection> {
bool FileSystemIsRegistered(const string &name);

//! Default connection to an in-memory database
static shared_ptr<DuckDBPyConnection> default_connection;
static DefaultConnectionHolder default_connection;
//! Caches and provides an interface to get frequently used modules+subtypes
static shared_ptr<PythonImportCache> import_cache;

Expand Down
28 changes: 21 additions & 7 deletions tools/pythonpkg/src/pyconnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

namespace duckdb {

shared_ptr<DuckDBPyConnection> DuckDBPyConnection::default_connection = nullptr; // NOLINT: allow global
DefaultConnectionHolder DuckDBPyConnection::default_connection; // NOLINT: allow global
DBInstanceCache instance_cache; // NOLINT: allow global
shared_ptr<PythonImportCache> DuckDBPyConnection::import_cache = nullptr; // NOLINT: allow global
PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global
Expand Down Expand Up @@ -1753,6 +1753,20 @@ void DuckDBPyConnection::LoadExtension(const string &extension) {
ExtensionHelper::LoadExternalExtension(*connection.context, extension);
}

shared_ptr<DuckDBPyConnection> DefaultConnectionHolder::Get() {
lock_guard<mutex> guard(l);
if (!connection) {
py::dict config_dict;
connection = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict);
}
return connection;
}

void DefaultConnectionHolder::Set(shared_ptr<DuckDBPyConnection> conn) {
lock_guard<mutex> guard(l);
connection = conn;
}

void DuckDBPyConnection::Cursors::AddCursor(shared_ptr<DuckDBPyConnection> conn) {
lock_guard<mutex> l(lock);

Expand Down Expand Up @@ -2025,11 +2039,11 @@ case_insensitive_map_t<BoundParameterData> DuckDBPyConnection::TransformPythonPa
}

shared_ptr<DuckDBPyConnection> DuckDBPyConnection::DefaultConnection() {
if (!default_connection) {
py::dict config_dict;
default_connection = DuckDBPyConnection::Connect(py::str(":memory:"), false, config_dict);
}
return default_connection;
return default_connection.Get();
}

void DuckDBPyConnection::SetDefaultConnection(shared_ptr<DuckDBPyConnection> connection) {
return default_connection.Set(std::move(connection));
}

PythonImportCache *DuckDBPyConnection::ImportCache() {
Expand Down Expand Up @@ -2074,7 +2088,7 @@ void DuckDBPyConnection::Exit(DuckDBPyConnection &self, const py::object &exc_ty
}

void DuckDBPyConnection::Cleanup() {
default_connection.reset();
default_connection.Set(nullptr);
import_cache.reset();
}

Expand Down
73 changes: 66 additions & 7 deletions tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,27 @@

@pytest.fixture()
def tbl_table():
con = duckdb.default_connection
con = duckdb.default_connection()
con.execute("drop table if exists tbl CASCADE")
con.execute("create table tbl (i integer)")
yield
con.execute('drop table tbl CASCADE')


@pytest.fixture()
def scoped_default(duckdb_cursor):
default = duckdb.connect(':default:')
duckdb.set_default_connection(duckdb_cursor)
# Overwrite the default connection
yield
# Set it back on finalizing of the function
duckdb.set_default_connection(default)


class TestRAPIQuery(object):
@pytest.mark.parametrize('steps', [1, 2, 3, 4])
def test_query_chain(self, steps):
con = duckdb.default_connection
con = duckdb.default_connection()
amount = int(1000000)
rel = None
for _ in range(steps):
Expand All @@ -28,7 +38,7 @@ def test_query_chain(self, steps):

@pytest.mark.parametrize('input', [[5, 4, 3], [], [1000]])
def test_query_table(self, tbl_table, input):
con = duckdb.default_connection
con = duckdb.default_connection()
rel = con.table("tbl")
for row in input:
rel.insert([row])
Expand All @@ -38,23 +48,23 @@ def test_query_table(self, tbl_table, input):
assert result.fetchall() == [tuple([x]) for x in input]

def test_query_table_basic(self, tbl_table):
con = duckdb.default_connection
con = duckdb.default_connection()
rel = con.table("tbl")
# Querying a table relation
rel = rel.query("x", "select 5")
result = rel.execute()
assert result.fetchall() == [(5,)]

def test_query_table_qualified(self, duckdb_cursor):
con = duckdb.default_connection
con = duckdb.default_connection()
con.execute("create schema fff")

# Create table in fff schema
con.execute("create table fff.t2 as select 1 as t")
assert con.table("fff.t2").fetchall() == [(1,)]

def test_query_insert_into_relation(self, tbl_table):
con = duckdb.default_connection
con = duckdb.default_connection()
rel = con.query("select i from range(1000) tbl(i)")
# Can't insert into this, not a table relation
with pytest.raises(duckdb.InvalidInputException):
Expand All @@ -79,7 +89,7 @@ def test_query_non_select_fail(self, duckdb_cursor):
rel.query("relation", "create table tbl as select * from not_a_valid_view")

def test_query_table_unrelated(self, tbl_table):
con = duckdb.default_connection
con = duckdb.default_connection()
rel = con.table("tbl")
# Querying a table relation
rel = rel.query("x", "select 5")
Expand Down Expand Up @@ -131,3 +141,52 @@ def test_replacement_scan_recursion(self, duckdb_cursor):
other_rel = duckdb_cursor.sql('select a from rel')
res = other_rel.fetchall()
assert res == [(84,)]

def test_set_default_connection(self, scoped_default):
duckdb.sql("create table t as select 42")
assert duckdb.table('t').fetchall() == [(42,)]
con = duckdb.connect(':default:')

# Uses the same db as the module
assert con.table('t').fetchall() == [(42,)]

con2 = duckdb.connect()
con2.sql("create table t as select 21")
assert con2.table('t').fetchall() == [(21,)]
# Change the db used by the module
duckdb.set_default_connection(con2)

with pytest.raises(duckdb.CatalogException, match='Table with name d does not exist'):
con2.table('d').fetchall()

assert duckdb.table('t').fetchall() == [(21,)]

duckdb.sql("create table d as select [1,2,3]")

assert duckdb.table('d').fetchall() == [([1, 2, 3],)]
assert con2.table('d').fetchall() == [([1, 2, 3],)]

def test_set_default_connection_error(self, scoped_default):
with pytest.raises(TypeError, match='Invoked with: None'):
# set_default_connection does not allow None
duckdb.set_default_connection(None)

with pytest.raises(TypeError, match='Invoked with: 5'):
duckdb.set_default_connection(5)

assert duckdb.sql("select 42").fetchall() == [(42,)]
duckdb.close()

with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed!'):
duckdb.sql("select 42").fetchall()

con2 = duckdb.connect()
duckdb.set_default_connection(con2)
assert duckdb.sql("select 42").fetchall() == [(42,)]

con3 = duckdb.connect()
con3.close()
duckdb.set_default_connection(con3)

with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed!'):
duckdb.sql("select 42").fetchall()
2 changes: 1 addition & 1 deletion tools/pythonpkg/tests/fast/test_parameter_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def test_exception(self, duckdb_cursor, pandas):
res = conn.execute("select count(*) from bool_table where a =?", [df_in])

def test_explicit_nan_param(self):
con = duckdb.default_connection
con = duckdb.default_connection()
res = con.execute('select isnan(cast(? as double))', (float("nan"),))
assert res.fetchone()[0] == True
2 changes: 1 addition & 1 deletion tools/pythonpkg/tests/fast/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_parquet_binary_as_string_pragma(self, duckdb_cursor):
assert res[0] == (b'foo',)

def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor):
duckdb.default_connection.execute("PRAGMA binary_as_string=1")
duckdb.execute("PRAGMA binary_as_string=1")

rel = duckdb.from_parquet(filename, True)
assert rel.types == [VARCHAR]
Expand Down

0 comments on commit 403f944

Please sign in to comment.