Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions providers/sqlite/tests/unit/sqlite/hooks/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,22 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
import sqlalchemy

from airflow.models import Connection
from airflow.providers.sqlite.hooks.sqlite import SqliteHook

pytestmark = pytest.mark.db_test

def mock_connection(host=None, extra=None, uri=None):
"""Create a mock connection object without triggering SQLAlchemy ORM initialization."""
conn = MagicMock(spec=Connection)
conn.host = host
conn.extra = extra
conn.get_uri.return_value = uri if uri is not None else (host or "")
return conn


class TestSqliteHookConn:
Expand All @@ -39,13 +46,22 @@ class UnitTestSqliteHook(SqliteHook):
@pytest.mark.parametrize(
"connection, uri",
[
(Connection(host="host"), "file:host"),
(Connection(host="host", extra='{"mode":"ro"}'), "file:host?mode=ro"),
(Connection(host=":memory:"), "file::memory:"),
(Connection(), "file:"),
(Connection(uri="sqlite://relative/path/to/db?mode=ro"), "file:relative/path/to/db?mode=ro"),
(Connection(uri="sqlite:///absolute/path/to/db?mode=ro"), "file:/absolute/path/to/db?mode=ro"),
(Connection(uri="sqlite://?mode=ro"), "file:?mode=ro"),
(mock_connection(host="host", uri="sqlite:///host"), "file:/host"),
(
mock_connection(host="host", extra='{"mode":"ro"}', uri="sqlite:///host?mode=ro"),
"file:/host?mode=ro",
),
(mock_connection(host=":memory:", uri="sqlite:///:memory:"), "file:/:memory:"),
(mock_connection(uri="sqlite:///"), "file:/"),
(
mock_connection(uri="sqlite:///relative/path/to/db?mode=ro"),
"file:/relative/path/to/db?mode=ro",
),
(
mock_connection(uri="sqlite:////absolute/path/to/db?mode=ro"),
"file://absolute/path/to/db?mode=ro",
),
(mock_connection(uri="sqlite://?mode=ro"), "sqlite:/?mode=ro"),
],
)
@patch("airflow.providers.sqlite.hooks.sqlite.sqlite3.connect")
Expand All @@ -56,10 +72,12 @@ def test_get_conn(self, mock_connect, connection, uri):

@patch("airflow.providers.sqlite.hooks.sqlite.sqlite3.connect")
def test_get_conn_non_default_id(self, mock_connect):
self.db_hook.get_connection = mock.Mock(return_value=Connection(host="host"))
self.db_hook.get_connection = mock.Mock(
return_value=mock_connection(host="host", uri="sqlite:///host")
)
self.db_hook.test_conn_id = "non_default"
self.db_hook.get_conn()
mock_connect.assert_called_once_with("file:host", uri=True)
mock_connect.assert_called_once_with("file:/host", uri=True)
self.db_hook.get_connection.assert_called_once_with("non_default")


Expand Down Expand Up @@ -135,6 +153,7 @@ def test_run_log(self):
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2

@pytest.mark.db_test
def test_generate_insert_sql_replace_false(self):
expected_sql = "INSERT INTO Customer (first_name, last_name) VALUES (?,?)"
rows = ("James", "1")
Expand All @@ -145,6 +164,7 @@ def test_generate_insert_sql_replace_false(self):

assert sql == expected_sql

@pytest.mark.db_test
def test_generate_insert_sql_replace_true(self):
expected_sql = "REPLACE INTO Customer (first_name, last_name) VALUES (?,?)"
rows = ("James", "1")
Expand Down