Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 135 additions & 15 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import tempfile
import time
import warnings
from collections import Counter
from datetime import datetime
from functools import partial
from typing import Any, Callable, Generator, cast
Expand Down Expand Up @@ -146,6 +147,7 @@ class SageMakerHook(AwsBaseHook):

non_terminal_states = {"InProgress", "Stopping"}
endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
pipeline_non_terminal_states = {"Executing", "Stopping"}
failed_states = {"Failed"}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -654,22 +656,21 @@ def check_status(
check_interval: int,
max_ingestion_time: int | None = None,
non_terminal_states: set | None = None,
):
) -> dict:
"""
Check status of a SageMaker job
Check status of a SageMaker resource

:param job_name: name of the job to check status
:param key: the key of the response dict
that points to the state
:param job_name: name of the resource to check status, can be a job but also pipeline for instance.
:param key: the key of the response dict that points to the state
:param describe_function: the function used to retrieve the status
:param args: the arguments for the function
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker job
will check the status of any SageMaker resource
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
SageMaker resources that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker resource.
:param non_terminal_states: the set of nonterminal states
:return: response of describe call after job is done
:return: response of describe call after resource is done
"""
if not non_terminal_states:
non_terminal_states = self.non_terminal_states
Expand All @@ -683,22 +684,22 @@ def check_status(
try:
response = describe_function(job_name)
status = response[key]
self.log.info("Job still running for %s seconds... current status is %s", sec, status)
self.log.info("Resource still running for %s seconds... current status is %s", sec, status)
except KeyError:
raise AirflowException("Could not get status of the SageMaker job")
raise AirflowException("Could not get status of the SageMaker resource")
except ClientError:
raise AirflowException("AWS request failed, check logs for more info")

if status in self.failed_states:
raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
raise AirflowException(f"SageMaker resource failed because {response['FailureReason']}")
elif status not in non_terminal_states:
break

if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
raise AirflowException(f"SageMaker job took more than {max_ingestion_time} seconds")
# ensure that the resource gets killed if the max ingestion time is exceeded
raise AirflowException(f"SageMaker resource took more than {max_ingestion_time} seconds")

self.log.info("SageMaker Job completed")
self.log.info("SageMaker resource completed")
return response

def check_training_status_with_log(
Expand Down Expand Up @@ -1010,3 +1011,122 @@ def delete_model(self, model_name: str):
except Exception as general_error:
self.log.error("Failed to delete model, error: %s", general_error)
raise

def describe_pipeline_exec(self, pipeline_exec_arn: str, verbose: bool = False):
    """Get info about a SageMaker pipeline execution

    :param pipeline_exec_arn: arn of the pipeline execution
    :param verbose: Whether to log details about the steps status in the pipeline execution
    """
    if verbose:
        # Fetch the step list first so we can summarize progress before describing.
        steps_response = self.conn.list_pipeline_execution_steps(PipelineExecutionArn=pipeline_exec_arn)
        steps = steps_response["PipelineExecutionSteps"]
        status_counts = Counter(step["StepStatus"] for step in steps)
        in_progress = [step["StepName"] for step in steps if step["StepStatus"] == "Executing"]
        self.log.info("state of the pipeline steps: %s", status_counts)
        self.log.info("steps currently in progress: %s", in_progress)

    return self.conn.describe_pipeline_execution(PipelineExecutionArn=pipeline_exec_arn)

def start_pipeline(
    self,
    pipeline_name: str,
    display_name: str = "airflow-triggered-execution",
    pipeline_params: dict | None = None,
    wait_for_completion: bool = False,
    check_interval: int = 30,
    verbose: bool = True,
) -> str:
    """
    Start a new execution for a SageMaker pipeline

    :param pipeline_name: Name of the pipeline to start (this is _not_ the ARN).
    :param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
    :param pipeline_params: Optional parameters for the pipeline.
        All parameters supplied need to already be present in the pipeline definition.
    :param wait_for_completion: Will only return once the pipeline is complete if true.
    :param check_interval: How long to wait between checks for pipeline status when waiting for
        completion.
    :param verbose: Whether to print steps details when waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.

    :return: the ARN of the pipeline execution launched.
    """
    # The API expects a list of {"Name": ..., "Value": ...} records rather than a plain dict.
    formatted_params = [
        {"Name": name, "Value": value} for name, value in (pipeline_params or {}).items()
    ]

    try:
        response = self.conn.start_pipeline_execution(
            PipelineName=pipeline_name,
            PipelineExecutionDisplayName=display_name,
            PipelineParameters=formatted_params,
        )
    except ClientError as ce:
        self.log.error("Failed to start pipeline execution, error: %s", ce)
        raise

    execution_arn = response["PipelineExecutionArn"]
    if wait_for_completion:
        # Block until the execution leaves the pipeline-specific non-terminal states.
        self.check_status(
            execution_arn,
            "PipelineExecutionStatus",
            lambda arn: self.describe_pipeline_exec(arn, verbose),
            check_interval,
            non_terminal_states=self.pipeline_non_terminal_states,
        )
    return execution_arn

def stop_pipeline(
    self,
    pipeline_exec_arn: str,
    wait_for_completion: bool = False,
    check_interval: int = 10,
    verbose: bool = True,
    fail_if_not_running: bool = False,
) -> str:
    """Stop SageMaker pipeline execution

    :param pipeline_exec_arn: Amazon Resource Name (ARN) of the pipeline execution.
        It's the ARN of the pipeline itself followed by "/execution/" and an id.
    :param wait_for_completion: Whether to wait for the pipeline to reach a final state.
        (i.e. either 'Stopped' or 'Failed')
    :param check_interval: How long to wait between checks for pipeline status when waiting for
        completion.
    :param verbose: Whether to print steps details when waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.
    :param fail_if_not_running: This method will raise an exception if the pipeline we're trying to stop
        is not in an "Executing" state when the call is sent (which would mean that the pipeline is
        already either stopping or stopped).
        Note that setting this to True will raise an error if the pipeline finished successfully before it
        was stopped.
    :return: Status of the pipeline execution after the operation.
        One of 'Executing'|'Stopping'|'Stopped'|'Failed'|'Succeeded'.
    """
    try:
        self.conn.stop_pipeline_execution(PipelineExecutionArn=pipeline_exec_arn)
    except ClientError as ce:
        # we have to rely on the message to catch the right error here, because its type
        # (ValidationException) is shared with other kinds of error (for instance, badly formatted ARN)
        was_not_running = (
            "Only pipelines with 'Executing' status can be stopped" in ce.response["Error"]["Message"]
        )
        if was_not_running and not fail_if_not_running:
            self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce)
        else:
            self.log.error(ce)
            raise

    execution_info = self.describe_pipeline_exec(pipeline_exec_arn)

    if wait_for_completion and execution_info["PipelineExecutionStatus"] in self.pipeline_non_terminal_states:
        # Still stopping (or executing): poll until a terminal state is reached.
        execution_info = self.check_status(
            pipeline_exec_arn,
            "PipelineExecutionStatus",
            lambda arn: self.describe_pipeline_exec(arn, verbose),
            check_interval,
            non_terminal_states=self.pipeline_non_terminal_states,
        )

    return execution_info["PipelineExecutionStatus"]
120 changes: 119 additions & 1 deletion airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _create_integer_fields(self) -> None:
"""
self.integer_fields = []

def execute(self, context: Context) -> None | dict:
def execute(self, context: Context):
raise NotImplementedError("Please implement execute() in sub class!")

@cached_property
Expand Down Expand Up @@ -750,3 +750,121 @@ def execute(self, context: Context) -> Any:
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
sagemaker_hook.delete_model(model_name=self.config["ModelName"])
self.log.info("Model %s deleted successfully.", self.config["ModelName"])


class SageMakerStartPipelineOperator(SageMakerBaseOperator):
    """
    Starts a SageMaker pipeline execution.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerStartPipelineOperator`

    :param aws_conn_id: The AWS connection ID to use.
    :param pipeline_name: Name of the pipeline to start.
    :param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
    :param pipeline_params: Optional parameters for the pipeline.
        All parameters supplied need to already be present in the pipeline definition.
    :param wait_for_completion: If true, this operator will only complete once the pipeline is complete.
    :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
    :param verbose: Whether to print steps details when waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.

    :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
    """

    template_fields: Sequence[str] = ("aws_conn_id", "pipeline_name", "display_name", "pipeline_params")

    def __init__(
        self,
        *,
        aws_conn_id: str = DEFAULT_CONN_ID,
        pipeline_name: str,
        display_name: str = "airflow-triggered-execution",
        pipeline_params: dict | None = None,
        wait_for_completion: bool = False,
        check_interval: int = CHECK_INTERVAL_SECOND,
        verbose: bool = True,
        **kwargs,
    ):
        # This operator has no job config of its own; the base class still requires one.
        super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
        self.pipeline_name = pipeline_name
        self.display_name = display_name
        self.pipeline_params = pipeline_params
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.verbose = verbose

    def execute(self, context: Context) -> str:
        arn = self.hook.start_pipeline(
            pipeline_name=self.pipeline_name,
            display_name=self.display_name,
            pipeline_params=self.pipeline_params,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            verbose=self.verbose,
        )
        # NOTE(review): with wait_for_completion=True this line only logs after the pipeline
        # finished, even though it reads as a "starting" message — the ARN isn't known earlier.
        self.log.info(
            "Starting a new execution for pipeline %s, running with ARN %s", self.pipeline_name, arn
        )
        return arn


class SageMakerStopPipelineOperator(SageMakerBaseOperator):
    """
    Stops a SageMaker pipeline execution.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerStopPipelineOperator`

    :param aws_conn_id: The AWS connection ID to use.
    :param pipeline_exec_arn: Amazon Resource Name of the pipeline execution to stop.
    :param wait_for_completion: If true, this operator will only complete once the pipeline is fully stopped.
    :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
    :param verbose: Whether to print steps details when waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.
    :param fail_if_not_running: raises an exception if the pipeline stopped or succeeded before this was run

    :return str: Returns the status of the pipeline execution after the operation has been done.
    """

    template_fields: Sequence[str] = (
        "aws_conn_id",
        "pipeline_exec_arn",
    )

    def __init__(
        self,
        *,
        aws_conn_id: str = DEFAULT_CONN_ID,
        pipeline_exec_arn: str,
        wait_for_completion: bool = False,
        check_interval: int = CHECK_INTERVAL_SECOND,
        verbose: bool = True,
        fail_if_not_running: bool = False,
        **kwargs,
    ):
        # This operator has no job config of its own; the base class still requires one.
        super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
        self.pipeline_exec_arn = pipeline_exec_arn
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.verbose = verbose
        self.fail_if_not_running = fail_if_not_running

    def execute(self, context: Context) -> str:
        # Delegates entirely to the hook, which handles the "already stopped" ClientError
        # according to fail_if_not_running.
        status = self.hook.stop_pipeline(
            pipeline_exec_arn=self.pipeline_exec_arn,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            verbose=self.verbose,
            fail_if_not_running=self.fail_if_not_running,
        )
        self.log.info(
            "Stop requested for pipeline execution with ARN %s. Status is now %s",
            self.pipeline_exec_arn,
            status,
        )
        return status
44 changes: 41 additions & 3 deletions airflow/providers/amazon/aws/sensors/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class SageMakerBaseSensor(BaseSensorOperator):

ui_color = "#ededed"

def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs):
def __init__(self, *, aws_conn_id: str = "aws_default", resource_type: str = "job", **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.resource_type = resource_type # only used for logs, to say what kind of resource we are sensing
self.hook: SageMakerHook | None = None

def get_hook(self) -> SageMakerHook:
Expand All @@ -55,12 +56,14 @@ def poke(self, context: Context):
self.log.info("Bad HTTP response: %s", response)
return False
state = self.state_from_response(response)
self.log.info("Job currently %s", state)
self.log.info("%s currently %s", self.resource_type, state)
if state in self.non_terminal_states():
return False
if state in self.failed_states():
failed_reason = self.get_failed_reason_from_response(response)
raise AirflowException(f"Sagemaker job failed for the following reason: {failed_reason}")
raise AirflowException(
f"Sagemaker {self.resource_type} failed for the following reason: {failed_reason}"
)
return True

def non_terminal_states(self) -> set[str]:
Expand Down Expand Up @@ -269,3 +272,38 @@ def get_failed_reason_from_response(self, response):

def state_from_response(self, response):
return response["TrainingJobStatus"]


class SageMakerPipelineSensor(SageMakerBaseSensor):
    """
    Polls a SageMaker pipeline execution until it leaves its non-terminal states.
    Raises an AirflowException with the failure reason if a failed state is reached.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerPipelineSensor`

    :param pipeline_exec_arn: ARN of the pipeline to watch.
    :param verbose: Whether to print steps details while waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.
    """

    template_fields: Sequence[str] = ("pipeline_exec_arn",)

    def __init__(self, *, pipeline_exec_arn: str, verbose: bool = True, **kwargs):
        # resource_type only changes the wording of the base sensor's log messages.
        super().__init__(resource_type="pipeline", **kwargs)
        self.pipeline_exec_arn = pipeline_exec_arn
        self.verbose = verbose

    def get_sagemaker_response(self) -> dict:
        self.log.info("Poking Sagemaker Pipeline Execution %s", self.pipeline_exec_arn)
        return self.get_hook().describe_pipeline_exec(self.pipeline_exec_arn, self.verbose)

    def state_from_response(self, response: dict) -> str:
        return response["PipelineExecutionStatus"]

    def non_terminal_states(self) -> set[str]:
        # Shared with the hook so polling and check_status agree on what "still running" means.
        return SageMakerHook.pipeline_non_terminal_states

    def failed_states(self) -> set[str]:
        return SageMakerHook.failed_states
Loading