Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

reduce repetition in print statements #464

Merged
merged 4 commits into from
Apr 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 24 additions & 68 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,7 @@ def dbt_diff(
_local_diff(diff_vars)
else:
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path)
+ "[/] \n"
_diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

Expand Down Expand Up @@ -172,14 +168,13 @@ def _get_diff_vars(

def _local_diff(diff_vars: DiffVars) -> None:
column_diffs_str = ""
dev_qualified_string = ".".join(diff_vars.dev_path)
prod_qualified_string = ".".join(diff_vars.prod_path)
dev_qualified_str = ".".join(diff_vars.dev_path)
prod_qualified_str = ".".join(diff_vars.prod_path)
diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str)

table1 = connect_to_table(
diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
)
table1 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads)
table2 = connect_to_table(
diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads
)

table1_columns = list(table1.get_schema())
Expand All @@ -188,15 +183,8 @@ def _local_diff(diff_vars: DiffVars) -> None:
# Not ideal, but we don't have more specific exceptions yet
except Exception as ex:
logger.debug(ex)
rich.print(
"[red]"
+ prod_qualified_string
+ " <> "
+ dev_qualified_string
+ "[/] \n"
+ column_diffs_str
+ "[green]New model or no access to prod table.[/] \n"
)
diff_output_str += "[red]New model or no access to prod table.[/] \n"
rich.print(diff_output_str)
return

mutual_set = set(table1_columns) & set(table2_columns)
Expand All @@ -215,29 +203,15 @@ def _local_diff(diff_vars: DiffVars) -> None:
diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)

if list(diff):
rich.print(
"[red]"
+ prod_qualified_string
+ " <> "
+ dev_qualified_string
+ "[/] \n"
+ column_diffs_str
+ diff.get_stats_string(is_dbt=True)
+ "\n"
)
diff_output_str += f"{column_diffs_str}{diff.get_stats_string(is_dbt=True)} \n"
rich.print(diff_output_str)
else:
rich.print(
"[red]"
+ prod_qualified_string
+ " <> "
+ dev_qualified_string
+ "[/] \n"
+ column_diffs_str
+ "[green]No row differences[/] \n"
)
diff_output_str += f"{column_diffs_str}[bold][green]No row differences[/][/] \n"
rich.print(diff_output_str)


def _cloud_diff(diff_vars: DiffVars, datasource_id: int, datafold_host: str, url: str, api_key: str) -> None:
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
payload = {
"data_source1_id": datasource_id,
"data_source2_id": datasource_id,
Expand Down Expand Up @@ -286,24 +260,11 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, datafold_host: str, url
diff_percent_list,
"Value Match Percent:",
)
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path)
+ f"[/]\n{diff_url}\n"
+ diff_output
+ "\n"
)
diff_output_str += f"{diff_url}\n {diff_output} \n"
rich.print(diff_output_str)
else:
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path)
+ f"[/]\n{diff_url}\n"
+ "[green]No row differences[/] \n"
)
diff_output_str += f"{diff_url}\n [green]No row differences[/] \n"
rich.print(diff_output_str)

except BaseException as ex: # Catch KeyboardInterrupt too
error = ex
Expand All @@ -328,12 +289,7 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, datafold_host: str, url
send_event_json(event_json)

if error:
rich.print(
"[red]"
+ ".".join(diff_vars.prod_path)
+ " <> "
+ ".".join(diff_vars.dev_path) + "[/]\n"
)
rich.print(diff_output_str)
if diff_id:
diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
rich.print(f"{diff_url} \n")
Expand Down Expand Up @@ -398,6 +354,10 @@ def _cloud_poll_and_get_summary_results(url, headers):
return summary_results


def _diff_output_base(dev_path: str, prod_path: str) -> str:
return f"[green]{prod_path} <> {dev_path}[/] \n"


class DbtParser:
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
Expand All @@ -414,12 +374,8 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str) -> Non
self.unique_columns = self.get_unique_columns()

def get_datadiff_variables(self) -> dict:
vars = get_from_dict_with_raise(
self.project_dict, "vars", f"No vars: found in dbt_project.yml."
)
return get_from_dict_with_raise(
vars, "data_diff", f"data_diff: section not found in dbt_project.yml vars:."
)
vars = get_from_dict_with_raise(self.project_dict, "vars", f"No vars: found in dbt_project.yml.")
return get_from_dict_with_raise(vars, "data_diff", f"data_diff: section not found in dbt_project.yml vars:.")

def get_models(self):
with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
Expand Down