Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 57 additions & 69 deletions airflow/providers/amazon/aws/operators/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

from typing import TYPE_CHECKING, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dms import DmsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class DmsCreateTaskOperator(BaseOperator):
class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
"""
Creates AWS DMS replication task.

Expand All @@ -42,13 +43,19 @@ class DmsCreateTaskOperator(BaseOperator):
:param migration_type: Migration type ('full-load'|'cdc'|'full-load-and-cdc'), full-load by default.
:param create_task_kwargs: Extra arguments for DMS replication task creation.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields(
"replication_task_id",
"source_endpoint_arn",
"target_endpoint_arn",
Expand All @@ -57,7 +64,6 @@ class DmsCreateTaskOperator(BaseOperator):
"migration_type",
"create_task_kwargs",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {
"table_mappings": "json",
"create_task_kwargs": "json",
Expand Down Expand Up @@ -92,9 +98,7 @@ def execute(self, context: Context):

:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)

task_arn = dms_hook.create_replication_task(
task_arn = self.hook.create_replication_task(
replication_task_id=self.replication_task_id,
source_endpoint_arn=self.source_endpoint_arn,
target_endpoint_arn=self.target_endpoint_arn,
Expand All @@ -108,7 +112,7 @@ def execute(self, context: Context):
return task_arn


class DmsDeleteTaskOperator(BaseOperator):
class DmsDeleteTaskOperator(AwsBaseOperator[DmsHook]):
"""
Deletes AWS DMS replication task.

Expand All @@ -118,39 +122,35 @@ class DmsDeleteTaskOperator(BaseOperator):

:param replication_task_arn: Replication task ARN
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
template_fields_renderers: dict[str, str] = {}
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("replication_task_arn")

def __init__(
self,
*,
replication_task_arn: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, replication_task_arn: str | None = None, **kwargs):
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Delete AWS DMS replication task from Airflow.

:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
dms_hook.delete_replication_task(replication_task_arn=self.replication_task_arn)
self.hook.delete_replication_task(replication_task_arn=self.replication_task_arn)
self.log.info("DMS replication task(%s) has been deleted.", self.replication_task_arn)


class DmsDescribeTasksOperator(BaseOperator):
class DmsDescribeTasksOperator(AwsBaseOperator[DmsHook]):
"""
Describes AWS DMS replication tasks.

Expand All @@ -160,38 +160,35 @@ class DmsDescribeTasksOperator(BaseOperator):

:param describe_tasks_kwargs: Describe tasks command arguments
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("describe_tasks_kwargs",)
template_ext: Sequence[str] = ()
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("describe_tasks_kwargs")
template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"}

def __init__(
self,
*,
describe_tasks_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, describe_tasks_kwargs: dict | None = None, **kwargs):
super().__init__(**kwargs)
self.describe_tasks_kwargs = describe_tasks_kwargs or {}
self.aws_conn_id = aws_conn_id

def execute(self, context: Context) -> tuple[str | None, list]:
"""
Describe AWS DMS replication tasks from Airflow.

:return: Marker and list of replication tasks
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
return dms_hook.describe_replication_tasks(**self.describe_tasks_kwargs)
return self.hook.describe_replication_tasks(**self.describe_tasks_kwargs)


class DmsStartTaskOperator(BaseOperator):
class DmsStartTaskOperator(AwsBaseOperator[DmsHook]):
"""
Starts AWS DMS replication task.

Expand All @@ -204,18 +201,23 @@ class DmsStartTaskOperator(BaseOperator):
('start-replication'|'resume-processing'|'reload-target')
:param start_task_kwargs: Extra start replication task arguments
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields(
"replication_task_arn",
"start_replication_task_type",
"start_task_kwargs",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {"start_task_kwargs": "json"}

def __init__(
Expand All @@ -234,22 +236,16 @@ def __init__(
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Start AWS DMS replication task from Airflow.

:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)

dms_hook.start_replication_task(
"""Start AWS DMS replication task from Airflow."""
self.hook.start_replication_task(
replication_task_arn=self.replication_task_arn,
start_replication_task_type=self.start_replication_task_type,
**self.start_task_kwargs,
)
self.log.info("DMS replication task(%s) is starting.", self.replication_task_arn)


class DmsStopTaskOperator(BaseOperator):
class DmsStopTaskOperator(AwsBaseOperator[DmsHook]):
"""
Stops AWS DMS replication task.

Expand All @@ -259,33 +255,25 @@ class DmsStopTaskOperator(BaseOperator):

:param replication_task_arn: Replication task ARN
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
template_fields_renderers: dict[str, str] = {}
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("replication_task_arn")

def __init__(
self,
*,
replication_task_arn: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, replication_task_arn: str | None = None, **kwargs):
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Stop AWS DMS replication task from Airflow.

:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
dms_hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
"""Stop AWS DMS replication task from Airflow."""
self.hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
self.log.info("DMS replication task(%s) is stopping.", self.replication_task_arn)
55 changes: 31 additions & 24 deletions airflow/providers/amazon/aws/sensors/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,53 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Iterable, Sequence

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.dms import DmsHook
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class DmsTaskBaseSensor(BaseSensorOperator):
class DmsTaskBaseSensor(AwsBaseSensor[DmsHook]):
"""
Contains general sensor behavior for DMS task.

Subclasses should set ``target_statuses`` and ``termination_statuses`` fields.

:param replication_task_arn: AWS DMS replication task ARN
:param aws_conn_id: aws connection to uses
:param target_statuses: the target statuses, sensor waits until
the task reaches any of these states
:param termination_statuses: the termination statuses, sensor fails when
the task reaches any of these states
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("replication_task_arn")

def __init__(
self,
replication_task_arn: str,
aws_conn_id="aws_default",
target_statuses: Iterable[str] | None = None,
termination_statuses: Iterable[str] | None = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.target_statuses: Iterable[str] = target_statuses or []
self.termination_statuses: Iterable[str] = termination_statuses or []
Expand All @@ -67,14 +73,8 @@ def get_hook(self) -> DmsHook:
"""Get DmsHook."""
return self.hook

@cached_property
def hook(self) -> DmsHook:
return DmsHook(self.aws_conn_id)

def poke(self, context: Context):
status: str | None = self.hook.get_task_status(self.replication_task_arn)

if not status:
if not (status := self.hook.get_task_status(self.replication_task_arn)):
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
if self.soft_fail:
Expand Down Expand Up @@ -105,15 +105,21 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor):
:ref:`howto/sensor:DmsTaskCompletedSensor`

:param replication_task_arn: AWS DMS replication task ARN
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.target_statuses = ["stopped"]
self.termination_statuses = [
def __init__(self, **kwargs):
kwargs["target_statuses"] = ["stopped"]
kwargs["termination_statuses"] = [
"creating",
"deleting",
"failed",
Expand All @@ -123,3 +129,4 @@ def __init__(self, *args, **kwargs):
"ready",
"testing",
]
super().__init__(**kwargs)
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/dms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
Loading