Merged

SparkConnectHook test module:
@@ -22,8 +22,6 @@
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook

-pytestmark = pytest.mark.db_test
-

class TestSparkConnectHook:
@pytest.fixture(autouse=True)

SparkJDBCHook test module:
@@ -24,8 +24,6 @@
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook

-pytestmark = pytest.mark.db_test
-

class TestSparkJDBCHook:
_config = {
@@ -122,6 +120,7 @@ def setup_connections(self, create_connection_without_db):
)
)

+@pytest.mark.db_test
def test_resolve_jdbc_connection(self):
# Given
hook = SparkJDBCHook(jdbc_conn_id="jdbc-default")
@@ -139,6 +138,7 @@ def test_resolve_jdbc_connection(self):
# Then
assert connection == expected_connection

+@pytest.mark.db_test
def test_build_jdbc_arguments(self):
# Given
hook = SparkJDBCHook(**self._config)
@@ -183,21 +183,25 @@ def test_build_jdbc_arguments(self):
]
assert expected_jdbc_arguments == cmd

+@pytest.mark.db_test
def test_build_jdbc_arguments_invalid(self):
# Given
hook = SparkJDBCHook(**self._invalid_config)

# Expect Exception
hook._build_jdbc_application_arguments(hook._resolve_jdbc_connection())

+@pytest.mark.db_test
def test_invalid_host(self):
with pytest.raises(ValueError, match="host should not contain a"):
SparkJDBCHook(jdbc_conn_id="jdbc-invalid-host", **self._config)

+@pytest.mark.db_test
def test_invalid_schema(self):
with pytest.raises(ValueError, match="schema should not contain a"):
SparkJDBCHook(jdbc_conn_id="jdbc-invalid-schema", **self._config)

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook.submit")
def test_invalid_extra_conn_prefix(self, mock_submit):
hook = SparkJDBCHook(jdbc_conn_id="jdbc-invalid-extra-conn-prefix", **self._config)

SparkSqlHook test module:
@@ -29,8 +29,6 @@

from tests_common.test_utils.db import clear_db_connections

-pytestmark = pytest.mark.db_test
-

def get_after(sentinel, iterable):
"""Get the value after `sentinel` in an `iterable`"""
@@ -78,6 +76,7 @@ def setup_connections(self, create_connection_without_db):
def teardown_class(cls) -> None:
clear_db_connections(add_default_connections_back=True)

+@pytest.mark.db_test
def test_build_command(self):
hook = SparkSqlHook(**self._config)

@@ -100,6 +99,7 @@ def test_build_command(self):
if self._config["verbose"]:
assert "--verbose" in cmd

+@pytest.mark.db_test
def test_build_command_with_str_conf(self):
hook = SparkSqlHook(**self._config_str)

@@ -123,6 +123,7 @@ def test_build_command_with_str_conf(self):
if self._config["verbose"]:
assert "--verbose" in cmd

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen")
def test_spark_process_runcmd(self, mock_popen):
# Given
@@ -171,6 +172,7 @@ def test_spark_process_runcmd(self, mock_popen):
universal_newlines=True,
)

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen")
def test_spark_process_runcmd_with_str(self, mock_popen):
# Given
@@ -201,6 +203,7 @@ def test_spark_process_runcmd_with_str(self, mock_popen):
universal_newlines=True,
)

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen")
def test_spark_process_runcmd_with_list(self, mock_popen):
# Given
@@ -231,6 +234,7 @@ def test_spark_process_runcmd_with_list(self, mock_popen):
universal_newlines=True,
)

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen")
def test_spark_process_runcmd_and_fail(self, mock_popen):
# Given

SparkSubmitHook test module:
@@ -29,8 +29,6 @@
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook

-pytestmark = pytest.mark.db_test
-

class TestSparkSubmitHook:
_spark_job_file = "test_application.py"
@@ -170,6 +168,7 @@ def setup_connections(self, create_connection_without_db):
)
)

+@pytest.mark.db_test
@patch(
"airflow.providers.apache.spark.hooks.spark_submit.os.getenv", return_value="/tmp/airflow_krb5_ccache"
)
@@ -290,6 +289,7 @@ def test_build_track_driver_status_command(self):
assert expected_spark_standalone_cluster == build_track_driver_status_spark_standalone_cluster
assert expected_spark_yarn_cluster == build_track_driver_status_spark_yarn_cluster

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
def test_spark_process_runcmd(self, mock_popen):
# Given
@@ -310,6 +310,7 @@ def test_spark_process_runcmd(self, mock_popen):
bufsize=-1,
)

+@pytest.mark.db_test
def test_resolve_should_track_driver_status(self):
# Given
hook_default = SparkSubmitHook(conn_id="")
@@ -345,6 +346,7 @@ def test_resolve_should_track_driver_status(self):
assert should_track_driver_status_spark_binary_set is False
assert should_track_driver_status_spark_standalone_cluster is True

+@pytest.mark.db_test
def test_resolve_connection_yarn_default(self):
# Given
hook = SparkSubmitHook(conn_id="")
@@ -367,6 +369,7 @@ def test_resolve_connection_yarn_default(self):
assert connection == expected_spark_connection
assert dict_cmd["--master"] == "yarn"

+@pytest.mark.db_test
def test_resolve_connection_yarn_default_connection(self):
# Given
hook = SparkSubmitHook(conn_id="spark_default")
@@ -560,6 +563,7 @@ def test_resolve_connection_spark_binary_default_value_override(self):
assert connection == expected_spark_connection
assert cmd[0] == "spark3-submit"

+@pytest.mark.db_test
def test_resolve_connection_spark_binary_default_value(self):
# Given
hook = SparkSubmitHook(conn_id="spark_default")
@@ -1044,6 +1048,7 @@ def test_k8s_process_on_kill(self, mock_popen, mock_client_method):
),
],
)
+@pytest.mark.db_test
def test_masks_passwords(self, command: str, expected: str) -> None:
# Given
hook = SparkSubmitHook()
@@ -1054,13 +1059,15 @@ def test_masks_passwords(self, command: str, expected: str) -> None:
# Then
assert command_masked == expected

+@pytest.mark.db_test
def test_create_keytab_path_from_base64_keytab_with_decode_exception(self):
hook = SparkSubmitHook()
invalid_base64 = "invalid_base64"

with pytest.raises(AirflowException, match="Failed to decode base64 keytab"):
hook._create_keytab_path_from_base64_keytab(invalid_base64, None)

+@pytest.mark.db_test
@patch("pathlib.Path.exists")
@patch("builtins.open", new_callable=mock_open)
def test_create_keytab_path_from_base64_keytab_with_write_exception(
@@ -1084,6 +1091,7 @@ def test_create_keytab_path_from_base64_keytab_with_write_exception(
# Then
assert mock_exists.call_count == 2 # called twice (before write, after write)

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_submit.shutil.move")
@patch("pathlib.Path.exists")
@patch("builtins.open", new_callable=mock_open)
@@ -1110,6 +1118,7 @@ def test_create_keytab_path_from_base64_keytab_with_move_exception(
mock_move.assert_called_once()
assert mock_exists.call_count == 2 # called twice (before write, after write)

+@pytest.mark.db_test
@patch("airflow.providers.apache.spark.hooks.spark_submit.uuid.uuid4")
@patch("pathlib.Path.resolve")
@patch("airflow.providers.apache.spark.hooks.spark_submit.shutil.move")
@@ -1140,6 +1149,7 @@ def test_create_keytab_path_from_base64_keytab_with_new_keytab(
mock_open().write.assert_called_once_with(keytab_value)
mock_move.assert_called_once()

+@pytest.mark.db_test
@patch("pathlib.Path.resolve")
@patch("airflow.providers.apache.spark.hooks.spark_submit.shutil.move")
@patch("pathlib.Path.exists")
@@ -1168,6 +1178,7 @@ def test_create_keytab_path_from_base64_keytab_with_new_keytab_with_principal(
mock_open().write.assert_called_once_with(keytab_value)
mock_move.assert_called_once()

+@pytest.mark.db_test
@patch("pathlib.Path.resolve")
@patch("pathlib.Path.exists")
@patch("builtins.open", new_callable=mock_open)
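
Taken together, the hunks apply one consistent refactor across the four Spark hook test modules: the module-level pytestmark = pytest.mark.db_test is removed, and the db_test marker is re-applied only on the individual tests that need it, presumably so that tests which never touch the metadata database can run without DB setup. A minimal sketch of the resulting shape, using a hypothetical TestExampleHook class and test names rather than the real Airflow tests:

import pytest

# Previously the module contained:  pytestmark = pytest.mark.db_test
# which marked every test below as a DB test. After the change, the marker
# is applied only where the database is actually needed.


class TestExampleHook:  # hypothetical class, not one of the actual test classes above
    @pytest.mark.db_test
    def test_resolve_connection(self):
        """Reads an Airflow Connection from the metadata DB, so it keeps the marker."""

    def test_build_command(self):
        """Pure command-building logic; it can now run without any database setup."""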