Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ def __init__(
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.indexing_error_max_attempts = 5
self.indexing_error_retry_delay = 5

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
validated_event = validate_execute_complete_event(event)
Expand All @@ -654,9 +656,37 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
return validated_event["ingestion_job_id"]

def execute(self, context: Context) -> str:
ingestion_job_id = self.hook.conn.start_ingestion_job(
knowledgeBaseId=self.knowledge_base_id, dataSourceId=self.data_source_id
)["ingestionJob"]["ingestionJobId"]
def start_ingestion_job():
try:
ingestion_job_id = self.hook.conn.start_ingestion_job(
knowledgeBaseId=self.knowledge_base_id, dataSourceId=self.data_source_id
)["ingestionJob"]["ingestionJobId"]

return ingestion_job_id
except ClientError as error:
error_message = error.response["Error"]["Message"].lower()
is_known_retryable_message = (
"dependency error document status code: 404" in error_message
or "request failed: [http_exception] server returned 401" in error_message
)
if all(
[
error.response["Error"]["Code"] == "ValidationException",
is_known_retryable_message,
self.indexing_error_max_attempts > 0,
]
):
self.indexing_error_max_attempts -= 1
self.log.warning(
"Index is not ready for ingestion, retrying in %s seconds.",
self.indexing_error_retry_delay,
)
self.log.info("%s retries remaining.", self.indexing_error_max_attempts)
sleep(self.indexing_error_retry_delay)
return start_ingestion_job()
raise

ingestion_job_id = start_ingestion_job()

if self.deferrable:
self.log.info("Deferring for ingestion job.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,104 @@ def test_id_returned(self, mock_conn):
def test_template_fields(self):
validate_template_fields(self.operator)

# Retry functionality tests

def _create_validation_error(self, message: str) -> ClientError:
"""Helper to create ValidationException with specific message."""
return ClientError(
error_response={"Error": {"Message": message, "Code": "ValidationException"}},
operation_name="StartIngestionJob",
)

@mock.patch("airflow.providers.amazon.aws.operators.bedrock.sleep")
@mock.patch("airflow.providers.amazon.aws.operators.bedrock.BedrockIngestDataOperator.log")
def test_retry_multiple_attempts_with_logging(self, mock_log, mock_sleep, mock_conn):
"""Test multiple retry attempts with proper logging."""
error_404 = self._create_validation_error("Dependency error document status code: 404")
success_response = {"ingestionJob": {"ingestionJobId": self.INGESTION_JOB_ID}}

# Fail 3 times, then succeed
mock_conn.start_ingestion_job.side_effect = [error_404, error_404, error_404, success_response]

result = self.operator.execute({})

assert result == self.INGESTION_JOB_ID
assert mock_conn.start_ingestion_job.call_count == 4
assert mock_sleep.call_count == 3

# Verify warning logs for retries
assert mock_log.warning.call_count == 3
mock_log.warning.assert_any_call("Index is not ready for ingestion, retrying in %s seconds.", 5)
assert mock_log.info.call_count == 3
expected_info_calls = [
mock.call("%s retries remaining.", 4),
mock.call("%s retries remaining.", 3),
mock.call("%s retries remaining.", 2),
]
mock_log.info.assert_has_calls(expected_info_calls)

@mock.patch("airflow.providers.amazon.aws.operators.bedrock.sleep")
def test_retry_exhaustion_raises_original_error(self, mock_sleep, mock_conn):
"""Test that original error is raised when retries are exhausted."""
error_404 = self._create_validation_error("Dependency error document status code: 404")

# Always fail (6 attempts total: initial + 5 retries)
mock_conn.start_ingestion_job.side_effect = [error_404] * 6

with pytest.raises(ClientError) as exc_info:
self.operator.execute({})

# Verify it's the original error
assert exc_info.value.response["Error"]["Code"] == "ValidationException"
assert "Dependency error document status code: 404" in exc_info.value.response["Error"]["Message"]

# Verify exactly 6 attempts were made (initial + 5 retries)
assert mock_conn.start_ingestion_job.call_count == 6
assert mock_sleep.call_count == 5

@pytest.mark.parametrize(
"error_message, should_retry",
[
("Dependency error document status code: 404", True),
("request failed: [http_exception] server returned 401", True),
("Some other validation error", False),
],
)
def test_retry_condition_validation(self, error_message, should_retry, mock_conn):
"""Test which error messages trigger retries."""
validation_error = self._create_validation_error(error_message)
mock_conn.start_ingestion_job.side_effect = [validation_error]

if should_retry:
# For retryable errors, we need to provide a success response for the retry
success_response = {"ingestionJob": {"ingestionJobId": self.INGESTION_JOB_ID}}
mock_conn.start_ingestion_job.side_effect = [validation_error, success_response]

with mock.patch("airflow.providers.amazon.aws.operators.bedrock.sleep"):
result = self.operator.execute({})
assert result == self.INGESTION_JOB_ID
assert mock_conn.start_ingestion_job.call_count == 2
else:
# For non-retryable errors, the original error should be raised immediately
with pytest.raises(ClientError):
self.operator.execute({})
assert mock_conn.start_ingestion_job.call_count == 1

def test_non_validation_exception_not_retried(self, mock_conn):
"""Test that non-ValidationException errors are not retried."""
access_denied_error = ClientError(
error_response={"Error": {"Message": "Access denied", "Code": "AccessDenied"}},
operation_name="StartIngestionJob",
)

mock_conn.start_ingestion_job.side_effect = [access_denied_error]

with pytest.raises(ClientError) as exc_info:
self.operator.execute({})

assert exc_info.value.response["Error"]["Code"] == "AccessDenied"
assert mock_conn.start_ingestion_job.call_count == 1


class TestBedrockRaGOperator:
VECTOR_SEARCH_CONFIG = {"filter": {"equals": {"key": "some key", "value": "some value"}}}
Expand Down