14 changes: 9 additions & 5 deletions airflow/providers/amazon/aws/hooks/batch_client.py
@@ -363,11 +363,15 @@ def get_job_description(self, job_id: str) -> dict:
                 return self.parse_job_description(job_id, response)
 
             except botocore.exceptions.ClientError as err:
-                error = err.response.get("Error", {})
-                if error.get("Code") == "TooManyRequestsException":
-                    pass  # allow it to retry, if possible
-                else:
-                    raise AirflowException(f"AWS Batch job ({job_id}) description error: {err}")
+                # Allow it to retry in case of exceeded quota limit of requests to AWS API
+                if err.response.get("Error", {}).get("Code") != "TooManyRequestsException":
+                    raise
+                self.log.warning(
+                    "Ignored TooManyRequestsException error, original message: %r. "
+                    "Please consider to setup retries mode in boto3, "
+                    "check Amazon Provider AWS Connection documentation for more details.",
+                    str(err),
+                )
 
             retries += 1
             if retries >= self.status_retries:
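Note: the warning added here points users at boto3's own retry machinery rather than the hook's fixed-count polling loop. As a hedged sketch, configuring that in botocore looks roughly like this (the client creation and retry values are illustrative, not from this PR):

```python
# Illustrative only: let botocore retry throttled AWS Batch calls itself,
# so TooManyRequestsException rarely reaches the hook's except-branch.
import boto3
from botocore.config import Config

retry_config = Config(retries={"max_attempts": 10, "mode": "standard"})
batch = boto3.client("batch", config=retry_config)
```

In the Amazon provider the equivalent settings can reportedly be supplied through the AWS connection's extra field; see the Amazon Provider AWS Connection documentation referenced in the warning message.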
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/pubsub.py
@@ -538,9 +538,9 @@ def acknowledge(
         :param metadata: (Optional) Additional metadata that is provided to the method.
         """
         if ack_ids is not None and messages is None:
-            pass
+            pass  # use ack_ids as is
         elif ack_ids is None and messages is not None:
-            ack_ids = [message.ack_id for message in messages]
+            ack_ids = [message.ack_id for message in messages]  # extract ack_ids from messages
         else:
             raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided")
 
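The branch being annotated implements an exactly-one-of-two-arguments contract. A standalone sketch of the same check, with invented names:

```python
from __future__ import annotations


def resolve_ack_ids(ack_ids: list[str] | None, messages: list | None) -> list[str]:
    # Exactly one argument must be provided; both-None and both-set are rejected.
    if (ack_ids is None) == (messages is None):
        raise ValueError("One and only one of 'ack_ids' and 'messages' arguments have to be provided")
    return ack_ids if ack_ids is not None else [message.ack_id for message in messages]
```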
10 changes: 4 additions & 6 deletions airflow/providers/imap/hooks/imap.py
@@ -182,14 +182,12 @@ def download_mail_attachments(
         self._create_files(mail_attachments, local_output_directory)
 
     def _handle_not_found_mode(self, not_found_mode: str) -> None:
-        if not_found_mode == "raise":
+        if not_found_mode not in ("raise", "warn", "ignore"):
+            self.log.error('Invalid "not_found_mode" %s', not_found_mode)
+        elif not_found_mode == "raise":
             raise AirflowException("No mail attachments found!")
-        if not_found_mode == "warn":
+        elif not_found_mode == "warn":
             self.log.warning("No mail attachments found!")
-        elif not_found_mode == "ignore":
-            pass  # Do not notify if the attachment has not been found.
-        else:
-            self.log.error('Invalid "not_found_mode" %s', not_found_mode)
 
     def _retrieve_mails_attachments_by_name(
         self, name: str, check_regex: bool, latest_only: bool, mail_folder: str, mail_filter: str
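The rewrite turns a trailing else-catchall into a validate-first guard: unknown modes are rejected up front, and the "ignore" mode then needs no branch at all. A minimal standalone illustration (function name and exception type are invented):

```python
import logging

logger = logging.getLogger(__name__)


def handle_not_found(mode: str) -> None:
    # Validate first, then dispatch on the known modes; "ignore" falls through silently.
    if mode not in ("raise", "warn", "ignore"):
        logger.error('Invalid "not_found_mode" %s', mode)
    elif mode == "raise":
        raise ValueError("No mail attachments found!")
    elif mode == "warn":
        logger.warning("No mail attachments found!")


handle_not_found("ignore")  # silent
handle_not_found("oops")    # logs an error instead of silently passing
```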
8 changes: 2 additions & 6 deletions airflow/providers/sftp/hooks/sftp.py
@@ -386,12 +386,8 @@ def get_file_by_pattern(self, path, fnmatch_pattern) -> str:
         :param fnmatch_pattern: The pattern that will be matched with `fnmatch`
         :return: string containing the first found file, or an empty string if none matched
         """
-        files_list = self.list_directory(path)
-
-        for file in files_list:
-            if not fnmatch(file, fnmatch_pattern):
-                pass
-            else:
+        for file in self.list_directory(path):
+            if fnmatch(file, fnmatch_pattern):
                 return file
 
         return ""
19 changes: 9 additions & 10 deletions airflow/providers/ssh/hooks/ssh.py
@@ -290,17 +290,16 @@ def get_conn(self) -> paramiko.SSHClient:
         known_hosts = os.path.expanduser("~/.ssh/known_hosts")
         if not self.allow_host_key_change and os.path.isfile(known_hosts):
             client.load_host_keys(known_hosts)
-        else:
-            if self.host_key is not None:
-                client_host_keys = client.get_host_keys()
-                if self.port == SSH_PORT:
-                    client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
-                else:
-                    client_host_keys.add(
-                        f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
-                    )
 
+        elif self.host_key is not None:
+            # Get host key from connection extra if it not set or None then we fallback to system host keys
+            client_host_keys = client.get_host_keys()
+            if self.port == SSH_PORT:
+                client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
+            else:
-                pass  # will fallback to system host keys if none explicitly specified in conn extra
+                client_host_keys.add(
+                    f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
+                )
 
         connect_kwargs: dict[str, Any] = dict(
             hostname=self.remote_host,
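For context on the host-key handling above: paramiko keeps an in-memory HostKeys mapping, and non-default ports use the bracketed known_hosts form. A hedged sketch (host, port, and the generated key are placeholders for the values the hook takes from the connection extra):

```python
import paramiko

client = paramiko.SSHClient()
host_keys = client.get_host_keys()         # in-memory paramiko.HostKeys store

host_key = paramiko.RSAKey.generate(2048)  # stand-in for the configured host key

host_keys.add("example.com", host_key.get_name(), host_key)         # default port 22
host_keys.add("[example.com]:2222", host_key.get_name(), host_key)  # non-default port
```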
46 changes: 27 additions & 19 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -17,12 +17,11 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
+import logging
 from unittest import mock
 
 import botocore.exceptions
 import pytest
-from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
@@ -36,7 +35,7 @@
 LOG_STREAM_NAME = "test/stream/d56a66bb98a14c4593defa1548686edf"
 
 
-class TestBatchClient(unittest.TestCase):
+class TestBatchClient:
 
     MAX_RETRIES = 2
     STATUS_RETRIES = 3
@@ -45,7 +44,7 @@ class TestBatchClient(unittest.TestCase):
     @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
     @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
     @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
-    def setUp(self, get_client_type_mock):
+    def setup_method(self, method, get_client_type_mock):
         self.get_client_type_mock = get_client_type_mock
         self.batch_client = BatchClientHook(
             max_retries=self.MAX_RETRIES,
@@ -135,13 +134,17 @@ def test_poll_job_complete_raises_for_max_retries(self):
         self.client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
         assert self.client_mock.describe_jobs.call_count == self.MAX_RETRIES + 1
 
-    def test_poll_job_status_hit_api_throttle(self):
+    def test_poll_job_status_hit_api_throttle(self, caplog):
         self.client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
             error_response={"Error": {"Code": "TooManyRequestsException"}},
             operation_name="get job description",
         )
         with pytest.raises(AirflowException) as ctx:
-            self.batch_client.poll_for_job_complete(JOB_ID)
+            with caplog.at_level(level=logging.getLevelName("WARNING")):
+                self.batch_client.poll_for_job_complete(JOB_ID)
+        log_record = caplog.records[0]
+        assert "Ignored TooManyRequestsException error" in log_record.message
+
         msg = f"AWS Batch job ({JOB_ID}) description error"
         assert msg in str(ctx.value)
         # It should retry when this client error occurs
@@ -153,10 +156,10 @@ def test_poll_job_status_with_client_error(self):
             error_response={"Error": {"Code": "InvalidClientTokenId"}},
             operation_name="get job description",
         )
-        with pytest.raises(AirflowException) as ctx:
+        with pytest.raises(botocore.exceptions.ClientError) as ctx:
             self.batch_client.poll_for_job_complete(JOB_ID)
-        msg = f"AWS Batch job ({JOB_ID}) description error"
-        assert msg in str(ctx.value)
+
+        assert ctx.value.response["Error"]["Code"] == "InvalidClientTokenId"
         # It will not retry when this client error occurs
         self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])

@@ -272,7 +275,7 @@ def test_job_awslogs_user_defined(self):
         assert awslogs["awslogs_group"] == "/test/batch/job"
         assert awslogs["awslogs_region"] == "ap-southeast-2"
 
-    def test_job_no_awslogs_stream(self):
+    def test_job_no_awslogs_stream(self, caplog):
         self.client_mock.describe_jobs.return_value = {
             "jobs": [
                 {
@@ -281,11 +284,13 @@ def test_job_no_awslogs_stream(self):
                 }
             ]
         }
-        with self.assertLogs(level="WARNING") as capture_logs:
+        with caplog.at_level(level=logging.getLevelName("WARNING")):
             assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
-        assert len(capture_logs.records) == 1
+        assert len(caplog.records) == 1
+        log_record = caplog.records[0]
+        assert "doesn't create AWS CloudWatch Stream" in log_record.message
 
-    def test_job_splunk_logs(self):
+    def test_job_splunk_logs(self, caplog):
         self.client_mock.describe_jobs.return_value = {
             "jobs": [
                 {
@@ -299,16 +304,18 @@ def test_job_splunk_logs(self, caplog):
                 }
             ]
         }
-        with self.assertLogs(level="WARNING") as capture_logs:
+        with caplog.at_level(level=logging.getLevelName("WARNING")):
            assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
-        assert len(capture_logs.records) == 1
+        assert len(caplog.records) == 1
+        log_record = caplog.records[0]
+        assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in log_record.message
 
 
-class TestBatchClientDelays(unittest.TestCase):
+class TestBatchClientDelays:
     @mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
     @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
     @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
-    def setUp(self):
+    def setup_method(self, method):
         self.batch_client = BatchClientHook(aws_conn_id="airflow_test", region_name=AWS_REGION)
         # We're mocking all actual AWS calls and don't need a connection. This
         # avoids an Airflow warning about connection cannot be found.
@@ -360,7 +367,8 @@ def test_delay_with_float(self, mock_sleep, mock_uniform):
         mock_uniform.assert_called_once_with(4.0, 6.0)  # in add_jitter
         mock_sleep.assert_called_once_with(mock_uniform.return_value)
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "tries, lower, upper",
         [
             (0, 0, 1),
             (1, 0, 2),
@@ -373,7 +381,7 @@ def test_delay_with_float(self, mock_sleep, mock_uniform):
             (8, 8, 25),
             (9, 10, 31),
             (45, 200, 600),  # > 40 tries invokes maximum delay allowed
-        ]
+        ],
     )
     def test_exponential_delay(self, tries, lower, upper):
         result = self.batch_client.exponential_delay(tries)
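The test changes above all follow one migration pattern: dropping unittest.TestCase so pytest-native features work — setup_method replaces setUp, the caplog fixture replaces assertLogs, and @pytest.mark.parametrize replaces parameterized.expand (pytest fixtures cannot be injected into unittest.TestCase methods). A self-contained sketch of the migrated style, with an invented function under test:

```python
import logging

import pytest

logger = logging.getLogger(__name__)


def add(a, b):
    logger.warning("add() called")
    return a + b


class TestAdd:  # plain class, no unittest.TestCase
    def setup_method(self, method):  # pytest's counterpart to setUp
        self.offset = 0

    @pytest.mark.parametrize("a, b, expected", [(1, 2, 3), (2, 3, 5)])
    def test_add(self, a, b, expected, caplog):
        with caplog.at_level(logging.WARNING):
            assert add(a, b) + self.offset == expected
        assert "add() called" in caplog.records[0].message
```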
42 changes: 29 additions & 13 deletions tests/providers/google/cloud/hooks/test_pubsub.py
@@ -61,23 +61,24 @@ def mock_init(
     pass
 
 
+def _generate_messages(count) -> list[ReceivedMessage]:
+    return [
+        ReceivedMessage(
+            ack_id=str(i),
+            message={
+                "data": f"Message {i}".encode(),
+                "attributes": {"type": "generated message"},
+            },
+        )
+        for i in range(1, count + 1)
+    ]
+
+
 class TestPubSubHook(unittest.TestCase):
     def setUp(self):
         with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
             self.pubsub_hook = PubSubHook(gcp_conn_id="test")
 
-    def _generate_messages(self, count) -> list[ReceivedMessage]:
-        return [
-            ReceivedMessage(
-                ack_id=str(i),
-                message={
-                    "data": f"Message {i}".encode(),
-                    "attributes": {"type": "generated message"},
-                },
-            )
-            for i in range(1, count + 1)
-        ]
-
     @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubHook.get_credentials")
     @mock.patch("airflow.providers.google.cloud.hooks.pubsub.PublisherClient")
     def test_publisher_client_creation(self, mock_client, mock_get_creds):
@@ -478,7 +479,7 @@ def test_acknowledge_by_message_objects(self, mock_service):
         self.pubsub_hook.acknowledge(
             project_id=TEST_PROJECT,
             subscription=TEST_SUBSCRIPTION,
-            messages=self._generate_messages(3),
+            messages=_generate_messages(3),
         )
         ack_method.assert_called_once_with(
             request=dict(
@@ -490,6 +491,21 @@ def test_acknowledge_by_message_objects(self, mock_service):
             metadata=(),
         )
 
+    @parameterized.expand([(None, None), ([1, 2, 3], _generate_messages(3))])
+    @mock.patch(PUBSUB_STRING.format("PubSubHook.subscriber_client"))
+    def test_acknowledge_fails_on_method_args_validation(self, ack_ids, messages, mock_service):
+        ack_method = mock_service.acknowledge
+
+        error_message = r"One and only one of 'ack_ids' and 'messages' arguments have to be provided"
+        with pytest.raises(ValueError, match=error_message):
+            self.pubsub_hook.acknowledge(
+                project_id=TEST_PROJECT,
+                subscription=TEST_SUBSCRIPTION,
+                ack_ids=ack_ids,
+                messages=messages,
+            )
+        ack_method.assert_not_called()
+
     @parameterized.expand(
         [
             (exception,)
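Worth noting why _generate_messages moved to module level: decorator arguments such as those of @parameterized.expand are evaluated while the class body is being defined, before any instance exists, so an instance method cannot supply them. A tiny invented illustration:

```python
def make_cases():
    # Runs at class-definition (import) time.
    return [(1,), (2,)]


def with_cases(cases):
    def wrap(func):
        func.cases = cases  # stand-in for what parameterized/pytest record
        return func
    return wrap


class TestExample:
    @with_cases(make_cases())  # fine: module-level helper
    def test_value(self, value):
        assert value in (1, 2)

    # @with_cases(self.make_cases())  # impossible: `self` does not exist here
```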