Skip to content

dbt cloud powered data-diff #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
metavar="PATH",
help="Override the dbt project directory. Otherwise assumed to be the current directory.",
)
@click.option(
"--select",
default=None,
metavar="PATH",
help="select dbt resources to compare",
)

def main(conf, run, **kw):
if kw["table2"] is None and kw["database2"]:
# Use the "database table table" form
Expand Down Expand Up @@ -263,8 +270,12 @@ def main(conf, run, **kw):
profiles_dir_override=kw["dbt_profiles_dir"],
project_dir_override=kw["dbt_project_dir"],
is_cloud=kw["cloud"],
selection=kw["select"],
)
render_diff(diff, kw["limit"], kw["stats"], kw["json_output"])
for d in diff:
# import pdb; pdb.set_trace()
rich.print(f"Diffing {'.'.join(d.info_tree.info.tables[0].table_path)} with {'.'.join(d.info_tree.info.tables[1].table_path)}")
render_diff(d, kw["limit"], kw["stats"], kw["json_output"])

else:
return _data_diff(**kw)
Expand Down
87 changes: 49 additions & 38 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import List, Optional, Dict

import requests
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
# from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
from dbt.config.renderer import ProfileRenderer

from .tracking import (
Expand All @@ -21,14 +21,19 @@
)
from .utils import run_as_daemon, truncate_error
from . import connect_to_table, diff_tables, Algorithm
import subprocess
from .dbt_cloud import get_client, dynamic_request

RUN_RESULTS_PATH = "/target/run_results.json"
MANIFEST_PATH = "/target/manifest.json"
PROJECT_FILE = "/dbt_project.yml"
PROFILES_FILE = "/profiles.yml"
LOWER_DBT_V = "1.0.0"
UPPER_DBT_V = "1.5.0"
DBT_CLOUD_API_KEY = os.getenv('DBT_CLOUD_API_KEY', None)
DBT_CLOUD_PROD_ENV_ID = os.getenv('DBT_CLOUD_PROD_ENV_ID', None)

dbtc_client = get_client(DBT_CLOUD_API_KEY)

@dataclass
class DiffVars:
Expand All @@ -40,25 +45,27 @@ class DiffVars:


def dbt_diff(
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False, selection: str = None
) -> None:
set_entrypoint_name("CLI-dbt")
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
models = dbt_parser.get_models()
models = dbt_parser.get_models(selection)
dbt_parser.set_project_dict()
datadiff_variables = dbt_parser.get_datadiff_variables()
config_prod_database = datadiff_variables.get("prod_database")
config_prod_schema = datadiff_variables.get("prod_schema")
# config_prod_database = datadiff_variables.get("prod_database")
# config_prod_schema = datadiff_variables.get("prod_schema")
datasource_id = datadiff_variables.get("datasource_id")

if not is_cloud:
dbt_parser.set_connection()

if config_prod_database is None or config_prod_schema is None:
raise ValueError("Expected a value for prod_database: or prod_schema: under \nvars:\n data_diff: ")
# if config_prod_database is None or config_prod_schema is None:
# raise ValueError("Expected a value for prod_database: or prod_schema: under \nvars:\n data_diff: ")

# import pdb; pdb.set_trace()
model_output = []
for model in models:
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id)
diff_vars = _get_diff_vars(dbt_parser, model, datasource_id)

if is_cloud and len(diff_vars.primary_keys) > 0:
_cloud_diff(diff_vars)
Expand All @@ -73,7 +80,8 @@ def dbt_diff(
)

if not is_cloud and len(diff_vars.primary_keys) == 1:
return _local_diff(diff_vars)
model_output.append(_local_diff(diff_vars))
# print(result.diff)
elif not is_cloud:
rich.print(
"[red]"
Expand All @@ -83,31 +91,38 @@ def dbt_diff(
+ "[/] \n"
+ "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n"
)

rich.print("Diffs Complete!")
return model_output


def _get_diff_vars(
dbt_parser: "DbtParser",
config_prod_database: Optional[str],
config_prod_schema: Optional[str],
model,
datasource_id: int,
) -> DiffVars:
dev_database = model.database
dev_schema = model.schema_
dev_database = model.get("database")
dev_schema = model.get("schema")
unique_id = model.get("unique_id")
primary_keys = dbt_parser.get_primary_keys(model)

prod_database = config_prod_database if config_prod_database else dev_database
prod_schema = config_prod_schema if config_prod_schema else dev_schema
prod_model_response = dynamic_request(
dbtc_client.metadata,
'get_model_by_environment',
environment_id=DBT_CLOUD_PROD_ENV_ID,
unique_id=unique_id,
last_run_count=1
)

prod_model_data = prod_model_response.get('data', {}).get("modelByEnvironment", [])[0]
prod_database = prod_model_data.get("database")
prod_schema = prod_model_data.get("schema")

if dbt_parser.requires_upper:
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.name]]
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.name]]
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.get("name")]]
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.get("name")]]
primary_keys = [x.upper() for x in primary_keys]
else:
dev_qualified_list = [dev_database, dev_schema, model.name]
prod_qualified_list = [prod_database, prod_schema, model.name]
dev_qualified_list = [dev_database, dev_schema, model.get("name")]
prod_qualified_list = [prod_database, prod_schema, model.get("name")]

return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)

Expand Down Expand Up @@ -243,32 +258,28 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
def get_datadiff_variables(self) -> dict:
return self.project_dict.get("vars").get("data_diff")

def get_models(self):
with open(self.project_dir + RUN_RESULTS_PATH) as run_results:
run_results_dict = json.load(run_results)
run_results_obj = parse_run_results(run_results=run_results_dict)

dbt_version = parse_version(run_results_obj.metadata.dbt_version)

if dbt_version < parse_version(LOWER_DBT_V) or dbt_version >= parse_version(UPPER_DBT_V):
raise Exception(
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}"
)
def get_models(self, selection):
if selection:
ls_cmd = ["dbt", "ls", "--resource-type", "model", "--select", selection]
else:
ls_cmd = ["dbt", "ls", "--resource-type", "model"]
result = subprocess.run(ls_cmd, capture_output=True, text=True)
model_list = ["model." + model for model in result.stdout.splitlines()]

with open(self.project_dir + MANIFEST_PATH) as manifest:
manifest_dict = json.load(manifest)
manifest_obj = parse_manifest(manifest=manifest_dict)
# for some reason the manifest parser appears to be choking on my seed
# manifest_obj = parse_manifest(manifest=manifest_dict)

success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
models = [manifest_obj.nodes.get(x) for x in success_models]
models = [manifest_dict.get("nodes").get(x) for x in model_list]
if not models:
raise ValueError("Expected > 0 successful models runs from the last dbt command.")
raise ValueError("No models selected!")

rich.print(f"Found {str(len(models))} successful model runs from the last dbt command.")
rich.print(f"Found {str(len(models))} models to compare.")
return models

def get_primary_keys(self, model):
return list((x.name for x in model.columns.values() if "primary-key" in x.tags))
return [x.get("name") for x in model.get("columns").values() if "primary-key" in x.get("tags")]

def set_project_dict(self):
with open(self.project_dir + PROJECT_FILE) as project:
Expand Down
13 changes: 13 additions & 0 deletions data_diff/dbt_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dbtc import dbtCloudClient

# first party

def get_client(service_token):
    """Build a dbt Cloud API client authenticated with ``service_token``.

    Args:
        service_token: Token handed to ``dbtCloudClient`` as its
            ``service_token`` keyword argument.

    Returns:
        The newly constructed ``dbtCloudClient`` instance.
    """
    return dbtCloudClient(service_token=service_token)

def dynamic_request(_prop, method, *args, **kwargs):
    """Call the attribute named ``method`` on ``_prop`` and return its result.

    Lets callers pick the request to run by name at runtime instead of
    hard-coding the attribute access.

    Args:
        _prop: Object that owns the callable (e.g. a sub-API of a client).
        method: Name of the attribute to look up on ``_prop``.
        *args: Positional arguments forwarded unchanged to the callable.
        **kwargs: Keyword arguments forwarded unchanged to the callable.

    Returns:
        Whatever the resolved callable returns.

    Raises:
        AttributeError: If ``_prop`` has no attribute named ``method``.
    """
    bound = getattr(_prop, method)
    return bound(*args, **kwargs)

# Script entry point: intentionally a no-op, this module is meant to be imported.
if __name__ == '__main__':
    pass