Merged
14 changes: 8 additions & 6 deletions airflow/providers/amazon/aws/sensors/batch.py
@@ -18,6 +18,8 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from deprecated import deprecated
+
 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
@@ -57,10 +59,9 @@ def __init__(
         self.job_id = job_id
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
-        self.hook: BatchClientHook | None = None
 
     def poke(self, context: Context) -> bool:
-        job_description = self.get_hook().get_job_description(self.job_id)
+        job_description = self.hook.get_job_description(self.job_id)
         state = job_description["status"]
 
         if state == BatchClientHook.SUCCESS_STATE:
@@ -74,16 +75,17 @@ def poke(self, context: Context) -> bool:
 
         raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")
 
+    @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> BatchClientHook:
         """Create and return a BatchClientHook"""
-        if self.hook:
-            return self.hook
+        return self.hook
 
-        self.hook = BatchClientHook(
+    @cached_property
+    def hook(self) -> BatchClientHook:
+        return BatchClientHook(
             aws_conn_id=self.aws_conn_id,
             region_name=self.region_name,
         )
-        return self.hook
 
 
 class BatchComputeEnvironmentSensor(BaseSensorOperator):
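Every file in this diff repeats the pattern shown above: the mutable self.hook attribute and the lazily-built get_hook() method are replaced by a cached_property named hook, with get_hook() kept as a deprecated alias. A minimal, runnable sketch of the pattern; stdlib functools.cached_property stands in for airflow.compat.functools, the deprecated package is the one the PR imports, and FakeHook/MySensor are made-up names, not provider classes:

from functools import cached_property

from deprecated import deprecated


class FakeHook:
    # Stand-in for a real AWS hook such as BatchClientHook.
    def __init__(self, aws_conn_id=None, region_name=None):
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name


class MySensor:
    def __init__(self, aws_conn_id="aws_default", region_name=None):
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        # Crucially, there is no `self.hook = None` here any more: cached_property
        # is a non-data descriptor, so a pre-existing instance attribute would
        # shadow it and `self.hook` would stay None forever. That is why the
        # diffs above delete those assignments from __init__.

    @deprecated(reason="use `hook` property instead.")
    def get_hook(self) -> FakeHook:
        # Backwards-compatible alias; calling it emits a DeprecationWarning.
        return self.hook

    @cached_property
    def hook(self) -> FakeHook:
        # Built on first access, then cached on the instance, so repeated
        # poke() calls reuse one hook instead of rebuilding it every time.
        return FakeHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)


sensor = MySensor()
assert sensor.hook is sensor.hook        # cached: same object on every access
assert sensor.get_hook() is sensor.hook  # old accessor still works, but warns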
15 changes: 9 additions & 6 deletions airflow/providers/amazon/aws/sensors/dms.py
@@ -19,6 +19,9 @@
 
 from typing import TYPE_CHECKING, Iterable, Sequence
 
+from deprecated import deprecated
+
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
 from airflow.sensors.base import BaseSensorOperator
@@ -58,18 +61,18 @@ def __init__(
         self.replication_task_arn = replication_task_arn
         self.target_statuses: Iterable[str] = target_statuses or []
         self.termination_statuses: Iterable[str] = termination_statuses or []
-        self.hook: DmsHook | None = None
 
+    @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> DmsHook:
         """Get DmsHook"""
-        if self.hook:
-            return self.hook
-
-        self.hook = DmsHook(self.aws_conn_id)
         return self.hook
 
+    @cached_property
+    def hook(self) -> DmsHook:
+        return DmsHook(self.aws_conn_id)
+
     def poke(self, context: Context):
-        status: str | None = self.get_hook().get_task_status(self.replication_task_arn)
+        status: str | None = self.hook.get_task_status(self.replication_task_arn)
 
         if not status:
             raise AirflowException(
8 changes: 6 additions & 2 deletions airflow/providers/amazon/aws/sensors/ec2.py
@@ -19,6 +19,7 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.sensors.base import BaseSensorOperator
 
@@ -62,8 +63,11 @@ def __init__(
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
 
+    @cached_property
+    def hook(self):
+        return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
     def poke(self, context: Context):
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-        instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id)
+        instance_state = self.hook.get_instance_state(instance_id=self.instance_id)
         self.log.info("instance state: %s", instance_state)
         return instance_state == self.target_state
25 changes: 16 additions & 9 deletions airflow/providers/amazon/aws/sensors/eks.py
@@ -19,6 +19,7 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.eks import (
     ClusterStates,
@@ -98,13 +99,15 @@ def __init__(
         self.region = region
         super().__init__(**kwargs)
 
-    def poke(self, context: Context):
-        eks_hook = EksHook(
+    @cached_property
+    def hook(self):
+        return EksHook(
             aws_conn_id=self.aws_conn_id,
             region_name=self.region,
         )
 
-        cluster_state = eks_hook.get_cluster_state(clusterName=self.cluster_name)
+    def poke(self, context: Context):
+        cluster_state = self.hook.get_cluster_state(clusterName=self.cluster_name)
         self.log.info("Cluster state: %s", cluster_state)
         if cluster_state in (CLUSTER_TERMINAL_STATES - {self.target_state}):
             # If we reach a terminal state which is not the target state:
@@ -167,13 +170,15 @@ def __init__(
         self.region = region
         super().__init__(**kwargs)
 
-    def poke(self, context: Context):
-        eks_hook = EksHook(
+    @cached_property
+    def hook(self):
+        return EksHook(
             aws_conn_id=self.aws_conn_id,
             region_name=self.region,
         )
 
-        fargate_profile_state = eks_hook.get_fargate_profile_state(
+    def poke(self, context: Context):
+        fargate_profile_state = self.hook.get_fargate_profile_state(
             clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
         )
         self.log.info("Fargate profile state: %s", fargate_profile_state)
@@ -238,13 +243,15 @@ def __init__(
         self.region = region
         super().__init__(**kwargs)
 
-    def poke(self, context: Context):
-        eks_hook = EksHook(
+    @cached_property
+    def hook(self):
+        return EksHook(
             aws_conn_id=self.aws_conn_id,
             region_name=self.region,
         )
 
-        nodegroup_state = eks_hook.get_nodegroup_state(
+    def poke(self, context: Context):
+        nodegroup_state = self.hook.get_nodegroup_state(
             clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
         )
         self.log.info("Nodegroup state: %s", nodegroup_state)
23 changes: 12 additions & 11 deletions airflow/providers/amazon/aws/sensors/emr.py
@@ -19,6 +19,8 @@
 
 from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
+from deprecated import deprecated
+
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -52,16 +54,15 @@ def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs):
         self.aws_conn_id = aws_conn_id
         self.target_states: Iterable[str] = []  # will be set in subclasses
         self.failed_states: Iterable[str] = []  # will be set in subclasses
-        self.hook: EmrHook | None = None
 
+    @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> EmrHook:
         """Get EmrHook"""
-        if self.hook:
-            return self.hook
-
-        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
         return self.hook
 
+    @cached_property
+    def hook(self) -> EmrHook:
+        return EmrHook(aws_conn_id=self.aws_conn_id)
+
     def poke(self, context: Context):
         response = self.get_emr_response(context=context)
 
@@ -332,7 +333,7 @@ def __init__(
         self.failed_states = failed_states or self.FAILURE_STATES
 
     def get_emr_response(self, context: Context) -> dict[str, Any]:
-        emr_client = self.get_hook().get_conn()
+        emr_client = self.hook.conn
         self.log.info("Poking notebook %s", self.notebook_execution_id)
 
         return emr_client.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)
@@ -408,15 +409,15 @@ def get_emr_response(self, context: Context) -> dict[str, Any]:
 
         :return: response
         """
-        emr_client = self.get_hook().get_conn()
+        emr_client = self.hook.conn
         self.log.info("Poking cluster %s", self.job_flow_id)
         response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
         log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])
         EmrLogsLink.persist(
             context=context,
             operator=self,
-            region_name=self.get_hook().conn_region_name,
-            aws_partition=self.get_hook().conn_partition,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
             job_flow_id=self.job_flow_id,
             log_uri="/".join(log_uri),
         )
@@ -497,7 +498,7 @@ def get_emr_response(self, context: Context) -> dict[str, Any]:
 
         :return: response
         """
-        emr_client = self.get_hook().get_conn()
+        emr_client = self.hook.conn
 
         self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id)
         return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
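Beyond get_hook(), the EMR hunks also swap self.get_hook().get_conn() for self.hook.conn. A rough sketch of why the two spellings are interchangeable; these are assumed internals for illustration, not the provider's actual base-hook code:

from functools import cached_property


class FakeAwsHook:
    # Assumed shape: the hook exposes the boto3 client both as a cached
    # `conn` property and via the older get_conn() accessor.
    @cached_property
    def conn(self):
        return object()  # a real hook would build and return a boto3 client here

    def get_conn(self):
        # Thin alias over the cached property.
        return self.conn


hook = FakeAwsHook()
assert hook.get_conn() is hook.conn  # one client, reachable through either spelling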
8 changes: 6 additions & 2 deletions airflow/providers/amazon/aws/sensors/glacier.py
@@ -20,6 +20,7 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
 from airflow.sensors.base import BaseSensorOperator
@@ -81,9 +82,12 @@ def __init__(
         self.poke_interval = poke_interval
         self.mode = mode
 
+    @cached_property
+    def hook(self):
+        return GlacierHook(aws_conn_id=self.aws_conn_id)
+
     def poke(self, context: Context) -> bool:
-        hook = GlacierHook(aws_conn_id=self.aws_conn_id)
-        response = hook.describe_job(vault_name=self.vault_name, job_id=self.job_id)
+        response = self.hook.describe_job(vault_name=self.vault_name, job_id=self.job_id)
 
         if response["StatusCode"] == JobStatus.SUCCEEDED.value:
             self.log.info("Job status: %s, code status: %s", response["Action"], response["StatusCode"])
10 changes: 7 additions & 3 deletions airflow/providers/amazon/aws/sensors/glue.py
@@ -19,6 +19,7 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.sensors.base import BaseSensorOperator
@@ -61,10 +62,13 @@ def __init__(
         self.errored_states: list[str] = ["FAILED", "STOPPED", "TIMEOUT"]
         self.next_log_token: str | None = None
 
+    @cached_property
+    def hook(self):
+        return GlueJobHook(aws_conn_id=self.aws_conn_id)
ferruzzi (Contributor), Jan 18, 2023:

In some of these files you are passing region to the hook and in others you are not. Should we standardize that as well?

PR author (Contributor):

I mostly kept the existing behavior. I don't know what the real use case for passing the region is; maybe that depends on the operator/sensor?

Contributor:

A while back I had an idea to keep Sensors/Operators more consistent and provide all of a Hook's supported parameters. I think it is safe to pass region_name because we set it to None by default in Operators/Sensors.
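
A minimal sketch of that convention (class names are illustrative, not provider code), assuming the hook treats region_name=None as "fall back to the connection/botocore default":

from functools import cached_property


class SomeAwsHook:
    def __init__(self, aws_conn_id=None, region_name=None):
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name  # None -> resolve the region elsewhere


class SomeSensor:
    def __init__(self, aws_conn_id="aws_default", region_name=None):
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name

    @cached_property
    def hook(self) -> SomeAwsHook:
        # Always forwarding region_name is harmless: users who never set it
        # pass None through, and the hook behaves as if it were omitted.
        return SomeAwsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)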

Contributor:

BTW, it is not related to this PR, but we have three different hooks that wrap around Glue.Client 🤦

  • airflow.providers.amazon.aws.hooks.glue.GlueJobHook
  • airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook
  • airflow.providers.amazon.aws.hooks.glue_crawler.GlueCrawlerHook

Contributor:

> BTW, it is not related to this PR, but we have three different hooks that wrap around Glue.Client 🤦

Yeah, there is also a BatchWaiterHook which is an oddball too. Some Day ™️

Contributor:

In one recent PR I marked Hooks as thin or thick wrappers, so we can easily find them come "Some Day":

grep -rl 'thick wrapper' airflow/providers/amazon/aws/hooks
airflow/providers/amazon/aws/hooks/emr.py
airflow/providers/amazon/aws/hooks/dynamodb.py
airflow/providers/amazon/aws/hooks/batch_client.py
airflow/providers/amazon/aws/hooks/kinesis.py
airflow/providers/amazon/aws/hooks/glue.py
airflow/providers/amazon/aws/hooks/sagemaker.py
airflow/providers/amazon/aws/hooks/ec2.py
airflow/providers/amazon/aws/hooks/datasync.py
airflow/providers/amazon/aws/hooks/athena.py
airflow/providers/amazon/aws/hooks/s3.py
airflow/providers/amazon/aws/hooks/elasticache_replication_group.py

ferruzzi (Contributor), Jan 18, 2023:

Tack onto the "Some Day" list: converting as many hooks as possible to thin hooks, maybe with a Protocol class for IDE type hinting and completion. I didn't like the EcsProtocol and BatchProtocol classes at first, but they grew on me.

Anyway, we should maybe get a list of "easy but tedious suggestions" like this for new contributors to flip through, but we're getting way off track here 😛
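
For reference, a minimal sketch of that Protocol idea; the method subset is illustrative, and a real protocol would mirror whatever calls the thin hook actually makes against boto3's Glue client:

from typing import Protocol


class GlueClientProtocol(Protocol):
    # Declares only the client calls a thin hook would use, so IDEs can
    # type-check and complete them without any botocore stub packages.
    def get_job_run(self, *, JobName: str, RunId: str) -> dict:
        ...

    def get_crawler(self, *, Name: str) -> dict:
        ...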

Contributor:

I also like @o-nikolas's idea of generating hooks for all supported services: #28560 (comment)


     def poke(self, context: Context):
-        hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
         self.log.info("Poking for job run status :for Glue Job %s and ID %s", self.job_name, self.run_id)
-        job_state = hook.get_job_state(job_name=self.job_name, run_id=self.run_id)
+        job_state = self.hook.get_job_state(job_name=self.job_name, run_id=self.run_id)
         job_failed = False
 
         try:
@@ -80,7 +84,7 @@ def poke(self, context: Context):
             return False
         finally:
             if self.verbose:
-                self.next_log_token = hook.print_job_logs(
+                self.next_log_token = self.hook.print_job_logs(
                     job_name=self.job_name,
                     run_id=self.run_id,
                     job_failed=job_failed,
15 changes: 9 additions & 6 deletions airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -19,6 +19,9 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from deprecated import deprecated
+
+from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 from airflow.sensors.base import BaseSensorOperator
 
@@ -71,7 +74,6 @@ def __init__(
         self.table_name = table_name
         self.expression = expression
         self.database_name = database_name
-        self.hook: GlueCatalogHook | None = None
 
     def poke(self, context: Context):
         """Checks for existence of the partition in the AWS Glue Catalog table"""
@@ -81,12 +83,13 @@ def poke(self, context: Context):
             "Poking for table %s. %s, expression %s", self.database_name, self.table_name, self.expression
         )
 
-        return self.get_hook().check_for_partition(self.database_name, self.table_name, self.expression)
+        return self.hook.check_for_partition(self.database_name, self.table_name, self.expression)
 
+    @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> GlueCatalogHook:
         """Gets the GlueCatalogHook"""
-        if self.hook:
-            return self.hook
-
-        self.hook = GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
         return self.hook
 
+    @cached_property
+    def hook(self) -> GlueCatalogHook:
+        return GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
18 changes: 10 additions & 8 deletions airflow/providers/amazon/aws/sensors/glue_crawler.py
@@ -19,6 +19,9 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from deprecated import deprecated
+
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
 from airflow.sensors.base import BaseSensorOperator
@@ -48,15 +51,13 @@ def __init__(self, *, crawler_name: str, aws_conn_id: str = "aws_default", **kwa
         self.aws_conn_id = aws_conn_id
         self.success_statuses = "SUCCEEDED"
         self.errored_statuses = ("FAILED", "CANCELLED")
-        self.hook: GlueCrawlerHook | None = None
 
     def poke(self, context: Context):
-        hook = self.get_hook()
         self.log.info("Poking for AWS Glue crawler: %s", self.crawler_name)
-        crawler_state = hook.get_crawler(self.crawler_name)["State"]
+        crawler_state = self.hook.get_crawler(self.crawler_name)["State"]
         if crawler_state == "READY":
             self.log.info("State: %s", crawler_state)
-            crawler_status = hook.get_crawler(self.crawler_name)["LastCrawl"]["Status"]
+            crawler_status = self.hook.get_crawler(self.crawler_name)["LastCrawl"]["Status"]
             if crawler_status == self.success_statuses:
                 self.log.info("Status: %s", crawler_status)
                 return True
@@ -65,10 +66,11 @@ def poke(self, context: Context):
             else:
                 return False
 
+    @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> GlueCrawlerHook:
         """Returns a new or pre-existing GlueCrawlerHook"""
-        if self.hook:
-            return self.hook
-
-        self.hook = GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
         return self.hook
 
+    @cached_property
+    def hook(self) -> GlueCrawlerHook:
+        return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)