feat(ingest/snowflake): generate lineage through temp views #13517

Open. Wants to merge 6 commits into base: master.

Changes from all commits:

@@ -408,7 +408,7 @@ def query(self, query: str) -> Any:
# We often run multiple queries in parallel across multiple threads,
# so we need to number them to help with log readability.
query_num = self.get_query_no()
logger.info(f"Query #{query_num}: {query}", stacklevel=2)
logger.info(f"Query #{query_num}: {query.rstrip()}", stacklevel=2)
resp = self._connection.cursor(DictCursor).execute(query)
if resp is not None and resp.rowcount is not None:
logger.info(

@@ -373,6 +373,13 @@ def fetch_query_log(
if entry:
yield entry

@classmethod
def _has_temp_keyword(cls, query_text: str) -> bool:
return (
re.search(r"\bTEMP\b", query_text, re.IGNORECASE) is not None
or re.search(r"\bTEMPORARY\b", query_text, re.IGNORECASE) is not None
)

def _parse_audit_log_row(
self, row: Dict[str, Any], users: UsersMapping
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery, ObservedQuery]]:
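
A quick sketch of what the new keyword check matches (standalone rewrite for illustration; collapsing the two searches into one alternation is equivalent):

import re

def has_temp_keyword(query_text: str) -> bool:
    # TEMP or TEMPORARY as a whole word, case-insensitive. \b treats "_"
    # as a word character, so identifiers like temp_orders do not match.
    return re.search(r"\b(TEMP|TEMPORARY)\b", query_text, re.IGNORECASE) is not None

assert has_temp_keyword("CREATE TEMP VIEW v AS SELECT 1")
assert has_temp_keyword("create or replace temporary view v as select 1")
assert not has_temp_keyword("CREATE VIEW temp_orders AS SELECT 1")
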
@@ -389,6 +396,15 @@ def _parse_audit_log_row(
key = key.lower()
res[key] = value

timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)

# TODO need to map snowflake query types to ours
query_text: str = res["query_text"]
query_type: QueryType = SNOWFLAKE_QUERY_TYPE_MAPPING.get(
res["query_type"], QueryType.UNKNOWN
)

direct_objects_accessed = res["direct_objects_accessed"]
objects_modified = res["objects_modified"]
object_modified_by_ddl = res["object_modified_by_ddl"]
@@ -399,9 +415,9 @@ def _parse_audit_log_row(
"Error fetching ddl lineage from Snowflake"
):
known_ddl_entry = self.parse_ddl_query(
res["query_text"],
query_text,
res["session_id"],
res["query_start_time"],
timestamp,
object_modified_by_ddl,
res["query_type"],
)
@@ -419,24 +435,33 @@
)
)

# Use direct_objects_accessed instead objects_modified
# objects_modified returns $SYS_VIEW_X with no mapping
# There are a couple of cases where we'd want to prefer our own SQL parsing
# over Snowflake's metadata.
# 1. For queries that use a stream, objects_modified returns $SYS_VIEW_X with no mapping.
# We can check direct_objects_accessed to see if there is a stream used, and if so,
# prefer doing SQL parsing over Snowflake's metadata.
# 2. For queries that create a view, objects_modified is empty and object_modified_by_ddl
# contains the view name and columns. Because `object_modified_by_ddl` doesn't contain
# source columns, i.e. lineage information, we must do our own SQL parsing. We're mainly
# focused on temporary views. It's fine if we parse a couple of extra views, but in general
# we want view definitions to come from Snowflake's schema metadata and not from query logs.

has_stream_objects = any(
obj.get("objectDomain") == "Stream" for obj in direct_objects_accessed
)
is_create_view = query_type == QueryType.CREATE_VIEW
is_create_temp_view = is_create_view and self._has_temp_keyword(query_text)

# If a stream is used, default to query parsing.
if has_stream_objects:
logger.debug("Found matching stream object")
if has_stream_objects or is_create_temp_view:
return ObservedQuery(
query=res["query_text"],
query=query_text,
session_id=res["session_id"],
timestamp=res["query_start_time"].astimezone(timezone.utc),
timestamp=timestamp,
user=user,
default_db=res["default_db"],
default_schema=res["default_schema"],
query_hash=get_query_fingerprint(
res["query_text"], self.identifiers.platform, fast=True
query_text, self.identifiers.platform, fast=True
),
)
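
To make case 2 concrete, the session-scoped pattern this branch targets looks roughly like the following (illustrative queries, not from this PR). The CREATE statement has query_type CREATE_VIEW plus a TEMP keyword, so it is now emitted as an ObservedQuery and run through our SQL parser, which is the only place the temp view's column-level lineage can be recovered:

session_queries = [
    "CREATE TEMP VIEW tmp_orders AS SELECT id, amount FROM raw.orders",
    "INSERT INTO analytics.orders SELECT id, amount FROM tmp_orders",
]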

@@ -502,25 +527,17 @@ def _parse_audit_log_row(
)
)

timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)

# TODO need to map snowflake query types to ours
query_type = SNOWFLAKE_QUERY_TYPE_MAPPING.get(
res["query_type"], QueryType.UNKNOWN
)

entry = PreparsedQuery(
# Despite having Snowflake's fingerprints available, our own fingerprinting logic does a better
# job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint
# here
query_id=get_query_fingerprint(
res["query_text"],
query_text,
self.identifiers.platform,
fast=True,
secondary_id=res["query_secondary_fingerprint"],
),
query_text=res["query_text"],
query_text=query_text,
upstreams=upstreams,
downstream=downstream,
column_lineage=column_lineage,
@@ -543,7 +560,6 @@ def parse_ddl_query(
object_modified_by_ddl: dict,
query_type: str,
) -> Optional[Union[TableRename, TableSwap]]:
timestamp = timestamp.astimezone(timezone.utc)
if (
object_modified_by_ddl["operationType"] == "ALTER"
and query_type == "RENAME_TABLE"

@@ -43,13 +43,6 @@ class SnowflakeQuery:
ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER = "({})".format(
",".join(f"'{domain}'" for domain in ACCESS_HISTORY_TABLE_VIEW_DOMAINS)
)
ACCESS_HISTORY_TABLE_DOMAINS_FILTER = (
"("
f"'{SnowflakeObjectDomain.TABLE.capitalize()}',"
f"'{SnowflakeObjectDomain.VIEW.capitalize()}',"
f"'{SnowflakeObjectDomain.STREAM.capitalize()}',"
")"
)
Comment on lines -46 to -52

Contributor: so this was just unused?

Collaborator (Author): yup - it was an unrelated cleanup

@staticmethod
def current_account() -> str:
@@ -262,6 +255,33 @@ def show_views_for_database(
LIMIT {limit} {from_clause};
"""

@staticmethod
def get_views_for_database(db_name: str) -> str:
# We've seen some issues with the `SHOW VIEWS` query,
# particularly when it requires pagination.
# This is an experimental alternative query that might be more reliable.
return f"""\
SELECT
TABLE_CATALOG as "VIEW_CATALOG",
TABLE_SCHEMA as "VIEW_SCHEMA",
TABLE_NAME as "VIEW_NAME",
COMMENT,
VIEW_DEFINITION,
CREATED,
LAST_ALTERED,
IS_SECURE
FROM "{db_name}".information_schema.views
WHERE TABLE_CATALOG = '{db_name}'
AND TABLE_SCHEMA != 'INFORMATION_SCHEMA'
"""

@staticmethod
def get_views_for_schema(db_name: str, schema_name: str) -> str:
return f"""\
{SnowflakeQuery.get_views_for_database(db_name).rstrip()}
AND TABLE_SCHEMA = '{schema_name}'
"""

@staticmethod
def get_secure_view_definitions() -> str:
# https://docs.snowflake.com/en/sql-reference/account-usage/views
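
As a sanity check on how the per-schema variant composes on top of the per-database one (output reconstructed from the f-strings above):

print(SnowflakeQuery.get_views_for_schema(db_name="ANALYTICS", schema_name="PUBLIC"))
# SELECT
#     TABLE_CATALOG as "VIEW_CATALOG",
#     ...
# FROM "ANALYTICS".information_schema.views
# WHERE TABLE_CATALOG = 'ANALYTICS'
# AND TABLE_SCHEMA != 'INFORMATION_SCHEMA'
# AND TABLE_SCHEMA = 'PUBLIC'
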

@@ -128,6 +128,7 @@ class SnowflakeV2Report(
# "Information schema query returned too much data. Please repeat query with more selective predicates.""
# This will result in overall increase in time complexity
num_get_tables_for_schema_queries: int = 0
num_get_views_for_schema_queries: int = 0

# these will be non-zero if the user chooses to enable the extract_tags = "with_lineage" option, which requires
# individual queries per object (database, schema, table) and an extra query per table to get the tags on the columns.

@@ -3,8 +3,9 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Callable, Dict, Iterable, List, MutableMapping, Optional
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Tuple

from datahub.cli.env_utils import get_boolean_env_variable
from datahub.ingestion.api.report import SupportsAsObj
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain
@@ -229,7 +230,11 @@ class SnowflakeDataDictionary(SupportsAsObj):
def __init__(self, connection: SnowflakeConnection) -> None:
self.connection = connection

def as_obj(self) -> Dict[str, Dict[str, int]]:
self._use_information_schema_for_views = get_boolean_env_variable(
"DATAHUB_SNOWFLAKE_USE_INFORMATION_SCHEMA_FOR_VIEWS", default=False
)

def as_obj(self) -> Dict[str, Any]:
# TODO: Move this into a proper report type that gets computed.

# Reports how many times we reset in-memory `functools.lru_cache` caches of data,
@@ -245,7 +250,9 @@ def as_obj(self) -> Dict[str, Dict[str, int]]:
self.get_fk_constraints_for_schema,
]

report = {}
report: Dict[str, Any] = {
"use_information_schema_for_views": self._use_information_schema_for_views,
}
for func in lru_cache_functions:
report[func.__name__] = func.cache_info()._asdict() # type: ignore
return report
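
Since the flag is read in __init__, it must be set before the data dictionary is constructed. Something like this toggles the experimental path (assuming get_boolean_env_variable accepts the usual truthy strings):

import os

os.environ["DATAHUB_SNOWFLAKE_USE_INFORMATION_SCHEMA_FOR_VIEWS"] = "true"
# Any SnowflakeDataDictionary created after this point uses the
# information_schema-based view listing instead of SHOW VIEWS, and
# as_obj() will include use_information_schema_for_views: True.
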
@@ -400,7 +407,17 @@ def get_tables_for_schema(
return tables

@serialized_lru_cache(maxsize=1)
def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]]:
def get_views_for_database(
self, db_name: str
) -> Optional[Dict[str, List[SnowflakeView]]]:
if self._use_information_schema_for_views:
return self._get_views_for_database_using_information_schema(db_name)
else:
return self._get_views_for_database_using_show(db_name)

def _get_views_for_database_using_show(
self, db_name: str
) -> Dict[str, List[SnowflakeView]]:
page_limit = SHOW_VIEWS_MAX_PAGE_SIZE

views: Dict[str, List[SnowflakeView]] = {}
@@ -431,10 +448,9 @@ def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]]
SnowflakeView(
name=view_name,
created=view["created_on"],
# last_altered=table["last_altered"],
comment=view["comment"],
view_definition=view["text"],
last_altered=view["created_on"],
last_altered=view["created_on"], # TODO: This is not correct.
materialized=(
view.get("is_materialized", "false").lower() == "true"
),
@@ -449,8 +465,53 @@ def get_views_for_database(self, db_name: str) -> Dict[str, List[SnowflakeView]]
)
view_pagination_marker = view_name

# Because this is in a cached function, this will only log once per database.
view_counts = {schema_name: len(views[schema_name]) for schema_name in views}
logger.info(
f"Finished fetching views in {db_name}; counts by schema {view_counts}"
)
return views
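
For context, the SHOW-based path above pages through results with Snowflake's LIMIT ... FROM '<name>' keyset mechanism, roughly like this sketch (the real builder is SnowflakeQuery.show_views_for_database, partially visible above):

from typing import Optional

def show_views_page_sql(db_name: str, limit: int, after_view: Optional[str]) -> str:
    # FROM '<name>' resumes the listing after that object name, which is
    # how view_pagination_marker drives the loop above.
    from_clause = f" FROM '{after_view}'" if after_view else ""
    return f'SHOW VIEWS IN DATABASE "{db_name}" LIMIT {limit}{from_clause}'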

@classmethod
def _map_view(cls, row: Dict[str, Any]) -> Tuple[str, SnowflakeView]:
schema_name = row["VIEW_SCHEMA"]
return schema_name, SnowflakeView(
name=row["VIEW_NAME"],
created=row["CREATED"],
comment=row["COMMENT"],
view_definition=row["VIEW_DEFINITION"],
last_altered=row["LAST_ALTERED"],
is_secure=(row.get("IS_SECURE", "false").lower() == "true"),
# TODO: This doesn't work for materialized views.
materialized=False,
)

def _get_views_for_database_using_information_schema(
self, db_name: str
) -> Optional[Dict[str, List[SnowflakeView]]]:
try:
cur = self.connection.query(
SnowflakeQuery.get_views_for_database(db_name),
)
except Exception as e:
logger.debug(f"Failed to get all views for database {db_name}", exc_info=e)
# Error - Information schema query returned too much data. Please repeat query with more selective predicates.
return None

views: Dict[str, List[SnowflakeView]] = {}
for row in cur:
schema_name, view = self._map_view(row)
views.setdefault(schema_name, []).append(view)
return views

def get_views_for_schema_using_information_schema(
self, *, schema_name: str, db_name: str
) -> List[SnowflakeView]:
cur = self.connection.query(
SnowflakeQuery.get_views_for_schema(db_name=db_name, schema_name=schema_name),
)
return [self._map_view(row)[1] for row in cur]
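
A hypothetical row, shaped like the information_schema query's output, showing what _map_view produces (values invented for illustration; the IS_SECURE string format is an assumption):

from datetime import datetime, timezone

row = {
    "VIEW_SCHEMA": "PUBLIC",
    "VIEW_NAME": "ORDERS_V",
    "CREATED": datetime(2024, 1, 1, tzinfo=timezone.utc),
    "COMMENT": None,
    "VIEW_DEFINITION": "SELECT id, amount FROM orders",
    "LAST_ALTERED": datetime(2024, 1, 2, tzinfo=timezone.utc),
    "IS_SECURE": "false",
}
schema_name, view = SnowflakeDataDictionary._map_view(row)
assert schema_name == "PUBLIC" and view.name == "ORDERS_V"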

@serialized_lru_cache(maxsize=SCHEMA_PARALLELISM)
def get_columns_for_schema(
self,

@@ -1218,7 +1218,10 @@ def get_tables_for_schema(
# falling back to get tables for schema
if tables is None:
self.report.num_get_tables_for_schema_queries += 1
return self.data_dictionary.get_tables_for_schema(schema_name, db_name)
return self.data_dictionary.get_tables_for_schema(
db_name=db_name,
schema_name=schema_name,
)

# Some schema may not have any table
return tables.get(schema_name, [])
@@ -1228,8 +1231,17 @@ def get_views_for_schema(
) -> List[SnowflakeView]:
views = self.data_dictionary.get_views_for_database(db_name)

# Some schema may not have any table
return views.get(schema_name, [])
if views is not None:
# Some schemas may not have any views
return views.get(schema_name, [])

# Usually this fails when there are too many views in the schema.
# Fall back to per-schema queries.
self.report.num_get_views_for_schema_queries += 1
return self.data_dictionary.get_views_for_schema_using_information_schema(
db_name=db_name,
schema_name=schema_name,
)
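
Putting it together, the view-listing resolution order after this change looks like this (wiring sketched; schema_gen stands in for the generator instance):

views = schema_gen.get_views_for_schema(schema_name="PUBLIC", db_name="ANALYTICS")
# 1. Try the cached database-wide listing: SHOW VIEWS by default, or
#    information_schema when DATAHUB_SNOWFLAKE_USE_INFORMATION_SCHEMA_FOR_VIEWS
#    is set.
# 2. If that returned None (e.g. the "too much data" error), run one
#    information_schema query scoped to this schema and bump
#    num_get_views_for_schema_queries.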

def get_columns_for_table(
self, table_name: str, snowflake_schema: SnowflakeSchema, db_name: str

@@ -109,7 +109,7 @@ class ObservedQuery:
query_hash: Optional[str] = None
usage_multiplier: int = 1

# Use this to store addtitional key-value information about query for debugging
# Use this to store additional key-value information about the query for debugging.
extra_info: Optional[dict] = None

