1 change: 0 additions & 1 deletion pyproject.toml
@@ -484,7 +484,6 @@ norecursedirs = [
log_level = "INFO"
filterwarnings = [
"error::pytest.PytestCollectionWarning",
"error::pytest.PytestReturnNotNoneWarning",
# Avoid building cartesian product which might impact performance
"error:SELECT statement has a cartesian product between FROM:sqlalchemy.exc.SAWarning:airflow",
'error:Coercing Subquery object into a select\(\) for use in IN\(\):sqlalchemy.exc.SAWarning:airflow',
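For context on the removed `filterwarnings` entry: pytest 7.2+ emits `PytestReturnNotNoneWarning` when a test function returns a value instead of asserting, and the `error::` prefix escalated that warning into a test failure. A minimal sketch of what the filter guarded against (the test below is illustrative, not from this PR):

```python
# Illustrative only. With
#   filterwarnings = ["error::pytest.PytestReturnNotNoneWarning"]
# in pyproject.toml, pytest (>= 7.2) fails this test because it returns a value
# instead of asserting, which is usually the sign of a forgotten `assert`.
def test_addition():
    return 1 + 1 == 2  # warns (and, with the filter, errors); should be `assert 1 + 1 == 2`
```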
3 changes: 2 additions & 1 deletion tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -45,7 +45,6 @@ def aws_region():
return AWS_REGION


@mock_aws
@pytest.fixture
def patch_hook(monkeypatch, aws_region):
"""Patch hook object by dummy boto3 Batch client."""
@@ -59,6 +58,7 @@ def test_batch_waiters(aws_region):
assert isinstance(batch_waiters, BatchWaitersHook)


@mock_aws
class TestBatchWaiters:
@pytest.fixture(autouse=True)
def setup_tests(self, patch_hook):
@@ -216,6 +216,7 @@ def test_wait_for_job_raises_for_waiter_error(self):
assert mock_waiter.wait.call_count == 1


@mock_aws
class TestBatchJobWaiters:
"""Test default waiters."""

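On the `@mock_aws` move: moto's `mock_aws` can decorate a function or an entire class (or be used as a context manager). Applied to the test classes it wraps every test method in the mocked AWS environment, whereas stacking it above `@pytest.fixture`, as before, does not reliably keep the mock active for the tests themselves. A minimal, self-contained sketch of the class-decorator form (bucket and test names are illustrative, assuming moto 5.x and boto3 are installed):

```python
import boto3
from moto import mock_aws


@mock_aws  # every test method in this class talks to moto's in-memory AWS
class TestBucketSetup:
    def test_create_bucket(self):
        client = boto3.client("s3", region_name="us-east-1")
        client.create_bucket(Bucket="example-bucket")
        names = [bucket["Name"] for bucket in client.list_buckets()["Buckets"]]
        assert names == ["example-bucket"]
```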
249 changes: 170 additions & 79 deletions tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -26,14 +26,15 @@
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

from airflow.models import DAG, DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import datetime
from tests.test_utils.config import conf_vars


@pytest.fixture(autouse=True)
@@ -43,26 +44,40 @@ def s3mock():


@pytest.mark.db_test
class TestS3TaskHandler:
@conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
class TestS3RemoteLogIO:
@pytest.fixture(autouse=True)
def setup_tests(self, create_log_template, tmp_path_factory, session):
self.remote_log_base = "s3://bucket/remote/log/location"
self.remote_log_location = "s3://bucket/remote/log/location/1.log"
self.remote_log_key = "remote/log/location/1.log"
self.local_log_location = str(tmp_path_factory.mktemp("local-s3-log-location"))
create_log_template("{try_number}.log")
self.s3_task_handler = S3TaskHandler(self.local_log_location, self.remote_log_base)
# Verify the hook now with the config override
assert self.s3_task_handler.hook is not None

date = datetime(2016, 1, 1)
self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date)
task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag)
dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date, run_id="test", run_type="manual")
session.add(dag_run)
session.commit()
session.refresh(dag_run)
with conf_vars({("logging", "remote_log_conn_id"): "aws_default"}):
self.remote_log_base = "s3://bucket/remote/log/location"
self.remote_log_location = "s3://bucket/remote/log/location/1.log"
self.remote_log_key = "remote/log/location/1.log"
self.local_log_location = str(tmp_path_factory.mktemp("local-s3-log-location"))
create_log_template("{try_number}.log")
self.s3_task_handler = S3TaskHandler(self.local_log_location, self.remote_log_base)
# Verify the hook now with the config override
self.subject = self.s3_task_handler.io
assert self.subject.hook is not None

date = datetime(2016, 1, 1)
self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date)
task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag)
if AIRFLOW_V_3_0_PLUS:
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=date,
run_id="test",
run_type="manual",
)
else:
dag_run = DagRun(
dag_id=self.dag.dag_id,
execution_date=date,
run_id="test",
run_type="manual",
)
session.add(dag_run)
session.commit()
session.refresh(dag_run)

self.ti = TaskInstance(task=task, run_id=dag_run.run_id)
self.ti.dag_run = dag_run
@@ -83,71 +98,30 @@ def setup_tests(self, create_log_template, tmp_path_factory, session):
os.remove(self.s3_task_handler.handler.baseFilename)

def test_hook(self):
assert isinstance(self.s3_task_handler.hook, S3Hook)
assert self.s3_task_handler.hook.transfer_config.use_threads is False
assert isinstance(self.subject.hook, S3Hook)
assert self.subject.hook.transfer_config.use_threads is False

def test_log_exists(self):
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"")
assert self.s3_task_handler.s3_log_exists(self.remote_log_location)
assert self.subject.s3_log_exists(self.remote_log_location)

def test_log_exists_none(self):
assert not self.s3_task_handler.s3_log_exists(self.remote_log_location)
assert not self.subject.s3_log_exists(self.remote_log_location)

def test_log_exists_raises(self):
assert not self.s3_task_handler.s3_log_exists("s3://nonexistentbucket/foo")
assert not self.subject.s3_log_exists("s3://nonexistentbucket/foo")

def test_log_exists_no_hook(self):
handler = S3TaskHandler(self.local_log_location, self.remote_log_base)
subject = S3TaskHandler(self.local_log_location, self.remote_log_base).io
with mock.patch.object(S3Hook, "__init__", spec=S3Hook) as mock_hook:
mock_hook.side_effect = ConnectionError("Fake: Failed to connect")
with pytest.raises(ConnectionError, match="Fake: Failed to connect"):
handler.s3_log_exists(self.remote_log_location)

def test_set_context_raw(self):
self.ti.raw = True
mock_open = mock.mock_open()
with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.open", mock_open):
self.s3_task_handler.set_context(self.ti)

assert not self.s3_task_handler.upload_on_close
mock_open.assert_not_called()

def test_set_context_not_raw(self):
mock_open = mock.mock_open()
with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.open", mock_open):
self.s3_task_handler.set_context(self.ti)

assert self.s3_task_handler.upload_on_close
mock_open.assert_called_once_with(os.path.join(self.local_log_location, "1.log"), "w")
mock_open().write.assert_not_called()

def test_read(self):
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\n")
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
log, metadata = self.s3_task_handler.read(ti)
actual = log[0][0][-1]
assert "*** Found logs in s3:\n*** * s3://bucket/remote/log/location/1.log\n" in actual
assert actual.endswith("Log line")
assert metadata == [{"end_of_log": True, "log_pos": 8}]

def test_read_when_s3_log_missing(self):
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], []))
log, metadata = self.s3_task_handler.read(ti)
assert 1 == len(log)
assert len(log) == len(metadata)
actual = log[0][0][-1]
expected = "*** No logs found on s3 for ti=<TaskInstance: dag_for_testing_s3_task_handler.task_for_testing_s3_log_handler test [success]>\n"
assert expected in actual
assert {"end_of_log": True, "log_pos": 0} == metadata[0]
subject.s3_log_exists(self.remote_log_location)

def test_s3_read_when_log_missing(self):
handler = self.s3_task_handler
url = "s3://bucket/foo"
with mock.patch.object(handler.log, "error") as mock_error:
result = handler.s3_read(url, return_error=True)
with mock.patch.object(self.subject.log, "error") as mock_error:
result = self.subject.s3_read(url, return_error=True)
msg = (
f"Could not read logs from {url} with error: An error occurred (404) when calling the "
f"HeadObject operation: Not Found"
@@ -156,10 +130,9 @@ def test_s3_read_when_log_missing(self):
mock_error.assert_called_once_with(msg, exc_info=True)

def test_read_raises_return_error(self):
handler = self.s3_task_handler
url = "s3://nonexistentbucket/foo"
with mock.patch.object(handler.log, "error") as mock_error:
result = handler.s3_read(url, return_error=True)
with mock.patch.object(self.subject.log, "error") as mock_error:
result = self.subject.s3_read(url, return_error=True)
msg = (
f"Could not read logs from {url} with error: An error occurred (NoSuchBucket) when "
f"calling the HeadObject operation: The specified bucket does not exist"
@@ -168,8 +141,8 @@ def test_read_raises_return_error(self):
mock_error.assert_called_once_with(msg, exc_info=True)

def test_write(self):
with mock.patch.object(self.s3_task_handler.log, "error") as mock_error:
self.s3_task_handler.s3_write("text", self.remote_log_location)
with mock.patch.object(self.subject.log, "error") as mock_error:
self.subject.write("text", self.remote_log_location)
# We shouldn't expect any error logs in the default working case.
mock_error.assert_not_called()
body = boto3.resource("s3").Object("bucket", self.remote_log_key).get()["Body"].read()
@@ -178,18 +151,132 @@ def test_write_existing(self):

def test_write_existing(self):
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"previous ")
self.s3_task_handler.s3_write("text", self.remote_log_location)
self.subject.write("text", self.remote_log_location)
body = boto3.resource("s3").Object("bucket", self.remote_log_key).get()["Body"].read()

assert body == b"previous \ntext"

def test_write_raises(self):
handler = self.s3_task_handler
url = "s3://nonexistentbucket/foo"
with mock.patch.object(handler.log, "error") as mock_error:
handler.s3_write("text", url)
with mock.patch.object(self.subject.log, "error") as mock_error:
self.subject.write("text", url)
mock_error.assert_called_once_with("Could not write logs to %s", url, exc_info=True)


@pytest.mark.db_test
class TestS3TaskHandler:
@conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
@pytest.fixture(autouse=True)
def setup_tests(self, create_log_template, tmp_path_factory, session):
self.remote_log_base = "s3://bucket/remote/log/location"
self.remote_log_location = "s3://bucket/remote/log/location/1.log"
self.remote_log_key = "remote/log/location/1.log"
self.local_log_location = str(tmp_path_factory.mktemp("local-s3-log-location"))
create_log_template("{try_number}.log")
self.s3_task_handler = S3TaskHandler(self.local_log_location, self.remote_log_base)
# Verify the hook now with the config override
assert self.s3_task_handler.io.hook is not None

date = datetime(2016, 1, 1)
self.dag = DAG("dag_for_testing_s3_task_handler", schedule=None, start_date=date)
task = EmptyOperator(task_id="task_for_testing_s3_log_handler", dag=self.dag)
if AIRFLOW_V_3_0_PLUS:
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=date,
run_id="test",
run_type="manual",
)
else:
dag_run = DagRun(
dag_id=self.dag.dag_id,
execution_date=date,
run_id="test",
run_type="manual",
)
session.add(dag_run)
session.commit()
session.refresh(dag_run)

self.ti = TaskInstance(task=task, run_id=dag_run.run_id)
self.ti.dag_run = dag_run
self.ti.try_number = 1
self.ti.state = State.RUNNING
session.add(self.ti)
session.commit()

self.conn = boto3.client("s3")
self.conn.create_bucket(Bucket="bucket")
yield

self.dag.clear()

session.query(DagRun).delete()
if self.s3_task_handler.handler:
with contextlib.suppress(Exception):
os.remove(self.s3_task_handler.handler.baseFilename)

def test_set_context_raw(self):
self.ti.raw = True
mock_open = mock.mock_open()
with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.open", mock_open):
self.s3_task_handler.set_context(self.ti)

assert not self.s3_task_handler.upload_on_close
mock_open.assert_not_called()

def test_set_context_not_raw(self):
mock_open = mock.mock_open()
with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.open", mock_open):
self.s3_task_handler.set_context(self.ti)

assert self.s3_task_handler.upload_on_close
mock_open.assert_called_once_with(os.path.join(self.local_log_location, "1.log"), "w")
mock_open().write.assert_not_called()

def test_read(self):
# Test what happens when we have two log files to read
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\nLine 2\n")
self.conn.put_object(
Bucket="bucket", Key=self.remote_log_key + ".trigger.log", Body=b"Log line 3\nLine 4\n"
)
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
log, metadata = self.s3_task_handler.read(ti)

expected_s3_uri = f"s3://bucket/{self.remote_log_key}"

if AIRFLOW_V_3_0_PLUS:
assert log[0].event == "::group::Log message source details"
assert expected_s3_uri in log[0].sources
assert log[1].event == "::endgroup::"
assert log[2].event == "Log line"
assert log[3].event == "Line 2"
assert log[4].event == "Log line 3"
assert log[5].event == "Line 4"
assert metadata == {"end_of_log": True, "log_pos": 4}
else:
actual = log[0][0][-1]
assert f"*** Found logs in s3:\n*** * {expected_s3_uri}\n" in actual
assert actual.endswith("Line 4")
assert metadata == [{"end_of_log": True, "log_pos": 33}]

def test_read_when_s3_log_missing(self):
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], []))
log, metadata = self.s3_task_handler.read(ti)
if AIRFLOW_V_3_0_PLUS:
assert len(log) == 2
assert metadata == {"end_of_log": True, "log_pos": 0}
else:
assert len(log) == 1
assert len(log) == len(metadata)
actual = log[0][0][-1]
expected = "*** No logs found on s3 for ti=<TaskInstance: dag_for_testing_s3_task_handler.task_for_testing_s3_log_handler test [success]>\n"
assert expected in actual
assert metadata[0] == {"end_of_log": True, "log_pos": 0}

def test_close(self):
self.s3_task_handler.set_context(self.ti)
assert self.s3_task_handler.upload_on_close
@@ -221,3 +308,7 @@ def test_close_with_delete_local_logs_conf(self, delete_local_copy, expected_existence_of_local_copy):

handler.close()
assert os.path.exists(handler.handler.baseFilename) == expected_existence_of_local_copy

def test_filename_template_for_backward_compatibility(self):
# filename_template arg support for running the latest provider on airflow 2
S3TaskHandler(self.local_log_location, self.remote_log_base, filename_template=None)
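A side note on the `conf_vars` usage that changes throughout this file: the helper, now imported from `tests_common.test_utils.config`, temporarily overrides Airflow configuration and restores it afterwards, and, as both forms above show, works as a decorator or as a context manager. A rough sketch of the two forms outside the fixtures, reusing the same override; the paths passed to `S3TaskHandler` are illustrative and this assumes the provider test environment is set up:

```python
from tests_common.test_utils.config import conf_vars

from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler

# Context-manager form (as in the new setup_tests fixture): the override is only
# in effect while the handler is constructed, then the previous config is restored.
with conf_vars({("logging", "remote_log_conn_id"): "aws_default"}):
    handler = S3TaskHandler("/tmp/local-logs", "s3://bucket/remote/log/location")


# Decorator form (as on the new TestS3TaskHandler fixture): the override wraps the
# whole decorated callable.
@conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
def build_handler():
    return S3TaskHandler("/tmp/local-logs", "s3://bucket/remote/log/location")
```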