Merged
47 changes: 31 additions & 16 deletions airflow/providers/amazon/aws/operators/athena.py
@@ -17,22 +17,22 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class AthenaOperator(BaseOperator):
class AthenaOperator(AwsBaseOperator[AthenaHook]):
"""
An operator that submits a presto query to athena.
An operator that submits a Trino/Presto query to Amazon Athena.

.. note:: if the task is killed while it runs, it'll cancel the athena query that was launched,
EXCEPT if running in deferrable mode.
@@ -41,11 +41,10 @@ class AthenaOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AthenaOperator`

:param query: Presto to be run on athena. (templated)
:param query: Trino/Presto query to be run on Amazon Athena. (templated)
:param database: Database to select. (templated)
:param catalog: Catalog to select. (templated)
:param output_location: s3 path to write the query results into. (templated)
:param aws_conn_id: aws connection to use
:param client_request_token: Unique token created by user to avoid multiple executions of same query
:param workgroup: Athena workgroup in which query will be run. (templated)
:param query_execution_context: Context in which query need to be run
@@ -55,10 +54,23 @@ class AthenaOperator(BaseOperator):
To limit task execution time, use execution_timeout.
:param log_query: Whether to log athena query and other execution params when it's executed.
Defaults to *True*.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

aws_hook_class = AthenaHook
ui_color = "#44b5e2"
template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup", "catalog")
template_fields: Sequence[str] = aws_template_fields(
"query", "database", "output_location", "workgroup", "catalog"
)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}

@@ -68,7 +80,6 @@ def __init__(
query: str,
database: str,
output_location: str,
aws_conn_id: str = "aws_default",
client_request_token: str | None = None,
workgroup: str = "primary",
query_execution_context: dict[str, str] | None = None,
@@ -84,7 +95,6 @@ def __init__(
self.query = query
self.database = database
self.output_location = output_location
self.aws_conn_id = aws_conn_id
self.client_request_token = client_request_token
self.workgroup = workgroup
self.query_execution_context = query_execution_context or {}
@@ -96,13 +106,12 @@ def __init__(
self.deferrable = deferrable
self.catalog: str = catalog

@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, log_query=self.log_query)
@property
def _hook_parameters(self) -> dict[str, Any]:
return {**super()._hook_parameters, "log_query": self.log_query}

def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena."""
"""Run Trino/Presto Query on Amazon Athena."""
self.query_execution_context["Database"] = self.database
self.query_execution_context["Catalog"] = self.catalog
self.result_configuration["OutputLocation"] = self.output_location
@@ -117,7 +126,13 @@ def execute(self, context: Context) -> str | None:
if self.deferrable:
self.defer(
trigger=AthenaTrigger(
self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
query_execution_id=self.query_execution_id,
waiter_delay=self.sleep_time,
waiter_max_attempts=self.max_polling_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
@@ -148,7 +163,7 @@ def execute_complete(self, context, event=None):
return event["value"]

def on_kill(self) -> None:
"""Cancel the submitted athena query."""
"""Cancel the submitted Amazon Athena query."""
if self.query_execution_id:
self.log.info("Received a kill signal.")
response = self.hook.stop_query(self.query_execution_id)
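With this change AthenaOperator inherits the generic AWS connection parameters from AwsBaseOperator[AthenaHook]. A minimal sketch of how a DAG could pass them after this PR — the DAG id, query, database, and S3 path below are placeholders, not values taken from the PR:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.athena import AthenaOperator

with DAG("example_athena", start_date=datetime(2024, 1, 1), schedule=None):
    # region_name, verify and botocore_config are now forwarded to the
    # underlying AthenaHook instead of requiring a custom hook subclass.
    run_query = AthenaOperator(
        task_id="run_query",
        query="SELECT * FROM my_table LIMIT 10",   # placeholder query
        database="my_database",                    # placeholder database
        output_location="s3://my-bucket/athena/",  # placeholder S3 path
        aws_conn_id="aws_default",
        region_name="eu-west-1",
        verify=True,
        botocore_config={"read_timeout": 42},
        deferrable=True,  # the trigger now also receives region/verify/config
    )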
31 changes: 18 additions & 13 deletions airflow/providers/amazon/aws/sensors/athena.py
@@ -17,18 +17,19 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.sensors.base import BaseSensorOperator


class AthenaSensor(BaseSensorOperator):
class AthenaSensor(AwsBaseSensor[AthenaHook]):
"""
Poll the state of the Query until it reaches a terminal state; fails if the query fails.

@@ -40,9 +41,18 @@ class AthenaSensor(BaseSensorOperator):
:param query_execution_id: query_execution_id to check the state of
:param max_retries: Number of times to poll for query state before
returning the current state, defaults to None
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param sleep_time: Time in seconds to wait between two consecutive call to
check query status on athena, defaults to 10
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

INTERMEDIATE_STATES = (
@@ -55,21 +65,21 @@ class AthenaSensor(BaseSensorOperator):
)
SUCCESS_STATES = ("SUCCEEDED",)

template_fields: Sequence[str] = ("query_execution_id",)
template_ext: Sequence[str] = ()
aws_hook_class = AthenaHook
template_fields: Sequence[str] = aws_template_fields(
"query_execution_id",
)
ui_color = "#66c3ff"

def __init__(
self,
*,
query_execution_id: str,
max_retries: int | None = None,
aws_conn_id: str = "aws_default",
sleep_time: int = 10,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.query_execution_id = query_execution_id
self.sleep_time = sleep_time
self.max_retries = max_retries
@@ -87,8 +97,3 @@ def poke(self, context: Context) -> bool:
if state in self.INTERMEDIATE_STATES:
return False
return True

@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id)
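AthenaSensor picks up the same generic parameters via AwsBaseSensor, which is why the explicit cached hook property above could be removed. A short sketch under the same assumptions (the query execution id is a placeholder):

from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

# aws_hook_class plus the generic parameters now build the AthenaHook for the sensor.
wait_for_query = AthenaSensor(
    task_id="wait_for_query",
    query_execution_id="00000000-0000-0000-0000-000000000000",  # placeholder
    max_retries=3,
    sleep_time=10,
    aws_conn_id="aws_default",
    region_name="eu-west-1",
    botocore_config={"retries": {"max_attempts": 5}},
)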
11 changes: 9 additions & 2 deletions airflow/providers/amazon/aws/triggers/athena.py
@@ -43,7 +43,8 @@ def __init__(
query_execution_id: str,
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str,
aws_conn_id: str | None,
**kwargs,
):
super().__init__(
serialized_fields={"query_execution_id": query_execution_id},
@@ -56,7 +57,13 @@
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return AthenaHook(self.aws_conn_id)
return AthenaHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
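Since the trigger now forwards **kwargs to the base waiter trigger, the AthenaHook built on the triggerer side honours the same connection settings. An illustrative direct instantiation, mirroring what the operator's defer call passes (all values are placeholders):

from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger

trigger = AthenaTrigger(
    query_execution_id="00000000-0000-0000-0000-000000000000",  # placeholder
    waiter_delay=30,
    waiter_max_attempts=60,
    aws_conn_id="aws_default",
    region_name="eu-west-1",       # accepted via **kwargs
    verify=True,                   # accepted via **kwargs
    botocore_config={"read_timeout": 42},
)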
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/athena.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

30 changes: 26 additions & 4 deletions tests/providers/amazon/aws/operators/test_athena.py
@@ -53,25 +53,47 @@ def setup_method(self):
"start_date": DEFAULT_DATE,
}

self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule="@once")
self.dag = DAG(TEST_DAG_ID, default_args=args, schedule="@once")

self.athena = AthenaOperator(
self.default_op_kwargs = dict(
task_id="test_athena_operator",
query="SELECT * FROM TEST_TABLE",
database="TEST_DATABASE",
output_location="s3://test_s3_bucket/",
client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
sleep_time=0,
max_polling_attempts=3,
dag=self.dag,
)
self.athena = AthenaOperator(**self.default_op_kwargs, aws_conn_id=None, dag=self.dag)

def test_base_aws_op_attributes(self):
op = AthenaOperator(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None
assert op.hook.log_query is True

op = AthenaOperator(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
log_query=False,
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42
assert op.hook.log_query is False

def test_init(self):
assert self.athena.task_id == MOCK_DATA["task_id"]
assert self.athena.query == MOCK_DATA["query"]
assert self.athena.database == MOCK_DATA["database"]
assert self.athena.catalog == MOCK_DATA["catalog"]
assert self.athena.aws_conn_id == "aws_default"
assert self.athena.client_request_token == MOCK_DATA["client_request_token"]
assert self.athena.sleep_time == 0

76 changes: 46 additions & 30 deletions tests/providers/amazon/aws/sensors/test_athena.py
@@ -26,48 +26,64 @@
from airflow.providers.amazon.aws.sensors.athena import AthenaSensor


@pytest.fixture
def mock_poll_query_status():
with mock.patch.object(AthenaHook, "poll_query_status") as m:
yield m


class TestAthenaSensor:
def setup_method(self):
self.sensor = AthenaSensor(
self.default_op_kwargs = dict(
task_id="test_athena_sensor",
query_execution_id="abc",
sleep_time=5,
max_retries=1,
aws_conn_id="aws_default",
)
self.sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None)

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("SUCCEEDED",))
def test_poke_success(self, mock_poll_query_status):
assert self.sensor.poke({}) is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("RUNNING",))
def test_poke_running(self, mock_poll_query_status):
assert self.sensor.poke({}) is False
def test_base_aws_op_attributes(self):
op = AthenaSensor(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None
assert op.hook.log_query is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("QUEUED",))
def test_poke_queued(self, mock_poll_query_status):
assert self.sensor.poke({}) is False
op = AthenaSensor(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("FAILED",))
def test_poke_failed(self, mock_poll_query_status):
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
assert "Athena sensor failed" in str(ctx.value)
@pytest.mark.parametrize("state", ["SUCCEEDED"])
def test_poke_success_states(self, state, mock_poll_query_status):
mock_poll_query_status.side_effect = [state]
assert self.sensor.poke({}) is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("CANCELLED",))
def test_poke_cancelled(self, mock_poll_query_status):
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
assert "Athena sensor failed" in str(ctx.value)
@pytest.mark.parametrize("state", ["RUNNING", "QUEUED"])
def test_poke_intermediate_states(self, state, mock_poll_query_status):
mock_poll_query_status.side_effect = [state]
assert self.sensor.poke({}) is False

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
"soft_fail, expected_exception",
[
pytest.param(False, AirflowException, id="not-soft-fail"),
pytest.param(True, AirflowSkipException, id="soft-fail"),
],
)
def test_fail_poke(self, soft_fail, expected_exception):
self.sensor.soft_fail = soft_fail
@pytest.mark.parametrize("state", ["FAILED", "CANCELLED"])
def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_poll_query_status):
mock_poll_query_status.side_effect = [state]
sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
message = "Athena sensor failed"
with pytest.raises(expected_exception, match=message), mock.patch(
"airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status"
) as poll_query_status:
poll_query_status.return_value = "FAILED"
self.sensor.poke(context={})
with pytest.raises(expected_exception, match=message):
sensor.poke({})