Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,20 @@ def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identi
)
return response['Cluster'] if response['Cluster'] else None

def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: str) -> str:
def create_cluster_snapshot(
    self, snapshot_identifier: str, cluster_identifier: str, retention_period: int = -1
) -> str:
    """
    Creates a manual snapshot of a cluster via the Redshift ``CreateClusterSnapshot`` API.

    :param snapshot_identifier: unique identifier for a snapshot of a cluster
    :param cluster_identifier: unique identifier of a cluster
    :param retention_period: The number of days that a manual snapshot is retained.
        If the value is -1, the manual snapshot is retained indefinitely.
    :return: the ``Snapshot`` structure from the API response describing the new snapshot
    """
    response = self.get_conn().create_cluster_snapshot(
        SnapshotIdentifier=snapshot_identifier,
        ClusterIdentifier=cluster_identifier,
        ManualSnapshotRetentionPeriod=retention_period,
    )
    # NOTE(review): returns None when the response's 'Snapshot' value is falsy,
    # so the declared return type `str` is inaccurate — consider Optional[str].
    # Raises KeyError if the key is absent entirely; mirrors the style of
    # restore_from_cluster_snapshot above.
    return response['Snapshot'] if response['Snapshot'] else None
66 changes: 66 additions & 0 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

Expand Down Expand Up @@ -242,6 +243,71 @@ def execute(self, context: 'Context'):
self.log.info(cluster)


class RedshiftCreateClusterSnapshotOperator(BaseOperator):
    """
    Create a manual snapshot of the specified Redshift cluster.

    The cluster must be in the ``available`` state; otherwise the operator
    fails immediately without calling the snapshot API.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftCreateClusterSnapshotOperator`

    :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
    :param cluster_identifier: The cluster identifier for which you want a snapshot
    :param retention_period: The number of days that a manual snapshot is retained.
        If the value is -1, the manual snapshot is retained indefinitely.
    :param wait_for_completion: Whether wait for the cluster snapshot to be in ``available`` state
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check state
    :param max_attempt: The maximum number of attempts to be made to check the state
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        The default connection id is ``aws_default``
    """

    def __init__(
        self,
        *,
        snapshot_identifier: str,
        cluster_identifier: str,
        retention_period: int = -1,
        wait_for_completion: bool = False,
        poll_interval: int = 15,
        max_attempt: int = 20,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.snapshot_identifier = snapshot_identifier
        self.cluster_identifier = cluster_identifier
        self.retention_period = retention_period
        self.wait_for_completion = wait_for_completion
        self.poll_interval = poll_interval
        self.max_attempt = max_attempt
        self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)

    def execute(self, context: "Context") -> Any:
        # A snapshot can only be taken of a cluster that is currently available.
        current_state = self.redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
        if current_state != "available":
            raise AirflowException(
                "Redshift cluster must be in available state. "
                f"Redshift cluster current state is {current_state}"
            )

        self.redshift_hook.create_cluster_snapshot(
            cluster_identifier=self.cluster_identifier,
            snapshot_identifier=self.snapshot_identifier,
            retention_period=self.retention_period,
        )

        if not self.wait_for_completion:
            return

        # Block until boto3 reports the snapshot as available, polling
        # every poll_interval seconds up to max_attempt times.
        waiter = self.redshift_hook.get_conn().get_waiter("snapshot_available")
        waiter.wait(
            ClusterIdentifier=self.cluster_identifier,
            SnapshotIdentifier=self.snapshot_identifier,
            WaiterConfig={
                "Delay": self.poll_interval,
                "MaxAttempts": self.max_attempt,
            },
        )


class RedshiftResumeClusterOperator(BaseOperator):
"""
Resume a paused AWS Redshift Cluster
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ To pause an 'available' Amazon Redshift cluster you can use
:start-after: [START howto_operator_redshift_pause_cluster]
:end-before: [END howto_operator_redshift_pause_cluster]

.. _howto/operator:RedshiftCreateClusterSnapshotOperator:

Create an Amazon Redshift cluster snapshot
==========================================

To create an Amazon Redshift cluster snapshot you can use
:class:`~airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftCreateClusterSnapshotOperator`

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_redshift_cluster.py
:language: python
:dedent: 4
:start-after: [START howto_operator_redshift_create_cluster_snapshot]
:end-before: [END howto_operator_redshift_create_cluster_snapshot]

.. _howto/operator:RedshiftDeleteClusterOperator:

Delete an Amazon Redshift cluster
Expand Down
53 changes: 53 additions & 0 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@

from unittest import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_cluster import (
RedshiftCreateClusterOperator,
RedshiftCreateClusterSnapshotOperator,
RedshiftDeleteClusterOperator,
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
Expand Down Expand Up @@ -99,6 +103,55 @@ def test_create_multi_node_cluster(self, mock_get_conn):
)


class TestRedshiftCreateClusterSnapshotOperator:
    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
    def test_create_cluster_snapshot_is_called_when_cluster_is_available(
        self, mock_get_conn, mock_cluster_status
    ):
        """An available cluster triggers the snapshot API call; no waiter without wait_for_completion."""
        mock_cluster_status.return_value = "available"
        operator = RedshiftCreateClusterSnapshotOperator(
            task_id="test_snapshot",
            cluster_identifier="test_cluster",
            snapshot_identifier="test_snapshot",
            retention_period=1,
        )
        operator.execute(None)

        mock_get_conn.return_value.create_cluster_snapshot.assert_called_once_with(
            ClusterIdentifier="test_cluster",
            SnapshotIdentifier="test_snapshot",
            ManualSnapshotRetentionPeriod=1,
        )
        mock_get_conn.return_value.get_waiter.assert_not_called()

    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
    def test_raise_exception_when_cluster_is_not_available(self, mock_cluster_status):
        """A cluster in any non-available state must abort with AirflowException."""
        mock_cluster_status.return_value = "paused"
        operator = RedshiftCreateClusterSnapshotOperator(
            task_id="test_snapshot", cluster_identifier="test_cluster", snapshot_identifier="test_snapshot"
        )
        with pytest.raises(AirflowException):
            operator.execute(None)

    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
    def test_create_cluster_snapshot_with_wait(self, mock_get_conn, mock_cluster_status):
        """wait_for_completion=True invokes the snapshot_available waiter with default WaiterConfig."""
        mock_cluster_status.return_value = "available"
        operator = RedshiftCreateClusterSnapshotOperator(
            task_id="test_snapshot",
            cluster_identifier="test_cluster",
            snapshot_identifier="test_snapshot",
            wait_for_completion=True,
        )
        operator.execute(None)

        mock_get_conn.return_value.get_waiter.return_value.wait.assert_called_once_with(
            ClusterIdentifier="test_cluster",
            SnapshotIdentifier="test_snapshot",
            WaiterConfig={"Delay": 15, "MaxAttempts": 20},
        )


class TestResumeClusterOperator:
def test_init(self):
redshift_operator = RedshiftResumeClusterOperator(
Expand Down
15 changes: 15 additions & 0 deletions tests/system/providers/amazon/aws/example_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.redshift_cluster import (
RedshiftCreateClusterOperator,
RedshiftCreateClusterSnapshotOperator,
RedshiftDeleteClusterOperator,
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
Expand All @@ -34,6 +35,9 @@
ENV_ID = set_env_id()
DAG_ID = 'example_redshift_cluster'
REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "redshift-cluster-1")
# Snapshot name used by the example DAG; override via the
# REDSHIFT_CLUSTER_SNAPSHOT_IDENTIFIER environment variable.
REDSHIFT_CLUSTER_SNAPSHOT_IDENTIFIER = getenv(
    "REDSHIFT_CLUSTER_SNAPSHOT_IDENTIFIER", "redshift-cluster-snapshot-1"
)

with DAG(
dag_id=DAG_ID,
Expand Down Expand Up @@ -85,6 +89,16 @@
)
# [END howto_operator_redshift_resume_cluster]

# [START howto_operator_redshift_create_cluster_snapshot]
task_create_cluster_snapshot = RedshiftCreateClusterSnapshotOperator(
task_id='create_cluster_snapshot',
cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
snapshot_identifier=REDSHIFT_CLUSTER_SNAPSHOT_IDENTIFIER,
retention_period=1,
poll_interval=5,
)
# [END howto_operator_redshift_create_cluster_snapshot]

# [START howto_operator_redshift_delete_cluster]
task_delete_cluster = RedshiftDeleteClusterOperator(
task_id="delete_cluster",
Expand All @@ -99,6 +113,7 @@
task_pause_cluster,
task_wait_cluster_paused,
task_resume_cluster,
task_create_cluster_snapshot,
task_delete_cluster,
)

Expand Down