Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Make DBT dependency optional #421

Merged
merged 8 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
import os
import time
import rich
import yaml
from dataclasses import dataclass
from packaging.version import parse as parse_version
from typing import List, Optional, Dict

import requests
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
from dbt.config.renderer import ProfileRenderer

def import_dbt():
try:
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
from dbt.config.renderer import ProfileRenderer
import yaml
except ImportError:
raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")

return parse_run_results, parse_manifest, ProfileRenderer, yaml

from .tracking import (
set_entrypoint_name,
Expand Down Expand Up @@ -263,13 +270,15 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
self.project_dict = None
self.requires_upper = False

self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()

def get_datadiff_variables(self) -> dict:
return self.project_dict.get("vars").get("data_diff")

def get_models(self):
with open(self.project_dir + RUN_RESULTS_PATH) as run_results:
run_results_dict = json.load(run_results)
run_results_obj = parse_run_results(run_results=run_results_dict)
run_results_obj = self.parse_run_results(run_results=run_results_dict)

dbt_version = parse_version(run_results_obj.metadata.dbt_version)

Expand All @@ -280,7 +289,7 @@ def get_models(self):

with open(self.project_dir + MANIFEST_PATH) as manifest:
manifest_dict = json.load(manifest)
manifest_obj = parse_manifest(manifest=manifest_dict)
manifest_obj = self.parse_manifest(manifest=manifest_dict)

success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
models = [manifest_obj.nodes.get(x) for x in success_models]
Expand All @@ -295,11 +304,11 @@ def get_primary_keys(self, model):

def set_project_dict(self):
with open(self.project_dir + PROJECT_FILE) as project:
self.project_dict = yaml.safe_load(project)
self.project_dict = self.yaml.safe_load(project)

def set_connection(self):
with open(self.profiles_dir + PROFILES_FILE) as profiles:
profiles = yaml.safe_load(profiles)
profiles = self.yaml.safe_load(profiles)

dbt_profile = self.project_dict.get("profile")
profile_outputs = profiles.get(dbt_profile)
Expand All @@ -308,7 +317,7 @@ def set_connection(self):
conn_type = credentials.get("type").lower()

# values can contain env_vars
rendered_credentials = ProfileRenderer().render_data(credentials)
rendered_credentials = self.ProfileRenderer().render_data(credentials)

if conn_type == "snowflake":
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ trino = {version="^0.314.0", optional=true}
presto-python-client = {version="*", optional=true}
clickhouse-driver = {version="*", optional=true}
duckdb = {version="^0.6.0", optional=true}
dbt-artifacts-parser = "^0.2.4"
dbt-core = "^1.0.0"
dbt-artifacts-parser = {version="^0.2.4", optional=true}
dbt-core = {version="^1.0.0", optional=true}

[tool.poetry.dev-dependencies]
parameterized = "*"
Expand All @@ -54,6 +54,8 @@ presto-python-client = "*"
clickhouse-driver = "*"
vertica-python = "*"
duckdb = "^0.6.0"
dbt-artifacts-parser = "^0.2.4"
dbt-core = "^1.0.0"
# google-cloud-bigquery = "*"
# databricks-sql-connector = "*"

Expand All @@ -70,6 +72,7 @@ trino = ["trino"]
clickhouse = ["clickhouse-driver"]
vertica = ["vertica-python"]
duckdb = ["duckdb"]
dbt = ["dbt-core", "dbt-artifacts-parser"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
98 changes: 45 additions & 53 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os

import yaml
from data_diff.diff_tables import Algorithm
from .test_cli import run_datadiff_cli

Expand Down Expand Up @@ -49,112 +50,102 @@ def test_get_datadiff_variables_empty(self):
DbtParser.get_datadiff_variables(mock_self)

@patch("builtins.open", new_callable=mock_open, read_data="{}")
@patch("data_diff.dbt.parse_run_results")
@patch("data_diff.dbt.parse_manifest")
def test_get_models(self, mock_manifest_parser, mock_run_parser, mock_open):
def test_get_models(self, mock_open):
expected_value = "expected_value"
mock_self = Mock()
mock_self.project_dir = ""
mock_run_results = Mock()
mock_success_result = Mock()
mock_failed_result = Mock()
mock_manifest = Mock()
mock_run_parser.return_value = mock_run_results
mock_self.parse_run_results.return_value = mock_run_results
mock_run_results.metadata.dbt_version = "1.0.0"
mock_success_result.unique_id = "success_unique_id"
mock_failed_result.unique_id = "failed_unique_id"
mock_success_result.status.name = "success"
mock_failed_result.status.name = "failed"
mock_run_results.results = [mock_success_result, mock_failed_result]
mock_manifest_parser.return_value = mock_manifest
mock_self.parse_manifest.return_value = mock_manifest
mock_manifest.nodes = {"success_unique_id": expected_value}

models = DbtParser.get_models(mock_self)

self.assertEqual(expected_value, models[0])
mock_open.assert_any_call(RUN_RESULTS_PATH)
mock_open.assert_any_call(MANIFEST_PATH)
mock_run_parser.assert_called_once_with(run_results={})
mock_manifest_parser.assert_called_once_with(manifest={})
mock_self.parse_run_results.assert_called_once_with(run_results={})
mock_self.parse_manifest.assert_called_once_with(manifest={})

@patch("builtins.open", new_callable=mock_open, read_data="{}")
@patch("data_diff.dbt.parse_run_results")
@patch("data_diff.dbt.parse_manifest")
def test_get_models_bad_lower_dbt_version(self, mock_manifest_parser, mock_run_parser, mock_open):
def test_get_models_bad_lower_dbt_version(self, mock_open):
mock_self = Mock()
mock_self.project_dir = ""
mock_run_results = Mock()
mock_run_parser.return_value = mock_run_results
mock_self.parse_run_results.return_value = mock_run_results
mock_run_results.metadata.dbt_version = "0.19.0"

with self.assertRaises(Exception) as ex:
DbtParser.get_models(mock_self)

mock_open.assert_called_once_with(RUN_RESULTS_PATH)
mock_run_parser.assert_called_once_with(run_results={})
mock_manifest_parser.assert_not_called()
mock_self.parse_run_results.assert_called_once_with(run_results={})
mock_self.parse_manifest.assert_not_called()
self.assertIn("version to be", ex.exception.args[0])

@patch("builtins.open", new_callable=mock_open, read_data="{}")
@patch("data_diff.dbt.parse_run_results")
@patch("data_diff.dbt.parse_manifest")
def test_get_models_bad_upper_dbt_version(self, mock_manifest_parser, mock_run_parser, mock_open):
def test_get_models_bad_upper_dbt_version(self, mock_open):
mock_self = Mock()
mock_self.project_dir = ""
mock_run_results = Mock()
mock_run_parser.return_value = mock_run_results
mock_self.parse_run_results.return_value = mock_run_results
mock_run_results.metadata.dbt_version = "1.5.1"

with self.assertRaises(Exception) as ex:
DbtParser.get_models(mock_self)

mock_open.assert_called_once_with(RUN_RESULTS_PATH)
mock_run_parser.assert_called_once_with(run_results={})
mock_manifest_parser.assert_not_called()
mock_self.parse_run_results.assert_called_once_with(run_results={})
mock_self.parse_manifest.assert_not_called()
self.assertIn("version to be", ex.exception.args[0])

@patch("builtins.open", new_callable=mock_open, read_data="{}")
@patch("data_diff.dbt.parse_run_results")
@patch("data_diff.dbt.parse_manifest")
def test_get_models_no_success(self, mock_manifest_parser, mock_run_parser, mock_open):
def test_get_models_no_success(self, mock_open):
mock_self = Mock()
mock_self.project_dir = ""
mock_run_results = Mock()
mock_success_result = Mock()
mock_failed_result = Mock()
mock_manifest = Mock()
mock_run_parser.return_value = mock_run_results
mock_self.parse_run_results.return_value = mock_run_results
mock_run_results.metadata.dbt_version = "1.0.0"
mock_failed_result.unique_id = "failed_unique_id"
mock_success_result.status.name = "success"
mock_failed_result.status.name = "failed"
mock_run_results.results = [mock_failed_result]
mock_manifest_parser.return_value = mock_manifest
mock_self.parse_manifest.return_value = mock_manifest
mock_manifest.nodes = {"success_unique_id": "a_unique_id"}

with self.assertRaises(Exception):
DbtParser.get_models(mock_self)

mock_open.assert_any_call(RUN_RESULTS_PATH)
mock_open.assert_any_call(MANIFEST_PATH)
mock_run_parser.assert_called_once_with(run_results={})
mock_manifest_parser.assert_called_once_with(manifest={})
mock_self.parse_run_results.assert_called_once_with(run_results={})
mock_self.parse_manifest.assert_called_once_with(manifest={})

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_project_dict(self, mock_open, mock_yaml_parse):
def test_set_project_dict(self, mock_open):
expected_dict = {"key1": "value1"}
mock_self = Mock()
mock_self.project_dir = ""
mock_yaml_parse.return_value = expected_dict
mock_self.yaml.safe_load.return_value = expected_dict
DbtParser.set_project_dict(mock_self)

self.assertEqual(mock_self.project_dict, expected_dict)
mock_open.assert_called_once_with(PROJECT_FILE)

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_connection_snowflake(self, mock_open_file, mock_yaml_parse):
def test_set_connection_snowflake(self, mock_open_file):
expected_driver = "snowflake"
expected_password = "password_value"
profiles_dict = {
Expand All @@ -172,19 +163,19 @@ def test_set_connection_snowflake(self, mock_open_file, mock_yaml_parse):
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "profile_name"}
mock_yaml_parse.return_value = profiles_dict
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
DbtParser.set_connection(mock_self)

self.assertIsInstance(mock_self.connection, dict)
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
self.assertEqual(mock_self.connection.get("password"), expected_password)
self.assertEqual(mock_self.requires_upper, True)
mock_open_file.assert_called_once_with(PROFILES_FILE)
mock_yaml_parse.assert_called_once_with(mock_open_file())
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_connection_snowflake_no_password(self, mock_open_file, mock_yaml_parse):
def test_set_connection_snowflake_no_password(self, mock_open_file):
expected_driver = "snowflake"
profiles_dict = {
"profile_name": {
Expand All @@ -196,18 +187,18 @@ def test_set_connection_snowflake_no_password(self, mock_open_file, mock_yaml_pa
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "profile_name"}
mock_yaml_parse.return_value = profiles_dict
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]

with self.assertRaises(Exception):
DbtParser.set_connection(mock_self)

mock_open_file.assert_called_once_with(PROFILES_FILE)
mock_yaml_parse.assert_called_once_with(mock_open_file())
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
self.assertNotIsInstance(mock_self.connection, dict)

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_connection_bigquery(self, mock_open_file, mock_yaml_parse):
def test_set_connection_bigquery(self, mock_open_file):
expected_driver = "bigquery"
expected_method = "oauth"
expected_project = "a_project"
Expand All @@ -229,19 +220,19 @@ def test_set_connection_bigquery(self, mock_open_file, mock_yaml_parse):
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "profile_name"}
mock_yaml_parse.return_value = profiles_dict
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
DbtParser.set_connection(mock_self)

self.assertIsInstance(mock_self.connection, dict)
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
self.assertEqual(mock_self.connection.get("project"), expected_project)
self.assertEqual(mock_self.connection.get("dataset"), expected_dataset)
mock_open_file.assert_called_once_with(PROFILES_FILE)
mock_yaml_parse.assert_called_once_with(mock_open_file())
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_connection_bigquery_not_oauth(self, mock_open_file, mock_yaml_parse):
def test_set_connection_bigquery_not_oauth(self, mock_open_file):
expected_driver = "bigquery"
expected_method = "not_oauth"
expected_project = "a_project"
Expand All @@ -263,17 +254,17 @@ def test_set_connection_bigquery_not_oauth(self, mock_open_file, mock_yaml_parse
mock_self = Mock()
mock_self.profiles_dir = ""
mock_self.project_dict = {"profile": "profile_name"}
mock_yaml_parse.return_value = profiles_dict
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
with self.assertRaises(Exception):
DbtParser.set_connection(mock_self)

mock_open_file.assert_called_once_with(PROFILES_FILE)
mock_yaml_parse.assert_called_once_with(mock_open_file())
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
self.assertNotIsInstance(mock_self.connection, dict)

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_connection_key_error(self, mock_open_file, mock_yaml_parse):
def test_set_connection_key_error(self, mock_open_file):
profiles_dict = {
"profile_name": {
"outputs": {
Expand All @@ -290,17 +281,17 @@ def test_set_connection_key_error(self, mock_open_file, mock_yaml_parse):
mock_self.profiles_dir = ""
mock_self.project_dir = ""
mock_self.project_dict = {"profile": "bad_key"}
mock_yaml_parse.return_value = profiles_dict
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
with self.assertRaises(Exception):
DbtParser.set_connection(mock_self)

mock_open_file.assert_called_once_with(PROFILES_FILE)
mock_yaml_parse.assert_called_once_with(mock_open_file())
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
self.assertNotIsInstance(mock_self.connection, dict)

@patch("data_diff.dbt.yaml.safe_load")
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
def test_set_connection_not_implemented(self, mock_open_file, mock_yaml_parse):
def test_set_connection_not_implemented(self, mock_open_file):
expected_driver = "not_implemented"
profiles_dict = {
"profile_name": {
Expand All @@ -317,12 +308,13 @@ def test_set_connection_not_implemented(self, mock_open_file, mock_yaml_parse):
mock_self.profiles_dir = ""
mock_self.project_dir = ""
mock_self.project_dict = {"profile": "profile_name"}
mock_yaml_parse.return_value = profiles_dict
mock_self.yaml.safe_load.return_value = profiles_dict
mock_self.ProfileRenderer().render_data.return_value = profiles_dict["profile_name"]["outputs"]["connection1"]
with self.assertRaises(NotImplementedError):
DbtParser.set_connection(mock_self)

mock_open_file.assert_called_once_with(PROFILES_FILE)
mock_yaml_parse.assert_called_once_with(mock_open_file())
mock_self.yaml.safe_load.assert_called_once_with(mock_open_file())
self.assertNotIsInstance(mock_self.connection, dict)


Expand Down