Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions airflow/providers/amazon/aws/hooks/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,35 @@ async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.
except botocore.exceptions.ClientError as error:
return {"status": "error", "message": str(error)}

async def resume_cluster(
    self,
    cluster_identifier: str,
    polling_period_seconds: float = 5.0,
) -> dict[str, Any]:
    """
    Connect to the AWS Redshift cluster via aiobotocore and resume the
    cluster identified by ``cluster_identifier``.

    :param cluster_identifier: unique identifier of a cluster
    :param polling_period_seconds: polling period in seconds to check for the status
    :return: a dict with ``status`` ``"success"``/``"error"`` and either the
        final cluster state or an error message
    """
    async with await self.get_client_async() as client:
        try:
            response = await client.resume_cluster(ClusterIdentifier=cluster_identifier)
            status = response["Cluster"]["ClusterStatus"] if response and response["Cluster"] else None
            if status == "resuming":
                flag = asyncio.Event()
                while True:
                    # Await the status check directly: wrapping a single
                    # coroutine in asyncio.create_task() and immediately
                    # awaiting it adds no concurrency, only overhead.
                    expected_response = await self.get_cluster_status(
                        cluster_identifier, "available", flag
                    )
                    await asyncio.sleep(polling_period_seconds)
                    if flag.is_set():
                        return expected_response
            # Cluster was not in a resumable state (e.g. already available
            # or mid-transition); report it without raising.
            return {"status": "error", "cluster_state": status}
        except botocore.exceptions.ClientError as error:
            # Surface AWS client errors as a structured failure payload so
            # the trigger/operator can relay them instead of crashing.
            return {"status": "error", "message": str(error)}

async def get_cluster_status(
self,
cluster_identifier: str,
Expand Down
63 changes: 50 additions & 13 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,11 @@ class RedshiftResumeClusterOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:RedshiftResumeClusterOperator`

:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
:param cluster_identifier: Unique identifier of the AWS Redshift cluster
:param aws_conn_id: The Airflow connection used for AWS credentials.
The default connection id is ``aws_default``
:param deferrable: Run operator in deferrable mode
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
"""

template_fields: Sequence[str] = ("cluster_identifier",)
Expand All @@ -410,11 +413,15 @@ def __init__(
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to resume the cluster despite reporting the correct state.
Expand All @@ -424,18 +431,48 @@ def __init__(
def execute(self, context: Context):
    """
    Resume the Redshift cluster, either by deferring to the async trigger
    (``deferrable=True``) or by calling the boto3 API synchronously with
    retries on ``InvalidClusterStateFault``.

    :param context: Airflow task context
    :raises InvalidClusterStateFault: when all synchronous attempts are exhausted
    """
    redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)

    if self.deferrable:
        self.defer(
            timeout=self.execution_timeout,
            trigger=RedshiftClusterTrigger(
                task_id=self.task_id,
                poll_interval=self.poll_interval,
                aws_conn_id=self.aws_conn_id,
                cluster_identifier=self.cluster_identifier,
                attempts=self._attempts,
                # Fix: this is the RESUME operator; the trigger was being
                # created with operation_type="pause_cluster" (copy-paste
                # from RedshiftPauseClusterOperator), which would pause the
                # cluster instead of resuming it.
                operation_type="resume_cluster",
            ),
            method_name="execute_complete",
        )
    else:
        # Retry loop: boto3 can report the cluster as available slightly
        # before it will accept a resume request, so the first call(s) may
        # raise InvalidClusterStateFault (see comment in __init__).
        while self._attempts >= 1:
            try:
                redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
                return
            except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
                self._attempts = self._attempts - 1

                if self._attempts > 0:
                    self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
                    time.sleep(self._attempt_interval)
                else:
                    raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
    """
    Callback for when the trigger fires - returns immediately.

    Relies on the trigger to throw an exception, otherwise it assumes
    execution was successful.

    :param context: Airflow task context (unused)
    :param event: payload emitted by the trigger; expected to contain a
        ``status`` key of ``"error"`` (with ``message``) or ``"success"``
    :raises AirflowException: if the trigger reported an error or no event
        was received
    """
    if event:
        if "status" in event and event["status"] == "error":
            msg = f"{event['status']}: {event['message']}"
            raise AirflowException(msg)
        elif "status" in event and event["status"] == "success":
            self.log.info("%s completed successfully.", self.task_id)
            # Fix: this operator RESUMES the cluster; the message previously
            # read "Paused cluster successfully" (copy-paste from the pause
            # operator's callback).
            self.log.info("Resumed cluster successfully")
    else:
        raise AirflowException("No event received from trigger")


class RedshiftPauseClusterOperator(BaseOperator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Resume an Amazon Redshift cluster

To resume a 'paused' Amazon Redshift cluster you can use
:class:`RedshiftResumeClusterOperator <airflow.providers.amazon.aws.operators.redshift_cluster>`
You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift.py
:language: python
Expand Down
50 changes: 50 additions & 0 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,56 @@ def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.resume_cluster.call_count == 10

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async")
def test_resume_cluster(self, mock_async_client, mock_async_resume_cluster, mock_sync_cluster_status):
    """Test Resume cluster operator run"""
    # Arrange: cluster starts out paused, and the async hook reports a
    # successful transition to "available".
    mock_sync_cluster_status.return_value = "paused"
    mock_async_client.return_value.resume_cluster.return_value = {
        "Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus": "resuming"}
    }
    mock_async_resume_cluster.return_value = {"status": "success", "cluster_state": "available"}

    operator = RedshiftResumeClusterOperator(
        task_id="task_test",
        cluster_identifier="test_cluster",
        aws_conn_id="aws_conn_test",
        deferrable=True,
    )

    # Act/Assert: in deferrable mode execute() must defer with a
    # RedshiftClusterTrigger rather than completing synchronously.
    with pytest.raises(TaskDeferred) as deferral:
        operator.execute({})

    assert isinstance(
        deferral.value.trigger, RedshiftClusterTrigger
    ), "Trigger is not a RedshiftClusterTrigger"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.resume_cluster")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.get_client_async")
def test_resume_cluster_failure(
    # Fix: parameter was misspelled "mock_sync_cluster_statue"; renamed to
    # match the patched attribute (cluster_status) and the sibling test.
    self, mock_async_client, mock_async_resume_cluster, mock_sync_cluster_status
):
    """Test that execute_complete raises AirflowException on an error event."""
    mock_sync_cluster_status.return_value = "paused"
    mock_async_client.return_value.resume_cluster.return_value = {
        "Cluster": {"ClusterIdentifier": "test_cluster", "ClusterStatus": "resuming"}
    }
    mock_async_resume_cluster.return_value = {"status": "success", "cluster_state": "available"}

    redshift_operator = RedshiftResumeClusterOperator(
        task_id="task_test",
        cluster_identifier="test_cluster",
        aws_conn_id="aws_conn_test",
        deferrable=True,
    )

    # An "error" event from the trigger must surface as an AirflowException.
    with pytest.raises(AirflowException):
        redshift_operator.execute_complete(
            context=None, event={"status": "error", "message": "test failure message"}
        )


class TestPauseClusterOperator:
def test_init(self):
Expand Down
7 changes: 7 additions & 0 deletions tests/system/providers/amazon/aws/example_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
task_id="resume_cluster",
cluster_identifier=redshift_cluster_identifier,
)

resume_cluster_in_deferred_mode = RedshiftResumeClusterOperator(
task_id="resume_cluster_in_deferred_mode",
cluster_identifier=redshift_cluster_identifier,
deferrable=True,
)
# [END howto_operator_redshift_resume_cluster]

wait_cluster_available_after_resume = RedshiftClusterSensor(
Expand Down Expand Up @@ -279,6 +285,7 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
pause_cluster,
wait_cluster_paused,
resume_cluster,
resume_cluster_in_deferred_mode,
wait_cluster_available_after_resume,
set_up_connection,
create_table_redshift_data,
Expand Down