
Unify run_sql tasks across DAGs #4883


Description

We have several separate implementations of a common `run_sql` task across our DAGs. This task was extracted into `common/sql.py` in #4836:

```python
@task
def run_sql(
    sql_template: str,
    postgres_conn_id: str = POSTGRES_CONN_ID,
    task: AbstractOperator = None,
    timeout: float = None,
    handler: callable = RETURN_ROW_COUNT,
    **kwargs,
):
    """
    Run an SQL query with the given template and parameters. Any kwargs handed
    into the function outside of those defined will be passed into the template
    `.format` call.
    """
    query = sql_template.format(**kwargs)
    postgres = PostgresHook(
        postgres_conn_id=postgres_conn_id,
        default_statement_timeout=(
            timeout if timeout else PostgresHook.get_execution_timeout(task)
        ),
    )
    return postgres.run(query, handler=handler)
```
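
Any extra kwargs flow straight into the template's `.format` call. A minimal, Airflow-free sketch of just that templating step (the helper name, template, and values below are hypothetical, not from the codebase):

```python
# Minimal sketch of run_sql's templating step: extra kwargs are
# interpolated into the SQL template via str.format. The template and
# values here are hypothetical examples.
def render_sql(sql_template: str, **kwargs) -> str:
    return sql_template.format(**kwargs)


query = render_sql(
    "SELECT COUNT(*) FROM {table} WHERE updated_on < '{cutoff}';",
    table="image",
    cutoff="2024-01-01",
)
# query == "SELECT COUNT(*) FROM image WHERE updated_on < '2024-01-01';"
```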

We have several DAGs that can now use this `run_sql` function directly rather than re-implementing their own:

  • `delete_records`:

```python
def run_sql(
    sql_template: str,
    postgres_conn_id: str = POSTGRES_CONN_ID,
    task: AbstractOperator = None,
    timeout: timedelta = None,
    handler: callable = RETURN_ROW_COUNT,
    **kwargs,
):
    query = sql_template.format(**kwargs)
    postgres = PostgresHook(
        postgres_conn_id=postgres_conn_id,
        default_statement_timeout=(
            timeout if timeout else PostgresHook.get_execution_timeout(task)
        ),
    )
    return postgres.run(query, handler=handler)
```
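
The one divergence from the shared function is the `timeout` annotation (`timedelta` here vs. `float` in `common/sql.py`). A hypothetical normalization helper (the name is assumed, not in the codebase) could let the unified task accept either form:

```python
from datetime import timedelta


# Hypothetical helper, not in the codebase: accept either a float number
# of seconds or a timedelta, and return seconds for the statement timeout.
def timeout_seconds(timeout):
    if timeout is None:
        return None
    if isinstance(timeout, timedelta):
        return timeout.total_seconds()
    return float(timeout)
```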

  • `batched_update` (this one may require some additional work on either the base function or the call to accommodate the `dry_run` variable):

```python
def run_sql(
    dry_run: bool,
    sql_template: str,
    query_id: str,
    log_sql: bool = True,
    postgres_conn_id: str = POSTGRES_CONN_ID,
    task: AbstractOperator = None,
    timeout: timedelta = None,
    handler: callable = RETURN_ROW_COUNT,
    **kwargs,
):
    query = sql_template.format(
        temp_table_name=constants.TEMP_TABLE_NAME.format(query_id=query_id), **kwargs
    )
    if dry_run:
        logger.info(
            "This is a dry run: no SQL will be executed. To perform the updates,"
            " rerun the DAG with the conf option `'dry_run': false`."
        )
        logger.info(query)
        return 0
    postgres = PostgresHook(
        postgres_conn_id=postgres_conn_id,
        default_statement_timeout=(
            timeout if timeout else PostgresHook.get_execution_timeout(task)
        ),
        log_sql=log_sql,
    )
    return postgres.run(query, handler=handler)
```
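
One way to accommodate `dry_run` without changing the shared task is a thin wrapper in the DAG that short-circuits before delegating. A sketch under that assumption (the wrapper name is hypothetical, and `run_sql_task` stands in for the shared `common.sql.run_sql`):

```python
import logging

logger = logging.getLogger(__name__)


# Hypothetical wrapper, not in the codebase: handle dry_run in the DAG and
# delegate real execution to the shared run_sql (passed in as run_sql_task).
def run_sql_or_dry_run(dry_run: bool, sql_template: str, run_sql_task, **kwargs):
    if dry_run:
        logger.info(
            "This is a dry run: no SQL will be executed. To perform the"
            " updates, rerun the DAG with the conf option `'dry_run': false`."
        )
        logger.info(sql_template.format(**kwargs))
        return 0
    return run_sql_task(sql_template=sql_template, **kwargs)
```

Alternatively, `dry_run` could become an optional parameter on the base function itself; the wrapper simply keeps the shared task's signature untouched.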

  • `add_license_url`:

```python
def run_sql(
    sql: str,
    log_sql: bool = True,
    method: str = "get_records",
    handler: callable = None,
    autocommit: bool = False,
    postgres_conn_id: str = POSTGRES_CONN_ID,
    dag_task: AbstractOperator = None,
):
    postgres = PostgresHook(
        postgres_conn_id=postgres_conn_id,
        default_statement_timeout=PostgresHook.get_execution_timeout(dag_task),
        log_sql=log_sql,
    )
    if method == "get_records":
        return postgres.get_records(sql)
    elif method == "get_first":
        return postgres.get_first(sql)
    else:
        return postgres.run(sql, autocommit=autocommit, handler=handler)
```
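
The `method` dispatch is the other wrinkle here: `get_records` and `get_first` bypass `run`/`handler` entirely. The dispatch logic can be isolated and exercised against a stub hook (the stub below is illustrative only, standing in for `PostgresHook`):

```python
# Stub standing in for PostgresHook, illustrative only.
class StubHook:
    def get_records(self, sql):
        return [("a",), ("b",)]

    def get_first(self, sql):
        return ("a",)

    def run(self, sql, autocommit=False, handler=None):
        return handler(sql) if handler else None


# The dispatch logic from add_license_url's run_sql, isolated so the
# unified task could adopt it (or callers could invoke hook methods directly).
def dispatch(hook, sql, method="get_records", autocommit=False, handler=None):
    if method == "get_records":
        return hook.get_records(sql)
    elif method == "get_first":
        return hook.get_first(sql)
    return hook.run(sql, autocommit=autocommit, handler=handler)
```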

Additional context

This also came up in the discussion of #4572.
