24 changes: 9 additions & 15 deletions airflow/providers/amazon/aws/operators/athena.py
@@ -16,23 +16,17 @@
# specific language governing permissions and limitations
# under the License.
#
import sys
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

if sys.version_info >= (3, 8):
from functools import cached_property
else:
from cached_property import cached_property

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class AthenaOperator(BaseOperator):
class AthenaOperator(AwsBaseOperator[AthenaHook]):
"""
An operator that submits a Presto query to Athena.

@@ -44,6 +38,8 @@ class AthenaOperator(BaseOperator):
:param database: Database to select. (templated)
:param output_location: s3 path to write the query results into. (templated)
:param aws_conn_id: aws connection to use
:param region_name: (optional) region name to use in AWS Hook.
Overrides the region_name defined in the connection (if provided)
:param client_request_token: Unique token created by user to avoid multiple executions of same query
:param workgroup: Athena workgroup in which query will be run
:param query_execution_context: Context in which the query needs to be run
@@ -57,13 +53,17 @@
template_ext: Sequence[str] = ('.sql',)
template_fields_renderers = {"query": "sql"}

aws_hook_class = AthenaHook
aws_hook_class_kwargs = {"sleep_time"}

def __init__(
self,
*,
query: str,
database: str,
output_location: str,
aws_conn_id: str = "aws_default",
region_name: Optional[str] = None,
client_request_token: Optional[str] = None,
workgroup: str = "primary",
query_execution_context: Optional[Dict[str, str]] = None,
@@ -72,11 +72,10 @@ def __init__(
max_tries: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
super().__init__(aws_conn_id=aws_conn_id, region_name=region_name, sleep_time=sleep_time, **kwargs)
self.query = query
self.database = database
self.output_location = output_location
self.aws_conn_id = aws_conn_id
self.client_request_token = client_request_token
self.workgroup = workgroup
self.query_execution_context = query_execution_context or {}
@@ -85,11 +84,6 @@ def __init__(
self.max_tries = max_tries
self.query_execution_id = None # type: Optional[str]

@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)

def execute(self, context: 'Context') -> Optional[str]:
"""Run Presto Query on Athena"""
self.query_execution_context['Database'] = self.database
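Taken together, the Athena changes mean `sleep_time` now reaches AthenaHook through the base class rather than a local `hook` property. A minimal usage sketch, assuming a DAG context; the task_id, query, and S3 path below are illustrative, not from this PR:

from airflow.providers.amazon.aws.operators.athena import AthenaOperator

read_table = AthenaOperator(
    task_id="read_table",  # illustrative task name
    query="SELECT * FROM my_table",
    database="my_database",
    output_location="s3://my-bucket/athena-results/",
    region_name="eu-west-1",  # new pass-through parameter handled by AwsBaseOperator
    sleep_time=30,  # popped via aws_hook_class_kwargs and forwarded to AthenaHook
)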
15 changes: 9 additions & 6 deletions airflow/providers/amazon/aws/operators/aws_lambda.py
@@ -19,14 +19,14 @@
import json
from typing import TYPE_CHECKING, Optional, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class AwsLambdaInvokeFunctionOperator(BaseOperator):
class AwsLambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
"""
Invokes an AWS Lambda function.
You can invoke a function synchronously (and wait for the response),
@@ -40,6 +40,8 @@ class AwsLambdaInvokeFunctionOperator(BaseOperator):
:param log_type: Set to Tail to include the execution log in the response. Otherwise, set to "None".
:param qualifier: Specify a version or alias to invoke a published version of the function.
:param aws_conn_id: The AWS connection ID to use
:param region_name: (optional) region name to use in AWS Hook.
Overrides the region_name defined in the connection (if provided)

.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -50,6 +52,8 @@ class AwsLambdaInvokeFunctionOperator(BaseOperator):
template_fields: Sequence[str] = ('function_name', 'payload', 'qualifier', 'invocation_type')
ui_color = '#ff7300'

aws_hook_class = LambdaHook

def __init__(
self,
*,
@@ -60,27 +64,26 @@
client_context: Optional[str] = None,
payload: Optional[str] = None,
aws_conn_id: str = 'aws_default',
region_name: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(aws_conn_id=aws_conn_id, region_name=region_name, **kwargs)
self.function_name = function_name
self.payload = payload
self.log_type = log_type
self.qualifier = qualifier
self.invocation_type = invocation_type
self.client_context = client_context
self.aws_conn_id = aws_conn_id

def execute(self, context: 'Context'):
"""
Invokes the target AWS Lambda function from Airflow.

:return: The response payload from the function, or an error object.
"""
hook = LambdaHook(aws_conn_id=self.aws_conn_id)
success_status_codes = [200, 202, 204]
self.log.info("Invoking AWS Lambda function: %s with payload: %s", self.function_name, self.payload)
response = hook.invoke_lambda(
response = self.hook.invoke_lambda(
function_name=self.function_name,
invocation_type=self.invocation_type,
log_type=self.log_type,
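As with Athena, the hook is now built lazily by the base class and reached through `self.hook`. A hedged sketch of the operator in use; the function name and payload are placeholders:

import json

from airflow.providers.amazon.aws.operators.aws_lambda import AwsLambdaInvokeFunctionOperator

invoke_fn = AwsLambdaInvokeFunctionOperator(
    task_id="invoke_lambda",  # illustrative task name
    function_name="my-function",
    invocation_type="RequestResponse",  # synchronous invoke; "Event" would be asynchronous
    payload=json.dumps({"name": "example"}),
    region_name="us-east-1",  # new pass-through parameter handled by AwsBaseOperator
)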
97 changes: 97 additions & 0 deletions airflow/providers/amazon/aws/operators/base_aws.py
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
import warnings
from typing import Generic, Optional, Set, Type, TypeVar

if sys.version_info >= (3, 8):
from functools import cached_property
else:
from cached_property import cached_property

from botocore.config import Config

from airflow import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

AwsHookClass = TypeVar("AwsHookClass", bound=AwsBaseHook)


class AwsBaseOperator(BaseOperator, Generic[AwsHookClass]):
"""Base implementations for amazon-provider operators.

:param aws_conn_id: aws connection to use
:param region_name: (optional) region name to use in AWS Hook.
Override the region_name in connection (if provided)
:param botocore_config: Configuration for botocore client.
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html)
"""

aws_hook_class: Type[AwsHookClass]
aws_hook_class_kwargs: Optional[Set[str]] = None

def __init__(
self,
*,
aws_conn_id: Optional[str] = "aws_default",
region_name: Optional[str] = None,
botocore_config: Optional[Config] = None,
**kwargs,
) -> None:
self.aws_conn_id = aws_conn_id

region = kwargs.pop("region", None)
if region:
warnings.warn(
'Parameter `region` is deprecated. Please use `region_name` instead.',
DeprecationWarning,
stacklevel=2,
)
if region_name:
raise AirflowException("Either `region_name` or `region` can be provided, not both.")
region_name = region

self.region_name = region_name
self.botocore_config = botocore_config

# Extra constructor arguments forwarded to the hook in `hook` below
self.hooks_class_args = {}

# Check whether the hook uses a connection attribute other than `aws_conn_id`
conn_name_attr = self.aws_hook_class.conn_name_attr
if conn_name_attr != "aws_conn_id":
self.hooks_class_args[conn_name_attr] = kwargs.pop(
conn_name_attr, self.aws_hook_class.default_conn_name
)

# Collect additional hook args and remove them from the operator's keyword arguments
for arg in self.aws_hook_class_kwargs or {}:
if arg in kwargs:
self.hooks_class_args[arg] = kwargs.pop(arg, None)

super().__init__(**kwargs)

@cached_property
def hook(self) -> AwsHookClass:
"""Create and return an AWS Hook."""
return self.aws_hook_class(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
config=self.botocore_config,
**self.hooks_class_args,
)
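The contract this base class defines is small: a subclass sets `aws_hook_class` to pick the hook type, and optionally lists extra constructor names in `aws_hook_class_kwargs`, which are popped from the operator's kwargs and forwarded to the hook. A sketch under those assumptions; `MyServiceHook`, `MyServiceOperator`, and `poll_interval` are invented for illustration:

from typing import Any

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator


class MyServiceHook(AwsBaseHook):
    """Hypothetical hook with one extra constructor argument."""

    def __init__(self, *args, poll_interval: int = 10, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.poll_interval = poll_interval


class MyServiceOperator(AwsBaseOperator[MyServiceHook]):
    aws_hook_class = MyServiceHook
    # names listed here are popped from the operator kwargs and passed to the hook
    aws_hook_class_kwargs = {"poll_interval"}

    def execute(self, context: Any) -> None:
        # `self.hook` is a cached MyServiceHook, already configured with
        # aws_conn_id, region_name, botocore config, and poll_interval
        self.log.info("Polling every %s seconds", self.hook.poll_interval)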
22 changes: 12 additions & 10 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -29,14 +29,14 @@
from typing import TYPE_CHECKING, Any, Optional, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class BatchOperator(BaseOperator):
class BatchOperator(AwsBaseOperator[BatchClientHook]):
"""
Execute a job on AWS Batch

@@ -95,6 +95,9 @@ class BatchOperator(BaseOperator):
)
template_fields_renderers = {"overrides": "json", "parameters": "json"}

aws_hook_class = BatchClientHook
aws_hook_class_kwargs = {"max_retries", "status_retries"}

def __init__(
self,
*,
@@ -113,8 +116,13 @@ def __init__(
tags: Optional[dict] = None,
**kwargs,
):

BaseOperator.__init__(self, **kwargs)
super().__init__(
aws_conn_id=aws_conn_id,
region_name=region_name,
max_retries=max_retries,
status_retries=status_retries,
**kwargs,
)
self.job_id = job_id
self.job_name = job_name
self.job_definition = job_definition
@@ -124,12 +132,6 @@ def __init__(
self.parameters = parameters or {}
self.waiters = waiters
self.tags = tags or {}
self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
aws_conn_id=aws_conn_id,
region_name=region_name,
)

def execute(self, context: 'Context'):
"""
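The Batch refactor follows the same pattern: `max_retries` and `status_retries` are declared in `aws_hook_class_kwargs` and routed to BatchClientHook by the base class instead of being wired into an eagerly constructed `self.hook`. A hedged usage sketch; the job name, definition, and queue are placeholders:

from airflow.providers.amazon.aws.operators.batch import BatchOperator

submit_job = BatchOperator(
    task_id="submit_batch_job",  # illustrative task name
    job_name="example-job",
    job_definition="example-job-definition",
    job_queue="example-job-queue",
    overrides={},  # no container overrides in this sketch
    region_name="us-east-1",
    max_retries=4200,  # forwarded to BatchClientHook via aws_hook_class_kwargs
    status_retries=10,
)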