Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace filtering API; includes support for filtering numbers #1069

Merged
merged 23 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions db/functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from sqlalchemy import column, not_, and_, or_, func, literal

from db.functions import hints
from db.functions.exceptions import BadDBFunctionFormat


# NOTE: this class is abstract.
class DBFunction(ABC):
id = None
name = None
Expand All @@ -30,13 +32,22 @@ class DBFunction(ABC):
# strings.
depends_on = None

def __eq__(self, other):
return (
isinstance(other, DBFunction)
and self.id == other.id
and self.parameters == other.parameters
)

def __init__(self, parameters):
if self.id is None:
raise ValueError('DBFunction subclasses must define an ID.')
if self.name is None:
raise ValueError('DBFunction subclasses must define a name.')
if self.depends_on is not None and not isinstance(self.depends_on, tuple):
raise ValueError('DBFunction subclasses\' depends_on attribute must either be None or a tuple of SQL function names.')
if not isinstance(parameters, list):
raise BadDBFunctionFormat('DBFunction instance parameter `parameters` must be a list.')
self.parameters = parameters

@property
Expand All @@ -45,7 +56,7 @@ def referenced_columns(self):
Useful when checking if all referenced columns are present in the queried relation."""
columns = set([])
for parameter in self.parameters:
if isinstance(parameter, ColumnReference):
if isinstance(parameter, ColumnName):
columns.add(parameter.column)
elif isinstance(parameter, DBFunction):
columns.update(parameter.referenced_columns)
Expand All @@ -70,9 +81,10 @@ def to_sa_expression(primitive):
return literal(primitive)


class ColumnReference(DBFunction):
id = 'column_reference'
name = 'as column Reference'
# This represents referencing columns by their Postgres name.
class ColumnName(DBFunction):
id = 'column_name'
name = 'as column name'
hints = tuple([
hints.parameter_count(1),
hints.parameter(1, hints.column),
Expand Down
8 changes: 6 additions & 2 deletions db/functions/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
class BadDBFunctionFormat(Exception):
class DBFunctionException(Exception):
pass


class UnknownDBFunctionId(BadDBFunctionFormat):
class BadDBFunctionFormat(DBFunctionException):
pass


class UnknownDBFunctionID(BadDBFunctionFormat):
pass


Expand Down
2 changes: 1 addition & 1 deletion db/functions/known_db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def _get_module_members_that_satisfy(module, predicate):
def _is_concrete_db_function_subclass(member):
return (
inspect.isclass(member)
and member != DBFunction
and issubclass(member, DBFunction)
and not inspect.isabstract(member)
)


Expand Down
2 changes: 1 addition & 1 deletion db/functions/operations/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from db.functions.operations.deserialize import get_db_function_from_ma_function_spec


def apply_ma_function_spec_as_filter(relation, ma_function_spec):
def apply_db_function_spec_as_filter(relation, ma_function_spec):
db_function = get_db_function_from_ma_function_spec(ma_function_spec)
return apply_db_function_as_filter(relation, db_function)

Expand Down
58 changes: 44 additions & 14 deletions db/functions/operations/deserialize.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,41 @@
from db.functions.base import DBFunction, Literal, ColumnReference
from db.functions.base import DBFunction, Literal, ColumnName
from db.functions.known_db_functions import known_db_functions
from db.functions.exceptions import UnknownDBFunctionId, BadDBFunctionFormat
from db.functions.exceptions import UnknownDBFunctionID, BadDBFunctionFormat


def get_db_function_from_ma_function_spec(spec: dict) -> DBFunction:
"""
Expects a db function specification in the following format:

```
{"and": [
{"empty": [
{"column_name": ["some_column"]},
]},
{"equal": [
{"to_lowercase": [
{"column_name": ["some_string_like_column"]},
]},
{"literal": ["some_string_literal"]},
]},
]}
```

Every serialized DBFunction is a dict containing one key-value pair. The key is the DBFunction
id, and the value is always a list of parameters.
"""
try:
db_function_subclass_id = _get_first_dict_key(spec)
db_function_subclass_id, raw_parameters = get_raw_spec_components(spec)
db_function_subclass = _get_db_function_subclass_by_id(db_function_subclass_id)
raw_parameters = spec[db_function_subclass_id]
if not isinstance(raw_parameters, list):
raise BadDBFunctionFormat(
"The value in the function's key-value pair must be a list."
)
parameters = [
_process_parameter(
parameter=raw_parameter,
parent_db_function_subclass=db_function_subclass
parent_db_function_subclass=db_function_subclass,
)
for raw_parameter in raw_parameters
]
return db_function_subclass(parameters=parameters)
db_function = db_function_subclass(parameters=parameters)
return db_function
except (TypeError, KeyError) as e:
raise BadDBFunctionFormat from e

Expand All @@ -30,10 +46,11 @@ def _process_parameter(parameter, parent_db_function_subclass):
return get_db_function_from_ma_function_spec(parameter)
elif (
parent_db_function_subclass is Literal
or parent_db_function_subclass is ColumnReference
or parent_db_function_subclass is ColumnName
):
# Everything except for a dict is considered a literal parameter.
# And, only the Literal and ColumnReference DBFunctions can have a literal parameter.
# Everything except for a dict is considered a literal parameter,
# and only the Literal and ColumnName DBFunctions can have
# a literal parameter.
return parameter
else:
raise BadDBFunctionFormat(
Expand All @@ -45,7 +62,20 @@ def _get_db_function_subclass_by_id(subclass_id):
for db_function_subclass in known_db_functions:
if db_function_subclass.id == subclass_id:
return db_function_subclass
raise UnknownDBFunctionId
raise UnknownDBFunctionID(
f"DBFunction subclass with id {subclass_id} not found (or not"
+ "available on this DB)."
)


def get_raw_spec_components(spec):
db_function_subclass_id = _get_first_dict_key(spec)
raw_parameters = spec[db_function_subclass_id]
if not isinstance(raw_parameters, list):
raise BadDBFunctionFormat(
"The value in the function's key-value pair must be a list."
)
return db_function_subclass_id, raw_parameters


def _get_first_dict_key(dict):
Expand Down
133 changes: 71 additions & 62 deletions db/records/operations/select.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,16 @@
from sqlalchemy import select, func
from sqlalchemy_filters import apply_filters, apply_sort
from sqlalchemy_filters.exceptions import BadFilterFormat, FilterFieldNotFound
from sqlalchemy_filters import apply_sort

from db.functions.operations.apply import apply_db_function_spec_as_filter
from db.columns.base import MathesarColumn
from db.records.operations import group
from db.tables.utils import get_primary_key_column
from db.types.operations.cast import get_column_cast_expression
from db.utils import execute_query


DUPLICATE_LABEL = "_is_dupe"
CONJUNCTIONS = ("and", "or", "not")


def _validate_nested_ops(filters):
for op in filters:
if op.get("op") == "get_duplicates":
raise BadFilterFormat("get_duplicates can not be nested")
for field in CONJUNCTIONS:
if field in op:
_validate_nested_ops(op[field])


def _get_duplicate_only_cte(table, duplicate_columns):
DUPLICATE_LABEL = "_is_dupe"
duplicate_flag_cte = (
select(
*table.c,
Expand All @@ -32,45 +20,42 @@ def _get_duplicate_only_cte(table, duplicate_columns):
return select(duplicate_flag_cte).where(duplicate_flag_cte.c[DUPLICATE_LABEL]).cte()


def _get_duplicate_data_columns(table, filters):
try:
duplicate_ops = [f for f in filters if f.get("op") == "get_duplicates"]
non_duplicate_ops = [f for f in filters if f.get("op") != "get_duplicates"]
except AttributeError:
# Ignore formatting errors - they will be handled by sqlalchemy_filters
return None, filters

_validate_nested_ops(non_duplicate_ops)
if len(duplicate_ops) > 1:
raise BadFilterFormat("get_duplicates can only be specified a single time")
elif len(duplicate_ops) == 1:
duplicate_cols = duplicate_ops[0]['value']
for col in duplicate_cols:
if col not in table.c:
raise FilterFieldNotFound(f"Table {table.name} has no column `{col}`.")
return duplicate_ops[0]['value'], non_duplicate_ops
else:
return None, filters
def _sort_and_filter(query, order_by, filter):
if order_by is not None:
query = apply_sort(query, order_by)
if filter is not None:
query = apply_db_function_spec_as_filter(query, filter)
return query


def get_query(table, limit, offset, order_by, filters, cols=None, group_by=None):
duplicate_columns, filters = _get_duplicate_data_columns(table, filters)
if duplicate_columns:
select_target = _get_duplicate_only_cte(table, duplicate_columns)
def get_query(
table,
limit,
offset,
order_by,
filter=None,
columns_to_select=None,
group_by=None,
duplicate_only=None
):
if duplicate_only:
select_target = _get_duplicate_only_cte(table, duplicate_only)
else:
select_target = table

if isinstance(group_by, group.GroupBy):
query = group.get_group_augmented_records_query(table, group_by)
selectable = group.get_group_augmented_records_query(select_target, group_by)
else:
query = select(*(cols or select_target.c)).select_from(select_target)
selectable = select(select_target)

query = query.limit(limit).offset(offset)
if order_by is not None:
query = apply_sort(query, order_by)
if filters is not None:
query = apply_filters(query, filters)
return query
selectable = _sort_and_filter(selectable, order_by, filter)

if columns_to_select:
selectable = selectable.cte()
selectable = select(*columns_to_select).select_from(selectable)

selectable = selectable.limit(limit).offset(offset)
return selectable


def get_record(table, engine, id_value):
Expand All @@ -82,23 +67,32 @@ def get_record(table, engine, id_value):


def get_records(
table, engine, limit=None, offset=None, order_by=[], filters=[], group_by=None,
table,
engine,
limit=None,
offset=None,
order_by=[],
filter=None,
group_by=None,
duplicate_only=None,
):
"""
Returns annotated records from a table.

Args:
table: SQLAlchemy table object
engine: SQLAlchemy engine object
limit: int, gives number of rows to return
offset: int, gives number of rows to skip
order_by: list of dictionaries, where each dictionary has a 'field' and
'direction' field.
See: https://github.com/centerofci/sqlalchemy-filters#sort-format
filters: list of dictionaries, where each dictionary has a 'field' and 'op'
field, in addition to an 'value' field if appropriate.
See: https://github.com/centerofci/sqlalchemy-filters#filters-format
group_by: group.GroupBy object
table: SQLAlchemy table object
engine: SQLAlchemy engine object
limit: int, gives number of rows to return
offset: int, gives number of rows to skip
order_by: list of dictionaries, where each dictionary has a 'field' and
'direction' field.
See: https://github.com/centerofci/sqlalchemy-filters#sort-format
filter: a dictionary with one key-value pair, where the key is the filter id and
the value is a list of parameters; supports composition/nesting.
See: https://github.com/centerofci/sqlalchemy-filters#filters-format
group_by: group.GroupBy object
duplicate_only: list of column names; only rows that have duplicates across those rows
will be returned
"""
if not order_by:
# Set default ordering if none was requested
Expand All @@ -111,14 +105,29 @@ def get_records(
order_by = [{'field': col, 'direction': 'asc'}
for col in table.columns]

query = get_query(table, limit, offset, order_by, filters, group_by=group_by)
query = get_query(
table=table,
limit=limit,
offset=offset,
order_by=order_by,
filter=filter,
group_by=group_by,
duplicate_only=duplicate_only
)
return execute_query(engine, query)


def get_count(table, engine, filters=[]):
def get_count(table, engine, filter=None):
col_name = "_count"
cols = [func.count().label(col_name)]
query = get_query(table, None, None, None, filters, cols)
columns_to_select = [func.count().label(col_name)]
query = get_query(
table=table,
limit=None,
offset=None,
order_by=None,
filter=filter,
columns_to_select=columns_to_select
)
return execute_query(engine, query)[0][col_name]


Expand Down
26 changes: 26 additions & 0 deletions db/tests/functions/operations/test_deserialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from db.functions.exceptions import UnknownDBFunctionID, BadDBFunctionFormat
from db.functions.operations.deserialize import get_db_function_from_ma_function_spec


exceptions_test_list = [
(
{
"non_existent_fn": [
{"column_name": ["varchar"]},
{"literal": ["test"]},
]
},
UnknownDBFunctionID
),
(
{"empty": {"column_name": ["varchar"]}, },
BadDBFunctionFormat
),
]


@pytest.mark.parametrize("filter,exception", exceptions_test_list)
def test_get_records_filters_exceptions(filter, exception):
with pytest.raises(exception):
get_db_function_from_ma_function_spec(filter)
Loading