Add EMR Serverless Operators and Hooks (#25324)
syedahsn authored Aug 5, 2022
1 parent 5480b4c commit 8df84e9
Showing 8 changed files with 1,232 additions and 8 deletions.
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime
from os import getenv

from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.emr import (
    EmrServerlessCreateApplicationOperator,
    EmrServerlessDeleteApplicationOperator,
    EmrServerlessStartJobOperator,
)
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor

EXECUTION_ROLE_ARN = getenv('EXECUTION_ROLE_ARN', 'execution_role_arn')
EMR_EXAMPLE_BUCKET = getenv('EMR_EXAMPLE_BUCKET', 'emr_example_bucket')
SPARK_JOB_DRIVER = {
    "sparkSubmit": {
        "entryPoint": "s3://us-east-1.elasticmapreduce/emr-containers/samples/wordcount/scripts/wordcount.py",
        "entryPointArguments": [f"s3://{EMR_EXAMPLE_BUCKET}/output"],
        "sparkSubmitParameters": "--conf spark.executor.cores=1 --conf spark.executor.memory=4g \
        --conf spark.driver.cores=1 --conf spark.driver.memory=4g --conf spark.executor.instances=1",
    }
}

SPARK_CONFIGURATION_OVERRIDES = {
    "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": f"s3://{EMR_EXAMPLE_BUCKET}/logs"}}
}
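
The two getenv defaults above are placeholders: the DAG only runs successfully if EXECUTION_ROLE_ARN points at an IAM role that EMR Serverless can assume and EMR_EXAMPLE_BUCKET names a writable S3 bucket. A minimal sketch of supplying them for a local test, with made-up values:

# Illustrative only: use your own role ARN and bucket, ideally set in the
# environment that runs the Airflow scheduler rather than in code.
import os

os.environ["EXECUTION_ROLE_ARN"] = "arn:aws:iam::123456789012:role/emr-serverless-example-role"
os.environ["EMR_EXAMPLE_BUCKET"] = "my-emr-serverless-example-bucket"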

with DAG(
    dag_id='example_emr_serverless',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as emr_serverless_dag:

    # [START howto_operator_emr_serverless_create_application]
    emr_serverless_app = EmrServerlessCreateApplicationOperator(
        task_id='create_emr_serverless_task',
        release_label='emr-6.6.0',
        job_type="SPARK",
        config={'name': 'new_application'},
    )
    # [END howto_operator_emr_serverless_create_application]

    # [START howto_sensor_emr_serverless_application]
    wait_for_app_creation = EmrServerlessApplicationSensor(
        task_id='wait_for_app_creation',
        application_id=emr_serverless_app.output,
    )
    # [END howto_sensor_emr_serverless_application]

    # [START howto_operator_emr_serverless_start_job]
    start_job = EmrServerlessStartJobOperator(
        task_id='start_emr_serverless_job',
        application_id=emr_serverless_app.output,
        execution_role_arn=EXECUTION_ROLE_ARN,
        job_driver=SPARK_JOB_DRIVER,
        configuration_overrides=SPARK_CONFIGURATION_OVERRIDES,
    )
    # [END howto_operator_emr_serverless_start_job]

    # [START howto_sensor_emr_serverless_job]
    wait_for_job = EmrServerlessJobSensor(
        task_id='wait_for_job', application_id=emr_serverless_app.output, job_run_id=start_job.output
    )
    # [END howto_sensor_emr_serverless_job]

    # [START howto_operator_emr_serverless_delete_application]
    delete_app = EmrServerlessDeleteApplicationOperator(
        task_id='delete_application', application_id=emr_serverless_app.output, trigger_rule="all_done"
    )
    # [END howto_operator_emr_serverless_delete_application]

    chain(
        emr_serverless_app,
        wait_for_app_creation,
        start_job,
        wait_for_job,
        delete_app,
    )
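
As a side note, the chain(...) call above is Airflow's helper for declaring a linear dependency; it is equivalent to wiring the same tasks with the bitshift operators:

# Equivalent dependency declaration using Airflow's bitshift syntax:
emr_serverless_app >> wait_for_app_creation >> start_job >> wait_for_job >> delete_app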
airflow/providers/amazon/aws/hooks/emr.py (75 changes: 74 additions & 1 deletion)
@@ -16,10 +16,11 @@
# specific language governing permissions and limitations
# under the License.
from time import sleep
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set

from botocore.exceptions import ClientError

+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

@@ -90,6 +91,78 @@ def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]:
return response


class EmrServerlessHook(AwsBaseHook):
    """
    Interact with EMR Serverless API.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs["client_type"] = "emr-serverless"
        super().__init__(*args, **kwargs)

    @cached_property
    def conn(self):
        """Get the underlying boto3 EmrServerlessAPIService client (cached)."""
        return super().conn

    # This method should be replaced with boto waiters, which would implement timeouts and backoff nicely.
    def waiter(
        self,
        get_state_callable: Callable,
        get_state_args: Dict,
        parse_response: List,
        desired_state: Set,
        failure_states: Set,
        object_type: str,
        action: str,
        countdown: int = 25 * 60,
        check_interval_seconds: int = 60,
    ) -> None:
        """
        Poll get_state_callable until the state parsed from its response reaches one of desired_state.

        :param get_state_callable: A callable whose response contains the current state
        :param get_state_args: Arguments to pass to get_state_callable
        :param parse_response: Dictionary keys to extract the state from the response of get_state_callable
        :param desired_state: Wait until the parsed state is one of these values
        :param failure_states: A set of states which indicate failure and should raise an
            exception if any are reached before the desired_state
        :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc.)
        :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc.)
        :param countdown: Total amount of time the waiter should wait for the desired state
            before timing out (in seconds). Defaults to 25 * 60 seconds.
        :param check_interval_seconds: Number of seconds the waiter should wait before attempting
            to retry get_state_callable. Defaults to 60 seconds.
        """
        response = get_state_callable(**get_state_args)
        state: str = self.get_state(response, parse_response)
        while state not in desired_state:
            if state in failure_states:
                raise AirflowException(f'{object_type.title()} reached failure state {state}.')
            if countdown >= check_interval_seconds:
                countdown -= check_interval_seconds
                self.log.info('Waiting for %s to be %s.', object_type.lower(), action.lower())
                sleep(check_interval_seconds)
                state = self.get_state(get_state_callable(**get_state_args), parse_response)
            else:
                message = f'{object_type.title()} still not {action.lower()} after the allocated time limit.'
                self.log.error(message)
                raise RuntimeError(message)

    def get_state(self, response, keys) -> str:
        """Walk the nested response dict along the given keys and return the value found (the state)."""
        value = response
        for key in keys:
            if value is not None:
                value = value.get(key, None)
        return value

class EmrContainerHook(AwsBaseHook):
    """
    Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status
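
To make the generic waiter above concrete, here is a minimal usage sketch (not part of this commit): it polls the boto3 emr-serverless client's get_application call until a newly created application reaches the CREATED state. The application_id value and the exact state sets are assumptions for illustration; the operators added in this commit drive the waiter in essentially this way.

# Illustrative sketch only, not part of the diff above.
from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook

hook = EmrServerlessHook(aws_conn_id="aws_default")
application_id = "00abc123def456"  # placeholder: returned by hook.conn.create_application(...)
hook.waiter(
    get_state_callable=hook.conn.get_application,  # boto3 emr-serverless client method
    get_state_args={"applicationId": application_id},
    parse_response=["application", "state"],  # get_state walks response["application"]["state"]
    desired_state={"CREATED"},
    failure_states={"STOPPED", "TERMINATED"},
    object_type="application",
    action="created",
)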