Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 32 additions & 41 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,19 @@ def test_get_session_returns_a_boto3_session(self):
assert table.item_count == 0

@pytest.mark.parametrize(
"client_meta",
"hook_params",
[
AwsBaseHook(client_type="s3").get_client_type().meta,
AwsBaseHook(resource_type="dynamodb").get_resource_type().meta.client.meta,
pytest.param({"client_type": "s3"}, id="client-type"),
pytest.param({"resource_type": "dynamodb"}, id="resource-type"),
],
)
def test_user_agent_extra_update(self, client_meta):
def test_user_agent_extra_update(self, hook_params):
"""
We are only looking for the keys appended by the AwsBaseHook. A user_agent string
is a number of key/value pairs such as: `BOTO3/1.25.4 AIRFLOW/2.5.0.DEV0 AMPP/6.0.0`.
"""
client_meta = AwsBaseHook(aws_conn_id=None, client_type="s3").conn_client_meta

expected_user_agent_tag_keys = ["Airflow", "AmPP", "Caller", "DagRunKey"]

result_user_agent_tags = client_meta.config.user_agent.split(" ")
Expand Down Expand Up @@ -477,31 +479,25 @@ def mock_assume_role(**kwargs):
return sts_response

with mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get"
) as mock_get, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.boto3"
) as mock_boto3, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
mock_get.return_value.ok = True

mock_client = mock_boto3.session.Session.return_value.client
"airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._create_basic_session",
spec=boto3.session.Session,
) as mocked_basic_session:
mocked_basic_session.return_value.region_name = "us-east-2"
mock_client = mocked_basic_session.return_value.client
mock_client.return_value.assume_role.side_effect = mock_assume_role

hook = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="s3")
hook.get_client_type("s3")

calls_assume_role = [
mock.call.session.Session().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call.session.Session()
.client()
.assume_role(
RoleArn=role_arn,
RoleSessionName=slugified_role_session_name,
),
]
mock_boto3.assert_has_calls(calls_assume_role)
AwsBaseHook(aws_conn_id=aws_conn_id, client_type="s3").get_client_type()
mocked_basic_session.assert_has_calls(
[
mock.call().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call()
.client()
.assume_role(
RoleArn=role_arn,
RoleSessionName=slugified_role_session_name,
),
]
)

def test_get_credentials_from_gcp_credentials(self):
mock_connection = Connection(
Expand Down Expand Up @@ -684,25 +680,21 @@ def mock_assume_role_with_saml(**kwargs):
with mock.patch("builtins.__import__", side_effect=import_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get"
) as mock_get, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.boto3"
) as mock_boto3, mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
mock_get.return_value.ok = True

mock_client = mock_boto3.session.Session.return_value.client
"airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._create_basic_session",
spec=boto3.session.Session,
) as mocked_basic_session:
mocked_basic_session.return_value.region_name = "us-east-2"
mock_client = mocked_basic_session.return_value.client
mock_client.return_value.assume_role_with_saml.side_effect = mock_assume_role_with_saml

hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
hook.get_client_type("s3")
AwsBaseHook(aws_conn_id="aws_default", client_type="s3").get_client_type()

mock_get.assert_called_once_with(idp_url, auth=mock_auth)
mock_xpath.assert_called_once_with(xpath)

calls_assume_role_with_saml = [
mock.call.session.Session().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call.session.Session()
mocked_basic_session.assert_has_calls = [
mock.call().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
mock.call()
.client()
.assume_role_with_saml(
DurationSeconds=duration_seconds,
Expand All @@ -711,7 +703,6 @@ def mock_assume_role_with_saml(**kwargs):
SAMLAssertion=encoded_saml_assertion,
),
]
mock_boto3.assert_has_calls(calls_assume_role_with_saml)

@mock_iam
def test_expand_role(self):
Expand Down
72 changes: 38 additions & 34 deletions tests/providers/amazon/aws/hooks/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from unittest import mock

import pytest

from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook

SUBMIT_JOB_SUCCESS_RETURN = {
Expand Down Expand Up @@ -46,6 +48,12 @@
}


@pytest.fixture
def mocked_hook_client():
with mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.conn") as m:
yield m


class TestEmrContainerHook:
def setup_method(self):
self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234")
Expand All @@ -54,14 +62,8 @@ def test_init(self):
assert self.emr_containers.aws_conn_id == "aws_default"
assert self.emr_containers.virtual_cluster_id == "vc1234"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_create_emr_on_eks_cluster(self, mock_session, mock_isinstance):
emr_client_mock = mock.MagicMock()
emr_client_mock.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
def test_create_emr_on_eks_cluster(self, mocked_hook_client):
mocked_hook_client.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN

emr_on_eks_create_cluster_response = self.emr_containers.create_emr_on_eks_cluster(
virtual_cluster_name="test_virtual_cluster",
Expand All @@ -70,15 +72,19 @@ def test_create_emr_on_eks_cluster(self, mock_session, mock_isinstance):
)
assert emr_on_eks_create_cluster_response == "vc1234"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_submit_job(self, mock_session, mock_isinstance):
mocked_hook_client.create_virtual_cluster.assert_called_once_with(
name="test_virtual_cluster",
containerProvider={
"id": "test_eks_cluster",
"type": "EKS",
"info": {"eksInfo": {"namespace": "test_eks_namespace"}},
},
tags={},
)

def test_submit_job(self, mocked_hook_client):
# Mock out the emr_client creator
emr_client_mock = mock.MagicMock()
emr_client_mock.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
mocked_hook_client.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN

emr_containers_job = self.emr_containers.submit_job(
name="test-job-run",
Expand All @@ -90,32 +96,30 @@ def test_submit_job(self, mock_session, mock_isinstance):
)
assert emr_containers_job == "job123456"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_query_status_polling_when_terminal(self, mock_session, mock_isinstance):
emr_client_mock = mock.MagicMock()
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
mocked_hook_client.start_job_run.assert_called_once_with(
name="test-job-run",
virtualClusterId="vc1234",
executionRoleArn="arn:aws:somerole",
releaseLabel="emr-6.3.0-latest",
jobDriver={},
configurationOverrides={},
tags={},
clientToken="uuidtoken",
)

def test_query_status_polling_when_terminal(self, mocked_hook_client):
mocked_hook_client.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
query_status = self.emr_containers.poll_query_status(job_id="job123456")
# should only poll once since query is already in terminal state
emr_client_mock.describe_job_run.assert_called_once()
mocked_hook_client.describe_job_run.assert_called_once()
assert query_status == "COMPLETED"

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
@mock.patch("boto3.session.Session")
def test_query_status_polling_with_timeout(self, mock_session, mock_isinstance):
emr_client_mock = mock.MagicMock()
emr_session_mock = mock.MagicMock()
emr_session_mock.client.return_value = emr_client_mock
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
def test_query_status_polling_with_timeout(self, mocked_hook_client):
mocked_hook_client.describe_job_run.return_value = JOB2_RUN_DESCRIPTION

query_status = self.emr_containers.poll_query_status(
job_id="job123456", max_polling_attempts=2, poll_interval=0
)
# should poll until max_tries is reached since query is in non-terminal state
assert emr_client_mock.describe_job_run.call_count == 2
assert mocked_hook_client.describe_job_run.call_count == 2
assert query_status == "RUNNING"
48 changes: 13 additions & 35 deletions tests/providers/amazon/aws/operators/test_cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.cloud_formation import (
CloudFormationCreateStackOperator,
Expand All @@ -31,19 +33,14 @@
DEFAULT_ARGS = {"owner": "airflow", "start_date": DEFAULT_DATE}


class TestCloudFormationCreateStackOperator:
def setup_method(self):
# Mock out the cloudformation_client (moto fails with an exception).
self.cloudformation_client_mock = MagicMock()

# Mock out the emr_client creator
cloudformation_session_mock = MagicMock()
cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)
@pytest.fixture
def mocked_hook_client():
with mock.patch("airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook.conn") as m:
yield m

self.mock_context = MagicMock()

def test_create_stack(self):
class TestCloudFormationCreateStackOperator:
def test_create_stack(self, mocked_hook_client):
stack_name = "myStack"
timeout = 15
template_body = "My stack body"
Expand All @@ -55,30 +52,15 @@ def test_create_stack(self):
dag=DAG("test_dag_id", default_args=DEFAULT_ARGS),
)

with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
operator.execute(self.mock_context)
operator.execute(MagicMock())

self.cloudformation_client_mock.create_stack.assert_any_call(
mocked_hook_client.create_stack.assert_any_call(
StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout
)


class TestCloudFormationDeleteStackOperator:
def setup_method(self):
# Mock out the cloudformation_client (moto fails with an exception).
self.cloudformation_client_mock = MagicMock()

# Mock out the emr_client creator
cloudformation_session_mock = MagicMock()
cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)

self.mock_context = MagicMock()

def test_delete_stack(self):
def test_delete_stack(self, mocked_hook_client):
stack_name = "myStackToBeDeleted"

operator = CloudFormationDeleteStackOperator(
Expand All @@ -87,10 +69,6 @@ def test_delete_stack(self):
dag=DAG("test_dag_id", default_args=DEFAULT_ARGS),
)

with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
mock_isinstance.return_value = True
operator.execute(self.mock_context)
operator.execute(MagicMock())

self.cloudformation_client_mock.delete_stack.assert_any_call(StackName=stack_name)
mocked_hook_client.delete_stack.assert_any_call(StackName=stack_name)
Loading