Merged
97 changes: 85 additions & 12 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -136,26 +136,64 @@ def _get_unique_job_name(
:param describe_func: The `describe_` function for that kind of job.
We use it as an O(1) way to check if a job exists.
"""
job_name = proposed_name
while self._check_if_job_exists(job_name, describe_func):
return self._get_unique_name(
proposed_name, fail_if_exists, describe_func, self._check_if_job_exists, "job"
)

def _get_unique_name(
self,
proposed_name: str,
fail_if_exists: bool,
describe_func: Callable[[str], Any],
check_exists_func: Callable[[str, Callable[[str], Any]], bool],
resource_type: str,
) -> str:
"""
Return the proposed name if no resource with that name exists; otherwise return it with a timestamp suffix appended.

:param proposed_name: Base name.
:param fail_if_exists: If True, raise an error when a resource with that name already exists
instead of generating a new name.
:param describe_func: The `describe_` function for that kind of resource,
passed to check_exists_func to look the resource up.
:param check_exists_func: The function to check if the resource exists.
It should take the resource name and a describe function as arguments.
:param resource_type: Type of the resource (e.g., "model", "job").
"""
self._check_resource_type(resource_type)
name = proposed_name
while check_exists_func(name, describe_func):
# This while loop should run at most once in most cases; it is written as a loop so a new
# name can be regenerated in case of a further collision.
if fail_if_exists:
raise AirflowException(f"A SageMaker job with name {job_name} already exists.")
raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
else:
job_name = f"{proposed_name}-{time.time_ns()//1000000}"
self.log.info("Changed job name to '%s' to avoid collision.", job_name)
return job_name
name = f"{proposed_name}-{time.time_ns()//1000000}"
self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
return name

def _check_resource_type(self, resource_type: str):
"""Raise exception if resource type is not 'model' or 'job'."""
if resource_type not in ("model", "job"):
raise AirflowException(
"Argument resource_type accepts only 'model' and 'job'. "
f"Provided value: '{resource_type}'."
)

def _check_if_job_exists(self, job_name, describe_func: Callable[[str], Any]) -> bool:
def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
"""Return True if job exists, False otherwise."""
return self._check_if_resource_exists(job_name, "job", describe_func)

def _check_if_resource_exists(
self, resource_name: str, resource_type: str, describe_func: Callable[[str], Any]
) -> bool:
"""Return True if resource exists, False otherwise."""
self._check_resource_type(resource_type)
try:
describe_func(job_name)
self.log.info("Found existing job with name '%s'.", job_name)
describe_func(resource_name)
self.log.info("Found existing %s with name '%s'.", resource_type, resource_name)
return True
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
return False # ValidationException is thrown when the job could not be found
return False # ValidationException is thrown when the resource could not be found
else:
raise e

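To make the refactored helpers concrete, here is a minimal, self-contained sketch of the pattern they implement: treat a `ValidationException` from the `describe_` call as "resource not found", and on collision append a millisecond timestamp to the proposed name. The `EXISTING`, `describe_job`, `check_exists`, and `unique_name` names are illustrative only and are not part of the provider.

```python
import time

from botocore.exceptions import ClientError

EXISTING = {"my-transform-job"}  # pretend this name is already taken in SageMaker


def describe_job(name: str) -> dict:
    """Stand-in for a boto3 describe_* call: raise ValidationException for unknown names."""
    if name in EXISTING:
        return {"TransformJobName": name}
    raise ClientError({"Error": {"Code": "ValidationException"}}, "DescribeTransformJob")


def check_exists(name: str, describe_func) -> bool:
    try:
        describe_func(name)
        return True
    except ClientError as e:
        if e.response["Error"]["Code"] == "ValidationException":
            return False  # the describe call could not find the resource
        raise


def unique_name(proposed: str, fail_if_exists: bool) -> str:
    name = proposed
    while check_exists(name, describe_job):
        if fail_if_exists:
            raise ValueError(f"A SageMaker job with name {name} already exists.")
        name = f"{proposed}-{time.time_ns() // 1_000_000}"  # millisecond timestamp suffix
    return name


print(unique_name("my-transform-job", fail_if_exists=False))  # e.g. my-transform-job-1718000000000
```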
@@ -637,6 +675,8 @@ def __init__(
max_ingestion_time: int | None = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = "timestamp",
check_if_model_exists: bool = True,
action_if_model_exists: str = "timestamp",
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
@@ -660,6 +700,14 @@
f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)
self.check_if_model_exists = check_if_model_exists
if action_if_model_exists in ("fail", "timestamp"):
self.action_if_model_exists = action_if_model_exists
else:
raise AirflowException(
f"Argument action_if_model_exists accepts only 'timestamp' and 'fail'. \
Provided value: '{action_if_model_exists}'."
)
self.deferrable = deferrable
self.serialized_model: dict
self.serialized_transform: dict
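For context, below is a hedged usage sketch of the two new constructor arguments. It assumes the hunk above belongs to `SageMakerTransformOperator` (the class manages both a Model and a Transform config); the `task_id` and config values are placeholders, not anything taken from the PR.

```python
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

transform = SageMakerTransformOperator(
    task_id="transform",  # placeholder task id
    config={
        "Model": {"ModelName": "my-model"},           # placeholder model definition
        "Transform": {"TransformJobName": "my-job"},  # placeholder transform definition
    },
    check_if_job_exists=True,
    action_if_job_exists="timestamp",  # existing option: suffix the job name on collision
    check_if_model_exists=True,
    action_if_model_exists="fail",     # new option: raise instead of renaming an existing model
)
```

With the default `action_if_model_exists="timestamp"`, an existing model name gets a millisecond suffix instead, mirroring the existing job-name behaviour.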
@@ -697,6 +745,14 @@ def execute(self, context: Context) -> dict:

model_config = self.config.get("Model")
if model_config:
if self.check_if_model_exists:
model_config["ModelName"] = self._get_unique_model_name(
model_config["ModelName"],
self.action_if_model_exists == "fail",
self.hook.describe_model,
)
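# Keep the Transform config pointing at the (possibly renamed) model.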
if "ModelName" in self.config["Transform"].keys():
self.config["Transform"]["ModelName"] = model_config["ModelName"]
self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
self.hook.create_model(model_config)

@@ -752,6 +808,17 @@ def execute(self, context: Context) -> dict:

return self.serialize_result()

def _get_unique_model_name(
self, proposed_name: str, fail_if_exists: bool, describe_func: Callable[[str], Any]
) -> str:
return self._get_unique_name(
proposed_name, fail_if_exists, describe_func, self._check_if_model_exists, "model"
)

def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str], Any]) -> bool:
"""Return True if model exists, False otherwise."""
return self._check_if_resource_exists(model_name, "model", describe_func)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
event = validate_execute_complete_event(event)

@@ -885,7 +952,8 @@ def _create_integer_fields(self) -> None:
def execute(self, context: Context) -> dict:
self.preprocess_config()
self.log.info(
"Creating SageMaker Hyper-Parameter Tuning Job %s", self.config["HyperParameterTuningJobName"]
"Creating SageMaker Hyper-Parameter Tuning Job %s",
self.config["HyperParameterTuningJobName"],
)
response = self.hook.create_tuning_job(
self.config,
@@ -1238,7 +1306,12 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
:return str: The ARN of the pipeline execution created in Amazon SageMaker.
"""

template_fields: Sequence[str] = ("aws_conn_id", "pipeline_name", "display_name", "pipeline_params")
template_fields: Sequence[str] = (
"aws_conn_id",
"pipeline_name",
"display_name",
"pipeline_params",
)

def __init__(
self,
92 changes: 87 additions & 5 deletions tests/providers/amazon/aws/operators/test_sagemaker_base.py
@@ -32,8 +32,16 @@
)
from airflow.utils import timezone

CONFIG: dict = {"key1": "1", "key2": {"key3": "3", "key4": "4"}, "key5": [{"key6": "6"}, {"key6": "7"}]}
PARSED_CONFIG: dict = {"key1": 1, "key2": {"key3": 3, "key4": 4}, "key5": [{"key6": 6}, {"key6": 7}]}
CONFIG: dict = {
"key1": "1",
"key2": {"key3": "3", "key4": "4"},
"key5": [{"key6": "6"}, {"key6": "7"}],
}
PARSED_CONFIG: dict = {
"key1": 1,
"key2": {"key3": 3, "key4": 4},
"key5": [{"key6": 6}, {"key6": 7}],
}

EXPECTED_INTEGER_FIELDS: list[list[Any]] = []

@@ -46,7 +54,12 @@ def setup_method(self):
self.sagemaker.aws_conn_id = "aws_default"

def test_parse_integer(self):
self.sagemaker.integer_fields = [["key1"], ["key2", "key3"], ["key2", "key4"], ["key5", "key6"]]
self.sagemaker.integer_fields = [
["key1"],
["key2", "key3"],
["key2", "key4"],
["key5", "key6"],
]
self.sagemaker.parse_config_integers()
assert self.sagemaker.config == PARSED_CONFIG

@@ -79,10 +92,77 @@ def test_job_not_unique_with_fail(self):
with pytest.raises(AirflowException):
self.sagemaker._get_unique_job_name("test", True, lambda _: None)

def test_check_resource_type_raises_exception_when_resource_type_is_invalid(self):
with pytest.raises(AirflowException) as context:
self.sagemaker._check_resource_type("invalid_resource")

assert str(context.value) == (
"Argument resource_type accepts only 'model' and 'job'. Provided value: 'invalid_resource'."
)

def test_get_unique_name_raises_exception_if_name_exists_when_fail_is_true(self):
with pytest.raises(AirflowException) as context:
self.sagemaker._get_unique_name(
"existing_name",
fail_if_exists=True,
describe_func=None,
check_exists_func=lambda name, describe_func: True,
resource_type="model",
)

assert str(context.value) == "A SageMaker model with name existing_name already exists."

@patch("airflow.providers.amazon.aws.operators.sagemaker.time.time_ns", return_value=3000000)
def test_get_unique_name_avoids_name_collision(self, time_mock):
new_name = self.sagemaker._get_unique_name(
"existing_name",
fail_if_exists=False,
describe_func=None,
check_exists_func=MagicMock(side_effect=[True, False]),
resource_type="model",
)

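# time.time_ns is patched to return 3_000_000 ns; 3_000_000 // 1_000_000 == 3, hence the "-3" suffix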
assert new_name == "existing_name-3"

def test_get_unique_name_checks_only_once_when_resource_does_not_exist(self):
describe_func = MagicMock(side_effect=ClientError({"Error": {"Code": "ValidationException"}}, "op"))
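# ValidationException from describe_func is treated as "not found", so the rename loop never runs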
new_name = "new_name"

name = self.sagemaker._get_unique_name(
new_name,
fail_if_exists=False,
describe_func=describe_func,
check_exists_func=self.sagemaker._check_if_job_exists,
resource_type="job",
)

describe_func.assert_called_once_with(new_name)
assert name == new_name

def test_check_if_resource_exists_returns_true_when_it_finds_existing_resource(self):
exists = self.sagemaker._check_if_resource_exists("job_123", "job", lambda name: None)
assert exists

def test_check_if_resource_exists_returns_false_when_validation_exception_is_raised(self):
describe_func = MagicMock(side_effect=ClientError({"Error": {"Code": "ValidationException"}}, "op"))
exists = self.sagemaker._check_if_resource_exists("job_123", "job", describe_func)
assert not exists

def test_check_if_resource_exists_raises_when_it_is_not_validation_exception(self):
describe_func = MagicMock(side_effect=ValueError("different exception"))

with pytest.raises(ValueError) as context:
self.sagemaker._check_if_resource_exists("job_123", "job", describe_func)

assert str(context.value) == "different exception"


@pytest.mark.db_test
class TestSageMakerExperimentOperator:
@patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock)
@patch(
"airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn",
new_callable=mock.PropertyMock,
)
def test_create_experiment(self, conn_mock):
conn_mock().create_experiment.return_value = {"ExperimentArn": "abcdef"}

@@ -106,5 +186,7 @@ def test_create_experiment(self, conn_mock):

assert ret == "abcdef"
conn_mock().create_experiment.assert_called_once_with(
ExperimentName="the name", Description="the desc", Tags=[{"Key": "jinja", "Value": "tid"}]
ExperimentName="the name",
Description="the desc",
Tags=[{"Key": "jinja", "Value": "tid"}],
)