Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions providers/amazon/tests/system/amazon/aws/example_mwaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@

from datetime import datetime

import boto3

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
from system.amazon.aws.utils import SystemTestContextBuilder
Expand All @@ -29,6 +34,7 @@
# Externally fetched variables:
EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME"
EXISTING_DAG_ID_KEY = "DAG_ID"
ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY = "ROLE_WITHOUT_INVOKE_REST_API_ARN"

sys_test_context_task = (
SystemTestContextBuilder()
Expand All @@ -45,6 +51,7 @@
# Make sure to set the environment variables with appropriate values
.add_variable(EXISTING_ENVIRONMENT_NAME_KEY)
.add_variable(EXISTING_DAG_ID_KEY)
.add_variable(ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY)
.build()
)

Expand All @@ -58,6 +65,7 @@
test_context = sys_test_context_task()
env_name = test_context[EXISTING_ENVIRONMENT_NAME_KEY]
trigger_dag_id = test_context[EXISTING_DAG_ID_KEY]
restricted_role_arn = test_context[ROLE_WITHOUT_INVOKE_REST_API_ARN_KEY]

# [START howto_operator_mwaa_trigger_dag_run]
trigger_dag_run = MwaaTriggerDagRunOperator(
Expand All @@ -77,12 +85,36 @@
)
# [END howto_sensor_mwaa_dag_run]

# This task in the system test verifies that the MwaaHook's IAM fallback mechanism continues to work with
# the live MWAA API. This fallback depends on parsing a specific error message from the MWAA API, so we
# want to ensure we find out if the API response format ever changes. Unit tests cover this with mocked
# responses, but this system test validates against the real API.
@task
def test_iam_fallback(role_to_assume_arn, mwaa_env_name):
sts_client = StsHook().conn
assumed_role = sts_client.assume_role(
RoleArn=role_to_assume_arn, RoleSessionName="MwaaSysTestIamFallback"
)

credentials = assumed_role["Credentials"]
session = boto3.Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)

mwaa_hook = MwaaHook()
mwaa_hook.conn = session.client("mwaa")
response = mwaa_hook.invoke_rest_api(env_name=mwaa_env_name, path="/dags", method="GET")
return "dags" in response["RestApiResponse"]

chain(
# TEST SETUP
test_context,
# TEST BODY
trigger_dag_run,
wait_for_dag_run,
test_iam_fallback(restricted_role_arn, env_name),
)

from tests_common.test_utils.watcher import watcher
Expand Down