Skip to content

Commit

Permalink
feat: support parsing wildcard in sqlparse
Browse files Browse the repository at this point in the history
  • Loading branch information
lixxvsky authored and reata committed Dec 26, 2023
1 parent 8b37464 commit b41ee16
Show file tree
Hide file tree
Showing 8 changed files with 505 additions and 14 deletions.
9 changes: 8 additions & 1 deletion sqllineage/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider


class LineageAnalyzer:
Expand All @@ -13,7 +14,13 @@ class LineageAnalyzer:
SUPPORTED_DIALECTS: List[str] = []

@abstractmethod
def analyze(self, sql: str, silent_mode: bool) -> StatementLineageHolder:
def analyze(
self,
sql: str,
pre_stmt_holders: List[StatementLineageHolder],
metadata_provider: MetaDataProvider,
silent_mode: bool,
) -> StatementLineageHolder:
"""
to analyze single statement sql and store the result into
:class:`sqllineage.core.holders.StatementLineageHolder`
Expand Down
42 changes: 42 additions & 0 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,48 @@ def add_column_lineage(self, src: Column, tgt: Column) -> None:
# starting NetworkX v2.6, None is not allowed as node, see https://github.com/networkx/networkx/pull/4892
self.graph.add_edge(src.parent, src, type=EdgeType.HAS_COLUMN)

def get_node_src_lineage(
    self, node: Union[Column, Table, SubQuery]
) -> List[Union[Column, Table, SubQuery]]:
    """
    Return the direct upstream nodes of ``node``.

    Looks at every edge of the holder graph that points *into* ``node`` and keeps the
    origin of those whose ``type`` attribute equals ``EdgeType.LINEAGE``.

    :param node: the graph node whose sources are wanted
    :return: origins of the incoming lineage edges, in the order the graph yields them
    """
    incoming = self.graph.in_edges(nbunch=node, data="type")
    return [origin for origin, _, kind in incoming if kind == EdgeType.LINEAGE]

def get_table_columns(self, table: Union[Table, SubQuery]) -> List[Column]:
    """
    Return the named columns attached to ``table`` in the holder graph.

    An edge leaving ``table`` contributes its endpoint when the edge ``type`` is
    ``EdgeType.HAS_COLUMN``, the endpoint is a :class:`Column`, and the column is not
    the wildcard ``*``.

    :param table: table or subquery whose columns are wanted
    :return: matching columns, in the order the graph yields the edges
    """
    outgoing = self.graph.edges(nbunch=table, data="type")
    return [
        candidate
        for _, candidate, kind in outgoing
        if kind == EdgeType.HAS_COLUMN
        and isinstance(candidate, Column)
        and candidate.raw_name != "*"
    ]

def replace_wildcard(
    self,
    target_table: Union[Table, SubQuery],
    source_columns: List[Column],
    current_wildcard: Column,
    src_wildcard: Column,
) -> None:
    """
    Expand a ``*`` column of ``target_table`` into concrete columns.

    For every entry of ``source_columns`` a same-named column is created under
    ``target_table`` and wired into the graph with HAS_COLUMN edges (for both the new
    column and the source column) plus a LINEAGE edge from source to new column.
    Source columns that are themselves ``*``, or whose counterpart already existed on
    ``target_table`` before this call, are skipped. Afterwards both wildcard nodes
    are dropped from the graph if they are still present.

    :param target_table: table or subquery that receives the expanded columns
    :param source_columns: concrete columns the wildcard stands for
    :param current_wildcard: wildcard node on the target side; ``None`` is tolerated
    :param src_wildcard: wildcard node on the source side; ``None`` is tolerated
    """
    # snapshot taken once up front: columns added below are not re-checked against it
    existing = self.get_table_columns(target_table)
    for origin in source_columns:
        expanded = Column(origin.raw_name)
        expanded.parent = target_table
        if not (expanded in existing or origin.raw_name == "*"):
            self.graph.add_edge(target_table, expanded, type=EdgeType.HAS_COLUMN)
            self.graph.add_edge(origin.parent, origin, type=EdgeType.HAS_COLUMN)
            self.graph.add_edge(origin, expanded, type=EdgeType.LINEAGE)
    # the wildcard placeholders are obsolete once expanded; has_node() keeps the
    # removal safe when a node is absent or was already removed
    for placeholder in (current_wildcard, src_wildcard):
        if placeholder is not None and self.graph.has_node(placeholder):
            self.graph.remove_node(placeholder)


class StatementLineageHolder(SubQueryLineageHolder, ColumnLineageMixin):
"""
Expand Down
9 changes: 8 additions & 1 deletion sqllineage/core/parser/sqlfluff/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqllineage.core.analyzer import LineageAnalyzer
from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.exceptions import (
InvalidSyntaxException,
Expand Down Expand Up @@ -35,7 +36,13 @@ def split_tsql(self, sql: str) -> List[str]:
sqls.append(segment.raw)
return sqls

def analyze(self, sql: str, silent_mode: bool = False) -> StatementLineageHolder:
def analyze(
self,
sql: str,
pre_stmt_holders: List[StatementLineageHolder],
metadata_provider: MetaDataProvider,
silent_mode: bool = False,
) -> StatementLineageHolder:
if sql in self.tsql_split_cache:
statement_segments = [self.tsql_split_cache[sql]]
else:
Expand Down
104 changes: 97 additions & 7 deletions sqllineage/core/parser/sqlparse/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from sqllineage.core.analyzer import LineageAnalyzer
from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder
from sqllineage.core.models import Column, SubQuery, Table
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.models import Column, Schema, SubQuery, Table
from sqllineage.core.parser.sqlparse.handlers.base import (
CurrentTokenBaseHandler,
NextTokenBaseHandler,
Expand All @@ -38,7 +39,13 @@ class SqlParseLineageAnalyzer(LineageAnalyzer):
PARSER_NAME = "sqlparse"
SUPPORTED_DIALECTS = ["non-validating"]

def analyze(self, sql: str, silent_mode: bool = False) -> StatementLineageHolder:
def analyze(
self,
sql: str,
pre_stmt_holders: List[StatementLineageHolder],
metadata_provider: MetaDataProvider,
silent_mode: bool = False
) -> StatementLineageHolder:
# get rid of comments, which cause inconsistencies in sqlparse output
stmt = sqlparse.parse(trim_comment(sql))[0]
if (
Expand All @@ -58,11 +65,15 @@ def analyze(self, sql: str, silent_mode: bool = False) -> StatementLineageHolder
):
holder = self._extract_from_ddl_alter(stmt)
elif stmt.get_type() == "MERGE":
holder = self._extract_from_dml_merge(stmt)
holder = self._extract_from_dml_merge(
stmt, pre_stmt_holders, metadata_provider
)
else:
# DML parsing logic also applies to CREATE DDL
holder = StatementLineageHolder.of(
self._extract_from_dml(stmt, AnalyzerContext())
self._extract_from_dml(
stmt, AnalyzerContext(), pre_stmt_holders, metadata_provider
)
)
return holder

Expand Down Expand Up @@ -101,7 +112,12 @@ def _extract_from_ddl_alter(cls, stmt: Statement) -> StatementLineageHolder:
return holder

@classmethod
def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:
def _extract_from_dml_merge(
cls,
stmt: Statement,
pre_stmt_holders: List[StatementLineageHolder],
metadata_provider: MetaDataProvider,
) -> StatementLineageHolder:
holder = StatementLineageHolder()
src_flag = tgt_flag = update_flag = insert_flag = False
insert_columns = []
Expand Down Expand Up @@ -134,6 +150,8 @@ def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:
holder |= cls._extract_from_dml(
sq.query,
AnalyzerContext(cte=holder.cte, write={sq}),
pre_stmt_holders,
metadata_provider,
)
else:
direct_source = SqlParseTable.of(token)
Expand Down Expand Up @@ -187,7 +205,11 @@ def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:

@classmethod
def _extract_from_dml(
cls, token: TokenList, context: AnalyzerContext
cls,
token: TokenList,
context: AnalyzerContext,
pre_stmt_holders: List[StatementLineageHolder],
metadata_provider: MetaDataProvider,
) -> SubQueryLineageHolder:
holder = SubQueryLineageHolder()
if context.cte is not None:
Expand Down Expand Up @@ -230,11 +252,79 @@ def _extract_from_dml(
# call end of query hook here as loop is over
for next_handler in next_handlers:
next_handler.end_of_query_cleanup(holder)

# find wildcard in current subquery
wildcards = []
for column in holder.write_columns:
if column.raw_name == "*":
wildcards.append(column)
# save write table of current subquery
if len(wildcards) > 0:
target_table = list(holder.write)[0]

# By recursively extracting each subquery of the parent and merge, we're doing Depth-first search
for sq in subqueries:
holder |= cls._extract_from_dml(
sq.query, AnalyzerContext(cte=holder.cte, write={sq})
sq.query,
AnalyzerContext(cte=holder.cte, write={sq}),
pre_stmt_holders,
metadata_provider,
)

# replace wildcard with real columns
for wildcard in wildcards:
for src_wildcard in holder.get_node_src_lineage(wildcard):
if isinstance(src_wildcard, Column):
source_table = src_wildcard.parent
if source_table is not None:
if isinstance(source_table, SubQuery):
src_table_columns = holder.get_table_columns(source_table)
if len(src_table_columns) > 0:
holder.replace_wildcard(
target_table,
src_table_columns,
wildcard,
src_wildcard,
)
elif isinstance(source_table, Table):
# if the wildcard's source table has the <default> schema, it is probably a
# temporary view, so search the previous statement holders first
is_replaced = False
if source_table.schema.raw_name == Schema.unknown:
for pre_stmt_holder in reversed(pre_stmt_holders):
if pre_stmt_holder.graph.has_node(source_table):
src_table_columns = (
pre_stmt_holder.get_table_columns(
source_table
)
)
if len(src_table_columns) > 0:
holder.replace_wildcard(
target_table,
src_table_columns,
wildcard,
src_wildcard,
)
is_replaced = True
break
# if not found, or the source table is not <default>, try to look it up via the metadata provider
if not is_replaced and metadata_provider is not None:
db = source_table.schema.raw_name
table = source_table.raw_name
source_columns = []
for col in metadata_provider.get_table_columns(
db, table
):
column = Column(col)
column.parent = source_table
source_columns.append(column)
if len(source_columns) > 0:
holder.replace_wildcard(
target_table,
source_columns,
wildcard,
src_wildcard,
)
return holder

@classmethod
Expand Down
7 changes: 5 additions & 2 deletions sqllineage/core/parser/sqlparse/handlers/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ def _handle_column(self, token: Token) -> None:
column_tokens = [
sub_token
for sub_token in token.tokens
if isinstance(sub_token, column_token_types)
and not sub_token.value.startswith("@")
if (
isinstance(sub_token, column_token_types)
and not sub_token.value.startswith("@")
)
or sub_token.ttype is Wildcard
# ignore tsql variable column starts with @
]
else:
Expand Down
8 changes: 5 additions & 3 deletions sqllineage/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,11 @@ def _eval(self):
)
self._stmt = split(self._sql.strip())

self._stmt_holders = [
analyzer.analyze(stmt, self._silent_mode) for stmt in self._stmt
]
stmt_holders = []
for stmt in self._stmt:
stmt_holder = analyzer.analyze(stmt, stmt_holders, self._metadata_provider, self._silent_mode)
stmt_holders.append(stmt_holder)
self._stmt_holders = stmt_holders
self._sql_holder = SQLLineageHolder.of(
self._metadata_provider, *self._stmt_holders
)
Expand Down
23 changes: 23 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,26 @@ def assert_lr_graphs_match(lr: LineageRunner, lr_sqlfluff: LineageRunner) -> Non
f"\n\tGraph with sqlparse: {lr._sql_holder.graph}\n\t"
f"Graph with sqlfluff: {lr_sqlfluff._sql_holder.graph}"
)


def assert_wildcard_lineage(
    sql: str,
    metadata_provider: MetaDataProvider,
    column_lineages=None,
    dialect: str = "ansi",
    test_sqlfluff: bool = False,
    test_sqlparse: bool = True,
    skip_graph_check: bool = False,
):
    """
    Assert column lineage of ``sql`` when a metadata provider is supplied.

    Two runners are always built: one with ``SQLPARSE_DIALECT`` and one with
    ``dialect``, both receiving ``metadata_provider``. Column lineage is asserted on
    each runner whose flag is set; when both flags are set and ``skip_graph_check`` is
    false, the two resulting graphs are additionally required to match.

    :param sql: SQL text under test
    :param metadata_provider: provider passed to both runners
    :param column_lineages: expected column lineage, forwarded to ``assert_column_lineage``
    :param dialect: dialect for the second (sqlfluff) runner
    :param test_sqlfluff: assert lineage on the runner built with ``dialect``
    :param test_sqlparse: assert lineage on the runner built with ``SQLPARSE_DIALECT``
    :param skip_graph_check: skip the cross-parser graph comparison
    """
    sqlparse_runner = LineageRunner(
        sql, dialect=SQLPARSE_DIALECT, metadata_provider=metadata_provider
    )
    sqlfluff_runner = LineageRunner(
        sql, dialect=dialect, metadata_provider=metadata_provider
    )
    # sqlparse is checked before sqlfluff, same as the flag order below
    for enabled, runner in (
        (test_sqlparse, sqlparse_runner),
        (test_sqlfluff, sqlfluff_runner),
    ):
        if enabled:
            assert_column_lineage(runner, column_lineages)
    if skip_graph_check or not (test_sqlparse and test_sqlfluff):
        return
    assert_lr_graphs_match(sqlparse_runner, sqlfluff_runner)
Loading

0 comments on commit b41ee16

Please sign in to comment.