Skip to content

feat: optimise AND queries #679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 13, 2025
10 changes: 9 additions & 1 deletion tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,11 @@ def search_library(
statement = statement.where(Entry.suffix.in_(extensions))

statement = statement.distinct(Entry.id)
start_time = time.time()
query_count = select(func.count()).select_from(statement.alias("entries"))
count_all: int = session.execute(query_count).scalar()
end_time = time.time()
logger.info(f"finished counting ({format_timespan(end_time-start_time)})")

sort_on: ColumnExpressionArgument = Entry.id
match search.sorting_mode:
Expand All @@ -609,9 +612,14 @@ def search_library(
query_full=str(statement.compile(compile_kwargs={"literal_binds": True})),
)

start_time = time.time()
items = session.scalars(statement).fetchall()
end_time = time.time()
logger.info(f"SQL Execution finished ({format_timespan(end_time - start_time)})")

res = SearchResult(
total_count=count_all,
items=list(session.scalars(statement)),
items=list(items),
)

session.expunge_all()
Expand Down
51 changes: 26 additions & 25 deletions tagstudio/src/core/library/alchemy/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import and_, distinct, func, or_, select, text
from sqlalchemy import ColumnElement, and_, distinct, func, or_, select, text
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
from src.core.query_lang import BaseVisitor
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
from src.core.query_lang.ast import ANDList, Constraint, ConstraintType, Not, ORList, Property

from .joins import TagEntry
from .models import Entry, Tag, TagAlias
Expand All @@ -33,7 +32,7 @@
FROM tag_parents tp
INNER JOIN ChildTags c ON tp.child_id = c.child_id
)
SELECT * FROM ChildTags;
SELECT child_id FROM ChildTags;
""") # noqa: E501


Expand All @@ -44,17 +43,17 @@ def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
return [item]


class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
def __init__(self, lib: Library) -> None:
super().__init__()
self.lib = lib

def visit_or_list(self, node: ORList) -> ColumnExpressionArgument:
def visit_or_list(self, node: ORList) -> ColumnElement[bool]:
return or_(*[self.visit(element) for element in node.elements])

def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
def visit_and_list(self, node: ANDList) -> ColumnElement[bool]:
tag_ids: list[int] = []
bool_expressions: list[ColumnExpressionArgument] = []
bool_expressions: list[ColumnElement[bool]] = []

# Search for TagID / unambiguous Tag Constraints and store the respective tag ids separately
for term in node.terms:
Expand All @@ -74,7 +73,7 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
tag_ids.append(ids[0])
continue

bool_expressions.append(self.__entry_satisfies_ast(term))
bool_expressions.append(self.visit(term))

# If there are at least two tag ids use a relational division query
# to efficiently check all of them
Expand All @@ -88,15 +87,15 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:

return and_(*bool_expressions)

def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument:
def visit_constraint(self, node: Constraint) -> ColumnElement[bool]:
"""Returns a Boolean Expression that is true, if the Entry satisfies the constraint."""
if len(node.properties) != 0:
raise NotImplementedError("Properties are not implemented yet") # TODO TSQLANG

if node.type == ConstraintType.Tag:
return Entry.tags.any(Tag.id.in_(self.__get_tag_ids(node.value)))
return self.__entry_matches_tag_ids(self.__get_tag_ids(node.value))
elif node.type == ConstraintType.TagID:
return Entry.tags.any(Tag.id == int(node.value))
return self.__entry_matches_tag_ids([int(node.value)])
elif node.type == ConstraintType.Path:
return Entry.path.op("GLOB")(node.value)
elif node.type == ConstraintType.MediaType:
Expand All @@ -120,8 +119,17 @@ def visit_constraint(self, node: Constraint) -> ColumnExpressionArgument:
def visit_property(self, node: Property) -> None:
raise NotImplementedError("This should never be reached!")

def visit_not(self, node: Not) -> ColumnExpressionArgument:
return ~self.__entry_satisfies_ast(node.child)
def visit_not(self, node: Not) -> ColumnElement[bool]:
return ~self.visit(node.child)

def __entry_matches_tag_ids(self, tag_ids: list[int]) -> ColumnElement[bool]:
"""Returns a boolean expression that is true if the entry has at least one of the supplied tags.""" # noqa: E501
return (
select(1)
.correlate(Entry)
.where(and_(TagEntry.entry_id == Entry.id, TagEntry.tag_id.in_(tag_ids)))
.exists()
)

def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
"""Given a tag name find the ids of all tags that this name could refer to."""
Expand All @@ -146,24 +154,17 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
return outp

def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
def __entry_has_all_tags(self, tag_ids: list[int]) -> ColumnElement[bool]:
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
# Relational Division Query
return Entry.id.in_(
select(Entry.id)
.outerjoin(TagEntry)
select(TagEntry.entry_id)
.where(TagEntry.tag_id.in_(tag_ids))
.group_by(Entry.id)
.group_by(TagEntry.entry_id)
.having(func.count(distinct(TagEntry.tag_id)) == len(tag_ids))
)

def __entry_satisfies_ast(self, partial_query: AST) -> BinaryExpression[bool]:
"""Returns Binary Expression that is true if the Entry satisfies the partial query."""
return self.__entry_satisfies_expression(self.visit(partial_query))

def __entry_satisfies_expression(
self, expr: ColumnExpressionArgument
) -> BinaryExpression[bool]:
def __entry_satisfies_expression(self, expr: ColumnElement[bool]) -> ColumnElement[bool]:
"""Returns Binary Expression that is true if the Entry satisfies the column expression.

Executed on: Entry ⟕ TagEntry (Entry LEFT OUTER JOIN TagEntry).
Expand Down
Loading