Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

support datadiff meta filter #522

Merged
merged 2 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class TCloudApiDataDiff(pydantic.BaseModel):
table1: List[str]
table2: List[str]
pk_columns: List[str]
filter1: Optional[str] = None
filter2: Optional[str] = None


class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
Expand Down
23 changes: 21 additions & 2 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class DiffVars:
primary_keys: List[str]
connection: Dict[str, str]
threads: Optional[int]
where_filter: Optional[str] = None


def dbt_diff(
Expand Down Expand Up @@ -191,7 +192,16 @@ def _get_diff_vars(
dev_qualified_list = [dev_database, dev_schema, model.alias]
prod_qualified_list = [prod_database, prod_schema, model.alias]

return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads)
where_filter = None
if model.meta:
try:
where_filter = model.meta["datafold"]["datadiff"]["filter"]
except KeyError:
pass

return DiffVars(
dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads, where_filter
)


def _local_diff(diff_vars: DiffVars) -> None:
Expand Down Expand Up @@ -228,7 +238,14 @@ def _local_diff(diff_vars: DiffVars) -> None:
mutual_set = mutual_set - set(diff_vars.primary_keys)
extra_columns = tuple(mutual_set)

diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)
diff = diff_tables(
table1,
table2,
threaded=True,
algorithm=Algorithm.JOINDIFF,
extra_columns=extra_columns,
where=diff_vars.where_filter,
)

if list(diff):
diff_output_str += f"{column_diffs_str}{diff.get_stats_string(is_dbt=True)} \n"
Expand Down Expand Up @@ -277,6 +294,8 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> No
table1=diff_vars.prod_path,
table2=diff_vars.dev_path,
pk_columns=diff_vars.primary_keys,
filter1=diff_vars.where_filter,
filter2=diff_vars.where_filter,
)

if is_tracking_enabled():
Expand Down
145 changes: 129 additions & 16 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,20 +425,27 @@ def test_local_diff(self, mock_diff_tables):
mock_diff = MagicMock()
mock_diff_tables.return_value = mock_diff
mock_diff.__iter__.return_value = [1, 2, 3]
threads = None
where = "a_string"
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_keys = ["key"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, None)
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, threads, where)
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
_local_diff(diff_vars)

mock_diff_tables.assert_called_once_with(
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
mock_table1,
mock_table2,
threaded=True,
algorithm=Algorithm.JOINDIFF,
extra_columns=ANY,
where=where,
)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
self.assertEqual(mock_connect.call_count, 2)
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), threads)
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), threads)
mock_diff.get_stats_string.assert_called_once()

@patch("data_diff.dbt.diff_tables")
Expand All @@ -455,12 +462,14 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_keys = ["primary_key_column"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, None)
threads = None
where = "a_string"
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, threads, where)
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
_local_diff(diff_vars)

mock_diff_tables.assert_called_once_with(
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY, where=where
)
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
self.assertEqual(mock_connect.call_count, 2)
Expand All @@ -479,7 +488,10 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_datasource_id = 1
expected_primary_keys = ["primary_key_column"]
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_primary_keys, None, None)
connection = None
threads = None
where = "a_string"
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_primary_keys, connection, threads, where)
_cloud_diff(diff_vars, expected_datasource_id, api=mock_api)

mock_api.create_data_diff.assert_called_once()
Expand All @@ -491,6 +503,8 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
self.assertEqual(payload.table1, prod_qualified_list)
self.assertEqual(payload.table2, dev_qualified_list)
self.assertEqual(payload.pk_columns, expected_primary_keys)
self.assertEqual(payload.filter1, where)
self.assertEqual(payload.filter2, where)

@patch("data_diff.dbt._initialize_api")
@patch("data_diff.dbt._get_diff_vars")
Expand All @@ -512,11 +526,14 @@ def test_diff_is_cloud(
api_key = "a_api_key"
api = DatafoldAPI(api_key=api_key, host=host)
mock_initialize_api.return_value = api
connection = None
threads = None
where = "a_string"

mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=True)
mock_dbt_parser_inst.get_models.assert_called_once()
Expand Down Expand Up @@ -547,11 +564,14 @@ def test_diff_is_cloud_no_ds_id(
api_key = "a_api_key"
api = DatafoldAPI(api_key=api_key, host=host)
mock_initialize_api.return_value = api
connection = None
threads = None
where = "a_string"

mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars

with self.assertRaises(ValueError):
Expand Down Expand Up @@ -579,7 +599,10 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
}
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
connection = None
threads = None
where = "a_string"
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=False)

Expand All @@ -606,7 +629,10 @@ def test_diff_no_prod_configs(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
connection = None
threads = None
where = "a_string"
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
with self.assertRaises(ValueError):
dbt_diff(is_cloud=False)
Expand All @@ -633,7 +659,10 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
}
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
connection = None
threads = None
where = "a_string"
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=False)

Expand Down Expand Up @@ -661,7 +690,10 @@ def test_diff_only_prod_schema(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
connection = None
threads = None
where = "a_string"
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
with self.assertRaises(ValueError):
dbt_diff(is_cloud=False)
Expand Down Expand Up @@ -697,7 +729,10 @@ def test_diff_is_cloud_no_pks(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
expected_diff_vars = DiffVars(["dev"], ["prod"], [], None, None)
connection = None
threads = None
where = "a_string"
expected_diff_vars = DiffVars(["dev"], ["prod"], [], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=True)

Expand Down Expand Up @@ -727,8 +762,10 @@ def test_diff_not_is_cloud_no_pks(

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict

expected_diff_vars = DiffVars(["dev"], ["prod"], [], None, None)
connection = None
threads = None
where = "a_string"
expected_diff_vars = DiffVars(["dev"], ["prod"], [], connection, threads, where)
mock_get_diff_vars.return_value = expected_diff_vars
dbt_diff(is_cloud=False)
mock_dbt_parser_inst.get_models.assert_called_once()
Expand All @@ -749,6 +786,7 @@ def test_get_diff_vars_replace_custom_schema(self):
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
mock_model.meta = None

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod_<custom_schema>", mock_model)

Expand All @@ -773,6 +811,7 @@ def test_get_diff_vars_static_custom_schema(self):
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
mock_model.meta = None

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)

Expand All @@ -796,6 +835,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
mock_model.meta = None

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)

Expand All @@ -817,6 +857,7 @@ def test_get_diff_vars_match_dev_schema(self):
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
mock_model.meta = None

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)

Expand Down Expand Up @@ -844,3 +885,75 @@ def test_get_diff_custom_schema_no_config_exception(self):
_get_diff_vars(mock_dbt_parser, prod_database, prod_schema, None, mock_model)

mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_vars_meta_where(self):
mock_model = Mock()
prod_database = "a_prod_db"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_schema"
mock_model.config.schema_ = None
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
where = "a filter"
mock_model.meta = {"datafold": {"datadiff": {"filter": where}}}

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
assert diff_vars.primary_keys == primary_keys
assert diff_vars.connection == mock_dbt_parser.connection
assert diff_vars.threads == mock_dbt_parser.threads
self.assertEqual(diff_vars.where_filter, where)
mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_vars_meta_unrelated(self):
mock_model = Mock()
prod_database = "a_prod_db"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_schema"
mock_model.config.schema_ = None
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
where = None
mock_model.meta = {"key": "value"}

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
assert diff_vars.primary_keys == primary_keys
assert diff_vars.connection == mock_dbt_parser.connection
assert diff_vars.threads == mock_dbt_parser.threads
self.assertEqual(diff_vars.where_filter, where)
mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_vars_meta_none(self):
mock_model = Mock()
prod_database = "a_prod_db"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_schema"
mock_model.config.schema_ = None
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False
where = None
mock_model.meta = None

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
assert diff_vars.primary_keys == primary_keys
assert diff_vars.connection == mock_dbt_parser.connection
assert diff_vars.threads == mock_dbt_parser.threads
self.assertEqual(diff_vars.where_filter, where)
mock_dbt_parser.get_pk_from_model.assert_called_once()