This repository was archived by the owner on May 17, 2024. It is now read-only.

Align PK support with Datafold SaaS #446

Merged 5 commits on Mar 22, 2023
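
In short, this PR broadens primary-key discovery in the dbt integration to match Datafold SaaS: where the CLI previously honored only column-level `primary-key` tags, `get_pk_from_model` now checks, in order, table-level meta, column-level meta, column-level tags, and finally columns covered by dbt `unique` / `unique_combination_of_columns` tests (collected once per run by `get_unique_columns`). A condensed, hypothetical sketch of that precedence (not the PR's literal code; `node` is a dbt manifest node):

```python
def resolve_primary_keys(node, unique_columns, pk_tag="primary-key"):
    # 1. Table-level meta: the model's meta maps pk_tag to a list of columns.
    if pk_tag in node.meta:
        pks = [pk for pk in node.meta[pk_tag] if pk in node.columns]
        if pks:
            return pks
    # 2. Column-level meta: any column whose meta mentions pk_tag.
    from_meta = [name for name, col in node.columns.items() if pk_tag in col.meta]
    if from_meta:
        return from_meta
    # 3. Column-level tags: any column tagged with pk_tag.
    from_tags = [name for name, col in node.columns.items() if pk_tag in col.tags]
    if from_tags:
        return from_tags
    # 4. Columns covered by dbt uniqueness tests on this model.
    return sorted(unique_columns.get(node.unique_id, set()))
```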
122 changes: 103 additions & 19 deletions data_diff/dbt.py
@@ -1,15 +1,20 @@
import json
import logging
import os
import time
import rich

from collections import defaultdict
from dataclasses import dataclass
from packaging.version import parse as parse_version
from typing import List, Optional, Dict, Tuple
from typing import List, Optional, Dict, Tuple, Set
from .utils import getLogger
from .version import __version__
from pathlib import Path

import requests

logger = getLogger(__name__)


def import_dbt():
try:
@@ -72,7 +77,6 @@ def dbt_diff(
set_entrypoint_name("CLI-dbt")
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
models = dbt_parser.get_models()
dbt_parser.set_project_dict()
datadiff_variables = dbt_parser.get_datadiff_variables()
config_prod_database = datadiff_variables.get("prod_database")
config_prod_schema = datadiff_variables.get("prod_schema")
@@ -105,7 +109,7 @@ def dbt_diff(
+ " <> "
+ ".".join(diff_vars.dev_path)
+ "[/] \n"
+ "Skipped due to missing primary-key tag(s).\n"
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

rich.print("Diffs Complete!")
@@ -121,7 +125,8 @@ def _get_diff_vars(
) -> DiffVars:
dev_database = model.database
dev_schema = model.schema_
primary_keys = dbt_parser.get_primary_keys(model)

primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")

prod_database = config_prod_database if config_prod_database else dev_database
prod_schema = config_prod_schema if config_prod_schema else dev_schema
@@ -162,7 +167,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
table2_columns = list(table2.get_schema())
# Not ideal, but we don't have more specific exceptions yet
except Exception as ex:
logging.info(ex)
logger.debug(ex)
rich.print(
"[red]"
+ prod_qualified_string
@@ -287,15 +292,16 @@ def _cloud_diff(diff_vars: DiffVars) -> None:

class DbtParser:
def __init__(self, profiles_dir_override: str, project_dir_override: str, is_cloud: bool) -> None:
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
self.project_dir = Path(project_dir_override or default_project_dir())
self.is_cloud = is_cloud
self.connection = None
self.project_dict = None
self.project_dict = self.get_project_dict()
self.manifest_obj = self.get_manifest_obj()
self.requires_upper = False
self.threads = None

self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
self.unique_columns = self.get_unique_columns()

def get_datadiff_variables(self) -> dict:
return self.project_dict.get("vars").get("data_diff")
@@ -315,24 +321,24 @@ def get_models(self):
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}"
)

with open(self.project_dir / MANIFEST_PATH) as manifest:
manifest_dict = json.load(manifest)
manifest_obj = self.parse_manifest(manifest=manifest_dict)

success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
models = [manifest_obj.nodes.get(x) for x in success_models]
models = [self.manifest_obj.nodes.get(x) for x in success_models]
if not models:
            raise ValueError("Expected > 0 successful model runs from the last dbt command.")

rich.print(f"Found {str(len(models))} successful model runs from the last dbt command.")
print(f"Running with data-diff={__version__}\n")
return models

def get_primary_keys(self, model):
return list((x.name for x in model.columns.values() if "primary-key" in x.tags))
def get_manifest_obj(self):
with open(self.project_dir / MANIFEST_PATH) as manifest:
manifest_dict = json.load(manifest)
manifest_obj = self.parse_manifest(manifest=manifest_dict)
return manifest_obj

def set_project_dict(self):
def get_project_dict(self):
with open(self.project_dir / PROJECT_FILE) as project:
self.project_dict = self.yaml.safe_load(project)
project_dict = self.yaml.safe_load(project)
return project_dict

def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
profiles_path = self.profiles_dir / PROFILES_FILE
@@ -437,3 +443,81 @@ def set_connection(self):
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")

self.connection = conn_info

def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str]:
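        """Return the model's primary-key columns, trying in order:
        table-level meta, column-level meta, column tags, then uniqueness tests."""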
try:
# Get a set of all the column names
column_names = {name for name, params in node.columns.items()}
# Check if the tag is present on a table level
if pk_tag in node.meta:
# Get all the PKs that are also present as a column
                pks = [pk for pk in node.meta[pk_tag] if pk in column_names]
                if pks:
                    # If there are any left, return them
logger.debug("Found PKs via Table META: " + str(pks))
return pks

from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
if from_meta:
logger.debug("Found PKs via META: " + str(from_meta))
return from_meta

from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
if from_tags:
logger.debug("Found PKs via Tags: " + str(from_tags))
return from_tags

if node.unique_id in unique_columns:
from_uniq = unique_columns.get(node.unique_id)
if from_uniq is not None:
logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
return list(from_uniq)

except (KeyError, IndexError, TypeError) as e:
raise e

logger.debug("Found no PKs")
return []

def get_unique_columns(self) -> Dict[str, Set[str]]:
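        """Map each model's unique_id to the set of its columns covered by
        `unique` or `unique_combination_of_columns` tests in the manifest."""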
manifest = self.manifest_obj
cols_by_uid = defaultdict(set)
for node in manifest.nodes.values():
try:
if not (node.resource_type.value == "test" and hasattr(node, "test_metadata")):
continue

                if node.depends_on is None or not node.depends_on.nodes:
continue

uid = node.depends_on.nodes[0]
model_node = manifest.nodes[uid]

if node.test_metadata.name == "unique":
column_name: str = node.test_metadata.kwargs["column_name"]
for col in self._parse_concat_pk_definition(column_name):
if model_node is None or col in model_node.columns:
# skip anything that is not a column.
# for example, string literals used in concat
# like "pk1 || '-' || pk2"
cols_by_uid[uid].add(col)

if node.test_metadata.name == "unique_combination_of_columns":
for col in node.test_metadata.kwargs["combination_of_columns"]:
cols_by_uid[uid].add(col)

except (KeyError, IndexError, TypeError) as e:
logger.warning("Failure while finding unique cols: %s", e)

return cols_by_uid

def _parse_concat_pk_definition(self, definition: str) -> List[str]:
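        """Split a composite-key expression such as `concat(a, b)` or
        `a || b` into its component tokens."""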
definition = definition.strip()
if definition.lower().startswith("concat(") and definition.endswith(")"):
definition = definition[7:-1] # Removes concat( and )
columns = definition.split(",")
else:
columns = definition.split("||")

stripped_columns = [col.strip('" ()') for col in columns]
return stripped_columns