Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

connect(): Added support for shared connection; Database.is_closed property #323

Merged
merged 1 commit into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ def close(self):
"Close connection(s) to the database instance. Querying will stop functioning."
...

@property
@abstractmethod
def is_closed(self) -> bool:
"Return whether or not the connection has been closed"

@abstractmethod
def _normalize_table_path(self, path: DbPath) -> DbPath:
...
Expand Down
6 changes: 6 additions & 0 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ class Database(AbstractDatabase):
CONNECT_URI_KWPARAMS = []

_interactive = False
is_closed = False

@property
def name(self):
Expand Down Expand Up @@ -440,6 +441,10 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis
callback = partial(self._query_cursor, c)
return apply_query(callback, sql_code)

def close(self):
self.is_closed = True
return super().close()


class ThreadedDatabase(Database):
"""Access the database through singleton threads.
Expand Down Expand Up @@ -476,6 +481,7 @@ def create_connection(self):
...

def close(self):
super().close()
self._queue.shutdown()

@property
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
return apply_query(self._query_atom, sql_code)

def close(self):
super().close()
self._client.close()

def select_table_schema(self, path: DbPath) -> str:
Expand Down
26 changes: 21 additions & 5 deletions data_diff/sqeleton/databases/connect.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Type, List, Optional, Union, Dict
from itertools import zip_longest
import dsnparse
from contextlib import suppress

from runtype import dataclass

from ..utils import WeakCache
from .base import Database, ThreadedDatabase
from .postgresql import PostgreSQL
from .mysql import MySQL
Expand All @@ -19,12 +21,13 @@
from .duckdb import DuckDB



@dataclass
class MatchUriPath:
database_cls: Type[Database]
params: List[str]
kwparams: List[str] = []
help_str: str
help_str: str = "<unspecified>"

def __post_init__(self):
assert self.params == self.database_cls.CONNECT_URI_PARAMS, self.params
Expand Down Expand Up @@ -101,6 +104,7 @@ def __init__(self, database_by_scheme: Dict[str, Database]):
name: MatchUriPath(cls, cls.CONNECT_URI_PARAMS, cls.CONNECT_URI_KWPARAMS, help_str=cls.CONNECT_URI_HELP)
for name, cls in database_by_scheme.items()
}
self.conn_cache = WeakCache()

def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
"""Connect to the given database uri
Expand Down Expand Up @@ -200,7 +204,7 @@ def _connection_created(self, db):
"Nop function to be overridden by subclasses."
return db

def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database:
"""Connect to a database using the given database configuration.

Configuration can be given either as a URI string, or as a dict of {option: value}.
Expand All @@ -213,6 +217,7 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
Parameters:
db_conf (str | dict): The configuration for the database to connect. URI or dict.
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)

Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.

Expand All @@ -235,8 +240,19 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
<data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820>
"""
if shared:
with suppress(KeyError):
conn = self.conn_cache.get(db_conf)
if not conn.is_closed:
return conn

if isinstance(db_conf, str):
return self.connect_to_uri(db_conf, thread_count)
conn = self.connect_to_uri(db_conf, thread_count)
elif isinstance(db_conf, dict):
return self.connect_with_dict(db_conf, thread_count)
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
conn = self.connect_with_dict(db_conf, thread_count)
else:
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")

if shared:
self.conn_cache.add(db_conf, conn)
return conn
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
return self._query_conn(self._conn, sql_code)

def close(self):
super().close()
self._conn.close()

def create_connection(self):
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _query(self, sql_code: str) -> list:
return query_cursor(c, sql_code)

def close(self):
super().close()
self._conn.close()

def select_table_schema(self, path: DbPath) -> str:
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self, *, schema: str, **kw):
self.default_schema = schema

def close(self):
super().close()
self._conn.close()

def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
Expand Down
26 changes: 26 additions & 0 deletions data_diff/sqeleton/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union, Dict, Any, Hashable
from weakref import ref
from typing import TypeVar
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict
from abc import abstractmethod
Expand All @@ -9,6 +11,30 @@
# -- Common --


class WeakCache:
def __init__(self):
self._cache = {}

def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
if isinstance(k, dict):
return tuple(k.items())
return k

def add(self, key: Union[dict, Hashable], value: Any):
key = self._hashable_key(key)
self._cache[key] = ref(value)

def get(self, key: Union[dict, Hashable]) -> Any:
key = self._hashable_key(key)

value = self._cache[key]()
if value is None:
del self._cache[key]
raise KeyError(f"Key {key} not found, or no longer a valid reference")

return value


def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
it = iter(iterable)
try:
Expand Down
18 changes: 8 additions & 10 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import arrow
from datetime import datetime

from data_diff import diff_tables, connect_to_table
from data_diff import diff_tables, connect_to_table, Algorithm
from data_diff.databases import MySQL
from data_diff.sqeleton.queries import table, commit

Expand Down Expand Up @@ -36,13 +36,17 @@ def setUp(self) -> None:
)

def test_api(self):
# test basic
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, (self.table_dst_name,))
diff = list(diff_tables(t1, t2))
diff = list(diff_tables(t1, t2, algorithm=Algorithm.JOINDIFF))
assert len(diff) == 1

t1.database.close()
t2.database.close()
# test algorithm
# (also tests shared connection on connect_to_table)
for algo in (Algorithm.HASHDIFF, Algorithm.JOINDIFF):
diff = list(diff_tables(t1, t2, algorithm=algo))
assert len(diff) == 1

# test where
diff_id = diff[0][1][0]
Expand All @@ -53,9 +57,6 @@ def test_api(self):
diff = list(diff_tables(t1, t2))
assert len(diff) == 0

t1.database.close()
t2.database.close()

def test_api_get_stats_dict(self):
# XXX Likely to change in the future
expected_dict = {
Expand All @@ -76,6 +77,3 @@ def test_api_get_stats_dict(self):
self.assertEqual(expected_dict, output)
self.assertIsNotNone(diff)
assert len(list(diff)) == 1

t1.database.close()
t2.database.close()