Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__(
s3_key: str,
schema: str | None = None,
redshift_conn_id: str = "redshift_default",
aws_conn_id: str | None = "aws_default",
aws_conn_id: str | None | ArgNotSet = NOTSET,
verify: bool | str | None = None,
column_list: list[str] | None = None,
copy_options: list | None = None,
Expand All @@ -117,6 +118,16 @@ def __init__(
self.method = method
self.upsert_keys = upsert_keys
self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
# In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
# actually provide a connection note that, because we don't want to let the exception bubble up in
# that case (since we're silently injecting a connection on their behalf).
self._aws_conn_id: str | None
if isinstance(aws_conn_id, ArgNotSet):
self.conn_set = False
self._aws_conn_id = "aws_default"
else:
self.conn_set = True
self._aws_conn_id = aws_conn_id

if self.redshift_data_api_kwargs:
for arg in ["sql", "parameters"]:
Expand Down Expand Up @@ -149,14 +160,19 @@ def execute(self, context: Context) -> None:
else:
redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)

conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
conn = (
S3Hook.get_connection(conn_id=self._aws_conn_id)
# Only fetch the connection if it was set by the user and it is not None
if self.conn_set and self._aws_conn_id
else None
)
region_info = ""
if conn and conn.extra_dejson.get("region", False):
region_info = f"region '{conn.extra_dejson['region']}'"
if conn and conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
else:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
credentials = s3_hook.get_credentials()
credentials_block = build_credentials_block(credentials)

Expand Down