Skip to content

Commit

Permalink
refactor: put wildcard expansion logic at end
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Dec 27, 2023
1 parent 8a1ef1d commit b2840be
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 41 deletions.
7 changes: 5 additions & 2 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def add_read(self, value) -> None:

@property
def write(self) -> Set[Union[SubQuery, Table]]:
# Because subqueries can be nested, SubQueryLineageHolder.write can return SubQuery or Table,
# or both once holders have been combined with __or__.
# This differs from StatementLineageHolder.write, where Table is the only possibility.
return self._property_getter(NodeTag.WRITE)

def add_write(self, value) -> None:
Expand All @@ -93,8 +96,8 @@ def write_columns(self) -> List[Column]:
or manually added via `add_write_column` if specified in DML
"""
tgt_cols = []
if self.write:
tgt_tbl = list(self.write)[0]
if write_only := self.write.difference(self.read):
tgt_tbl = list(write_only)[0]
tgt_col_with_idx: List[Tuple[Column, int]] = sorted(
[
(col, attr.get(EdgeTag.INDEX, 0))
Expand Down
77 changes: 38 additions & 39 deletions sqllineage/core/parser/sqlparse/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,6 @@ def _extract_from_dml(
for next_handler in next_handlers:
next_handler.end_of_query_cleanup(holder)

# find wildcard in current subquery
wildcards = []
for column in holder.write_columns:
if column.raw_name == "*":
wildcards.append(column)
# save write table of current subquery
if len(wildcards) > 0:
tgt_table = list(holder.write)[0]

# By recursively extracting each subquery of the parent and merge, we're doing Depth-first search
for sq in subqueries:
holder |= cls._extract_from_dml(
Expand All @@ -259,36 +250,44 @@ def _extract_from_dml(
metadata_provider,
)

# replace wildcard with real columns
for wildcard in wildcards:
for src_wildcard in holder.get_source_columns(wildcard):
if source_table := src_wildcard.parent:
if isinstance(source_table, SubQuery):
# the columns of SubQuery can be inferred from graph
if src_table_columns := holder.get_table_columns(source_table):
holder.replace_wildcard(
tgt_table,
src_table_columns,
wildcard,
src_wildcard,
)
elif isinstance(source_table, Table):
# search by metadata service
if metadata_provider:
db = source_table.schema.raw_name
table = source_table.raw_name
source_columns = []
for col in metadata_provider.get_table_columns(db, table):
column = Column(col)
column.parent = source_table
source_columns.append(column)
if len(source_columns) > 0:
holder.replace_wildcard(
tgt_table,
source_columns,
wildcard,
src_wildcard,
)
# replace wildcards with real columns; done here so that wildcards in subqueries are already replaced
if write_only := holder.write.difference(holder.read):
tgt_table = list(write_only)[0]
for column in holder.write_columns:
if column.raw_name == "*":
wildcard = column
for src_wildcard in holder.get_source_columns(wildcard):
if source_table := src_wildcard.parent:
if isinstance(source_table, SubQuery):
# the columns of SubQuery can be inferred from graph
if src_table_columns := holder.get_table_columns(
source_table
):
holder.replace_wildcard(
tgt_table,
src_table_columns,
wildcard,
src_wildcard,
)
elif isinstance(source_table, Table):
# search by metadata service
if metadata_provider:
db = source_table.schema.raw_name
table = source_table.raw_name
source_columns = []
for col in metadata_provider.get_table_columns(
db, table
):
column = Column(col)
column.parent = source_table
source_columns.append(column)
if len(source_columns) > 0:
holder.replace_wildcard(
tgt_table,
source_columns,
wildcard,
src_wildcard,
)
return holder

@classmethod
Expand Down

0 comments on commit b2840be

Please sign in to comment.