Skip to content

Commit

Permalink
perf(datasets): don't connect in __init__ method
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
  • Loading branch information
deepyaman committed Oct 9, 2023
1 parent 82d50eb commit 80b6069
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
36 changes: 19 additions & 17 deletions kedro-datasets/kedro_datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,7 @@ class SQLQueryDataSet(AbstractDataSet[None, pd.DataFrame]):
>>> date: "%Y-%m-%d %H:%M:%S.%f0 %z"
"""

# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}
engines: Dict[str, Engine] = {}

def __init__( # pylint: disable=too-many-arguments
self,
Expand Down Expand Up @@ -474,22 +472,27 @@ def __init__( # pylint: disable=too-many-arguments
self.adapt_mssql_date_params()

@classmethod
def create_connection(cls, connection_str: str) -> None:
def create_connection(cls, connection_str: str) -> Engine:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLQueryDataSet` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return
if connection_str not in cls.engines:
try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
cls.engines[connection_str] = engine

return cls.engines[connection_str]

cls.engines[connection_str] = engine
@property
def engine(self):
"""The ``Engine`` object for the dataset's connection string."""
return self.create_connection(self._connection_str)

def _describe(self) -> Dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
Expand All @@ -502,16 +505,15 @@ def _describe(self) -> Dict[str, Any]:

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
engine = self.engines[self._connection_str].execution_options(
**self._execution_options
) # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

return pd.read_sql_query(con=engine, **load_args)
return pd.read_sql_query(
con=self.engine.execution_options(**self._execution_options), **load_args
)

def _save(self, data: None) -> NoReturn: # pylint: disable=no-self-use
raise DataSetError("'save' is not supported on SQLQueryDataSet")
Expand Down
6 changes: 3 additions & 3 deletions kedro-datasets/tests/pandas/test_sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_load_driver_missing(self, mocker):
"kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err
)
with pytest.raises(DataSetError, match=ERROR_PREFIX + "mysqlclient"):
SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION})
SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()

def test_invalid_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
Expand All @@ -354,7 +354,7 @@ def test_invalid_module(self, mocker):
)
pattern = ERROR_PREFIX + r"Invalid module some\_module"
with pytest.raises(DataSetError, match=pattern):
SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION})
SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()

def test_load_unknown_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
Expand All @@ -365,7 +365,7 @@ def test_load_unknown_module(self, mocker):
)
pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'"
with pytest.raises(DataSetError, match=pattern):
SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION})
SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}).load()

def test_load_unknown_sql(self):
"""Check the error when unknown SQL dialect is provided
Expand Down

0 comments on commit 80b6069

Please sign in to comment.