Skip to content

Commit

Permalink
[KED-2865] Make sql datasets use a singleton pattern for connection (#…
Browse files Browse the repository at this point in the history
…1163)

Signed-off-by: lorenabalan <lorena.balan@quantumblack.com>
  • Loading branch information
Lorena Bălan authored Feb 3, 2022
1 parent 4e75b7d commit ceae20c
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 159 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements
* `pipeline` now accepts `tags` and a collection of `Node`s and/or `Pipeline`s rather than just a single `Pipeline` object. `pipeline` should be used in preference to `Pipeline` when creating a Kedro pipeline.
* `pandas.SQLTableDataSet` and `pandas.SQLQueryDataSet` now create a single shared engine per connection string (one cache per dataset class) at instantiation time (and therefore at catalog creation time), rather than opening a new connection for every load/save operation.

## Bug fixes and other changes
* Added tutorial documentation for experiment tracking (`03_tutorial/07_set_up_experiment_tracking.md`).
Expand Down
98 changes: 63 additions & 35 deletions kedro/extras/datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ class SQLTableDataSet(AbstractDataSet):
"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any]
DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}

def __init__(
self,
Expand Down Expand Up @@ -207,42 +210,50 @@ def __init__(
self._load_args["table_name"] = table_name
self._save_args["name"] = table_name

self._load_args["con"] = self._save_args["con"] = credentials["con"]
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, connection_str: str) -> None:
    """Register a shared engine for ``connection_str`` on the class.

    The engine is stored in ``cls.engines`` keyed by the connection
    string, so every ``SQLTableDataSet`` pointing at the same source
    reuses it. Calling this again with a string that is already
    registered does nothing.

    Args:
        connection_str: Connection string handed to ``create_engine``
            and used as the cache key.

    Raises:
        The error built by ``_get_missing_module_error`` when
        ``create_engine`` raises ``ImportError``, or the one built by
        ``_get_sql_alchemy_missing_error`` when it raises
        ``NoSuchModuleError``; the original exception is chained.
    """
    if connection_str not in cls.engines:
        try:
            shared_engine = create_engine(connection_str)
        except ImportError as import_error:
            raise _get_missing_module_error(import_error) from import_error
        except NoSuchModuleError as exc:
            raise _get_sql_alchemy_missing_error() from exc
        else:
            # Only cache once the engine was built successfully.
            cls.engines[connection_str] = shared_engine

def _describe(self) -> Dict[str, Any]:
load_args = self._load_args.copy()
save_args = self._save_args.copy()
load_args = copy.deepcopy(self._load_args)
save_args = copy.deepcopy(self._save_args)
del load_args["table_name"]
del load_args["con"]
del save_args["name"]
del save_args["con"]
return dict(
table_name=self._load_args["table_name"],
load_args=load_args,
save_args=save_args,
)

def _load(self) -> pd.DataFrame:
try:
return pd.read_sql_table(**self._load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines[self._connection_str] # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
try:
data.to_sql(**self._save_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines[self._connection_str] # type: ignore
data.to_sql(con=engine, **self._save_args)

def _exists(self) -> bool:
eng = create_engine(self._load_args["con"])
eng = self.engines[self._connection_str] # type: ignore
schema = self._load_args.get("schema", None)
exists = self._load_args["table_name"] in eng.table_names(schema)
eng.dispose()
return exists


Expand Down Expand Up @@ -299,6 +310,10 @@ class SQLQueryDataSet(AbstractDataSet):
"""

# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}

def __init__( # pylint: disable=too-many-arguments
self,
sql: str = None,
Expand Down Expand Up @@ -374,32 +389,45 @@ def __init__( # pylint: disable=too-many-arguments
self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
self._filepath = path
self._load_args["con"] = credentials["con"]
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, connection_str: str) -> None:
    """Register a shared engine for ``connection_str`` on the class.

    The engine is stored in ``cls.engines`` keyed by the connection
    string, so every ``SQLQueryDataSet`` pointing at the same source
    reuses it. Calling this again with a string that is already
    registered does nothing.

    Args:
        connection_str: Connection string handed to ``create_engine``
            and used as the cache key.

    Raises:
        The error built by ``_get_missing_module_error`` when
        ``create_engine`` raises ``ImportError``, or the one built by
        ``_get_sql_alchemy_missing_error`` when it raises
        ``NoSuchModuleError``; the original exception is chained.
    """
    already_registered = connection_str in cls.engines
    if not already_registered:
        try:
            shared_engine = create_engine(connection_str)
        except ImportError as import_error:
            raise _get_missing_module_error(import_error) from import_error
        except NoSuchModuleError as exc:
            raise _get_sql_alchemy_missing_error() from exc
        else:
            # Only cache once the engine was built successfully.
            cls.engines[connection_str] = shared_engine

def _describe(self) -> Dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
desc = {}
desc["sql"] = str(load_args.pop("sql", None))
desc["filepath"] = str(self._filepath)
del load_args["con"]
desc["load_args"] = str(load_args)

return desc
return dict(
sql=str(load_args.pop("sql", None)),
filepath=str(self._filepath),
load_args=str(load_args),
)

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
engine = self.engines[self._connection_str] # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

try:
return pd.read_sql_query(**load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
return pd.read_sql_query(con=engine, **load_args)

def _save(self, data: pd.DataFrame) -> None:
    """Always fail: saving is not supported by this dataset.

    Args:
        data: Ignored; present only to match the save signature.

    Raises:
        DataSetError: Unconditionally.
    """
    raise DataSetError("`save` is not supported on SQLQueryDataSet")
Loading

0 comments on commit ceae20c

Please sign in to comment.