Skip to content

Commit 2615e7d

Browse files
committed
feat: instead of hardcoding child tag ids into main query, include subquery
1 parent b791159 commit 2615e7d

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

tagstudio/src/core/library/alchemy/visitors.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from typing import TYPE_CHECKING
22

33
import structlog
4-
from sqlalchemy import and_, distinct, func, or_, select, text
4+
from sqlalchemy import and_, column, distinct, func, or_, select, text, union_all
55
from sqlalchemy.orm import Session
6-
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
6+
from sqlalchemy.sql.expression import (
7+
BinaryExpression,
8+
ColumnExpressionArgument,
9+
CompoundSelect,
10+
)
711
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
812
from src.core.query_lang import BaseVisitor
913
from src.core.query_lang.ast import AST, ANDList, Constraint, ConstraintType, Not, ORList, Property
@@ -28,7 +32,7 @@
2832
FROM tag_subtags ts
2933
INNER JOIN Subtags s ON ts.child_id = s.child_id
3034
)
31-
SELECT * FROM Subtags;
35+
SELECT child_id FROM Subtags
3236
""") # noqa: E501
3337

3438

@@ -59,7 +63,10 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
5963
tag_ids.append(int(term.value))
6064
continue
6165
case ConstraintType.Tag:
62-
if len(ids := self.__get_tag_ids(term.value)) == 1:
66+
if (
67+
isinstance((ids := self.__get_tag_ids(term.value)), list)
68+
and len(ids) == 1
69+
):
6370
tag_ids.append(ids[0])
6471
continue
6572

@@ -113,7 +120,9 @@ def visit_property(self, node: Property) -> None:
113120
def visit_not(self, node: Not) -> ColumnExpressionArgument:
114121
return ~self.__entry_satisfies_ast(node.child)
115122

116-
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
123+
def __get_tag_ids(
124+
self, tag_name: str, include_children: bool = True
125+
) -> list[int] | CompoundSelect:
117126
"""Given a tag name find the ids of all tags that this name could refer to."""
118127
with Session(self.lib.engine) as session:
119128
tag_ids = list(
@@ -131,10 +140,13 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in
131140
)
132141
if not include_children:
133142
return tag_ids
134-
outp = []
135-
for tag_id in tag_ids:
136-
outp.extend(list(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
137-
return outp
143+
queries = [
144+
CHILDREN_QUERY.bindparams(tag_id=id).columns(column("child_id")) for id in tag_ids
145+
]
146+
outp = union_all(*queries)
147+
# if only one tag is found return that a list with that tag instead,
148+
# in order to make use of the optimisations in __entry_has_all_tags
149+
return t if len(t := list(session.scalars(outp))) == 1 else outp
138150

139151
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
140152
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""

0 commit comments

Comments
 (0)