5 changes: 5 additions & 0 deletions providers/amazon/docs/operators/sagemaker.rst
@@ -31,6 +31,11 @@ Prerequisite Tasks

 .. include:: ../_partials/prerequisite_tasks.rst

+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------

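With the sensors now based on AwsBaseSensor, the generic parameters documented in that section (aws_conn_id, region_name, verify, botocore_config) can be passed straight to any SageMaker sensor. A minimal usage sketch, not part of the diff, with an illustrative task id, job name and region:

from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor

wait_for_training = SageMakerTrainingSensor(
    task_id="wait_for_training",
    job_name="my-training-job",   # hypothetical job name
    aws_conn_id="aws_default",
    region_name="eu-west-1",      # generic parameter supplied by AwsBaseSensor
    verify=True,                  # generic parameter supplied by AwsBaseSensor
)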

Large diffs are not rendered by default.

@@ -18,36 +18,32 @@

 import time
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class SageMakerBaseSensor(BaseSensorOperator):
+class SageMakerBaseSensor(AwsBaseSensor[SageMakerHook]):
     """
     Contains general sensor behavior for SageMaker.

     Subclasses should implement get_sagemaker_response() and state_from_response() methods.
     Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
     """

+    aws_hook_class = SageMakerHook
     ui_color = "#ededed"

-    def __init__(self, *, aws_conn_id: str | None = "aws_default", resource_type: str = "job", **kwargs):
+    def __init__(self, *, resource_type: str = "job", **kwargs):
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.resource_type = resource_type  # only used for logs, to say what kind of resource we are sensing

-    @cached_property
-    def hook(self) -> SageMakerHook:
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
     def poke(self, context: Context):
         response = self.get_sagemaker_response()
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
@@ -96,7 +92,9 @@ class SageMakerEndpointSensor(SageMakerBaseSensor):
     :param endpoint_name: Name of the endpoint instance to watch.
     """

-    template_fields: Sequence[str] = ("endpoint_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "endpoint_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, endpoint_name, **kwargs):
@@ -131,7 +129,9 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
     :param job_name: Name of the transform job to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name: str, **kwargs):
@@ -166,7 +166,9 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
     :param job_name: Name of the tuning instance to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name: str, **kwargs):
@@ -202,7 +204,9 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
     :param print_log: Prints the cloudwatch log if True; Defaults to True.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name, print_log=True, **kwargs):
@@ -281,7 +285,9 @@ class SageMakerPipelineSensor(SageMakerBaseSensor):
         Defaults to true, consider turning off for pipelines that have thousands of steps.
     """

-    template_fields: Sequence[str] = ("pipeline_exec_arn",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "pipeline_exec_arn",
+    )

     def __init__(self, *, pipeline_exec_arn: str, verbose: bool = True, **kwargs):
         super().__init__(resource_type="pipeline", **kwargs)
@@ -313,7 +319,9 @@ class SageMakerAutoMLSensor(SageMakerBaseSensor):
     :param job_name: unique name of the AutoML job to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )

     def __init__(self, *, job_name: str, **kwargs):
         super().__init__(resource_type="autoML job", **kwargs)
@@ -344,7 +352,9 @@ class SageMakerProcessingSensor(SageMakerBaseSensor):
     :param job_name: Name of the processing job to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name: str, **kwargs):
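A minimal sketch (not part of the diff) of the pattern the base-class change implies for these sensors: the hook is supplied by AwsBaseSensor through aws_hook_class, and aws_template_fields() appends the shared AWS fields (aws_conn_id, region_name, verify, botocore_config) to the sensor-specific ones. The subclass name and poke logic below are illustrative only.

from __future__ import annotations

from collections.abc import Sequence

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


class MySageMakerJobSensor(AwsBaseSensor[SageMakerHook]):
    """Illustrative subclass following the migrated pattern."""

    aws_hook_class = SageMakerHook  # AwsBaseSensor builds self.hook from this
    template_fields: Sequence[str] = aws_template_fields("job_name")

    def __init__(self, *, job_name: str, **kwargs):
        super().__init__(**kwargs)
        self.job_name = job_name

    def poke(self, context) -> bool:
        # self.hook is created by the base class from aws_conn_id, region_name,
        # verify and botocore_config; no per-sensor hook property is needed.
        status = self.hook.describe_training_job(self.job_name)["TrainingJobStatus"]
        return status == "Completed"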
@@ -87,5 +87,5 @@ def test_sensor(self, mock_describe, hook_init, mock_get_conn):
         assert mock_describe.call_count == 3

         # make sure the hook was initialized with the specific params
-        calls = [mock.call(aws_conn_id="aws_test")]
+        calls = [mock.call(aws_conn_id="aws_test", config=None, verify=None, region_name=None)]
         hook_init.assert_has_calls(calls)
@@ -106,5 +106,5 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
         assert mock_describe_job.call_count == 3

         # make sure the hook was initialized with the specific params
-        calls = [mock.call(aws_conn_id="aws_test")]
+        calls = [mock.call(aws_conn_id="aws_test", config=None, verify=None, region_name=None)]
         hook_init.assert_has_calls(calls)
@@ -91,7 +91,7 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
         assert mock_describe_job.call_count == 3

         # make sure the hook was initialized with the specific params
-        calls = [mock.call(aws_conn_id="aws_test")]
+        calls = [mock.call(aws_conn_id="aws_test", config=None, verify=None, region_name=None)]
         hook_init.assert_has_calls(calls)

     @mock.patch.object(SageMakerHook, "get_conn")
@@ -123,5 +123,5 @@ def test_sensor_with_log(
         assert mock_describe_job_with_log.call_count == 3
         assert mock_describe_job.call_count == 1

-        calls = [mock.call(aws_conn_id="aws_test")]
+        calls = [mock.call(aws_conn_id="aws_test", config=None, verify=None, region_name=None)]
         hook_init.assert_has_calls(calls)
@@ -85,5 +85,5 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
         assert mock_describe_job.call_count == 3

         # make sure the hook was initialized with the specific params
-        calls = [mock.call(aws_conn_id="aws_test")]
+        calls = [mock.call(aws_conn_id="aws_test", config=None, verify=None, region_name=None)]
         hook_init.assert_has_calls(calls)
@@ -88,5 +88,5 @@ def test_sensor(self, mock_describe_job, hook_init, mock_client):
         assert mock_describe_job.call_count == 3

         # make sure the hook was initialized with the specific params
-        calls = [mock.call(aws_conn_id="aws_test")]
+        calls = [mock.call(aws_conn_id="aws_test", config=None, verify=None, region_name=None)]
         hook_init.assert_has_calls(calls)
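For context on the updated assertions: because the sensors now go through AwsBaseSensor, the hook is expected to be constructed from all of the generic AWS parameters, which default to None when not set on the sensor. A rough, assumed simplification (not the provider's literal code) of what the base class does when it builds the hook:

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor

sensor = SageMakerTrainingSensor(task_id="wait", job_name="my-job", aws_conn_id="aws_test")
# In effect, the base class passes every generic parameter along:
hook = SageMakerHook(
    aws_conn_id=sensor.aws_conn_id,  # "aws_test"
    config=sensor.botocore_config,   # None unless set on the sensor
    verify=sensor.verify,            # None unless set on the sensor
    region_name=sensor.region_name,  # None unless set on the sensor
)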