14 changes: 12 additions & 2 deletions airflow/providers/amazon/aws/operators/s3.py
@@ -24,6 +24,9 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence

+import pytz
+from dateutil import parser
+
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -498,8 +501,8 @@ def __init__(
         bucket: str,
         keys: str | list | None = None,
         prefix: str | None = None,
-        from_datetime: datetime | None = None,
-        to_datetime: datetime | None = None,
+        from_datetime: datetime | str | None = None,
+        to_datetime: datetime | str | None = None,
         aws_conn_id: str | None = "aws_default",
         verify: str | bool | None = None,
         **kwargs,
@@ -530,6 +533,13 @@ def execute(self, context: Context):

         if isinstance(self.keys, (list, str)) and not self.keys:
             return
+        # Handle the case where the dates are strings, e.g. when rendered from template fields and macros.
+        if isinstance(self.to_datetime, str):
+            self.to_datetime = parser.parse(self.to_datetime).replace(tzinfo=pytz.UTC)
+
+        if isinstance(self.from_datetime, str):
+            self.from_datetime = parser.parse(self.from_datetime).replace(tzinfo=pytz.UTC)
+
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

         keys = self.keys or s3_hook.list_keys(
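Note: with this change, `from_datetime` and `to_datetime` accept the strings that Jinja templating produces and normalize them at the top of `execute()`. A minimal sketch of what the new branch does to a rendered value (the sample string below is illustrative, not from the PR):

```python
import pytz
from dateutil import parser

# A template such as "{{ macros.ds_add(ds, -1) }}" renders to a plain
# "YYYY-MM-DD" string; execute() now parses it and pins it to UTC.
rendered = "2020-01-01"
parsed = parser.parse(rendered).replace(tzinfo=pytz.UTC)
print(parsed)  # 2020-01-01 00:00:00+00:00
```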
42 changes: 42 additions & 0 deletions tests/providers/amazon/aws/operators/test_s3.py
@@ -21,6 +21,7 @@
 import os
 import shutil
 import sys
+from datetime import timedelta
 from io import BytesIO
 from tempfile import mkdtemp
 from unittest import mock
@@ -29,7 +30,10 @@
 import pytest
 from moto import mock_aws

+from airflow import DAG
 from airflow.exceptions import AirflowException
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.operators.s3 import (
     S3CopyObjectOperator,
@@ -52,6 +56,7 @@
 )
 from airflow.providers.openlineage.extractors import OperatorLineage
 from airflow.utils.timezone import datetime, utcnow
+from airflow.utils.types import DagRunType
 from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

 BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-bucket")
@@ -623,6 +628,43 @@ def test_s3_delete_multiple_objects(self):
         # There should be no object found in the bucket created earlier
         assert "Contents" not in conn.list_objects(Bucket=bucket, Prefix=key_pattern)

+    @pytest.mark.db_test
+    def test_dates_from_template(self, session):
+        """Test that dates passed via templating can arrive as strings and are parsed correctly."""
+        bucket = "testbucket"
+        key_pattern = "path/data"
+        n_keys = 3
+        keys = [key_pattern + str(i) for i in range(n_keys)]
+
+        conn = boto3.client("s3")
+        conn.create_bucket(Bucket=bucket)
+        for k in keys:
+            conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=BytesIO(b"input"))
+
+        execution_date = utcnow()
+        dag = DAG("test_dag", start_date=datetime(2020, 1, 1), schedule=timedelta(days=1))
+        # use macros.ds_add since it returns a string, not a date
+        op = S3DeleteObjectsOperator(
+            task_id="XXXXXXXXXXXXXXXXXXXXXXX",
+            bucket=bucket,
+            from_datetime="{{ macros.ds_add(ds, -1) }}",
+            to_datetime="{{ macros.ds_add(ds, 1) }}",
+            dag=dag,
+        )
+
+        dag_run = DagRun(
+            dag_id=dag.dag_id, execution_date=execution_date, run_id="test", run_type=DagRunType.MANUAL
+        )
+        ti = TaskInstance(task=op)
+        ti.dag_run = dag_run
+        session.add(ti)
+        session.commit()
+        context = ti.get_template_context(session)
+
+        ti.render_templates(context)
+        op.execute(None)
+        assert "Contents" not in conn.list_objects(Bucket=bucket)
+
     def test_s3_delete_from_to_datetime(self):
         bucket = "testbucket"
         key_pattern = "path/data"
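For reference, a usage sketch of the templated form the new test exercises. The DAG id, task_id, bucket, and prefix below are hypothetical, not from the PR:

```python
from datetime import timedelta

from airflow import DAG
from airflow.providers.amazon.aws.operators.s3 import S3DeleteObjectsOperator
from airflow.utils.timezone import datetime

# Illustrative DAG: delete objects last modified within a one-day window
# around the logical date. The templated values render to "YYYY-MM-DD"
# strings, which execute() now parses into UTC-aware datetimes.
with DAG("s3_cleanup_example", start_date=datetime(2020, 1, 1), schedule=timedelta(days=1)):
    S3DeleteObjectsOperator(
        task_id="delete_stale_objects",
        bucket="my-bucket",
        prefix="path/data",
        from_datetime="{{ macros.ds_add(ds, -1) }}",
        to_datetime="{{ macros.ds_add(ds, 1) }}",
    )
```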