1
1
from typing import TYPE_CHECKING
2
2
3
3
import structlog
4
- from sqlalchemy import and_ , distinct , func , or_ , select , text
4
+ from sqlalchemy import and_ , column , distinct , func , or_ , select , text , union_all
5
5
from sqlalchemy .orm import Session
6
- from sqlalchemy .sql .expression import BinaryExpression , ColumnExpressionArgument
6
+ from sqlalchemy .sql .expression import (
7
+ BinaryExpression ,
8
+ ColumnExpressionArgument ,
9
+ CompoundSelect ,
10
+ )
7
11
from src .core .media_types import FILETYPE_EQUIVALENTS , MediaCategories
8
12
from src .core .query_lang import BaseVisitor
9
13
from src .core .query_lang .ast import AST , ANDList , Constraint , ConstraintType , Not , ORList , Property
28
32
FROM tag_subtags ts
29
33
INNER JOIN Subtags s ON ts.child_id = s.child_id
30
34
)
31
- SELECT * FROM Subtags;
35
+ SELECT child_id FROM Subtags
32
36
""" ) # noqa: E501
33
37
34
38
@@ -59,7 +63,10 @@ def visit_and_list(self, node: ANDList) -> ColumnExpressionArgument:
59
63
tag_ids .append (int (term .value ))
60
64
continue
61
65
case ConstraintType .Tag :
62
- if len (ids := self .__get_tag_ids (term .value )) == 1 :
66
+ if (
67
+ isinstance ((ids := self .__get_tag_ids (term .value )), list )
68
+ and len (ids ) == 1
69
+ ):
63
70
tag_ids .append (ids [0 ])
64
71
continue
65
72
@@ -113,7 +120,9 @@ def visit_property(self, node: Property) -> None:
113
120
def visit_not (self , node : Not ) -> ColumnExpressionArgument :
114
121
return ~ self .__entry_satisfies_ast (node .child )
115
122
116
- def __get_tag_ids (self , tag_name : str , include_children : bool = True ) -> list [int ]:
123
+ def __get_tag_ids (
124
+ self , tag_name : str , include_children : bool = True
125
+ ) -> list [int ] | CompoundSelect :
117
126
"""Given a tag name find the ids of all tags that this name could refer to."""
118
127
with Session (self .lib .engine ) as session :
119
128
tag_ids = list (
@@ -131,10 +140,13 @@ def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[in
131
140
)
132
141
if not include_children :
133
142
return tag_ids
134
- outp = []
135
- for tag_id in tag_ids :
136
- outp .extend (list (session .scalars (CHILDREN_QUERY , {"tag_id" : tag_id })))
137
- return outp
143
+ queries = [
144
+ CHILDREN_QUERY .bindparams (tag_id = id ).columns (column ("child_id" )) for id in tag_ids
145
+ ]
146
+ outp = union_all (* queries )
147
+ # if only one tag is found return that a list with that tag instead,
148
+ # in order to make use of the optimisations in __entry_has_all_tags
149
+ return t if len (t := list (session .scalars (outp ))) == 1 else outp
138
150
139
151
def __entry_has_all_tags (self , tag_ids : list [int ]) -> BinaryExpression [bool ]:
140
152
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
0 commit comments