Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ def add_job_flow_steps(
)
return response["StepIds"]

def terminate_job_flow(self, job_flow_id: str) -> None:
    """
    Terminate a given EMR cluster (job flow) by id. If TerminationProtected=True on the cluster,
    termination will be unsuccessful.

    :param job_flow_id: id of the job flow to terminate
    """
    # Thin wrapper around the boto3 EMR client call; a single-element list is
    # passed because the API terminates job flows in batches.
    emr_client = self.get_conn()
    emr_client.terminate_job_flows(JobFlowIds=[job_flow_id])

def test_connection(self):
"""
Return failed state for test Amazon Elastic MapReduce Connection (untestable).
Expand Down
62 changes: 51 additions & 11 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,12 @@ class EmrCreateJobFlowOperator(BaseOperator):
:param job_flow_overrides: boto3 style arguments or reference to an arguments file
(must be '.json') to override specific ``emr_conn_id`` extra parameters. (templated)
:param region_name: Region name passed to EmrHook
:param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
completion (True)
:param waiter_countdown: Max. seconds to wait for jobflow completion (only in combination with
wait_for_completion=True, None = no limit)
:param waiter_check_interval_seconds: Number of seconds between polling the jobflow state. Defaults to 60
seconds.
"""

template_fields: Sequence[str] = ("job_flow_overrides",)
Expand All @@ -538,42 +544,76 @@ def __init__(
emr_conn_id: str | None = "emr_default",
job_flow_overrides: str | dict[str, Any] | None = None,
region_name: str | None = None,
wait_for_completion: bool = False,
waiter_countdown: int | None = None,
waiter_check_interval_seconds: int = 60,
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.waiter_countdown = waiter_countdown
self.waiter_check_interval_seconds = waiter_check_interval_seconds

self._job_flow_id: str | None = None

def execute(self, context: Context) -> str:
emr = EmrHook(
@cached_property
def _emr_hook(self) -> EmrHook:
    """Lazily build (and cache per instance) the EmrHook used by this operator."""
    hook_kwargs = {
        "aws_conn_id": self.aws_conn_id,
        "emr_conn_id": self.emr_conn_id,
        "region_name": self.region_name,
    }
    return EmrHook(**hook_kwargs)

def execute(self, context: Context) -> str | None:
    """
    Create the EMR job flow and, when requested, block until it has finished starting.

    :param context: the task execution context
    :return: the id of the created job flow
    """
    self.log.info(
        "Creating job flow using aws_conn_id: %s, emr_conn_id: %s", self.aws_conn_id, self.emr_conn_id
    )
    overrides = self.job_flow_overrides
    if isinstance(overrides, str):
        # Templated string overrides are a Python-literal dict; parse once and
        # store back so later accesses see the parsed form.
        overrides = ast.literal_eval(overrides)
        self.job_flow_overrides = overrides
    response = self._emr_hook.create_job_flow(overrides)

    # Fail fast on anything other than a successful API response.
    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        raise AirflowException(f"Job flow creation failed: {response}")

    self._job_flow_id = response["JobFlowId"]
    self.log.info("Job flow with id %s created", self._job_flow_id)
    EmrClusterLink.persist(
        context=context,
        operator=self,
        region_name=self._emr_hook.conn_region_name,
        aws_partition=self._emr_hook.conn_partition,
        job_flow_id=self._job_flow_id,
    )

    if self.wait_for_completion:
        # Didn't use a boto-supplied waiter because those don't support waiting for WAITING state.
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#waiters
        waiter(
            get_state_callable=self._emr_hook.get_conn().describe_cluster,
            get_state_args={"ClusterId": self._job_flow_id},
            parse_response=["Cluster", "Status", "State"],
            # Cluster will be in WAITING after finishing if KeepJobFlowAliveWhenNoSteps is True
            desired_state={"WAITING", "TERMINATED"},
            failure_states={"TERMINATED_WITH_ERRORS"},
            object_type="job flow",
            action="finished",
            countdown=self.waiter_countdown,
            check_interval_seconds=self.waiter_check_interval_seconds,
        )

    return self._job_flow_id

def on_kill(self) -> None:
    """Terminate the created job flow, if one exists, when the task is killed."""
    if not self._job_flow_id:
        # Nothing was created (or creation failed), so there is nothing to clean up.
        return
    self.log.info("Terminating job flow %s", self._job_flow_id)
    self._emr_hook.terminate_job_flow(self._job_flow_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self._emr_hook.terminate_job_flow(self._job_flow_id)
self._emr_hook.get_conn().terminate_job_flows(JobFlowIds=[self._job_flow_id])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @BasPH, thoughts on this suggested change? Otherwise the PR looks good



class EmrModifyClusterOperator(BaseOperator):
Expand Down
10 changes: 7 additions & 3 deletions airflow/providers/amazon/aws/utils/waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def waiter(
failure_states: set,
object_type: str,
action: str,
countdown: int = 25 * 60,
countdown: int | float | None = 25 * 60,
check_interval_seconds: int = 60,
) -> None:
"""
Expand All @@ -49,8 +49,8 @@ def waiter(
exception if any are reached before the desired_state
:param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
:param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
:param countdown: Total amount of time the waiter should wait for the desired state
before timing out (in seconds). Defaults to 25 * 60 seconds.
:param countdown: Number of seconds the waiter should wait for the desired state before timing out.
Defaults to 25 * 60 seconds. None = infinite.
:param check_interval_seconds: Number of seconds waiter should wait before attempting
to retry get_state_callable. Defaults to 60 seconds.
"""
Expand All @@ -60,6 +60,10 @@ def waiter(
break
if state in failure_states:
raise AirflowException(f"{object_type.title()} reached failure state {state}.")

if countdown is None:
countdown = float("inf")

if countdown > check_interval_seconds:
countdown -= check_interval_seconds
log.info("Waiting for %s to be %s.", object_type.lower(), action.lower())
Expand Down