Skip to content

Commit

Permalink
Refactor create many to many links function to use `create_mathesar_t…
Browse files Browse the repository at this point in the history
…able` function
  • Loading branch information
silentninja committed May 31, 2022
1 parent ab1084d commit 3da93cb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 29 deletions.
27 changes: 9 additions & 18 deletions db/links/operations/create.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from alembic.operations import Operations
from alembic.migration import MigrationContext
from sqlalchemy import MetaData
from sqlalchemy import ForeignKey, MetaData

from db.columns.base import MathesarColumn
from db.constraints.utils import naming_convention
from db.tables.operations.create import create_mathesar_table
from db.tables.operations.select import reflect_table_from_oid, reflect_tables_from_oids
from db.tables.utils import get_primary_key_column

Expand Down Expand Up @@ -48,27 +49,17 @@ def create_many_to_many_link(engine, schema, map_table_name, referents):
referent_tables_oid = [referent['referent_table'] for referent in referents]
referent_tables = reflect_tables_from_oids(referent_tables_oid, engine, conn)
metadata = MetaData(bind=engine, schema=schema, naming_convention=naming_convention)
opts = {
'target_metadata': metadata
}
ctx = MigrationContext.configure(conn, opts=opts)
op = Operations(ctx)
op.create_table(map_table_name, schema=schema)
# Throws sqlalchemy.exc.NoReferencedTableError if metadata is not reflected.
metadata.reflect()
referrer_columns = []
for referent in referents:
referent_table_oid = referent['referent_table']
referent_table = referent_tables[referent_table_oid]
col_name = referent['column_name']
primary_key_column = get_primary_key_column(referent_table)
foreign_keys = {ForeignKey(primary_key_column)}
column = MathesarColumn(
col_name, primary_key_column.type
)
op.add_column(map_table_name, column, schema=schema)
op.create_foreign_key(
None,
map_table_name,
referent_table.name,
[column.name],
[primary_key_column.name],
source_schema=schema,
referent_schema=schema
col_name, primary_key_column.type, foreign_keys=foreign_keys,
)
referrer_columns.append(column)
create_mathesar_table(map_table_name, schema, referrer_columns, engine, metadata)
73 changes: 62 additions & 11 deletions mathesar/tests/api/test_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from sqlalchemy import Table as SATable
from django.core.cache import cache

from db.constraints.utils import ConstraintType
from db.tables.operations.select import get_oid_from_table
from db.tests.types import fixtures
from mathesar import models
from db.tables.utils import get_primary_key_column

engine_with_types = fixtures.engine_with_types
engine_email_type = fixtures.engine_email_type
temporary_testing_schema = fixtures.temporary_testing_schema
from mathesar import models
from mathesar.models import Constraint, Table


@pytest.fixture
Expand All @@ -33,9 +32,9 @@ def column_test_table(patent_schema):
return table


def test_one_to_one_link_create(column_test_table, client, create_table):
def test_one_to_one_link_create(column_test_table, client, create_patents_table):
cache.clear()
table_2 = create_table('Table 2')
table_2 = create_patents_table('Table 2')
data = {
"link_type": "one-to-one",
"reference_column_name": "col_1",
Expand All @@ -47,11 +46,31 @@ def test_one_to_one_link_create(column_test_table, client, create_table):
data=data,
)
assert response.status_code == 201
constraints = Constraint.objects.filter(table=table_2)
assert constraints.count() == 3

unique_constraint = next(
constraint
for constraint in constraints
if constraint.type == ConstraintType.UNIQUE.value
)
fk_constraint = next(
constraint
for constraint in constraints
if constraint.type == ConstraintType.FOREIGN_KEY.value
)
unique_constraint_columns = list(unique_constraint.columns.all())
fk_constraint_columns = list(fk_constraint.columns.all())
referent_columns = list(fk_constraint.referent_columns.all())
assert unique_constraint_columns == table_2.get_columns_by_name(['col_1'])
assert fk_constraint_columns == table_2.get_columns_by_name(['col_1'])
referent_primary_key_column_name = get_primary_key_column(column_test_table._sa_table).name
assert referent_columns == column_test_table.get_columns_by_name([referent_primary_key_column_name])


def test_one_to_many_link_create(column_test_table, client, create_table):
def test_one_to_many_link_create(column_test_table, client, create_patents_table):
cache.clear()
table_2 = create_table('Table 2')
table_2 = create_patents_table('Table 2')
data = {
"link_type": "one-to-many",
"reference_column_name": "col_1",
Expand All @@ -63,6 +82,19 @@ def test_one_to_many_link_create(column_test_table, client, create_table):
data=data,
)
assert response.status_code == 201
constraints = Constraint.objects.filter(table=table_2)
assert constraints.count() == 2

fk_constraint = next(
constraint
for constraint in constraints
if constraint.type == ConstraintType.FOREIGN_KEY.value
)
fk_constraint_columns = list(fk_constraint.columns.all())
referent_columns = list(fk_constraint.referent_columns.all())
assert fk_constraint_columns == table_2.get_columns_by_name(['col_1'])
referent_primary_key_column_name = get_primary_key_column(column_test_table._sa_table).name
assert referent_columns == column_test_table.get_columns_by_name([referent_primary_key_column_name])


def test_one_to_many_self_referential_link_create(column_test_table, client):
Expand All @@ -78,9 +110,24 @@ def test_one_to_many_self_referential_link_create(column_test_table, client):
data=data,
)
assert response.status_code == 201
constraints = Constraint.objects.filter(table=column_test_table)
assert constraints.count() == 2

fk_constraint = next(
constraint
for constraint in constraints
if constraint.type == ConstraintType.FOREIGN_KEY.value
)
fk_constraint_columns = list(fk_constraint.columns.all())
referent_columns = list(fk_constraint.referent_columns.all())
assert fk_constraint_columns == column_test_table.get_columns_by_name(['col_1'])
referent_primary_key_column_name = get_primary_key_column(column_test_table._sa_table).name
assert referent_columns == column_test_table.get_columns_by_name([referent_primary_key_column_name])


def test_many_to_many_self_referential_link_create(column_test_table, client):
schema = column_test_table.schema
engine = schema._sa_engine
cache.clear()
data = {
"link_type": "many-to-many",
Expand All @@ -95,11 +142,15 @@ def test_many_to_many_self_referential_link_create(column_test_table, client):
data=data,
)
assert response.status_code == 201
map_table_oid = get_oid_from_table("map_table", schema.name, engine)
map_table = Table.objects.get(oid=map_table_oid)
constraints = Constraint.objects.filter(table=map_table)
assert constraints.count() == 3


def test_many_to_many_link_create(column_test_table, client, create_table):
def test_many_to_many_link_create(column_test_table, client, create_patents_table):
cache.clear()
table_2 = create_table('Table 2')
table_2 = create_patents_table('Table 2')
data = {
"link_type": "many-to-many",
"mapping_table_name": "map_table",
Expand Down

0 comments on commit 3da93cb

Please sign in to comment.