5 changes: 5 additions & 0 deletions providers/amazon/docs/changelog.rst
@@ -38,6 +38,11 @@ Misc

* ``The experimental AWS auth manager is no longer compatible with Airflow 2``

Bug Fixes
~~~~~~~~~

* ``The DMS waiter replication_terminal_status has been extended to proceed on 2 additional states: "created" and "deprovisioned"``

9.2.0
.....

@@ -493,10 +493,6 @@ def execute(self, context: Context) -> None:
Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)
self.hook.get_waiter("replication_deprovisioned").wait(
Filters=[{"Name": "replication-config-arn", "Values": [self.replication_config_arn]}],
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)
self.hook.delete_replication_config(self.replication_config_arn)
self.handle_delete_wait()

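For clarity, the delete path in `DmsDeleteReplicationConfigOperator.execute` now condenses to a single terminal-status wait followed by the delete call. A minimal sketch under stated assumptions — the helper name, delay values, and ARN are illustrative, not provider code, and the waiter name is taken from the updated tests below:

```python
# Sketch of the simplified delete path: the separate
# "replication_deprovisioned" wait is gone because "deprovisioned" is now
# a success state of "replication_terminal_status" itself.
from airflow.providers.amazon.aws.hooks.dms import DmsHook


def delete_config(hook: DmsHook, replication_config_arn: str) -> None:
    # Wait until the replication reaches any terminal state
    # (stopped, created, or deprovisioned).
    hook.get_waiter("replication_terminal_status").wait(
        Filters=[{"Name": "replication-config-arn", "Values": [replication_config_arn]}],
        WaiterConfig={"Delay": 60, "MaxAttempts": 60},
    )
    hook.delete_replication_config(replication_config_arn)
    # The operator then waits for the config itself to be gone via
    # handle_delete_wait(), as in the hunk above.


# Usage (assumes AWS credentials and an existing replication config):
# delete_config(DmsHook(aws_conn_id="aws_default"), "arn:aws:dms:...")
```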
@@ -36,6 +36,18 @@
"argument": "Replications[0].Status",
"expected": "stopped",
"state": "success"
},
{
"matcher": "path",
"argument": "Replications[0].Status",
"expected": "created",
"state": "success"
},
{
"matcher": "path",
"argument": "Replications[0].ProvisionData.ProvisionState",
"expected": "deprovisioned",
"state": "success"
}
]
},
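The acceptors above are plain botocore waiter configuration. As a minimal sketch of what the change means operationally, the same waiter can be built directly from botocore primitives — the values below are assumptions mirroring the JSON hunk, and Airflow's `DmsHook.get_waiter("replication_terminal_status")` resolves the provider's bundled waiter file in an equivalent way:

```python
import boto3
from botocore.waiter import WaiterModel, create_waiter_with_client

# Acceptor list mirrors the hunk above: "stopped" was already a success
# state; "created" and "deprovisioned" are the two new ones.
waiter_config = {
    "version": 2,
    "waiters": {
        "replication_terminal_status": {
            "operation": "DescribeReplications",
            "delay": 60,
            "maxAttempts": 60,
            "acceptors": [
                {"matcher": "path", "argument": "Replications[0].Status",
                 "expected": "stopped", "state": "success"},
                {"matcher": "path", "argument": "Replications[0].Status",
                 "expected": "created", "state": "success"},
                {"matcher": "path",
                 "argument": "Replications[0].ProvisionData.ProvisionState",
                 "expected": "deprovisioned", "state": "success"},
            ],
        }
    },
}

client = boto3.client("dms")
waiter = create_waiter_with_client(
    "replication_terminal_status", WaiterModel(waiter_config), client
)
# WaiterConfig overrides delay/maxAttempts per call, as the operator does.
waiter.wait(
    Filters=[{"Name": "replication-config-arn", "Values": ["arn:aws:dms:..."]}],  # illustrative ARN
    WaiterConfig={"Delay": 60, "MaxAttempts": 60},
)
```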
@@ -647,11 +647,6 @@ def test_happy_path(self, mock_waiter, mock_handle, mock_describe_replications,
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
mock.call("replication_deprovisioned"),
mock.call().wait(
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
]
)
mock_handle.assert_called_once()
@@ -695,17 +690,9 @@ def test_wait_for_completion(self, mock_waiter, mock_describe_replications, mock
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
mock.call("replication_deprovisioned"),
mock.call().wait(
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
]
)

# mock_waiter.assert_called_with("replication_config_deleted")
# mock_waiter.assert_called_once()

@mock.patch.object(DmsHook, "conn")
@mock.patch.object(DmsHook, "describe_replications")
@mock.patch.object(DmsHook, "get_waiter")
@@ -724,12 +711,7 @@ def test_wait_for_completion_not_ready(self, mock_waiter, mock_describe_replicat

mock_waiter.assert_has_calls(
[
mock.call("replication_deprovisioned"),
mock.call().wait(
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
mock.call("replication_config_deleted"),
mock.call("replication_terminal_status"),
mock.call().wait(
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
@@ -759,11 +741,6 @@ def test_not_ready_state(self, mock_waiter, mock_handle, mock_describe, mock_con
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
mock.call("replication_deprovisioned"),
mock.call().wait(
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
]
)
mock_handle.assert_called_once()
@@ -790,11 +767,6 @@ def test_not_deprovisioned(self, mock_waiter, mock_handle, mock_describe, mock_c
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
mock.call("replication_deprovisioned"),
mock.call().wait(
Filters=[{"Name": "replication-config-arn", "Values": ["arn:xxxxxx"]}],
WaiterConfig={"Delay": 60, "MaxAttempts": 60},
),
]
)
mock_handle.assert_called_once()
129 changes: 10 additions & 119 deletions providers/amazon/tests/system/amazon/aws/example_dms_serverless.py
@@ -27,7 +27,6 @@

import boto3
from providers.amazon.tests.system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
from providers.amazon.tests.system.amazon.aws.utils.ec2 import get_default_vpc_id
from sqlalchemy import Column, MetaData, String, Table, create_engine

from airflow.decorators import task
@@ -38,8 +37,6 @@
DmsDeleteReplicationConfigOperator,
DmsDescribeReplicationConfigsOperator,
DmsDescribeReplicationsOperator,
DmsStartReplicationOperator,
DmsStopReplicationOperator,
)
from airflow.providers.amazon.aws.operators.rds import (
RdsCreateDbInstanceOperator,
@@ -76,11 +73,6 @@
("Subversion", "2000"),
("NiFi", "2006"),
]
SG_IP_PERMISSION = {
"FromPort": 5432,
"IpProtocol": "All",
"IpRanges": [{"CidrIp": "0.0.0.0/0"}],
}


def _get_rds_instance_endpoint(instance_name: str):
@@ -92,25 +84,6 @@ def _get_rds_instance_endpoint(instance_name: str):
return rds_instance_endpoint


@task
def create_security_group(security_group_name: str, vpc_id: str):
client = boto3.client("ec2")
security_group = client.create_security_group(
GroupName=security_group_name,
Description="Created for DMS system test",
VpcId=vpc_id,
)
client.get_waiter("security_group_exists").wait(
GroupIds=[security_group["GroupId"]],
)
client.authorize_security_group_ingress(
GroupId=security_group["GroupId"],
IpPermissions=[SG_IP_PERMISSION],
)

return security_group["GroupId"]


@task
def create_sample_table(instance_name: str, db_name: str, table_name: str):
print("Creating sample table.")
@@ -138,37 +111,6 @@ def create_sample_table(instance_name: str, db_name: str, table_name: str):
connection.execute(table.select())


@task(trigger_rule=TriggerRule.ALL_SUCCESS)
def create_vpc_endpoints(vpc_id: str):
print("Creating VPC endpoints in vpc: %s", vpc_id)
client = boto3.client("ec2")
session = boto3.session.Session()
region = session.region_name
route_tbls = client.describe_route_tables(Filters=[{"Name": "vpc-id", "Values": [vpc_id]}])
endpoints = client.create_vpc_endpoint(
VpcId=vpc_id,
ServiceName=f"com.amazonaws.{region}.s3",
VpcEndpointType="Gateway",
RouteTableIds=[tbl["RouteTableId"] for tbl in route_tbls["RouteTables"]],
)

return endpoints.get("VpcEndpoint", {}).get("VpcEndpointId")


@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_vpc_endpoints(endpoint_ids: list[str]):
if len(endpoint_ids) == 0:
print("No VPC endpoints to delete.")
return

print("Deleting VPC endpoints.")
client = boto3.client("ec2")

client.delete_vpc_endpoints(VpcEndpointIds=endpoint_ids, DryRun=False)

print("Deleted endpoints: %s", endpoint_ids)


@task(multiple_outputs=True)
def create_dms_assets(
db_name: str,
@@ -223,20 +165,8 @@ def delete_dms_assets(
target_endpoint_identifier: str,
):
dms_client = boto3.client("dms")

print("Deleting DMS assets.")

print(source_endpoint_arn)
print(target_endpoint_arn)

try:
dms_client.delete_endpoint(EndpointArn=source_endpoint_arn)
dms_client.delete_endpoint(EndpointArn=target_endpoint_arn)
except Exception as ex:
print("Exception while cleaning up endpoints:%s", ex)

print("Awaiting DMS assets tear-down.")

dms_client.delete_endpoint(EndpointArn=source_endpoint_arn)
dms_client.delete_endpoint(EndpointArn=target_endpoint_arn)
dms_client.get_waiter("endpoint_deleted").wait(
Filters=[
{
@@ -247,44 +177,26 @@ def delete_dms_assets(
)


@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_security_group(security_group_id: str, security_group_name: str):
boto3.client("ec2").delete_security_group(GroupId=security_group_id, GroupName=security_group_name)


# setup
# source: aurora serverless
# dest: S3
# S3

with DAG(
dag_id=DAG_ID,
schedule="@once",
start_date=datetime(2021, 1, 1),
tags=["example"],
catchup=False,
) as dag:
test_context = sys_test_context_task()
env_id = test_context[ENV_ID_KEY]
role_arn = test_context[ROLE_ARN_KEY]

bucket_name = f"{env_id}-dms-bucket"
bucket_name = f"{env_id}-dms-serverless-bucket"
rds_instance_name = f"{env_id}-instance"
rds_db_name = f"{env_id}_source_database" # dashes are not allowed in db name
rds_table_name = f"{env_id}-table"
dms_replication_instance_name = f"{env_id}-replication-instance"
dms_replication_task_id = f"{env_id}-replication-task"
source_endpoint_identifier = f"{env_id}-source-endpoint"
target_endpoint_identifier = f"{env_id}-target-endpoint"
security_group_name = f"{env_id}-dms-security-group"
replication_id = f"{env_id}-replication-id"

create_s3_bucket = S3CreateBucketOperator(task_id="create_s3_bucket", bucket_name=bucket_name)

get_vpc_id = get_default_vpc_id()

create_sg = create_security_group(security_group_name, get_vpc_id)

create_db_instance = RdsCreateDbInstanceOperator(
task_id="create_db_instance",
db_instance_identifier=rds_instance_name,
Expand All @@ -296,9 +208,6 @@ def delete_security_group(security_group_id: str, security_group_name: str):
"MasterUsername": RDS_USERNAME,
"MasterUserPassword": RDS_PASSWORD,
"PubliclyAccessible": True,
"VpcSecurityGroupIds": [
create_sg,
],
},
)

@@ -360,24 +269,24 @@
},
replication_type="full-load",
table_mappings=json.dumps(table_mappings),
trigger_rule=TriggerRule.ALL_SUCCESS,
)
# [END howto_operator_dms_create_replication_config]

# [START howto_operator_dms_describe_replication_config]
describe_replication_configs = DmsDescribeReplicationConfigsOperator(
task_id="describe_replication_configs",
trigger_rule=TriggerRule.ALL_SUCCESS,
)
# [END howto_operator_dms_describe_replication_config]

# [START howto_operator_dms_serverless_describe_replication]
describe_replications = DmsDescribeReplicationsOperator(
task_id="describe_replications",
trigger_rule=TriggerRule.ALL_SUCCESS,
)
# [END howto_operator_dms_serverless_describe_replication]

# The next two tasks are commented out because they take too much time to run in the CI
# They are kept for documentation purposes
"""
# [START howto_operator_dms_serverless_start_replication]
replicate = DmsStartReplicationOperator(
task_id="replicate",
Expand All @@ -386,34 +295,28 @@ def delete_security_group(security_group_id: str, security_group_name: str):
wait_for_completion=True,
waiter_delay=60,
waiter_max_attempts=200,
trigger_rule=TriggerRule.ALL_SUCCESS,
deferrable=False,
)
# [END howto_operator_dms_serverless_start_replication]

# [START howto_operator_dms_serverless_stop_replication]
stop_relication = DmsStopReplicationOperator(
stop_replication = DmsStopReplicationOperator(
task_id="stop_replication",
replication_config_arn="{{ task_instance.xcom_pull(task_ids='create_replication_config', key='return_value') }}",
wait_for_completion=True,
waiter_delay=120,
waiter_max_attempts=200,
trigger_rule=TriggerRule.ALL_SUCCESS,
deferrable=False,
)
# [END howto_operator_dms_serverless_stop_replication]
"""

# [START howto_operator_dms_serverless_delete_replication_config]
delete_replication_config = DmsDeleteReplicationConfigOperator(
task_id="delete_replication_config",
wait_for_completion=True,
waiter_delay=60,
waiter_max_attempts=200,
deferrable=False,
replication_config_arn="{{ task_instance.xcom_pull(task_ids='create_replication_config', key='return_value') }}",
trigger_rule=TriggerRule.ALL_DONE,
)
# [END howto_operator_dms_serverless_delete_replication_config]
delete_replication_config.trigger_rule = TriggerRule.ALL_DONE

delete_assets = delete_dms_assets(
source_endpoint_arn=create_assets["source_endpoint_arn"],
Expand All @@ -440,31 +343,19 @@ def delete_security_group(security_group_id: str, security_group_name: str):

chain(
# TEST SETUP
test_context,
create_s3_bucket,
get_vpc_id,
create_sg,
create_db_instance,
create_sample_table(rds_instance_name, rds_db_name, rds_table_name),
create_vpc_endpoints(
vpc_id="{{ task_instance.xcom_pull(task_ids='get_default_vpc_id',key='return_value')}}"
),
create_assets,
# TEST BODY
create_replication_config,
describe_replication_configs,
replicate,
stop_relication,
describe_replications,
delete_replication_config,
# TEST TEARDOWN
delete_vpc_endpoints(
endpoint_ids=[
"{{ task_instance.xcom_pull(task_ids='create_vpc_endpoints', key='return_value') }}"
]
),
delete_assets,
delete_db_instance,
delete_security_group(create_sg, security_group_name),
delete_s3_bucket,
)

3 changes: 3 additions & 0 deletions tests/always/test_project_structure.py
@@ -553,6 +553,9 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
MISSING_EXAMPLES_FOR_CLASSES = {
# S3 Exasol transfer difficult to test, see: https://github.com/apache/airflow/issues/22632
"airflow.providers.amazon.aws.transfers.exasol_to_s3.ExasolToS3Operator",
# These operations take a lot of time, so they are commented out in the system tests
"airflow.providers.amazon.aws.operators.dms.DmsStartReplicationOperator",
"airflow.providers.amazon.aws.operators.dms.DmsStopReplicationOperator",
}

DEPRECATED_CLASSES = {