From 3da93cb7d135772f481288d3ead05c6c7942e7b7 Mon Sep 17 00:00:00 2001 From: silentninja Date: Tue, 31 May 2022 05:28:15 +0400 Subject: [PATCH] Refactor create many to many links function to use `create_mathesar_table` function --- db/links/operations/create.py | 27 ++++-------- mathesar/tests/api/test_links.py | 73 +++++++++++++++++++++++++++----- 2 files changed, 71 insertions(+), 29 deletions(-) diff --git a/db/links/operations/create.py b/db/links/operations/create.py index ecaa033829..e2d6d654bb 100644 --- a/db/links/operations/create.py +++ b/db/links/operations/create.py @@ -1,9 +1,10 @@ from alembic.operations import Operations from alembic.migration import MigrationContext -from sqlalchemy import MetaData +from sqlalchemy import ForeignKey, MetaData from db.columns.base import MathesarColumn from db.constraints.utils import naming_convention +from db.tables.operations.create import create_mathesar_table from db.tables.operations.select import reflect_table_from_oid, reflect_tables_from_oids from db.tables.utils import get_primary_key_column @@ -48,27 +49,17 @@ def create_many_to_many_link(engine, schema, map_table_name, referents): referent_tables_oid = [referent['referent_table'] for referent in referents] referent_tables = reflect_tables_from_oids(referent_tables_oid, engine, conn) metadata = MetaData(bind=engine, schema=schema, naming_convention=naming_convention) - opts = { - 'target_metadata': metadata - } - ctx = MigrationContext.configure(conn, opts=opts) - op = Operations(ctx) - op.create_table(map_table_name, schema=schema) + # Throws sqlalchemy.exc.NoReferencedTableError if metadata is not reflected. + metadata.reflect() + referrer_columns = [] for referent in referents: referent_table_oid = referent['referent_table'] referent_table = referent_tables[referent_table_oid] col_name = referent['column_name'] primary_key_column = get_primary_key_column(referent_table) + foreign_keys = {ForeignKey(primary_key_column)} column = MathesarColumn( - col_name, primary_key_column.type - ) - op.add_column(map_table_name, column, schema=schema) - op.create_foreign_key( - None, - map_table_name, - referent_table.name, - [column.name], - [primary_key_column.name], - source_schema=schema, - referent_schema=schema + col_name, primary_key_column.type, foreign_keys=foreign_keys, ) + referrer_columns.append(column) + create_mathesar_table(map_table_name, schema, referrer_columns, engine, metadata) diff --git a/mathesar/tests/api/test_links.py b/mathesar/tests/api/test_links.py index d6f7eea3ed..dbc143e984 100644 --- a/mathesar/tests/api/test_links.py +++ b/mathesar/tests/api/test_links.py @@ -3,13 +3,12 @@ from sqlalchemy import Table as SATable from django.core.cache import cache +from db.constraints.utils import ConstraintType from db.tables.operations.select import get_oid_from_table -from db.tests.types import fixtures -from mathesar import models +from db.tables.utils import get_primary_key_column -engine_with_types = fixtures.engine_with_types -engine_email_type = fixtures.engine_email_type -temporary_testing_schema = fixtures.temporary_testing_schema +from mathesar import models +from mathesar.models import Constraint, Table @pytest.fixture @@ -33,9 +32,9 @@ def column_test_table(patent_schema): return table -def test_one_to_one_link_create(column_test_table, client, create_table): +def test_one_to_one_link_create(column_test_table, client, create_patents_table): cache.clear() - table_2 = create_table('Table 2') + table_2 = create_patents_table('Table 2') data = { "link_type": "one-to-one", "reference_column_name": "col_1", @@ -47,11 +46,31 @@ def test_one_to_one_link_create(column_test_table, client, create_table): data=data, ) assert response.status_code == 201 + constraints = Constraint.objects.filter(table=table_2) + assert constraints.count() == 3 + + unique_constraint = next( + constraint + for constraint in constraints + if constraint.type == ConstraintType.UNIQUE.value + ) + fk_constraint = next( + constraint + for constraint in constraints + if constraint.type == ConstraintType.FOREIGN_KEY.value + ) + unique_constraint_columns = list(unique_constraint.columns.all()) + fk_constraint_columns = list(fk_constraint.columns.all()) + referent_columns = list(fk_constraint.referent_columns.all()) + assert unique_constraint_columns == table_2.get_columns_by_name(['col_1']) + assert fk_constraint_columns == table_2.get_columns_by_name(['col_1']) + referent_primary_key_column_name = get_primary_key_column(column_test_table._sa_table).name + assert referent_columns == column_test_table.get_columns_by_name([referent_primary_key_column_name]) -def test_one_to_many_link_create(column_test_table, client, create_table): +def test_one_to_many_link_create(column_test_table, client, create_patents_table): cache.clear() - table_2 = create_table('Table 2') + table_2 = create_patents_table('Table 2') data = { "link_type": "one-to-many", "reference_column_name": "col_1", @@ -63,6 +82,19 @@ def test_one_to_many_link_create(column_test_table, client, create_table): data=data, ) assert response.status_code == 201 + constraints = Constraint.objects.filter(table=table_2) + assert constraints.count() == 2 + + fk_constraint = next( + constraint + for constraint in constraints + if constraint.type == ConstraintType.FOREIGN_KEY.value + ) + fk_constraint_columns = list(fk_constraint.columns.all()) + referent_columns = list(fk_constraint.referent_columns.all()) + assert fk_constraint_columns == table_2.get_columns_by_name(['col_1']) + referent_primary_key_column_name = get_primary_key_column(column_test_table._sa_table).name + assert referent_columns == column_test_table.get_columns_by_name([referent_primary_key_column_name]) def test_one_to_many_self_referential_link_create(column_test_table, client): @@ -78,9 +110,24 @@ def test_one_to_many_self_referential_link_create(column_test_table, client): data=data, ) assert response.status_code == 201 + constraints = Constraint.objects.filter(table=column_test_table) + assert constraints.count() == 2 + + fk_constraint = next( + constraint + for constraint in constraints + if constraint.type == ConstraintType.FOREIGN_KEY.value + ) + fk_constraint_columns = list(fk_constraint.columns.all()) + referent_columns = list(fk_constraint.referent_columns.all()) + assert fk_constraint_columns == column_test_table.get_columns_by_name(['col_1']) + referent_primary_key_column_name = get_primary_key_column(column_test_table._sa_table).name + assert referent_columns == column_test_table.get_columns_by_name([referent_primary_key_column_name]) def test_many_to_many_self_referential_link_create(column_test_table, client): + schema = column_test_table.schema + engine = schema._sa_engine cache.clear() data = { "link_type": "many-to-many", @@ -95,11 +142,15 @@ def test_many_to_many_self_referential_link_create(column_test_table, client): data=data, ) assert response.status_code == 201 + map_table_oid = get_oid_from_table("map_table", schema.name, engine) + map_table = Table.objects.get(oid=map_table_oid) + constraints = Constraint.objects.filter(table=map_table) + assert constraints.count() == 3 -def test_many_to_many_link_create(column_test_table, client, create_table): +def test_many_to_many_link_create(column_test_table, client, create_patents_table): cache.clear() - table_2 = create_table('Table 2') + table_2 = create_patents_table('Table 2') data = { "link_type": "many-to-many", "mapping_table_name": "map_table",