6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/glacier.py
@@ -35,9 +35,9 @@ class GlacierHook(AwsBaseHook):
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, aws_conn_id: str = "aws_default") -> None:
super().__init__(client_type="glacier")
self.aws_conn_id = aws_conn_id
def __init__(self, *args, **kwargs) -> None:
kwargs.update({"client_type": "glacier", "resource_type": None})
super().__init__(*args, **kwargs)
Comment on lines -38 to +40
Contributor Author
Accidentally found that GlacierHook only allows aws_conn_id and ignores the other AwsBaseHook parameters
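
A minimal sketch of the behavior change (not part of this diff; the connection and region values are illustrative): with the *args/**kwargs signature, AwsBaseHook parameters such as region_name and verify are now forwarded instead of being silently dropped.

from airflow.providers.amazon.aws.hooks.glacier import GlacierHook

# Before this change __init__ accepted only aws_conn_id, so these kwargs
# never reached AwsBaseHook; with the new signature they are forwarded.
hook = GlacierHook(
    aws_conn_id="aws_default",
    region_name="eu-west-1",  # now honored
    verify=False,             # now honored
)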


def retrieve_inventory(self, vault_name: str) -> dict[str, Any]:
"""Initiate an Amazon Glacier inventory-retrieval job.
30 changes: 11 additions & 19 deletions airflow/providers/amazon/aws/operators/glacier.py
@@ -19,14 +19,15 @@

from typing import TYPE_CHECKING, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlacierCreateJobOperator(BaseOperator):
class GlacierCreateJobOperator(AwsBaseOperator[GlacierHook]):
"""
Initiate an Amazon Glacier inventory-retrieval job.

@@ -38,25 +39,18 @@ class GlacierCreateJobOperator(BaseOperator):
:param vault_name: the Glacier vault on which the job is executed
"""

template_fields: Sequence[str] = ("vault_name",)
aws_hook_class = GlacierHook
template_fields: Sequence[str] = aws_template_fields("vault_name")

def __init__(
self,
*,
aws_conn_id="aws_default",
vault_name: str,
**kwargs,
):
def __init__(self, *, vault_name: str, **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.vault_name = vault_name

def execute(self, context: Context):
hook = GlacierHook(aws_conn_id=self.aws_conn_id)
return hook.retrieve_inventory(vault_name=self.vault_name)
return self.hook.retrieve_inventory(vault_name=self.vault_name)


class GlacierUploadArchiveOperator(BaseOperator):
class GlacierUploadArchiveOperator(AwsBaseOperator[GlacierHook]):
"""
This operator adds an archive to an Amazon S3 Glacier vault.

@@ -74,7 +68,8 @@ class GlacierUploadArchiveOperator(BaseOperator):
:param aws_conn_id: The reference to the AWS connection details
"""

template_fields: Sequence[str] = ("vault_name",)
aws_hook_class = GlacierHook
template_fields: Sequence[str] = aws_template_fields("vault_name")

def __init__(
self,
@@ -84,20 +79,17 @@ def __init__(
checksum: str | None = None,
archive_description: str | None = None,
account_id: str | None = None,
aws_conn_id="aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.account_id = account_id
self.vault_name = vault_name
self.body = body
self.checksum = checksum
self.archive_description = archive_description

def execute(self, context: Context):
hook = GlacierHook(aws_conn_id=self.aws_conn_id)
return hook.get_conn().upload_archive(
return self.hook.conn.upload_archive(
accountId=self.account_id,
vaultName=self.vault_name,
archiveDescription=self.archive_description,
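
A hypothetical usage sketch of the migrated operator (task id and vault name are illustrative): the generic AWS parameters handled by AwsBaseOperator, exercised by the new tests below, can now be passed alongside the operator's own arguments.

from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator

create_job = GlacierCreateJobOperator(
    task_id="create_inventory_retrieval_job",  # hypothetical task id
    vault_name="my-vault",                     # hypothetical vault
    region_name="eu-west-1",                   # generic parameter, now forwarded to GlacierHook
    botocore_config={"read_timeout": 42},      # generic parameter, as in the new tests
)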
15 changes: 5 additions & 10 deletions airflow/providers/amazon/aws/sensors/glacier.py
@@ -18,12 +18,12 @@
from __future__ import annotations

from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -36,7 +36,7 @@ class JobStatus(Enum):
SUCCEEDED = "Succeeded"


class GlacierJobOperationSensor(BaseSensorOperator):
class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
"""
Glacier sensor for checking job state. This operator runs only in reschedule mode.

@@ -63,29 +63,24 @@ class GlacierJobOperationSensor(BaseSensorOperator):
prevent too much load on the scheduler.
"""

template_fields: Sequence[str] = ("vault_name", "job_id")
aws_hook_class = GlacierHook
template_fields: Sequence[str] = aws_template_fields("vault_name", "job_id")

def __init__(
self,
*,
aws_conn_id: str = "aws_default",
vault_name: str,
job_id: str,
poke_interval: int = 60 * 20,
mode: str = "reschedule",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.vault_name = vault_name
self.job_id = job_id
self.poke_interval = poke_interval
self.mode = mode

@cached_property
def hook(self):
return GlacierHook(aws_conn_id=self.aws_conn_id)

def poke(self, context: Context) -> bool:
response = self.hook.describe_job(vault_name=self.vault_name, job_id=self.job_id)

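
A hypothetical usage sketch of the migrated sensor: it keeps its reschedule-mode defaults while picking up the same generic AWS parameters from AwsBaseSensor (the job id mirrors the example value used in the tests).

from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor

wait_for_job = GlacierJobOperationSensor(
    task_id="wait_for_inventory_job",  # hypothetical task id
    vault_name="my-vault",             # hypothetical vault
    job_id="1a2b3c4d",                 # id of a previously created Glacier job
    poke_interval=60 * 20,             # default: poke every 20 minutes
    mode="reschedule",                 # default: frees the worker slot between pokes
)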
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
@@ -27,6 +27,11 @@ Prerequisite Tasks

.. include:: ../../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../../_partials/generic_parameters.rst

Operators
---------

72 changes: 58 additions & 14 deletions tests/providers/amazon/aws/operators/test_glacier.py
@@ -17,13 +17,19 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from unittest import mock

import pytest

from airflow.providers.amazon.aws.operators.glacier import (
GlacierCreateJobOperator,
GlacierUploadArchiveOperator,
)

if TYPE_CHECKING:
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator

AWS_CONN_ID = "aws_default"
BUCKET_NAME = "airflow_bucket"
FILENAME = "path/to/file/"
@@ -34,22 +40,60 @@
VAULT_NAME = "airflow"


class TestGlacierCreateJobOperator:
@mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook")
class BaseGlacierOperatorsTests:
op_class: type[AwsBaseOperator]
default_op_kwargs: dict[str, Any]

def test_base_aws_op_attributes(self):
op = self.op_class(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

op = self.op_class(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42


class TestGlacierCreateJobOperator(BaseGlacierOperatorsTests):
op_class = GlacierCreateJobOperator

@pytest.fixture(autouse=True)
def setup_test_cases(self):
self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID}

@mock.patch.object(GlacierCreateJobOperator, "hook", new_callable=mock.PropertyMock)
def test_execute(self, hook_mock):
op = GlacierCreateJobOperator(aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, task_id=TASK_ID)
op = self.op_class(aws_conn_id=None, **self.default_op_kwargs)
op.execute(mock.MagicMock())
hook_mock.assert_called_once_with(aws_conn_id=AWS_CONN_ID)
hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME)


class TestGlacierUploadArchiveOperator:
@mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook.get_conn")
def test_execute(self, hook_mock):
op = GlacierUploadArchiveOperator(
aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, body=b"Test Data", task_id=TASK_ID
)
op.execute(mock.MagicMock())
hook_mock.return_value.upload_archive.assert_called_once_with(
accountId=None, vaultName=VAULT_NAME, archiveDescription=None, body=b"Test Data", checksum=None
)
class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests):
op_class = GlacierUploadArchiveOperator

@pytest.fixture(autouse=True)
def setup_test_cases(self):
self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID, "body": b"Test Data"}

def test_execute(self):
with mock.patch.object(self.op_class.aws_hook_class, "conn", new_callable=mock.PropertyMock) as m:
op = self.op_class(aws_conn_id=None, **self.default_op_kwargs)
op.execute(mock.MagicMock())
m.return_value.upload_archive.assert_called_once_with(
accountId=None,
vaultName=VAULT_NAME,
archiveDescription=None,
body=b"Test Data",
checksum=None,
)
65 changes: 42 additions & 23 deletions tests/providers/amazon/aws/sensors/test_glacier.py
@@ -28,49 +28,68 @@
IN_PROGRESS = "InProgress"


@pytest.fixture
def mocked_describe_job():
with mock.patch("airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job") as m:
yield m


class TestAmazonGlacierSensor:
def setup_method(self):
self.op = GlacierJobOperationSensor(
self.default_op_kwargs = dict(
task_id="test_athena_sensor",
aws_conn_id="aws_default",
vault_name="airflow",
job_id="1a2b3c4d",
poke_interval=60 * 20,
)
self.op = GlacierJobOperationSensor(**self.default_op_kwargs, aws_conn_id=None)

@mock.patch(
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
side_effect=[{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}],
)
def test_poke_succeeded(self, _):
def test_base_aws_op_attributes(self):
op = GlacierJobOperationSensor(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

op = GlacierJobOperationSensor(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

def test_poke_succeeded(self, mocked_describe_job):
mocked_describe_job.side_effect = [{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}]
assert self.op.poke(None)

@mock.patch(
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
side_effect=[{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}],
)
def test_poke_in_progress(self, _):
def test_poke_in_progress(self, mocked_describe_job):
mocked_describe_job.side_effect = [{"Action": "", "StatusCode": JobStatus.IN_PROGRESS.value}]
assert not self.op.poke(None)

@mock.patch(
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
side_effect=[{"Action": "", "StatusCode": ""}],
)
def test_poke_fail(self, _):
with pytest.raises(AirflowException) as ctx:
def test_poke_fail(self, mocked_describe_job):
mocked_describe_job.side_effect = [{"Action": "", "StatusCode": ""}]
with pytest.raises(AirflowException, match="Sensor failed"):
self.op.poke(None)
assert "Sensor failed" in str(ctx.value)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
"soft_fail, expected_exception",
[
pytest.param(False, AirflowException, id="not-soft-fail"),
pytest.param(True, AirflowSkipException, id="soft-fail"),
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.describe_job")
def test_fail_poke(self, describe_job, soft_fail, expected_exception):
def test_fail_poke(self, soft_fail, expected_exception, mocked_describe_job):
self.op.soft_fail = soft_fail
response = {"Action": "some action", "StatusCode": "Failed"}
message = f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}'
with pytest.raises(expected_exception, match=message):
describe_job.return_value = response
mocked_describe_job.return_value = response
self.op.poke(context={})

