29 changes: 29 additions & 0 deletions devel-common/src/tests_common/pytest_plugin.py
@@ -2071,6 +2071,35 @@ def mock_supervisor_comms(monkeypatch):
yield comms


@pytest.fixture
def sdk_connection_not_found(mock_supervisor_comms):
"""
Fixture that mocks supervisor comms to return CONNECTION_NOT_FOUND error.

This eliminates the need to manually set up the mock in every test that
needs a connection not found message through supervisor comms.

Example:
@pytest.mark.db_test
def test_invalid_location(self, sdk_connection_not_found):
# Test logic that expects CONNECTION_NOT_FOUND error
with pytest.raises(AirflowException):
operator.execute(context)
"""
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if not AIRFLOW_V_3_0_PLUS:
yield None
return

from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse

mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)

yield mock_supervisor_comms


@pytest.fixture
def mocked_parse(spy_agency):
"""
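For readers skimming the plugin change above, a minimal sketch of what the new fixture actually wires up. This is illustrative, not part of the PR; it only restates the fixture body (Airflow 3 only; on older versions the fixture yields None, so a real test would need a version guard):

```python
from airflow.sdk.exceptions import ErrorType


def test_fixture_stubs_connection_not_found(sdk_connection_not_found):
    # The fixture set send.return_value to ErrorResponse(error=CONNECTION_NOT_FOUND),
    # so any code that asks supervisor comms for a connection receives that error.
    response = sdk_connection_not_found.send.return_value
    assert response.error == ErrorType.CONNECTION_NOT_FOUND
```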
23 changes: 5 additions & 18 deletions providers/amazon/tests/unit/amazon/aws/bundles/test_s3.py
@@ -27,10 +27,8 @@
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils import db

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_connections

AWS_CONN_ID_WITH_REGION = "s3_dags_connection"
AWS_CONN_ID_REGION = "eu-central-1"
@@ -78,13 +76,9 @@ def bundle_temp_dir(tmp_path):

@pytest.mark.skipif(not airflow.version.version.strip().startswith("3"), reason="Airflow >=3.0.0 test")
class TestS3DagBundle:
@classmethod
def teardown_class(cls) -> None:
clear_db_connections()

@classmethod
def setup_class(cls) -> None:
db.merge_conn(
@pytest.fixture(autouse=True)
def setup_connections(self, create_connection_without_db):
create_connection_without_db(
Connection(
conn_id=AWS_CONN_ID_DEFAULT,
conn_type="aws",
@@ -93,8 +87,8 @@ def setup_class(cls) -> None:
},
)
)
db.merge_conn(
conn=Connection(
create_connection_without_db(
Connection(
conn_id=AWS_CONN_ID_WITH_REGION,
conn_type="aws",
extra={
@@ -104,7 +98,6 @@ def setup_class(cls) -> None:
)
)

@pytest.mark.db_test
def test_view_url_generates_presigned_url(self):
bundle = S3DagBundle(
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1/dags", bucket_name=S3_BUCKET_NAME
@@ -113,15 +106,13 @@ def test_view_url_generates_presigned_url(self):
url: str = bundle.view_url("test_version")
assert url.startswith("https://my-airflow-dags-bucket.s3.amazonaws.com/project1/dags")

@pytest.mark.db_test
def test_view_url_template_generates_presigned_url(self):
bundle = S3DagBundle(
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1/dags", bucket_name=S3_BUCKET_NAME
)
url: str = bundle.view_url_template()
assert url.startswith("https://my-airflow-dags-bucket.s3.amazonaws.com/project1/dags")

@pytest.mark.db_test
def test_supports_versioning(self):
bundle = S3DagBundle(
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1/dags", bucket_name=S3_BUCKET_NAME
@@ -136,14 +127,12 @@ def test_supports_versioning(self):
with pytest.raises(AirflowException, match="S3 url with version is not supported"):
bundle.view_url("test_version")

@pytest.mark.db_test
def test_correct_bundle_path_used(self):
bundle = S3DagBundle(
name="test", aws_conn_id=AWS_CONN_ID_DEFAULT, prefix="project1_dags", bucket_name="airflow_dags"
)
assert str(bundle.base_dir) == str(bundle.s3_dags_dir)

@pytest.mark.db_test
def test_s3_bucket_and_prefix_validated(self, s3_bucket):
hook = S3Hook(aws_conn_id=AWS_CONN_ID_DEFAULT)
assert hook.check_for_bucket(s3_bucket.name) is True
@@ -195,7 +184,6 @@ def _upload_fixtures(self, bucket: str, fixtures_dir: str) -> None:
key = os.path.relpath(path, fixtures_dir)
client.upload_file(Filename=path, Bucket=bucket, Key=key)

@pytest.mark.db_test
def test_refresh(self, s3_bucket, s3_client):
bundle = S3DagBundle(
name="test",
@@ -218,7 +206,6 @@ def test_refresh(self, s3_bucket, s3_client):
assert bundle._log.debug.call_count == 3
assert bundle._log.debug.call_args_list == [download_log_call, download_log_call, download_log_call]

@pytest.mark.db_test
def test_refresh_without_prefix(self, s3_bucket, s3_client):
bundle = S3DagBundle(
name="test",
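The setup_connections fixture above leans on tests_common's create_connection_without_db. As a hedged sketch of the underlying idea, assuming the helper exposes the Connection through Airflow's documented AIRFLOW_CONN_<CONN_ID> environment variable (JSON form) rather than writing to the metadata DB:

```python
import json
import os
from unittest import mock

from airflow.models import Connection

conn = Connection(conn_id="aws_default", conn_type="aws", extra={"region_name": "us-east-1"})
# In the JSON connection format, "extra" is carried as a JSON-encoded string.
payload = json.dumps({"conn_type": conn.conn_type, "extra": json.dumps(conn.extra_dejson)})
with mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{conn.conn_id.upper()}": payload}):
    # Resolution goes through the environment-variables secrets backend,
    # so no metadata-DB row is needed.
    resolved = Connection.get_connection_from_secrets("aws_default")
    assert resolved.extra_dejson["region_name"] == "us-east-1"
```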
15 changes: 12 additions & 3 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py
@@ -50,6 +50,7 @@
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

pytest.importorskip("aiobotocore")

Expand Down Expand Up @@ -430,9 +431,8 @@ def test_user_agent_caller_target_function_found(self, mock_class_name, found_cl
assert mock_class_name.call_count == len(found_classes)
assert user_agent_tags["Caller"] == found_classes[-1]

@pytest.mark.db_test
@mock.patch.object(AwsEcsExecutor, "_load_run_kwargs")
def test_user_agent_caller_target_executor_found(self, mock_load_run_kwargs):
def test_user_agent_caller_target_executor_found(self, mock_load_run_kwargs, sdk_connection_not_found):
with conf_vars(
{
("aws_ecs_executor", "cluster"): "foo",
@@ -456,7 +456,16 @@ def test_user_agent_caller_target_function_not_found(self):
@pytest.mark.db_test
@pytest.mark.parametrize("env_var, expected_version", [({"AIRFLOW_CTX_DAG_ID": "banana"}, 5), [{}, None]])
@mock.patch.object(AwsBaseHook, "_get_caller", return_value="Test")
def test_user_agent_dag_run_key_is_hashed_correctly(self, _, env_var, expected_version):
def test_user_agent_dag_run_key_is_hashed_correctly(
self, _, env_var, expected_version, mock_supervisor_comms
):
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.execution_time.comms import ConnectionResult

mock_supervisor_comms.send.return_value = ConnectionResult(
conn_id="aws_default",
conn_type="aws",
)
with mock.patch.dict(os.environ, env_var, clear=True):
dag_run_key = self.fetch_tags()["DagRunKey"]

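The ConnectionResult stub above is the mirror image of sdk_connection_not_found. If more tests end up needing a found connection, the same plugin pattern could be factored out into a companion fixture; the sketch below is hypothetical (name and defaults are illustrative, not part of this PR):

```python
import pytest


@pytest.fixture
def sdk_aws_connection(mock_supervisor_comms):
    from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

    if not AIRFLOW_V_3_0_PLUS:
        yield None
        return

    from airflow.sdk.execution_time.comms import ConnectionResult

    # Answer every supervisor-comms lookup with a minimal AWS connection.
    mock_supervisor_comms.send.return_value = ConnectionResult(
        conn_id="aws_default",
        conn_type="aws",
    )
    yield mock_supervisor_comms
```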
2 changes: 1 addition & 1 deletion providers/amazon/tests/unit/amazon/aws/hooks/test_emr.py
@@ -195,7 +195,7 @@ def test_empty_emr_conn_id(self, mock_boto3_client):

@pytest.mark.db_test
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
def test_missing_emr_conn_id(self, mock_boto3_client):
def test_missing_emr_conn_id(self, mock_boto3_client, sdk_connection_not_found):
"""Test not exists ``emr_conn_id``."""
mock_run_job_flow = mock.MagicMock()
mock_boto3_client.return_value.run_job_flow = mock_run_job_flow
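For context on the EMR change, a hedged reconstruction of what that test exercises: with the SDK reporting CONNECTION_NOT_FOUND for emr_conn_id, EmrHook.create_job_flow is expected to fall back to passing the raw config through to run_job_flow. The test body below is a sketch, not the PR's code:

```python
from unittest import mock

from airflow.providers.amazon.aws.hooks.emr import EmrHook


def test_missing_emr_conn_id_sketch(sdk_connection_not_found):
    with mock.patch(
        "airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn"
    ) as mock_get_conn:
        hook = EmrHook(emr_conn_id="not_found_conn")
        hook.create_job_flow({"Name": "test"})
        # With no stored connection, the overrides are passed through unchanged.
        mock_get_conn.return_value.run_job_flow.assert_called_once_with(Name="test")
```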
@@ -106,7 +106,7 @@ def test_poke_reached_unexpected_terminal_state(self, mock_get_cluster_state, un
mock_get_cluster_state.assert_called_once_with(clusterName=CLUSTER_NAME)

@pytest.mark.db_test
def test_region_argument(self):
def test_region_argument(self, sdk_connection_not_found):
with pytest.warns(AirflowProviderDeprecationWarning) as w:
w.sensor = EksClusterStateSensor(
task_id=TASK_ID,
@@ -600,24 +600,39 @@ def test_table_unloading_using_redshift_data_api(
# test sql arg
assert_equal_ignore_multiple_spaces(mock_rs.execute_statement.call_args.kwargs["Sql"], unload_query)

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_default(
self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_run, mock_session, create_connection_without_db
):
create_connection_without_db(
Connection(
conn_id="aws_conn_id",
conn_type="aws",
schema="database",
port=5439,
host="cluster.id.region.redshift.amazonaws.com",
extra={},
)
)
create_connection_without_db(
Connection(
conn_id="redshift_conn_id",
conn_type="redshift",
schema="database",
port=5439,
host="cluster.id.region.redshift.amazonaws.com",
extra={},
)
)
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None

mock_connection.return_value = mock.MagicMock(
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
)
mock_facets = {
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
"documentation": DocumentationDatasetFacet(description="mock_description"),
@@ -671,24 +686,38 @@ def test_get_openlineage_facets_on_complete_default(
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
}

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_with_select_query(
self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_run, mock_session, create_connection_without_db
):
create_connection_without_db(
Connection(
conn_id="redshift_conn_id",
conn_type="redshift",
schema="database",
port=5439,
host="cluster.id.region.redshift.amazonaws.com",
extra={},
)
)
create_connection_without_db(
Connection(
conn_id="aws_conn_id",
conn_type="aws",
schema="database",
port=5439,
host="cluster.id.region.redshift.amazonaws.com",
extra={},
)
)
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None

mock_connection.return_value = mock.MagicMock(
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
)
mock_facets = {
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
"documentation": DocumentationDatasetFacet(description="mock_description"),
@@ -835,8 +864,6 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
"schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]),
}

@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
@mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@@ -846,22 +873,38 @@ def test_get_openlineage_facets_on_complete_using_redshift_data_api(
)
@mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table")
def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, create_connection_without_db
):
"""
Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage.
"""
create_connection_without_db(
Connection(
conn_id="redshift_conn_id",
conn_type="redshift",
schema="database",
port=5439,
host="cluster.id.region.redshift.amazonaws.com",
extra={},
)
)
create_connection_without_db(
Connection(
conn_id="aws_conn_id",
conn_type="aws",
schema="database",
port=5439,
host="cluster.id.region.redshift.amazonaws.com",
extra={},
)
)

access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None

mock_connection.return_value = mock.MagicMock(
schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
)
mock_hook.return_value = Connection()
mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
mock_rs.describe_statement.return_value = {"Status": "FINISHED"}

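A note on why these tests now register real Connection objects instead of patching get_connection with MagicMocks: the OpenLineage helpers derive dataset namespaces and facets from actual connection fields, so host, port, and schema must round-trip like a genuine Redshift connection. A small sketch; the namespace format shown is an assumption for illustration, not quoted from the provider code:

```python
from airflow.models import Connection

conn = Connection(
    conn_id="redshift_conn_id",
    conn_type="redshift",
    host="cluster.id.region.redshift.amazonaws.com",
    port=5439,
    schema="database",
)
# An OpenLineage-style namespace built from real connection fields.
namespace = f"redshift://{conn.host}:{conn.port}"
assert namespace == "redshift://cluster.id.region.redshift.amazonaws.com:5439"
```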