Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

support custom schemas #437

Merged
merged 3 commits into from
Mar 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def dbt_diff(
config_prod_database = datadiff_variables.get("prod_database")
config_prod_schema = datadiff_variables.get("prod_schema")
datasource_id = datadiff_variables.get("datasource_id")
custom_schemas = datadiff_variables.get("custom_schemas")
# Custom schemas are dbt's default behavior, so default to True if the var is not set
custom_schemas = True if custom_schemas is None else custom_schemas

if not is_cloud:
dbt_parser.set_connection()
Expand All @@ -70,7 +73,9 @@ def dbt_diff(
)

for model in models:
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id)
diff_vars = _get_diff_vars(
dbt_parser, config_prod_database, config_prod_schema, model, datasource_id, custom_schemas
)

if is_cloud and len(diff_vars.primary_keys) > 0:
_cloud_diff(diff_vars)
Expand All @@ -95,6 +100,7 @@ def _get_diff_vars(
config_prod_schema: Optional[str],
model,
datasource_id: int,
custom_schemas: bool,
) -> DiffVars:
dev_database = model.database
dev_schema = model.schema_
Expand All @@ -103,6 +109,12 @@ def _get_diff_vars(
prod_database = config_prod_database if config_prod_database else dev_database
prod_schema = config_prod_schema if config_prod_schema else dev_schema

# If the project uses custom schemas (the default), the prod schema must be
# constructed as <prod_target_schema>_<custom_schema>.
# https://docs.getdbt.com/docs/build/custom-schemas
if custom_schemas and model.config.schema_:
prod_schema = prod_schema + "_" + model.config.schema_

if dbt_parser.requires_upper:
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
Expand All @@ -111,16 +123,22 @@ def _get_diff_vars(
dev_qualified_list = [dev_database, dev_schema, model.alias]
prod_qualified_list = [prod_database, prod_schema, model.alias]

return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)
return DiffVars(
dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads
)


def _local_diff(diff_vars: DiffVars) -> None:
column_diffs_str = ""
dev_qualified_string = ".".join(diff_vars.dev_path)
prod_qualified_string = ".".join(diff_vars.prod_path)

table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
table1 = connect_to_table(
diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
)
table2 = connect_to_table(
diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
)

table1_columns = list(table1.get_schema())
try:
Expand Down