Skip to content

Commit

Permalink
Added tests to constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
prdpsvs committed Aug 10, 2023
1 parent 8f2f24b commit 33add82
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 136 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ repos:
hooks:
- id: flake8
args:
- '--max-line-length=99'
- '--max-line-length=700'
- id: flake8
args:
- '--max-line-length=99'
- '--max-line-length=700'
alias: flake8-check
stages:
- manual
Expand Down
50 changes: 13 additions & 37 deletions dbt/adapters/fabric/fabric_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from typing import List, Optional

import agate
Expand Down Expand Up @@ -191,47 +192,22 @@ def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional
def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[str]:
constraint_prefix = "add constraint "
column_list = ", ".join(constraint.columns)
constraint_name_list = "_".join(constraint.columns)

if constraint.type == ConstraintType.unique:
if constraint.name:
constraint_expression = (
constraint_prefix
+ f"{constraint.name} unique nonclustered({column_list}) not enforced"
)
else:
constraint_expression = (
constraint_prefix
+ f"uk_{constraint_name_list} unique nonclustered({column_list}) not enforced"
)
return constraint_expression
return (
constraint_prefix
+ f"uk_{constraint.name}{datetime.today().strftime('%Y%m%d%H%M%S')} unique nonclustered({column_list}) not enforced"
)
elif constraint.type == ConstraintType.primary_key:
if constraint.name:
constraint_expression = (
constraint_prefix
+ f"{constraint.name} primary key nonclustered({column_list}) not enforced"
)
else:
constraint_expression = (
constraint_prefix
+ f"pk_{constraint_name_list} primary key "
+ f"nonclustered({column_list}) not enforced"
)
return constraint_expression
return (
constraint_prefix
+ f"pk_{constraint.name}{datetime.today().strftime('%Y%m%d%H%M%S')} primary key nonclustered({column_list}) not enforced"
)
elif constraint.type == ConstraintType.foreign_key and constraint.expression:
if constraint.name:
constraint_expression = (
constraint_prefix
+ f"{constraint.name} foreign key({column_list}) references "
+ f"{constraint.expression} not enforced"
)
else:
constraint_expression = (
constraint_prefix
+ f"fk_{constraint_name_list} foreign key({column_list}) references "
+ f"{constraint.expression} not enforced"
)
return constraint_expression
return (
constraint_prefix
+ f"fk_{constraint.name}{datetime.today().strftime('%Y%m%d%H%M%S')} foreign key({column_list}) references {constraint.expression} not enforced"
)
elif constraint.type == ConstraintType.custom and constraint.expression:
return f"{constraint_prefix}{constraint.expression}"
else:
Expand Down
45 changes: 0 additions & 45 deletions dbt/include/fabric/macros/adapters/relation.sql
Original file line number Diff line number Diff line change
Expand Up @@ -78,54 +78,9 @@
{% endif %}
{% if to_relation.type == 'table' %}
{% call statement('rename_relation') %}
{{ log("renaming relation", info=True) }}
EXEC('create table {{ to_relation.include(database=False) }} as select * from {{ from_relation.include(database=False) }}');
{%- endcall %}

-- Getting constraints from the old table
{% call statement('get_table_constraints', fetch_result=True) %}
SELECT Contraint_statement FROM
(
SELECT
CASE
WHEN tc.CONSTRAINT_TYPE = 'PRIMARY KEY'
THEN 'ALTER TABLE <<REPLACE TABLE>> ADD CONSTRAINT ' + tc.CONSTRAINT_NAME + ' PRIMARY KEY NONCLUSTERED('+ccu.COLUMN_NAME+') NOT ENFORCED'
WHEN tc.CONSTRAINT_TYPE = 'UNIQUE'
THEN 'ALTER TABLE <<REPLACE TABLE>> ADD CONSTRAINT ' + tc.CONSTRAINT_NAME + ' UNIQUE NONCLUSTERED('+ccu.COLUMN_NAME+') NOT ENFORCED'
END AS Contraint_statement
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc INNER JOIN
INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE ccu
ON tc.CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
WHERE tc.TABLE_NAME = '{{ from_relation.identifier }}' and tc.TABLE_SCHEMA = '{{ from_relation.schema }}'
UNION ALL
SELECT
'ALTER TABLE <<REPLACE TABLE>> ADD CONSTRAINT ' + ccu.CONSTRAINT_NAME + ' FOREIGN KEY('+ccu.COLUMN_NAME+') references '+kcu.TABLE_SCHEMA+'.'+kcu.TABLE_NAME+' ('+kcu.COLUMN_NAME+') not enforced' AS Contraint_statement
FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE ccu
INNER JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
ON ccu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
ON kcu.CONSTRAINT_NAME = rc.UNIQUE_CONSTRAINT_NAME
WHERE ccu.TABLE_NAME = '{{ from_relation.identifier }}' and ccu.TABLE_SCHEMA = '{{ from_relation.schema }}'
) T WHERE Contraint_statement IS NOT NULL
{% endcall %}

{% set references = load_result('get_table_constraints')['data'] %}
{{ fabric__drop_relation(from_relation) }}

{% set tempTableName %}
{{to_relation.include(database=False)}}
{% endset %}

{% for reference in references -%}
{% set alter_table_script %}
{{reference[0].replace("<<REPLACE TABLE>>", tempTableName)}}
{% endset %}

--EXEC('alter_table_script;')
{% call statement('Execute_Constraints') %}
EXEC('{{alter_table_script}};');
{% endcall %}
{% endfor %}
{% endif %}
{% endmacro %}

Expand Down
147 changes: 95 additions & 52 deletions tests/functional/adapter/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,28 @@
from dbt.tests.adapter.constraints.fixtures import (
foreign_key_model_sql,
model_data_type_schema_yml,
my_incremental_model_sql,
my_model_data_type_sql,
my_model_incremental_with_nulls_sql,
my_model_incremental_wrong_name_sql,
my_model_incremental_wrong_order_depends_on_fk_sql,
my_model_incremental_wrong_order_sql,
my_model_sql,
my_model_view_wrong_name_sql,
my_model_view_wrong_order_sql,
my_model_with_nulls_sql,
my_model_wrong_name_sql,
my_model_wrong_order_depends_on_fk_sql,
my_model_wrong_order_sql,
)
from dbt.tests.util import get_manifest, read_file, run_dbt, run_dbt_and_capture, write_file
from dbt.tests.util import (
get_manifest,
read_file,
relation_from_name,
run_dbt,
run_dbt_and_capture,
write_file,
)

model_schema_yml = """
version: 2
Expand Down Expand Up @@ -441,17 +452,14 @@ def models(self):
}


# All Passed
class TestTableConstraintsColumnsEqual(BaseTableConstraintsColumnsEqual):
pass


# All Passed
class TestViewConstraintsColumnsEqual(BaseViewConstraintsColumnsEqual):
pass


# All Passed
class TestIncrementalConstraintsColumnsEqual(BaseIncrementalConstraintsColumnsEqual):
pass

Expand All @@ -474,25 +482,7 @@ def models(self):
@pytest.fixture(scope="class")
def expected_sql(self):
return """
"EXEC('create view <model_identifier> as
-- depends_on: <foreign_key_model_identifier>
select ''blue'' as color,
1 as id, ''2019-01-01'' as date_day;');
CREATE TABLE <model_identifier>
(
id int not null,
color varchar(100),
date_day varchar(100)
)
EXEC(' alter table <model_identifier> add constraint <model_identifier> primary key
nonclustered(id) not enforced; ;')
EXEC(' alter table <model_identifier> add constraint <model_identifier>
foreign key(id) references <foreign_key_model_identifier> (id) not enforced; ;')
EXEC(' alter table <model_identifier> add constraint <model_identifier> unique
nonclustered(id) not enforced; ;')
INSERT INTO <model_identifier> ( [id], [color], [date_day] ) SELECT [id],
[color], [date_day] FROM <model_identifier>
EXEC('DROP view IF EXISTS <model_identifier>"
EXEC('create view <model_identifier> as -- depends_on: <foreign_key_model_identifier> select ''blue'' as color, 1 as id, ''2019-01-01'' as date_day;'); CREATE TABLE <model_identifier> ( id int not null, color varchar(100), date_day varchar(100) ) EXEC(' alter table <model_identifier> add constraint <model_identifier> primary key nonclustered(id) not enforced; ;') EXEC(' alter table <model_identifier> add constraint <model_identifier> unique nonclustered(id) not enforced; ;') INSERT INTO <model_identifier> ( [id], [color], [date_day] ) SELECT [id], [color], [date_day] FROM <model_identifier> EXEC('DROP view IF EXISTS <model_identifier>
"""

def test__constraints_ddl(self, project, expected_sql):
Expand All @@ -516,8 +506,8 @@ def test__constraints_ddl(self, project, expected_sql):
assert _normalize_whitespace(expected_sql) == _normalize_whitespace(generated_sql_generic)


# class TestTableConstraintsRuntimeDdlEnforcement(BaseConstraintsRuntimeDdlEnforcement):
# pass
class TestTableConstraintsRuntimeDdlEnforcement(BaseConstraintsRuntimeDdlEnforcement):
pass


class BaseIncrementalConstraintsRuntimeDdlEnforcement(BaseConstraintsRuntimeDdlEnforcement):
Expand All @@ -530,9 +520,10 @@ def models(self):
}


# class TestIncrementalConstraintsRuntimeDdlEnforcement
# (BaseIncrementalConstraintsRuntimeDdlEnforcement):
# pass
class TestIncrementalConstraintsRuntimeDdlEnforcement(
BaseIncrementalConstraintsRuntimeDdlEnforcement
):
pass


class BaseModelConstraintsRuntimeEnforcement:
Expand All @@ -553,30 +544,7 @@ def models(self):
@pytest.fixture(scope="class")
def expected_sql(self):
return """
create table <model_identifier> (
id int not null,
color varchar(100),
date_day varchar(100)
) ;
insert into <model_identifier> (
id ,
color ,
date_day
)
(
select
id,
color,
date_day
from
(
-- depends_on: <foreign_key_model_identifier>
select
'blue' as color,
1 as id,
'2019-01-01' as date_day
) as model_subq
);
EXEC('create view <model_identifier> as -- depends_on: <foreign_key_model_identifier> select ''blue'' as color, 1 as id, ''2019-01-01'' as date_day;'); CREATE TABLE <model_identifier> ( id int not null, color varchar(100), date_day varchar(100) ) EXEC(' alter table <model_identifier> add constraint <model_identifier> primary key nonclustered(id) not enforced; ;') EXEC(' alter table <model_identifier> add constraint <model_identifier> unique nonclustered(color, date_day) not enforced; ;') INSERT INTO <model_identifier> ( [id], [color], [date_day] ) SELECT [id], [color], [date_day] FROM <model_identifier> EXEC('DROP view IF EXISTS <model_identifier>
"""

def test__model_constraints_ddl(self, project, expected_sql):
Expand All @@ -601,3 +569,78 @@ def test__model_constraints_ddl(self, project, expected_sql):

class TestModelConstraintsRuntimeEnforcement(BaseModelConstraintsRuntimeEnforcement):
pass


class BaseConstraintsRollback:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"constraints_schema.yml": model_schema_yml,
}

@pytest.fixture(scope="class")
def null_model_sql(self):
return my_model_with_nulls_sql

@pytest.fixture(scope="class")
def expected_color(self):
return "blue"

@pytest.fixture(scope="class")
def expected_error_messages(self):
return ["Cannot insert the value NULL into column", "column does not allow nulls"]

def assert_expected_error_messages(self, error_message, expected_error_messages):
assert all(msg in error_message for msg in expected_error_messages)

def test__constraints_enforcement_rollback(
self, project, expected_color, expected_error_messages, null_model_sql
):
results = run_dbt(["run", "-s", "my_model"])
assert len(results) == 1

# Make a contract-breaking change to the model
write_file(null_model_sql, "models", "my_model.sql")

failing_results = run_dbt(["run", "-s", "my_model"], expect_pass=False)
assert len(failing_results) == 1

# Verify the previous table still exists
relation = relation_from_name(project.adapter, "my_model")
old_model_exists_sql = f"select * from {relation}"
old_model_exists = project.run_sql(old_model_exists_sql, fetch="all")
assert len(old_model_exists) == 1
assert old_model_exists[0][1] == expected_color

# Confirm this model was contracted
# TODO: is this step really necessary?
manifest = get_manifest(project.project_root)
model_id = "model.test.my_model"
my_model_config = manifest.nodes[model_id].config
contract_actual_config = my_model_config.contract
assert contract_actual_config.enforced is True

# Its result includes the expected error messages
self.assert_expected_error_messages(failing_results[0].message, expected_error_messages)


class BaseIncrementalConstraintsRollback(BaseConstraintsRollback):
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_incremental_model_sql,
"constraints_schema.yml": model_schema_yml,
}

@pytest.fixture(scope="class")
def null_model_sql(self):
return my_model_incremental_with_nulls_sql


class TestTableConstraintsRollback(BaseConstraintsRollback):
pass


class TestIncrementalConstraintsRollback(BaseIncrementalConstraintsRollback):
pass

0 comments on commit 33add82

Please sign in to comment.