Skip to content

Commit

Permalink
Foreach: resolve vars in queries (#2183)
Browse files Browse the repository at this point in the history
  • Loading branch information
m1n0 authored Nov 4, 2024
1 parent bbe338b commit 0ecbec4
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 3 deletions.
4 changes: 2 additions & 2 deletions soda/core/soda/execution/check/metric_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ def __init__(
f"User defined metric {metric_name} is a data_source supported metric. Please, choose a different name for the metric.",
location=metric_check_cfg.location,
)

jinja_resolve = self.data_source_scan.scan.jinja_resolve
metric = UserDefinedNumericMetric(
data_source_scan=self.data_source_scan,
check_name=metric_check_cfg.source_line,
sql=metric_check_cfg.metric_query,
sql=jinja_resolve(metric_check_cfg.metric_query, metric_check_cfg.variables),
check=self,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ def __init__(
check_cfg: UserDefinedFailedRowsCheckCfg = self.check_cfg
self.check_value = None

jinja_resolve = self.data_source_scan.scan.jinja_resolve

metric = UserDefinedFailedRowsMetric(
data_source_scan=self.data_source_scan,
check_name=check_cfg.source_line,
query=check_cfg.query,
query=jinja_resolve(check_cfg.query, self.check_cfg.variables),
check=self,
partition=partition,
)
Expand Down
2 changes: 2 additions & 0 deletions soda/core/soda/sodacl/check_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
self.samples_limit: int | None = samples_limit
self.samples_columns: list | None = samples_columns
self.failed_rows_query: str | None = failed_rows_query
self.variables: dict = {}

def get_column_name(self) -> str | None:
pass
Expand All @@ -35,6 +36,7 @@ def instantiate_for_each_dataset(
self, name: str, table_alias: str, table_name: str, partition_name: str
) -> CheckCfg:
instantiated_check_cfg = deepcopy(self)
instantiated_check_cfg.variables[table_alias] = table_name
partition_replace = f" [{partition_name}]" if partition_name else ""
instantiated_check_cfg.name = name
instantiated_check_cfg.source_header = f"checks for {table_alias} being {table_name}{partition_replace}"
Expand Down
26 changes: 26 additions & 0 deletions soda/core/tests/data_source/test_for_each_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,29 @@ def test_for_each_dataset_with_quotes_warning(data_source_fixture: DataSourceFix
x for x in scan_results["logs"] if "It looks like quote characters are present" in x["message"]
]
assert len(character_log_warnings) == 2


@pytest.mark.skipif(
test_data_source in ["sqlserver"],
reason="Avoid running in datasources that need special quoting as that clashes with templates.",
)
def test_for_each_dataset_variables(data_source_fixture: DataSourceFixture):
customers_table_name = data_source_fixture.ensure_test_table(customers_test_table)

scan = data_source_fixture.create_test_scan()
scan.add_sodacl_yaml_str(
f"""
for each dataset D:
datasets:
- {customers_table_name}
checks:
- failed rows:
fail query: "SELECT * FROM ${{D}} WHERE {scan.casify_column_name('id')} = 'ID100'"
- user_metric < 100:
user_metric query: "SELECT count(*) FROM ${{D}}"
"""
)
scan.execute()

scan.assert_all_checks_pass()
assert len(scan._checks) == 2
7 changes: 7 additions & 0 deletions soda/core/tests/helpers/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,16 @@ def __get_error_message(self, expected_outcome) -> str | None:

return error_message

@property
def data_source(self) -> DataSource:
return self._data_source_manager.data_sources[self._data_source_name]

def casify_data_type(self, data_type: str) -> str:
data_source_type = self.data_source.get_sql_type_for_schema_check(data_type)
return self.data_source.default_casify_column_name(data_source_type)

def casify_column_name(self, test_column_name: str) -> str:
return self.data_source.default_casify_column_name(test_column_name)

def quote_table_name(self, table_name: str) -> str:
return self.data_source.quote_table(table_name)

0 comments on commit 0ecbec4

Please sign in to comment.