Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions sqlmesh/core/plan/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,23 +574,31 @@ def _get_audit_only_snapshots(
) -> t.Dict[SnapshotId, Snapshot]:
metadata_snapshots = []
for snapshot in new_snapshots.values():
if not snapshot.is_metadata or not snapshot.is_model or not snapshot.evaluatable:
if (
not snapshot.is_metadata
or not snapshot.is_model
or not snapshot.evaluatable
or not snapshot.previous_version
):
continue

metadata_snapshots.append(snapshot)

# Bulk load all the previous snapshots
previous_snapshots = self.state_reader.get_snapshots(
[
s.previous_version.snapshot_id(s.name)
for s in metadata_snapshots
if s.previous_version
]
).values()
previous_snapshot_ids = [
s.previous_version.snapshot_id(s.name) for s in metadata_snapshots if s.previous_version
]
previous_snapshots = {
s.name: s for s in self.state_reader.get_snapshots(previous_snapshot_ids).values()
}

# Check if any of the snapshots have modifications to the audits field by comparing the hashes
audit_snapshots = {}
for snapshot, previous_snapshot in zip(metadata_snapshots, previous_snapshots):
for snapshot in metadata_snapshots:
if snapshot.name not in previous_snapshots:
continue

previous_snapshot = previous_snapshots[snapshot.name]
new_audits_hash = snapshot.model.audit_metadata_hash()
previous_audit_hash = previous_snapshot.model.audit_metadata_hash()

Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,9 @@ def _dag(
dag = DAG[SchedulingUnit]()

for snapshot_id in snapshot_dag:
if snapshot_id.name not in self.snapshots_by_name:
continue

snapshot = self.snapshots_by_name[snapshot_id.name]
intervals = intervals_per_snapshot.get(snapshot.name, [])

Expand Down
25 changes: 25 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6294,6 +6294,31 @@ def test_restatement_shouldnt_backfill_beyond_prod_intervals(init_and_plan_conte
].intervals[-1][1] == to_timestamp("2023-01-08 00:00:00 UTC")


@time_machine.travel("2023-01-08 15:00:00 UTC")
@use_terminal_console
def test_audit_only_metadata_change(init_and_plan_context: t.Callable):
context, plan = init_and_plan_context("examples/sushi")
context.apply(plan)

# Add a new audit
model = context.get_model("sushi.waiter_revenue_by_day")
audits = model.audits.copy()
audits.append(("number_of_rows", {"threshold": exp.Literal.number(1)}))
model = model.copy(update={"audits": audits})
context.upsert_model(model)

plan = context.plan_builder("prod", skip_tests=True).build()
assert len(plan.new_snapshots) == 2
assert all(s.change_category.is_metadata for s in plan.new_snapshots)
assert not plan.missing_intervals

with capture_output() as output:
context.apply(plan)

assert "Auditing models" in output.stdout
assert model.name in output.stdout


def initial_add(context: Context, environment: str):
assert not context.state_reader.get_environment(environment)

Expand Down