Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class AthenaOperator(BaseOperator):

:param query: Presto query to be run on Athena. (templated)
:param database: Database to select. (templated)
:param catalog: Catalog to select. (templated)
:param output_location: s3 path to write the query results into. (templated)
:param aws_conn_id: AWS connection to use
:param client_request_token: Unique token created by user to avoid multiple executions of same query
Expand All @@ -57,7 +58,7 @@ class AthenaOperator(BaseOperator):
"""

ui_color = "#44b5e2"
template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup")
template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup", "catalog")
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}

Expand All @@ -76,6 +77,7 @@ def __init__(
max_polling_attempts: int | None = None,
log_query: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
catalog: str = "AwsDataCatalog",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -92,6 +94,7 @@ def __init__(
self.query_execution_id: str | None = None
self.log_query: bool = log_query
self.deferrable = deferrable
self.catalog: str = catalog

@cached_property
def hook(self) -> AthenaHook:
Expand All @@ -101,6 +104,7 @@ def hook(self) -> AthenaHook:
def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena."""
self.query_execution_context["Database"] = self.database
self.query_execution_context["Catalog"] = self.catalog
self.result_configuration["OutputLocation"] = self.output_location
self.query_execution_id = self.hook.run_query(
self.query,
Expand Down
20 changes: 19 additions & 1 deletion tests/providers/amazon/aws/operators/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@
"task_id": "test_athena_operator",
"query": "SELECT * FROM TEST_TABLE",
"database": "TEST_DATABASE",
"catalog": "AwsDataCatalog",
"outputLocation": "s3://test_s3_bucket/",
"client_request_token": "eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
"workgroup": "primary",
}

query_context = {"Database": MOCK_DATA["database"]}
query_context = {"Database": MOCK_DATA["database"], "Catalog": MOCK_DATA["catalog"]}
result_configuration = {"OutputLocation": MOCK_DATA["outputLocation"]}


Expand Down Expand Up @@ -69,10 +70,27 @@ def test_init(self):
assert self.athena.task_id == MOCK_DATA["task_id"]
assert self.athena.query == MOCK_DATA["query"]
assert self.athena.database == MOCK_DATA["database"]
assert self.athena.catalog == MOCK_DATA["catalog"]
assert self.athena.aws_conn_id == "aws_default"
assert self.athena.client_request_token == MOCK_DATA["client_request_token"]
assert self.athena.sleep_time == 0

@mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_override_catalog(self, mock_conn, mock_run_query, mock_check_query_status):
    """Check that a catalog overridden on the operator is passed to ``run_query``."""
    # Replace the catalog on the already-built operator before executing it.
    custom_catalog = "MyCatalog"
    self.athena.catalog = custom_catalog
    expected_context = {"Database": MOCK_DATA["database"], "Catalog": custom_catalog}

    self.athena.execute({})

    # The query execution context handed to the hook must carry the override.
    mock_run_query.assert_called_once_with(
        MOCK_DATA["query"],
        expected_context,
        result_configuration,
        MOCK_DATA["client_request_token"],
        MOCK_DATA["workgroup"],
    )
    assert mock_check_query_status.call_count == 1

@mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
Expand Down