Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Tolerate empty samples & allow custom database schemas #802

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Accept external database types/schemes at runtime
  • Loading branch information
Sergey Vasilyev committed Dec 15, 2023
commit 6c9ab5fdf68f19994a45ac6d7d90a8dfb743cc9a
14 changes: 5 additions & 9 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from data_diff.databases.mssql import MsSQL


@attrs.define(frozen=True)
@attrs.frozen
class MatchUriPath:
database_cls: Type[Database]

Expand Down Expand Up @@ -98,13 +98,11 @@ class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""

database_by_scheme: Dict[str, Database]
match_uri_path: Dict[str, MatchUriPath]
conn_cache: MutableMapping[Hashable, Database]

def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
super().__init__()
self.database_by_scheme = database_by_scheme
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
self.conn_cache = weakref.WeakValueDictionary()

def for_databases(self, *dbs) -> Self:
Expand Down Expand Up @@ -157,12 +155,10 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
return self.connect_with_dict(conn_dict, thread_count, **kwargs)

try:
matcher = self.match_uri_path[scheme]
cls = self.database_by_scheme[scheme]
except KeyError:
raise NotImplementedError(f"Scheme '{scheme}' currently not supported")

cls = matcher.database_cls

if scheme == "databricks":
assert not dsn.user
kw = {}
Expand All @@ -175,6 +171,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)
kw["filepath"] = dsn.dbname
kw["dbname"] = dsn.user
else:
matcher = MatchUriPath(cls)
kw = matcher.match_path(dsn)

if scheme == "bigquery":
Expand All @@ -198,7 +195,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs)

kw = {k: v for k, v in kw.items() if v is not None}

if issubclass(cls, ThreadedDatabase):
if isinstance(cls, type) and issubclass(cls, ThreadedDatabase):
db = cls(thread_count=thread_count, **kw, **kwargs)
else:
db = cls(**kw, **kwargs)
Expand All @@ -209,11 +206,10 @@ def connect_with_dict(self, d, thread_count, **kwargs):
d = dict(d)
driver = d.pop("driver")
try:
matcher = self.match_uri_path[driver]
cls = self.database_by_scheme[driver]
except KeyError:
raise NotImplementedError(f"Driver '{driver}' currently not supported")

cls = matcher.database_cls
if issubclass(cls, ThreadedDatabase):
db = cls(thread_count=thread_count, **d, **kwargs)
else:
Expand Down