Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion dlt/sources/sql_database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads."""

from typing import Callable, Dict, List, Optional, Union, Iterable, Any
from typing import Callable, Dict, List, Optional, Union, Iterable, Any, Type

import dlt
from dlt.common.configuration.specs import ConnectionStringCredentials
Expand All @@ -20,6 +20,7 @@
_detect_precision_hints_deprecated,
TQueryAdapter,
TTableAdapter,
BaseTableLoader,
)
from .schema_types import (
table_to_resource_hints,
Expand All @@ -46,6 +47,7 @@ def sql_database(
query_adapter_callback: Optional[TQueryAdapter] = None,
resolve_foreign_keys: bool = False,
engine_adapter_callback: Optional[Callable[[Engine], Engine]] = None,
table_loader_class: Optional[Type[BaseTableLoader]] = None,
) -> Iterable[DltResource]:
"""
A dlt source which loads data from an SQL database using SQLAlchemy.
Expand Down Expand Up @@ -96,6 +98,9 @@ def sql_database(
engine_adapter_callback (Optional[Callable[[Engine], Engine]]): Callback to configure or modify an Engine instance that will be used to open a connection, e.g. to
set transaction isolation level.

table_loader_class (Optional[Type[BaseTableLoader]]): Custom TableLoader class to use for loading data from tables.
Must inherit from BaseTableLoader and implement the required abstract methods.

Yields:
DltResource: DLT resources for each table to be loaded.
"""
Expand Down Expand Up @@ -150,6 +155,7 @@ def sql_database(
query_adapter_callback=query_adapter_callback,
resolve_foreign_keys=resolve_foreign_keys,
engine_adapter_callback=engine_adapter_callback,
table_loader_class=table_loader_class,
)


Expand All @@ -173,6 +179,7 @@ def sql_table(
resolve_foreign_keys: bool = False,
engine_adapter_callback: Callable[[Engine], Engine] = None,
write_disposition: TWriteDispositionConfig = "append",
table_loader_class: Optional[Type[BaseTableLoader]] = None,
primary_key: TColumnNames = None,
merge_key: TColumnNames = None,
) -> DltResource:
Expand Down Expand Up @@ -228,6 +235,10 @@ def sql_table(
set transaction isolation level.

write_disposition (TWriteDispositionConfig): write disposition of the table resource, defaults to `append`.

table_loader_class (Optional[Type[BaseTableLoader]]): Custom TableLoader class to use for loading data from this table.
Must inherit from BaseTableLoader and implement the required abstract methods.

primary_key (TColumnNames): A list of column names that comprise a primary key. Typically used with "merge" write disposition to deduplicate loaded data.
merge_key (TColumnNames): A list of column names that define a merge key. Typically used with "merge" write disposition to remove overlapping data ranges, e.g. to
keep a single record for a given day.
Expand Down Expand Up @@ -295,6 +306,7 @@ def sql_table(
included_columns=included_columns,
query_adapter_callback=query_adapter_callback,
resolve_foreign_keys=resolve_foreign_keys,
table_loader_class=table_loader_class,
)


Expand All @@ -308,4 +320,5 @@ def sql_table(
"TableBackend",
"TQueryAdapter",
"TTableAdapter",
"BaseTableLoader",
]
28 changes: 25 additions & 3 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Union,
)
import operator
from abc import ABC, abstractmethod

import dlt
from dlt.common.configuration.specs import (
Expand Down Expand Up @@ -57,7 +58,9 @@
TTableAdapter = Callable[[Table], Optional[Union[SelectAny, Table]]]


class TableLoader:
class BaseTableLoader(ABC):
"""Abstract base class for TableLoader implementations."""

def __init__(
self,
engine: Engine,
Expand All @@ -75,6 +78,8 @@ def __init__(
self.chunk_size = chunk_size
self.query_adapter_callback = query_adapter_callback
self.incremental = incremental

# Initialize incremental-related attributes
if incremental:
column_name = extract_simple_field_name(incremental.cursor_path)

Expand Down Expand Up @@ -106,6 +111,20 @@ def __init__(
self.range_start = None
self.range_end = None

@abstractmethod
def make_query(self) -> SelectClause:
"""Create the query to be executed."""
# Abstract hook: declared without an implementation here, so each concrete
# loader class has to supply its own query construction.
...

@abstractmethod
def load_rows(self, backend_kwargs: Optional[Dict[str, Any]]) -> Iterator[TDataItem]:
"""Load rows from the table and yield them as data items."""
# Abstract hook with no body here. Note that `backend_kwargs` has no default
# in this declaration, while the default TableLoader override in this file
# declares it with `= None`.
...


class TableLoader(BaseTableLoader):
"""Default TableLoader implementation for SQL database sources."""

def _make_query(self) -> SelectAny:
table = self.table
query = table.select()
Expand Down Expand Up @@ -169,7 +188,7 @@ def make_query(self) -> SelectClause:

return self._make_query()

def load_rows(self, backend_kwargs: Dict[str, Any] = None) -> Iterator[TDataItem]:
def load_rows(self, backend_kwargs: Optional[Dict[str, Any]] = None) -> Iterator[TDataItem]:
# make copy of kwargs
backend_kwargs = dict(backend_kwargs or {})
query = self.make_query()
Expand Down Expand Up @@ -262,6 +281,7 @@ def table_rows(
included_columns: Optional[List[str]],
query_adapter_callback: Optional[TQueryAdapter],
resolve_foreign_keys: bool,
table_loader_class: Optional[type[BaseTableLoader]] = None,
) -> Iterator[TDataItem]:
if isinstance(table, str): # Reflection is deferred
table = Table(
Expand Down Expand Up @@ -303,7 +323,9 @@ def table_rows(
resolve_foreign_keys=resolve_foreign_keys,
)

loader = TableLoader(
# Use custom table loader class if provided, otherwise use default
loader_class = table_loader_class or TableLoader
loader = loader_class(
engine,
backend,
table,
Expand Down