Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret

Expand Down Expand Up @@ -493,7 +494,7 @@ def conn(self) -> BaseAwsConnection:

:return: boto3.client or boto3.resource
"""
if not ((not self.client_type) ^ (not self.resource_type)):
if not exactly_one(self.client_type, self.resource_type):
raise ValueError(
f"Either client_type={self.client_type!r} or "
f"resource_type={self.resource_type!r} must be provided, not both."
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
wait_for_completion: bool = False,
**kwargs,
):
if not (job_flow_id is None) ^ (job_flow_name is None):
if not exactly_one(job_flow_id is None, job_flow_name is None):
raise AirflowException("Exactly one of job_flow_id or job_flow_name must be specified.")
super().__init__(**kwargs)
cluster_states = cluster_states or []
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -463,11 +464,11 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify

if not bool(keys is None) ^ bool(prefix is None):
if not exactly_one(prefix is None, keys is None):
raise AirflowException("Either keys or prefix should be set.")

def execute(self, context: Context):
if not bool(self.keys is None) ^ bool(self.prefix is None):
if not exactly_one(self.keys is None, self.prefix is None):
raise AirflowException("Either keys or prefix should be set.")

if isinstance(self.keys, (list, str)) and not bool(self.keys):
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/operators/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
CloudBuildTriggersListLink,
)
from airflow.utils import yaml
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -917,7 +918,7 @@ def __init__(self, build: dict | Build) -> None:
self.build = deepcopy(build)

def _verify_source(self) -> None:
if not (("storage_source" in self.build["source"]) ^ ("repo_source" in self.build["source"])):
if not exactly_one("storage_source" in self.build["source"], "repo_source" in self.build["source"]):
raise AirflowException(
"The source could not be determined. Please choose one data source from: "
"storage_source and repo_source."
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/slack/hooks/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.slack.utils import ConnectionExtraConfig
from airflow.utils.helpers import exactly_one
from airflow.utils.log.secrets_masker import mask_secret

if TYPE_CHECKING:
Expand Down Expand Up @@ -268,7 +269,7 @@ def send_file(
- `Slack API files.upload method <https://api.slack.com/methods/files.upload>`_
- `File types <https://api.slack.com/types/file#file_types>`_
"""
if not ((not file) ^ (not content)):
if not exactly_one(file, content):
raise ValueError("Either `file` or `content` must be provided, not both.")
elif file:
file = Path(file)
Expand Down
21 changes: 18 additions & 3 deletions tests/providers/amazon/aws/operators/test_emr_add_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import json
import os
import unittest
from datetime import timedelta
from unittest.mock import MagicMock, call, patch

Expand All @@ -41,7 +40,7 @@
)


class TestEmrAddStepsOperator(unittest.TestCase):
class TestEmrAddStepsOperator:
# When
_config = [
{
Expand All @@ -54,7 +53,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
}
]

def setUp(self):
def setup_method(self):
self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}

# Mock out the emr_client (moto has incorrect response)
Expand All @@ -79,6 +78,22 @@ def test_init(self):
assert self.operator.job_flow_id == "j-8989898989"
assert self.operator.aws_conn_id == "aws_default"

@pytest.mark.parametrize(
"job_flow_id, job_flow_name",
[
pytest.param("j-8989898989", "test_cluster", id="both-specified"),
pytest.param(None, None, id="both-none"),
],
)
def test_validate_mutually_exclusive_args(self, job_flow_id, job_flow_name):
error_message = r"Exactly one of job_flow_id or job_flow_name must be specified\."
with pytest.raises(AirflowException, match=error_message):
EmrAddStepsOperator(
task_id="test_validate_mutually_exclusive_args",
job_flow_id=job_flow_id,
job_flow_name=job_flow_name,
)

def test_render_template(self):
dag_run = DagRun(dag_id=self.operator.dag.dag_id, execution_date=DEFAULT_DATE, run_id="test")
ti = TaskInstance(task=self.operator)
Expand Down
74 changes: 39 additions & 35 deletions tests/providers/amazon/aws/operators/test_s3_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock

import boto3
import pytest
from moto import mock_s3

from airflow import AirflowException
Expand Down Expand Up @@ -95,8 +96,8 @@ def test_s3_copy_object_arg_combination_2(self):
assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key


class TestS3DeleteObjectsOperator(unittest.TestCase):
@mock_s3
@mock_s3
class TestS3DeleteObjectsOperator:
def test_s3_delete_single_object(self):
bucket = "testbucket"
key = "path/data.txt"
Expand All @@ -116,7 +117,6 @@ def test_s3_delete_single_object(self):
# There should be no object found in the bucket created earlier
assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key)

@mock_s3
def test_s3_delete_multiple_objects(self):
bucket = "testbucket"
key_pattern = "path/data"
Expand All @@ -139,7 +139,6 @@ def test_s3_delete_multiple_objects(self):
# There should be no object found in the bucket created earlier
assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)

@mock_s3
def test_s3_delete_prefix(self):
bucket = "testbucket"
key_pattern = "path/data"
Expand All @@ -162,7 +161,6 @@ def test_s3_delete_prefix(self):
# There should be no object found in the bucket created earlier
assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)

@mock_s3
def test_s3_delete_empty_list(self):
bucket = "testbucket"
key_of_test = "path/data.txt"
Expand All @@ -185,7 +183,6 @@ def test_s3_delete_empty_list(self):
# the object found should be consistent with dest_key specified earlier
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test

@mock_s3
def test_s3_delete_empty_string(self):
bucket = "testbucket"
key_of_test = "path/data.txt"
Expand All @@ -208,50 +205,57 @@ def test_s3_delete_empty_string(self):
# the object found should be consistent with dest_key specified earlier
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test

@mock_s3
def test_assert_s3_both_keys_and_prifix_given(self):
bucket = "testbucket"
keys = "path/data.txt"
key_pattern = "path/data"

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=keys, Fileobj=io.BytesIO(b"input"))

# The object should be detected before the DELETE action is tested
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=keys)
assert len(objects_in_dest_bucket["Contents"]) == 1
assert objects_in_dest_bucket["Contents"][0]["Key"] == keys
with self.assertRaises(AirflowException):
op = S3DeleteObjectsOperator(
task_id="test_assert_s3_both_keys_and_prifix_given",
bucket=bucket,
@pytest.mark.parametrize(
"keys, prefix",
[
pytest.param("path/data.txt", "path/data", id="single-key-and-prefix"),
pytest.param(["path/data.txt"], "path/data", id="multiple-keys-and-prefix"),
pytest.param(None, None, id="both-none"),
],
)
def test_validate_keys_and_prefix_in_constructor(self, keys, prefix):
with pytest.raises(AirflowException, match=r"Either keys or prefix should be set\."):
S3DeleteObjectsOperator(
task_id="test_validate_keys_and_prefix_in_constructor",
bucket="foo-bar-bucket",
keys=keys,
prefix=key_pattern,
prefix=prefix,
)
op.execute(None)

# The object found in the bucket created earlier should still be there
assert len(objects_in_dest_bucket["Contents"]) == 1
# the object found should be consistent with dest_key specified earlier
assert objects_in_dest_bucket["Contents"][0]["Key"] == keys

@mock_s3
def test_assert_s3_no_keys_or_prifix_given(self):
@pytest.mark.parametrize(
"keys, prefix",
[
pytest.param("path/data.txt", "path/data", id="single-key-and-prefix"),
pytest.param(["path/data.txt"], "path/data", id="multiple-keys-and-prefix"),
pytest.param(None, None, id="both-none"),
],
)
def test_validate_keys_and_prefix_in_execute(self, keys, prefix):
bucket = "testbucket"
key_of_test = "path/data.txt"

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=io.BytesIO(b"input"))

# Set valid values for constructor, and change them later for emulate rendering template
op = S3DeleteObjectsOperator(
task_id="test_validate_keys_and_prefix_in_execute",
bucket=bucket,
keys="keys-exists",
prefix=None,
)
op.keys = keys
op.prefix = prefix

# The object should be detected before the DELETE action is tested
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_of_test)
assert len(objects_in_dest_bucket["Contents"]) == 1
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test
with self.assertRaises(AirflowException):
op = S3DeleteObjectsOperator(task_id="test_assert_s3_no_keys_or_prifix_given", bucket=bucket)

with pytest.raises(AirflowException, match=r"Either keys or prefix should be set\."):
op.execute(None)

# The object found in the bucket created earlier should still be there
assert len(objects_in_dest_bucket["Contents"]) == 1
# the object found should be consistent with dest_key specified earlier
Expand Down
6 changes: 5 additions & 1 deletion tests/providers/google/cloud/operators/test_cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,13 @@ def test_update_build_trigger(self, mock_hook):

class TestBuildProcessor(TestCase):
def test_verify_source(self):
with pytest.raises(AirflowException, match="The source could not be determined."):
error_message = r"The source could not be determined."
with pytest.raises(AirflowException, match=error_message):
BuildProcessor(build={"source": {"storage_source": {}, "repo_source": {}}}).process_body()

with pytest.raises(AirflowException, match=error_message):
BuildProcessor(build={"source": {}}).process_body()

@parameterized.expand(
[
(
Expand Down