Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -362,26 +362,57 @@ def process_queue(self, queue_url: str):
MaxNumberOfMessages=10,
)

# Pagination? Maybe we don't need it. But we don't always delete messages after viewing them so we
# could possibly accumulate a lot of messages in the queue and get stuck if we don't read bigger
# chunks and paginate.
messages = response.get("Messages", [])
# Pagination? Maybe we don't need it. Since we always delete messages after looking at them.
# But then that may delete messages that could have been adopted. Let's leave it for now and see how it goes.
# The keys that we validate in the messages below will be different depending on whether or not
# the message is from the dead letter queue or the main results queue.
message_keys = ("return_code", "task_key")
if messages and queue_url == self.dlq_url:
self.log.warning("%d messages received from the dead letter queue", len(messages))
message_keys = ("command", "task_key")

for message in messages:
delete_message = False
receipt_handle = message["ReceiptHandle"]
body = json.loads(message["Body"])
try:
body = json.loads(message["Body"])
except json.JSONDecodeError:
self.log.warning(
"Received a message from the queue that could not be parsed as JSON: %s",
message["Body"],
)
delete_message = True
# If the message is not already marked for deletion, check if it has the required keys.
if not delete_message and not all(key in body for key in message_keys):
self.log.warning(
"Message is not formatted correctly, %s and/or %s are missing: %s", *message_keys, body
)
delete_message = True
if delete_message:
self.log.warning("Deleting the message to avoid processing it again.")
self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
continue
return_code = body.get("return_code")
ser_task_key = body.get("task_key")
# Fetch the real task key from the running_tasks dict, using the serialized task key.
try:
task_key = self.running_tasks[ser_task_key]
except KeyError:
self.log.warning(
"Received task %s from the queue which is not found in running tasks. Removing message.",
self.log.debug(
"Received task %s from the queue which is not found in running tasks, it is likely "
"from another Lambda Executor sharing this queue or might be a stale message that needs "
"deleting manually. Marking the message as visible again.",
ser_task_key,
)
task_key = None
# Mark task as visible again in SQS so that another executor can pick it up.
self.sqs_client.change_message_visibility(
QueueUrl=queue_url,
ReceiptHandle=receipt_handle,
VisibilityTimeout=0,
)
continue

if task_key:
if return_code == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_sync_running_dlq(self, success_mock, fail_mock, mock_executor, mock_air
mock_executor.running_tasks.clear()
mock_executor.running_tasks[ser_airflow_key] = airflow_key
mock_executor.sqs_client.receive_message.side_effect = [
{}, # First request from the results queue will be empt
{}, # First request from the results queue will be empty
{
# Second request from the DLQ will have a message
"Messages": [
Expand Down Expand Up @@ -510,6 +510,87 @@ def test_sync_running_fail(self, success_mock, fail_mock, mock_executor, mock_ai
fail_mock.assert_called_once()
assert mock_executor.sqs_client.delete_message.call_count == 1

def test_sync_running_fail_bad_json(self, mock_executor, mock_airflow_key):
    """A queue message whose body is not valid JSON must be deleted, not processed."""
    key = mock_airflow_key()
    serialized_key = json.dumps(key._asdict())

    mock_executor.running_tasks.clear()
    mock_executor.running_tasks[serialized_key] = key

    # First poll (results queue) yields a message whose body cannot be parsed as
    # JSON; the second poll (DLQ) yields nothing.
    unparseable_message = {
        "ReceiptHandle": "receipt_handle",
        "Body": "Banana",  # Body not json format
    }
    mock_executor.sqs_client.receive_message.side_effect = [
        {"Messages": [unparseable_message]},
        {},  # Second request from the DLQ will be empty
    ]

    mock_executor.sync_running_tasks()

    # Both queues were polled, and the unparseable message was deleted so it
    # is not picked up and re-processed on the next sync.
    assert mock_executor.sqs_client.receive_message.call_count == 2
    assert mock_executor.sqs_client.delete_message.call_count == 1

def test_sync_running_fail_bad_format(self, mock_executor, mock_airflow_key):
    """A JSON message missing the expected keys (e.g. "task_key") must be deleted."""
    key = mock_airflow_key()
    serialized_key = json.dumps(key._asdict())

    mock_executor.running_tasks.clear()
    mock_executor.running_tasks[serialized_key] = key

    # Valid JSON, but not shaped like a results message: no "task_key", and a
    # non-zero return code. It should be discarded rather than processed.
    malformed_body = json.dumps(
        {
            "foo": "bar",  # Missing expected keys like "task_key"
            "return_code": 1,  # Non-zero return code, task failed
        }
    )
    mock_executor.sqs_client.receive_message.side_effect = [
        {"Messages": [{"ReceiptHandle": "receipt_handle", "Body": malformed_body}]},
        {},  # Second request from the DLQ will be empty
    ]

    mock_executor.sync_running_tasks()

    # Both queues were polled, and the badly formatted message was deleted.
    assert mock_executor.sqs_client.receive_message.call_count == 2
    assert mock_executor.sqs_client.delete_message.call_count == 1

def test_sync_running_fail_bad_format_dlq(self, mock_executor, mock_airflow_key):
    """A dead-letter-queue message missing the expected keys must also be deleted."""
    key = mock_airflow_key()
    serialized_key = json.dumps(key._asdict())

    mock_executor.running_tasks.clear()
    mock_executor.running_tasks[serialized_key] = key

    # Failure message
    malformed_body = json.dumps(
        {
            "foo": "bar",  # Missing expected keys like "task_key"
            "return_code": 1,
        }
    )
    # This time the results queue is empty and the malformed message arrives
    # via the DLQ on the second poll.
    mock_executor.sqs_client.receive_message.side_effect = [
        {},
        {"Messages": [{"ReceiptHandle": "receipt_handle", "Body": malformed_body}]},
    ]

    mock_executor.sync_running_tasks()

    # Both queues were polled, and the badly formatted DLQ message was deleted.
    assert mock_executor.sqs_client.receive_message.call_count == 2
    assert mock_executor.sqs_client.delete_message.call_count == 1

@mock.patch.object(BaseExecutor, "fail")
@mock.patch.object(BaseExecutor, "success")
def test_sync_running_short_circuit(self, success_mock, fail_mock, mock_executor, mock_airflow_key):
Expand Down Expand Up @@ -605,10 +686,12 @@ def test_sync_running_unknown_task(self, success_mock, fail_mock, mock_executor,
mock_executor.running_tasks[ser_airflow_key] = airflow_key

# Receive the known task and unknown task
known_task_receipt = "receipt_handle_known"
unknown_task_receipt = "receipt_handle_unknown"
mock_executor.sqs_client.receive_message.return_value = {
"Messages": [
{
"ReceiptHandle": "receipt_handle",
"ReceiptHandle": known_task_receipt,
"Body": json.dumps(
{
"task_key": ser_airflow_key,
Expand All @@ -617,7 +700,7 @@ def test_sync_running_unknown_task(self, success_mock, fail_mock, mock_executor,
),
},
{
"ReceiptHandle": "receipt_handle",
"ReceiptHandle": unknown_task_receipt,
"Body": json.dumps(
{
"task_key": ser_airflow_key_2,
Expand All @@ -635,8 +718,20 @@ def test_sync_running_unknown_task(self, success_mock, fail_mock, mock_executor,
assert len(mock_executor.running_tasks) == 0
success_mock.assert_called_once()
fail_mock.assert_not_called()
# Both messages from the queue should be deleted, both known and unknown
assert mock_executor.sqs_client.delete_message.call_count == 2
# Only the known message from the queue should be deleted, the other should be marked as visible again
assert mock_executor.sqs_client.delete_message.call_count == 1
assert mock_executor.sqs_client.change_message_visibility.call_count == 1
# The argument to delete_message should be the known task
assert mock_executor.sqs_client.delete_message.call_args_list[0].kwargs == {
"QueueUrl": DEFAULT_QUEUE_URL,
"ReceiptHandle": known_task_receipt,
}
# The change_message_visibility should be called with the unknown task
assert mock_executor.sqs_client.change_message_visibility.call_args_list[0].kwargs == {
"QueueUrl": DEFAULT_QUEUE_URL,
"ReceiptHandle": unknown_task_receipt,
"VisibilityTimeout": 0,
}

def test_start_no_check_health(self, mock_executor):
mock_executor.check_health = mock.Mock()
Expand Down