Skip to content

Commit

Permalink
feat: add OpenLineage support for RedshiftToS3Operator (#41632)
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
  • Loading branch information
kacpermuda authored Oct 22, 2024
1 parent 15b41b4 commit be55378
Show file tree
Hide file tree
Showing 3 changed files with 437 additions and 8 deletions.
113 changes: 106 additions & 7 deletions providers/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def default_select_query(self) -> str | None:
table = self.table
return f"SELECT * FROM {table}"

@property
def use_redshift_data(self):
return bool(self.redshift_data_api_kwargs)

def execute(self, context: Context) -> None:
if self.table and self.table_as_file_name:
self.s3_key = f"{self.s3_key}/{self.table}_"
Expand All @@ -164,14 +168,13 @@ def execute(self, context: Context) -> None:
if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = [*self.unload_options, "HEADER"]

redshift_hook: RedshiftDataHook | RedshiftSQLHook
if self.redshift_data_api_kwargs:
redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
if self.use_redshift_data:
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
for arg in ["sql", "parameters"]:
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
else:
redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
if conn and conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
Expand All @@ -187,10 +190,106 @@ def execute(self, context: Context) -> None:
)

self.log.info("Executing UNLOAD command...")
if isinstance(redshift_hook, RedshiftDataHook):
redshift_hook.execute_query(
if self.use_redshift_data:
redshift_data_hook.execute_query(
sql=unload_query, parameters=self.parameters, **self.redshift_data_api_kwargs
)
else:
redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters)
redshift_sql_hook.run(unload_query, self.autocommit, parameters=self.parameters)
self.log.info("UNLOAD command complete...")

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we may query for table details."""
from airflow.providers.amazon.aws.utils.openlineage import (
get_facets_from_redshift_table,
get_identity_column_lineage_facet,
)
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
Error,
ExtractionErrorRunFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

output_dataset = Dataset(
namespace=f"s3://{self.s3_bucket}",
name=self.s3_key,
)

if self.use_redshift_data:
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
database = self.redshift_data_api_kwargs.get("database")
identifier = self.redshift_data_api_kwargs.get(
"cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name")
)
port = self.redshift_data_api_kwargs.get("port", "5439")
authority = f"{identifier}.{redshift_data_hook.region_name}:{port}"
else:
redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
database = redshift_sql_hook.conn.schema
authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority

if self.select_query == self.default_select_query:
if self.use_redshift_data:
input_dataset_facets = get_facets_from_redshift_table(
redshift_data_hook, self.table, self.redshift_data_api_kwargs, self.schema
)
else:
input_dataset_facets = get_facets_from_redshift_table(
redshift_sql_hook, self.table, {}, self.schema
)

input_dataset = Dataset(
namespace=f"redshift://{authority}",
name=f"{database}.{self.schema}.{self.table}" if database else f"{self.schema}.{self.table}",
facets=input_dataset_facets,
)

# If default select query is used (SELECT *) output file matches the input table.
output_dataset.facets = {
"schema": input_dataset_facets["schema"],
"columnLineage": get_identity_column_lineage_facet(
field_names=[field.name for field in input_dataset_facets["schema"].fields],
input_datasets=[input_dataset],
),
}

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])

try:
from airflow.providers.openlineage.sqlparser import SQLParser, from_table_meta
except ImportError:
return OperatorLineage(outputs=[output_dataset])

run_facets = {}
parse_result = SQLParser(dialect="redshift", default_schema=self.schema).parse(self.select_query)
if parse_result.errors:
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=1,
failedTasks=1,
errors=[
Error(
errorMessage=error.message,
stackTrace=None,
task=error.origin_statement,
taskNumber=error.index,
)
for error in parse_result.errors
],
)

input_datasets = []
for in_tb in parse_result.in_tables:
ds = from_table_meta(in_tb, database, f"redshift://{authority}", False)
schema, table = ds.name.split(".")[-2:]
if self.use_redshift_data:
input_dataset_facets = get_facets_from_redshift_table(
redshift_data_hook, table, self.redshift_data_api_kwargs, schema
)
else:
input_dataset_facets = get_facets_from_redshift_table(redshift_sql_hook, table, {}, schema)

ds.facets = input_dataset_facets
input_datasets.append(ds)

return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets)
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def get_openlineage_facets_on_complete(self, task_instance):

output_dataset = Dataset(
namespace=f"redshift://{authority}",
name=f"{database}.{self.schema}.{self.table}",
name=f"{database}.{self.schema}.{self.table}" if database else f"{self.schema}.{self.table}",
facets=output_dataset_facets,
)

Expand Down
Loading

0 comments on commit be55378

Please sign in to comment.