Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_conn(self) -> connection:

# check for authentication via AWS IAM
if conn.extra_dejson.get("iam", False):
conn.login, conn.password, conn.port = self.get_iam_token(conn)
conn.login, conn.password, conn.host, conn.port = self.get_iam_token(conn)

conn_args = dict(
host=conn.host,
Expand Down Expand Up @@ -177,6 +177,33 @@ def get_uri(self) -> str:
uri = conn.get_uri().replace("postgres://", "postgresql://")
return uri

def resolve_rds_cname(self, hostname):
"""Resolve a CNAME record to the original RDS endpoint.
This is required for AWS where the hostname of the RDS instance is part of
the signing request.
Looking for the endpoint which is of the form cluster-name.accountandregionhash.regionid.rds.amazonaws.com
To do so, recursively resolve the host name until it's a subdomain of rds.amazonaws.com.
"""
try:
import dns.name
import dns.resolver
from dns.exception import DNSException
except ImportError:
self.log.warning(
"Module dns cannot be imported. Make sure you added it to your requirements",
)
return hostname

base_domain = dns.name.from_text("rds.amazonaws.com")
answer = dns.name.from_text(hostname)
while not answer.is_subdomain(base_domain):
try:
answer = dns.resolver.resolve(answer, dns.rdatatype.CNAME, search=True)[0].target
except DNSException as e:
self.log.error("Failed to resolve hostname to database endpoint")
raise e
return answer.to_text().strip(".")

def bulk_load(self, table: str, tmp_file: str) -> None:
"""Loads a tab-delimited file into a database table"""
self.copy_expert(f"COPY {table} FROM STDIN", tmp_file)
Expand All @@ -200,7 +227,7 @@ def _serialize_cell(cell: object, conn: connection | None = None) -> Any:
"""
return cell

def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
def get_iam_token(self, conn: Connection) -> tuple[str, str, str, int]:
"""
Uses AWSHook to retrieve a temporary password to connect to Postgres
or Redshift. Port is required. If none is provided, default is used for
Expand All @@ -218,6 +245,9 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:

aws_conn_id = conn.extra_dejson.get("aws_conn_id", "aws_default")
login = conn.login
host = conn.host
if conn.extra_dejson.get("resolve_rds_cname", False):
host = self.resolve_rds_cname(conn.host)
if conn.extra_dejson.get("redshift", False):
port = conn.port or 5439
# Pull the custer-identifier from the beginning of the Redshift URL
Expand All @@ -237,8 +267,8 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
port = conn.port or 5432
rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port
token = rds_client.generate_db_auth_token(host, port, conn.login)
return login, token, host, port

def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
"""
Expand Down