Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6a91eb9
refactor: Make sure reserved words in column names are all escaped in…
davidblain-infrabel Jan 14, 2025
be67e62
refactor: Fixed the remove_quotes method so it removes all special ch…
davidblain-infrabel Jan 14, 2025
9c9beaf
refactor: Make sure only escape characters are being removed when une…
davidblain-infrabel Jan 14, 2025
156a95e
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 14, 2025
b388be8
refactor: Renamed escape_column_name to escape_word and unescape_colu…
davidblain-infrabel Jan 14, 2025
c667158
refactor: Reformatted default and postgres dialect
davidblain-infrabel Jan 15, 2025
baa7d31
refactor: Fixed TestMsSqlHook
davidblain-infrabel Jan 15, 2025
bee3291
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 15, 2025
6a94065
refactor: Fixed docstring of unescape_word method in Dialect class
davidblain-infrabel Jan 15, 2025
a1532d4
refactor: Fixed passing escape_word_format in MSSQLHook and MySQLHook
davidblain-infrabel Jan 15, 2025
f0d71cc
refactor: Also allow escape_word_format, insert_statement_format and …
davidblain-infrabel Jan 15, 2025
04cf829
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 15, 2025
83b1f8a
refactor: Fixed TestTeradataHook
davidblain-infrabel Jan 15, 2025
1cfeea3
refactor: Reformatted escape_word_format method in DbAPiHook
davidblain-infrabel Jan 15, 2025
d442159
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 15, 2025
f86c078
refactor: Renamed formatting properties on Dialect and added unit tests
davidblain-infrabel Jan 15, 2025
4ef0d9a
refactor: Removed duplicate inspector property from Dialect
davidblain-infrabel Jan 15, 2025
531de1e
refactor: Added missing escape_word_format property from Dialect defi…
davidblain-infrabel Jan 15, 2025
1304a22
refactor: Also support un-escaping words when schema and table is con…
davidblain-infrabel Jan 15, 2025
540eb4b
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 15, 2025
bc04dee
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 15, 2025
c8d9a41
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 15, 2025
0c824b3
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 16, 2025
648b5d8
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 17, 2025
b208f75
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 17, 2025
3dde4d5
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 17, 2025
fa9e01c
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 17, 2025
934c16f
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 18, 2025
047a8e6
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 19, 2025
3cc2610
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 19, 2025
8c35659
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 20, 2025
2c5b469
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 20, 2025
5b8e6f1
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 22, 2025
209f820
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 22, 2025
147c4f6
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 22, 2025
880bded
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 22, 2025
468eb4e
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 22, 2025
99133ea
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 23, 2025
066280a
refactor: Added 2 tests with other escape format
davidblain-infrabel Jan 23, 2025
58d2f4c
refactor: Changed import of handlers to avoid deprecation warnings
davidblain-infrabel Jan 23, 2025
2f91392
refactor: Added option to allow forcing escaping of column names even…
davidblain-infrabel Jan 23, 2025
3f476d5
refactor: Added more test cases
davidblain-infrabel Jan 23, 2025
10592f9
refactor: Added even more test cases
davidblain-infrabel Jan 23, 2025
4b1e3f8
refactor: escape_column_names of DbApiHook should be property instead…
davidblain-infrabel Jan 23, 2025
df09c3c
refactor: Updated docstring escape_word method of Dialect
davidblain-infrabel Jan 23, 2025
9ee0367
refactor: Fixed TestPostgresDialect
davidblain-infrabel Jan 23, 2025
ce0cf51
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 23, 2025
534ec26
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 24, 2025
ffcf0d9
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 24, 2025
e6b8b14
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 24, 2025
a23ed8c
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
davidblain-infrabel Jan 26, 2025
f4b74ce
fix: Fixed test_escape_column_names use mock_db_hook instead of mock_…
davidblain-infrabel Jan 26, 2025
c4357fa
Merge branch 'main' into fix/escape-columns-as-reserver-words-mssql
dabla Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class Dialect(LoggingMixin):
"""Generic dialect implementation."""

pattern = re.compile(r'"([a-zA-Z0-9_]+)"')
pattern = re.compile(r"[^\w]")

def __init__(self, hook, **kwargs) -> None:
super().__init__(**kwargs)
Expand All @@ -45,12 +45,6 @@ def __init__(self, hook, **kwargs) -> None:

self.hook: DbApiHook = hook

@classmethod
def remove_quotes(cls, value: str | None) -> str | None:
if value:
return cls.pattern.sub(r"\1", value)
return value

@property
def placeholder(self) -> str:
return self.hook.placeholder
Expand All @@ -60,16 +54,56 @@ def inspector(self) -> Inspector:
return self.hook.inspector

@property
def _insert_statement_format(self) -> str:
return self.hook._insert_statement_format # type: ignore
def insert_statement_format(self) -> str:
return self.hook.insert_statement_format

@property
def replace_statement_format(self) -> str:
return self.hook.replace_statement_format

@property
def _replace_statement_format(self) -> str:
return self.hook._replace_statement_format # type: ignore
def escape_word_format(self) -> str:
return self.hook.escape_word_format

@property
def _escape_column_name_format(self) -> str:
return self.hook._escape_column_name_format # type: ignore
def escape_column_names(self) -> bool:
return self.hook.escape_column_names

def escape_word(self, word: str) -> str:
"""
Escape the word if necessary.

If the word is a reserved word or contains special characters or if the ``escape_column_names``
property is set to True in connection extra field, then the given word will be escaped.

:param word: Name of the column
:return: The escaped word
"""
if word != self.escape_word_format.format(self.unescape_word(word)) and (
self.escape_column_names or word.casefold() in self.reserved_words or self.pattern.search(word)
):
return self.escape_word_format.format(word)
return word

def unescape_word(self, word: str | None) -> str | None:
"""
Remove escape characters from each part of a dotted identifier (e.g., schema.table).

:param word: Escaped schema, table, or column name, potentially with multiple segments.
:return: The word without escaped characters.
"""
if not word:
return word

escape_char_start = self.escape_word_format[0]
escape_char_end = self.escape_word_format[-1]

def unescape_part(part: str) -> str:
if part.startswith(escape_char_start) and part.endswith(escape_char_end):
return part[1:-1]
return part

return ".".join(map(unescape_part, word.split(".")))

@classmethod
def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]:
Expand All @@ -87,8 +121,8 @@ def get_column_names(
for column in filter(
predicate,
self.inspector.get_columns(
table_name=self.remove_quotes(table),
schema=self.remove_quotes(schema) if schema else None,
table_name=self.unescape_word(table),
schema=self.unescape_word(schema) if schema else None,
),
)
)
Expand All @@ -110,8 +144,8 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] |
if schema is None:
table, schema = self.extract_schema_from_table(table)
primary_keys = self.inspector.get_pk_constraint(
table_name=self.remove_quotes(table),
schema=self.remove_quotes(schema) if schema else None,
table_name=self.unescape_word(table),
schema=self.unescape_word(schema) if schema else None,
).get("constrained_columns", [])
self.log.debug("Primary keys for table '%s': %s", table, primary_keys)
return primary_keys
Expand All @@ -138,20 +172,6 @@ def get_records(
def reserved_words(self) -> set[str]:
return self.hook.reserved_words

def escape_column_name(self, column_name: str) -> str:
"""
Escape the column name if it's a reserved word.

:param column_name: Name of the column
:return: The escaped column name if needed
"""
if (
column_name != self._escape_column_name_format.format(column_name)
and column_name.casefold() in self.reserved_words
):
return self._escape_column_name_format.format(column_name)
return column_name

def _joined_placeholders(self, values) -> str:
placeholders = [
self.placeholder,
Expand All @@ -160,7 +180,7 @@ def _joined_placeholders(self, values) -> str:

def _joined_target_fields(self, target_fields) -> str:
if target_fields:
target_fields = ", ".join(map(self.escape_column_name, target_fields))
target_fields = ", ".join(map(self.escape_word, target_fields))
return f"({target_fields})"
return ""

Expand All @@ -173,7 +193,7 @@ def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str:
:param target_fields: The names of the columns to fill in the table
:return: The generated INSERT SQL statement
"""
return self._insert_statement_format.format(
return self.insert_statement_format.format(
table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
)

Expand All @@ -186,6 +206,6 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
:param target_fields: The names of the columns to fill in the table
:return: The generated REPLACE SQL statement
"""
return self._replace_statement_format.format(
return self.replace_statement_format.format(
table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
)
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,17 @@ T = TypeVar("T")
class Dialect(LoggingMixin):
hook: Incomplete
def __init__(self, hook, **kwargs) -> None: ...
@classmethod
def remove_quotes(cls, value: str | None) -> str | None: ...
def escape_word(self, column_name: str) -> str: ...
def unescape_word(self, value: str | None) -> str | None: ...
@property
def placeholder(self) -> str: ...
@property
def insert_statement_format(self) -> str: ...
@property
def replace_statement_format(self) -> str: ...
@property
def escape_word_format(self) -> str: ...
@property
def inspector(self) -> Inspector: ...
@classmethod
def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]: ...
Expand All @@ -72,6 +78,5 @@ class Dialect(LoggingMixin):
) -> Any: ...
@property
def reserved_words(self) -> set[str]: ...
def escape_column_name(self, column_name: str) -> str: ...
def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str: ...
def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: ...
58 changes: 41 additions & 17 deletions providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.dialects.dialect import Dialect
from airflow.providers.common.sql.hooks import handlers
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
Expand All @@ -67,24 +68,18 @@
def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool | None):
warnings.warn(WARNING_MESSAGE.format("return_single_query_results"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers

return handlers.return_single_query_results(sql, return_last, split_statements)


def fetch_all_handler(cursor) -> list[tuple] | None:
warnings.warn(WARNING_MESSAGE.format("fetch_all_handler"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers

return handlers.fetch_all_handler(cursor)


def fetch_one_handler(cursor) -> list[tuple] | None:
warnings.warn(WARNING_MESSAGE.format("fetch_one_handler"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers

return handlers.fetch_one_handler(cursor)


Expand Down Expand Up @@ -184,13 +179,10 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
self.__schema = schema
self.log_sql = log_sql
self.descriptions: list[Sequence[Sequence] | None] = []
self._insert_statement_format: str = kwargs.get(
"insert_statement_format", "INSERT INTO {} {} VALUES ({})"
)
self._replace_statement_format: str = kwargs.get(
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)
self._escape_column_name_format: str = kwargs.get("escape_column_name_format", '"{}"')
self._insert_statement_format: str | None = kwargs.get("insert_statement_format")
self._replace_statement_format: str | None = kwargs.get("replace_statement_format")
self._escape_word_format: str | None = kwargs.get("escape_word_format")
self._escape_column_names: bool | None = kwargs.get("escape_column_names")
self._connection: Connection | None = kwargs.pop("connection", None)

def get_conn_id(self) -> str:
Expand All @@ -212,6 +204,38 @@ def placeholder(self) -> str:
)
return self._placeholder

@property
def insert_statement_format(self) -> str:
"""Return the insert statement format."""
if not self._insert_statement_format:
self._insert_statement_format = self.connection_extra.get(
"insert_statement_format", "INSERT INTO {} {} VALUES ({})"
)
return self._insert_statement_format

@property
def replace_statement_format(self) -> str:
"""Return the replacement statement format."""
if not self._replace_statement_format:
self._replace_statement_format = self.connection_extra.get(
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)
return self._replace_statement_format

@property
def escape_word_format(self) -> str:
"""Return the escape word format."""
if not self._escape_word_format:
self._escape_word_format = self.connection_extra.get("escape_word_format", '"{}"')
return self._escape_word_format

@property
def escape_column_names(self) -> bool:
"""Return the escape column names flag."""
if not self._escape_column_names:
self._escape_column_names = self.connection_extra.get("escape_column_names", False)
return self._escape_column_names

@property
def connection(self) -> Connection:
if self._connection is None:
Expand Down Expand Up @@ -413,7 +437,7 @@ def get_records(
:param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler)
return self.run(sql=sql, parameters=parameters, handler=handlers.fetch_all_handler)

def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
"""
Expand All @@ -422,7 +446,7 @@ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, An
:param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
return self.run(sql=sql, parameters=parameters, handler=fetch_one_handler)
return self.run(sql=sql, parameters=parameters, handler=handlers.fetch_one_handler)

@staticmethod
def strip_sql_string(sql: str) -> str:
Expand Down Expand Up @@ -557,7 +581,7 @@ def run(

if handler is not None:
result = self._make_common_data_structure(handler(cur))
if return_single_query_results(sql, return_last, split_statements):
if handlers.return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
else:
Expand All @@ -572,7 +596,7 @@ def run(

if handler is None:
return None
if return_single_query_results(sql, return_last, split_statements):
if handlers.return_single_query_results(sql, return_last, split_statements):
self.descriptions = [_last_description]
return _last_result
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class DbApiHook(BaseHook):
@cached_property
def placeholder(self) -> str: ...
@property
def insert_statement_format(self) -> str: ...
@property
def replace_statement_format(self) -> str: ...
@property
def escape_word_format(self) -> str: ...
@property
def escape_column_names(self) -> bool: ...
@property
def connection(self) -> Connection: ...
@connection.setter
def connection(self, value: Any) -> None: ...
Expand Down
Loading