Skip to content

Commit

Permalink
Merge pull request #366 from centerofci/column_type_targets
Browse files Browse the repository at this point in the history
Add column type targets to API
  • Loading branch information
mathemancer authored Jul 20, 2021
2 parents 0211da8 + 6807e26 commit 4163144
Show file tree
Hide file tree
Showing 11 changed files with 447 additions and 140 deletions.
55 changes: 50 additions & 5 deletions db/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from alembic.migration import MigrationContext
from alembic.operations import Operations
from sqlalchemy import (
Column, Integer, ForeignKey, Table, MetaData, and_, select
Column, Integer, ForeignKey, Table, MetaData, and_, select, inspect
)
from db import constants, tables
from db.types import alteration
Expand Down Expand Up @@ -47,7 +47,9 @@ def __init__(
Optional keyword arguments:
primary_key -- Boolean giving whether the column is a primary key.
nullable -- Boolean giving whether the column is nullable.
"""
self.engine = None
super().__init__(
*foreign_keys,
name=name,
Expand Down Expand Up @@ -82,6 +84,46 @@ def is_default(self):
and self.nullable == default_def.get(NULLABLE, True)
)

def add_engine(self, engine):
self.engine = engine

@property
def valid_target_types(self):
"""
Returns a set of valid types to which the type of the column can be
altered.
"""
if self.engine is not None and not self.is_default:
db_type = self.type.compile(dialect=self.engine.dialect)
valid_target_types = sorted(
list(
set(
alteration.get_full_cast_map(self.engine).get(db_type, [])
)
)
)
return valid_target_types if valid_target_types else None

@property
def column_index(self):
"""
Get the ordinal index of this column in its table, if it is
attached to a table that is associated with the column's engine.
"""
if (
self.engine is not None
and self.table is not None
and inspect(self.engine).has_table(self.table.name, schema=self.table.schema)
):
table_oid = tables.get_oid_from_table(
self.table.name, self.table.schema, self.engine
)
return get_column_index_from_name(
table_oid,
self.name,
self.engine
)


def get_default_mathesar_column_list():
return [
Expand Down Expand Up @@ -117,11 +159,13 @@ def get_column_index_from_name(table_oid, column_name, engine):
def create_column(engine, table_oid, column_data):
column_type = column_data[TYPE]
column_nullable = column_data.get(NULLABLE, True)
supported_types = alteration.get_supported_alter_column_types(engine)
sa_type = supported_types.get(column_type.lower())
supported_types = alteration.get_supported_alter_column_types(
engine, friendly_names=False,
)
sa_type = supported_types.get(column_type)
if sa_type is None:
logger.warning("Requested type not supported. falling back to String")
sa_type = supported_types[alteration.STRING]
logger.warning("Requested type not supported. falling back to VARCHAR")
sa_type = supported_types["VARCHAR"]
table = tables.reflect_table_from_oid(table_oid, engine)
column = MathesarColumn(
column_data[NAME], sa_type, nullable=column_nullable,
Expand Down Expand Up @@ -175,6 +219,7 @@ def retype_column(table_oid, column_index, new_type, engine):
table.columns[column_index].name,
new_type,
engine,
friendly_names=False,
)
return tables.reflect_table_from_oid(table_oid, engine).columns[column_index]

Expand Down
15 changes: 15 additions & 0 deletions db/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,21 @@ def reflect_table_from_oid(oid, engine):
return reflect_table(table_name, schema, engine)


def get_enriched_column_table(raw_sa_table, engine=None):
table_columns = [
columns.MathesarColumn.from_column(c) for c in raw_sa_table.columns
]
if engine is not None:
for col in table_columns:
col.add_engine(engine)
return Table(
raw_sa_table.name,
MetaData(),
*table_columns,
schema=raw_sa_table.schema
)


def get_table_oids_from_schema(schema_oid, engine):
metadata = MetaData()

Expand Down
64 changes: 62 additions & 2 deletions db/tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
)
from sqlalchemy.exc import IntegrityError
from db import columns, tables, constants
from db.tests.types import fixtures

engine_with_types = fixtures.engine_with_types


def init_column(*args, **kwargs):
Expand Down Expand Up @@ -144,6 +147,62 @@ def test_MC_is_default_when_false_for_pk():
assert not col.is_default


def test_MC_valid_target_types_no_engine():
mc = columns.MathesarColumn('testable_col', String)
assert mc.valid_target_types is None


def test_MC_valid_target_types_default_engine(engine):
mc = columns.MathesarColumn('testable_col', String)
mc.add_engine(engine)
assert "VARCHAR" in mc.valid_target_types


def test_MC_valid_target_types_custom_engine(engine_with_types):
mc = columns.MathesarColumn('testable_col', String)
mc.add_engine(engine_with_types)
assert "mathesar_types.email" in mc.valid_target_types


def test_MC_column_index_when_no_engine():
mc = columns.MathesarColumn('testable_col', String)
assert mc.column_index is None


def test_MC_column_index_when_no_table(engine):
mc = columns.MathesarColumn('testable_col', String)
mc.add_engine(engine)
assert mc.column_index is None


def test_MC_column_index_when_no_db_table(engine):
mc = columns.MathesarColumn('testable_col', String)
mc.add_engine(engine)
table = Table('atable', MetaData(), mc)
assert mc.table == table and mc.column_index is None


def test_MC_column_index_single(engine_with_schema):
engine, schema = engine_with_schema
mc = columns.MathesarColumn('testable_col', String)
mc.add_engine(engine)
metadata = MetaData(bind=engine, schema=schema)
Table('asupertable', metadata, mc).create()
assert mc.column_index == 0


def test_MC_column_index_multiple(engine_with_schema):
engine, schema = engine_with_schema
mc_1 = columns.MathesarColumn('testable_col', String)
mc_2 = columns.MathesarColumn('testable_col2', String)
mc_1.add_engine(engine)
mc_2.add_engine(engine)
metadata = MetaData(bind=engine, schema=schema)
Table('asupertable', metadata, mc_1, mc_2).create()
assert mc_1.column_index == 0
assert mc_2.column_index == 1


@pytest.mark.parametrize(
"column_dict,func_name",
[
Expand Down Expand Up @@ -272,14 +331,15 @@ def test_retype_column_correct_column(engine_with_schema):
table_name,
target_column_name,
"boolean",
engine
engine,
friendly_names=False,
)


def test_create_column(engine_with_schema):
engine, schema = engine_with_schema
table_name = "atableone"
target_type = "boolean"
target_type = "BOOLEAN"
initial_column_name = "original_column"
new_column_name = "added_column"
table = Table(
Expand Down
15 changes: 14 additions & 1 deletion db/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from sqlalchemy import MetaData, select, Column, String, Table, ForeignKey, Integer
from sqlalchemy.exc import NoSuchTableError
from psycopg2.errors import DependentObjectsStillExist

from db import tables, constants, columns

ROSTER = "Roster"
Expand Down Expand Up @@ -432,6 +431,20 @@ def test_move_columns_moves_column_from_rem_to_ext(extracted_remainder_roster):
assert sorted(actual_remainder_cols) == sorted(expect_remainder_cols)


def test_get_enriched_column_table(engine):
abc = "abc"
table = Table("testtable", MetaData(), Column(abc, String), Column('def', String))
enriched_table = tables.get_enriched_column_table(table, engine=engine)
assert enriched_table.columns[abc].engine == engine


def test_get_enriched_column_table_no_engine():
abc = "abc"
table = Table("testtable", MetaData(), Column(abc, String), Column('def', String))
enriched_table = tables.get_enriched_column_table(table)
assert enriched_table.columns[abc].engine is None


def test_infer_table_column_types_doesnt_touch_defaults(engine_with_schema):
column_list = []
engine, schema = engine_with_schema
Expand Down
35 changes: 35 additions & 0 deletions db/tests/types/test_alteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def test_get_alter_column_types_with_custom_engine(engine_with_types):
)


def test_get_alter_column_types_with_unfriendly_names(engine_with_types):
type_dict = alteration.get_supported_alter_column_types(
engine_with_types, friendly_names=False
)
assert all(
[
type_dict[type_]().compile(dialect=engine_with_types.dialect) == type_
for type_ in type_dict
]
)


type_test_list = [
(String, "boolean", "BOOLEAN"),
(String, "interval", "INTERVAL"),
Expand Down Expand Up @@ -165,3 +177,26 @@ def test_alter_column_type_raises_on_bad_column_data(
alteration.alter_column_type(
schema, TABLE_NAME, COLUMN_NAME, target_type, engine,
)


def test_get_full_cast_map(engine_with_types):
"""
This test specifies the full map of what types can be cast to what
target types in Mathesar. When the map is modified, this test
should be updated accordingly.
"""
expect_cast_map = {
'NUMERIC': ['BOOLEAN', 'NUMERIC', 'VARCHAR'],
'VARCHAR': ['NUMERIC', 'VARCHAR', 'INTERVAL', 'mathesar_types.email', 'BOOLEAN'],
'mathesar_types.email': ['mathesar_types.email', 'VARCHAR'],
'INTERVAL': ['INTERVAL', 'VARCHAR'],
'BOOLEAN': ['NUMERIC', 'BOOLEAN', 'VARCHAR']
}
actual_cast_map = alteration.get_full_cast_map(engine_with_types)
assert len(actual_cast_map) == len(expect_cast_map)
assert all(
[
sorted(actual_cast_map[type_]) == sorted(expect_target_list)
for type_, expect_target_list in expect_cast_map.items()
]
)
Loading

0 comments on commit 4163144

Please sign in to comment.