diff --git a/CHANGELOG.md b/CHANGELOG.md index cbad59cc3..316c1036d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ Write the date in place of the "Unreleased" in the case a new version is release - Propagate setting `include_data_sources` into child nodes. - Populate attributes in member data variables and coordinates of xarray Datasets. - Update dependencies. +- Fix behavior of queries `In` and `NotIn` when passed an empty list of values. ## v0.1.0a120 (25 April 2024) diff --git a/tiled/_tests/test_queries.py b/tiled/_tests/test_queries.py index 2974fb6a7..cfa344659 100644 --- a/tiled/_tests/test_queries.py +++ b/tiled/_tests/test_queries.py @@ -230,6 +230,10 @@ def test_in(client, query_values): ] +def test_in_empty(client): + assert list(client.search(In("letter", []))) == [] + + @pytest.mark.parametrize( "query_values", [ @@ -256,6 +260,10 @@ def test_notin(client, query_values): ) +def test_not_in_empty(client): + assert sorted(list(client.search(NotIn("letter", [])))) == sorted(list(client)) + + @pytest.mark.parametrize( "include_values,exclude_values", [ diff --git a/tiled/adapters/mapping.py b/tiled/adapters/mapping.py index 367d82c72..fbaddf0d4 100644 --- a/tiled/adapters/mapping.py +++ b/tiled/adapters/mapping.py @@ -766,6 +766,8 @@ def notin(query: Any, tree: MapAdapter) -> MapAdapter: """ matches = {} + if len(query.value) == 0: + return tree for key, value, term in iter_child_metadata(query.key, tree): if term not in query.value: matches[key] = value diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py index 4ee465cb8..a329bb787 100644 --- a/tiled/catalog/adapter.py +++ b/tiled/catalog/adapter.py @@ -10,11 +10,24 @@ import uuid from functools import partial, reduce from pathlib import Path +from typing import Callable, Dict from urllib.parse import quote_plus, urlparse import anyio from fastapi import HTTPException -from sqlalchemy import delete, event, func, not_, or_, select, text, type_coerce, update +from sqlalchemy 
import ( + delete, + event, + false, + func, + not_, + or_, + select, + text, + true, + type_coerce, + update, +) from sqlalchemy.dialects.postgresql import JSONB, REGCONFIG from sqlalchemy.engine import make_url from sqlalchemy.exc import IntegrityError @@ -1206,22 +1219,40 @@ def specs(query, tree): return tree.new_variation(conditions=tree.conditions + conditions) -def in_or_not_in(query, tree, method): - dialect_name = tree.engine.url.get_dialect().name +def in_or_not_in_sqlite(query, tree, method): keys = query.key.split(".") attr = orm.Node.metadata_[keys] - if dialect_name == "sqlite": - condition = getattr(_get_value(attr, type(query.value[0])), method)(query.value) - elif dialect_name == "postgresql": - # Engage btree_gin index with @> operator + if len(query.value) == 0: if method == "in_": + # Results cannot possibly be "in" an empty list, + # so put a False condition in the list ensuring that + # there are no rows returned. + condition = false() + else: # method == "not_in" + # All results are always "not in" an empty list. 
+ condition = true() + else: + condition = getattr(_get_value(attr, type(query.value[0])), method)(query.value) + return tree.new_variation(conditions=tree.conditions + [condition]) + + +def in_or_not_in_postgresql(query, tree, method): + keys = query.key.split(".") + # Engage btree_gin index with @> operator + if method == "in_": + if len(query.value) == 0: + condition = false() + else: condition = or_( *( orm.Node.metadata_.op("@>")(key_array_to_json(keys, item)) for item in query.value ) ) - elif method == "not_in": + elif method == "not_in": + if len(query.value) == 0: + condition = true() + else: condition = not_( or_( *( @@ -1230,13 +1261,23 @@ def in_or_not_in(query, tree, method): ) ) ) - else: - raise UnsupportedQueryType("NotIn") - else: - raise UnsupportedQueryType("In/NotIn") return tree.new_variation(conditions=tree.conditions + [condition]) +_IN_OR_NOT_IN_DIALECT_DISPATCH: Dict[str, Callable] = { + "sqlite": in_or_not_in_sqlite, + "postgresql": in_or_not_in_postgresql, +} + + +def in_or_not_in(query, tree, method): + METHODS = {"in_", "not_in"} + if method not in METHODS: + raise ValueError(f"method must be one of {METHODS}") + dialect_name = tree.engine.url.get_dialect().name + return _IN_OR_NOT_IN_DIALECT_DISPATCH[dialect_name](query, tree, method) + + def keys_filter(query, tree): condition = orm.Node.key.in_(query.keys) return tree.new_variation(conditions=tree.conditions + [condition])