Skip to content

perf: optimize query methods and reduce preview panel updates #794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 68 additions & 36 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datetime import UTC, datetime
from os import makedirs
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4
from warnings import catch_warnings

Expand Down Expand Up @@ -69,6 +70,10 @@
from .models import Entry, Folder, Namespace, Preferences, Tag, TagAlias, TagColorGroup, ValueType
from .visitors import SQLBoolExpressionBuilder

if TYPE_CHECKING:
from sqlalchemy import Select


logger = structlog.get_logger(__name__)

TAG_CHILDREN_QUERY = text("""
Expand Down Expand Up @@ -259,7 +264,7 @@ def migrate_json_to_sqlite(self, json_lib: JsonLibrary):
for k, v in field.items():
# Old tag fields get added as tags
if k in LEGACY_TAG_FIELD_IDS:
self.add_tags_to_entry(entry_id=entry.id + 1, tag_ids=v)
self.add_tags_to_entries(entry_ids=entry.id + 1, tag_ids=v)
else:
self.add_field_to_entry(
entry_id=(entry.id + 1), # JSON IDs start at 0 instead of 1
Expand Down Expand Up @@ -513,30 +518,49 @@ def get_entry_full(
self, entry_id: int, with_fields: bool = True, with_tags: bool = True
) -> Entry | None:
"""Load entry and join with all joins and all tags."""
# NOTE: TODO: Currently this method makes multiple separate queries to the db and combines
# those into a final Entry object (if using "with" args). This was done due to it being
# much more efficient than the existing join query, however there likely exists a single
# query that can accomplish the same task without exhibiting the same slowdown.
with Session(self.engine) as session:
statement = select(Entry).where(Entry.id == entry_id)
tags: set[Tag] | None = None
tag_stmt: Select[tuple[Tag]]
entry_stmt = select(Entry).where(Entry.id == entry_id).limit(1)
if with_fields:
statement = (
statement.outerjoin(Entry.text_fields)
entry_stmt = (
entry_stmt.outerjoin(Entry.text_fields)
.outerjoin(Entry.datetime_fields)
.options(selectinload(Entry.text_fields), selectinload(Entry.datetime_fields))
)
# if with_tags:
# entry_stmt = entry_stmt.outerjoin(Entry.tags).options(selectinload(Entry.tags))
if with_tags:
statement = (
statement.outerjoin(Entry.tags)
.outerjoin(TagAlias)
.options(
selectinload(Entry.tags).options(
joinedload(Tag.aliases),
joinedload(Tag.parent_tags),
)
tag_stmt = select(Tag).where(
and_(
TagEntry.tag_id == Tag.id,
TagEntry.entry_id == entry_id,
)
)
entry = session.scalar(statement)

start_time = time.time()
entry = session.scalar(entry_stmt)
if with_tags:
tags = set(session.scalars(tag_stmt)) # pyright: ignore [reportPossiblyUnboundVariable]
end_time = time.time()
logger.info(
f"[Library] Time it took to get entry: "
f"{format_timespan(end_time-start_time, max_units=5)}",
with_fields=with_fields,
with_tags=with_tags,
)
if not entry:
return None
session.expunge(entry)
make_transient(entry)

# Recombine the separately queried tags with the base entry object.
if with_tags and tags:
entry.tags = tags
return entry

def get_entries_full(self, entry_ids: list[int] | set[int]) -> Iterator[Entry]:
Expand Down Expand Up @@ -1089,41 +1113,49 @@ def add_tag(
session.rollback()
return None

def add_tags_to_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]) -> bool:
"""Add one or more tags to an entry."""
tag_ids = [tag_ids] if isinstance(tag_ids, int) else tag_ids
def add_tags_to_entries(
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
) -> bool:
"""Add one or more tags to one or more entries."""
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
with Session(self.engine, expire_on_commit=False) as session:
for tag_id in tag_ids:
try:
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
session.flush()
except IntegrityError:
session.rollback()
for tag_id in tag_ids_:
for entry_id in entry_ids_:
try:
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
session.flush()
except IntegrityError:
session.rollback()
try:
session.commit()
except IntegrityError as e:
logger.warning("[add_tags_to_entry]", warning=e)
logger.warning("[Library][add_tags_to_entries]", warning=e)
session.rollback()
return False
return True

def remove_tags_from_entry(self, entry_id: int, tag_ids: int | list[int] | set[int]) -> bool:
"""Remove one or more tags from an entry."""
def remove_tags_from_entries(
self, entry_ids: int | list[int], tag_ids: int | list[int] | set[int]
) -> bool:
"""Remove one or more tags from one or more entries."""
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
with Session(self.engine, expire_on_commit=False) as session:
try:
for tag_id in tag_ids_:
tag_entry = session.scalars(
select(TagEntry).where(
and_(
TagEntry.tag_id == tag_id,
TagEntry.entry_id == entry_id,
for entry_id in entry_ids_:
tag_entry = session.scalars(
select(TagEntry).where(
and_(
TagEntry.tag_id == tag_id,
TagEntry.entry_id == entry_id,
)
)
)
).first()
if tag_entry:
session.delete(tag_entry)
session.commit()
).first()
if tag_entry:
session.delete(tag_entry)
session.flush()
session.commit()
return True
except IntegrityError as e:
Expand Down Expand Up @@ -1331,7 +1363,7 @@ def merge_entries(self, from_entry: Entry, into_entry: Entry) -> None:
value=field.value,
)
tag_ids = [tag.id for tag in from_entry.tags]
self.add_tags_to_entry(into_entry.id, tag_ids)
self.add_tags_to_entries(into_entry.id, tag_ids)
self.remove_entries([from_entry.id])

@property
Expand Down
2 changes: 1 addition & 1 deletion tagstudio/src/qt/modals/folders_to_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def add_tag_to_tree(items: list[Tag]):

tag = add_folders_to_tree(library, tree, folders).tag
if tag and not entry.has_tag(tag):
library.add_tags_to_entry(entry.id, tag.id)
library.add_tags_to_entries(entry.id, tag.id)

logger.info("Done")

Expand Down
32 changes: 29 additions & 3 deletions tagstudio/src/qt/ts_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,8 +952,7 @@ def clear_select_action_callback(self):
self.preview_panel.update_widgets()

def add_tags_to_selected_callback(self, tag_ids: list[int]):
for entry_id in self.selected:
self.lib.add_tags_to_entry(entry_id, tag_ids)
self.lib.add_tags_to_entries(self.selected, tag_ids)

def delete_files_callback(self, origin_path: str | Path, origin_id: int | None = None):
"""Callback to send on or more files to the system trash.
Expand Down Expand Up @@ -1359,7 +1358,7 @@ def paste_fields_action_callback(self):
exists = True
if not exists:
self.lib.add_field_to_entry(id, field_id=field.type_key, value=field.value)
self.lib.add_tags_to_entry(id, self.copy_buffer["tags"])
self.lib.add_tags_to_entries(id, self.copy_buffer["tags"])
if len(self.selected) > 1:
if TAG_ARCHIVED in self.copy_buffer["tags"]:
self.update_badges({BadgeType.ARCHIVED: True}, origin_id=0, add_tags=False)
Expand Down Expand Up @@ -1650,14 +1649,41 @@ def update_badges(self, badge_values: dict[BadgeType, bool], origin_id: int, add
the items. Defaults to True.
"""
item_ids = self.selected if (not origin_id or origin_id in self.selected) else [origin_id]
pending_entries: dict[BadgeType, list[int]] = {}

logger.info(
"[QtDriver][update_badges] Updating ItemThumb badges",
badge_values=badge_values,
origin_id=origin_id,
add_tags=add_tags,
)
for it in self.item_thumbs:
if it.item_id in item_ids:
for badge_type, value in badge_values.items():
if add_tags:
if not pending_entries.get(badge_type):
pending_entries[badge_type] = []
pending_entries[badge_type].append(it.item_id)
it.toggle_item_tag(it.item_id, value, BADGE_TAGS[badge_type])
it.assign_badge(badge_type, value)

if not add_tags:
return

logger.info(
"[QtDriver][update_badges] Adding tags to updated entries",
pending_entries=pending_entries,
)
for badge_type, value in badge_values.items():
if value:
self.lib.add_tags_to_entries(
pending_entries.get(badge_type, []), BADGE_TAGS[badge_type]
)
else:
self.lib.remove_tags_from_entries(
pending_entries.get(badge_type, []), BADGE_TAGS[badge_type]
)

def filter_items(self, filter: FilterState | None = None) -> None:
if not self.lib.library_dir:
logger.info("Library not loaded")
Expand Down
14 changes: 5 additions & 9 deletions tagstudio/src/qt/widgets/item_thumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,15 +499,11 @@ def toggle_item_tag(
toggle_value: bool,
tag_id: int,
):
logger.info("toggle_item_tag", entry_id=entry_id, toggle_value=toggle_value, tag_id=tag_id)

if toggle_value:
self.lib.add_tags_to_entry(entry_id, tag_id)
else:
self.lib.remove_tags_from_entry(entry_id, tag_id)

if self.driver.preview_panel.is_open:
self.driver.preview_panel.update_widgets(update_preview=False)
if entry_id in self.driver.selected and self.driver.preview_panel.is_open:
if len(self.driver.selected) == 1:
self.driver.preview_panel.fields.update_toggled_tag(tag_id, toggle_value)
else:
pass

def mouseMoveEvent(self, event): # noqa: N802
if event.buttons() is not Qt.MouseButton.LeftButton:
Expand Down
32 changes: 24 additions & 8 deletions tagstudio/src/qt/widgets/preview/field_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,29 @@ def update_from_entry(self, entry_id: int, update_badges: bool = True):
logger.warning("[FieldContainers] Updating Selection", entry_id=entry_id)

self.cached_entries = [self.lib.get_entry_full(entry_id)]
entry_ = self.cached_entries[0]
container_len: int = len(entry_.fields)
container_index = 0
entry = self.cached_entries[0]
self.update_granular(entry.tags, entry.fields, update_badges)

def update_granular(
self, entry_tags: set[Tag], entry_fields: list[BaseField], update_badges: bool = True
):
"""Individually update elements of the item preview."""
container_len: int = len(entry_fields)
container_index = 0
# Write tag container(s)
if entry_.tags:
categories = self.get_tag_categories(entry_.tags)
if entry_tags:
categories = self.get_tag_categories(entry_tags)
for cat, tags in sorted(categories.items(), key=lambda kv: (kv[0] is None, kv)):
self.write_tag_container(
container_index, tags=tags, category_tag=cat, is_mixed=False
)
container_index += 1
container_len += 1
if update_badges:
self.emit_badge_signals({t.id for t in entry_.tags})
self.emit_badge_signals({t.id for t in entry_tags})

# Write field container(s)
for index, field in enumerate(entry_.fields, start=container_index):
for index, field in enumerate(entry_fields, start=container_index):
self.write_container(index, field, is_mixed=False)

# Hide leftover container(s)
Expand All @@ -140,6 +145,17 @@ def update_from_entry(self, entry_id: int, update_badges: bool = True):
if i > (container_len - 1):
c.setHidden(True)

def update_toggled_tag(self, tag_id: int, toggle_value: bool):
"""Visually add or remove a tag from the item preview without needing to query the db."""
entry = self.cached_entries[0]
tag = self.lib.get_tag(tag_id)
if not tag:
return
new_tags = (
entry.tags.union({tag}) if toggle_value else {t for t in entry.tags if t.id != tag_id}
)
self.update_granular(entry_tags=new_tags, entry_fields=entry.fields, update_badges=False)

def hide_containers(self):
"""Hide all field and tag containers."""
for c in self.containers:
Expand Down Expand Up @@ -262,7 +278,7 @@ def add_tags_to_selected(self, tags: int | list[int]):
tags=tags,
)
for entry_id in self.driver.selected:
self.lib.add_tags_to_entry(
self.lib.add_tags_to_entries(
entry_id,
tag_ids=tags,
)
Expand Down
2 changes: 1 addition & 1 deletion tagstudio/src/qt/widgets/tag_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def remove_tag(self, tag_id: int):
)

for entry_id in self.driver.selected:
self.driver.lib.remove_tags_from_entry(entry_id, tag_id)
self.driver.lib.remove_tags_from_entries(entry_id, tag_id)

self.updated.emit()
4 changes: 2 additions & 2 deletions tagstudio/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ def library(request):
path=pathlib.Path("foo.txt"),
fields=lib.default_fields,
)
assert lib.add_tags_to_entry(entry.id, tag.id)
assert lib.add_tags_to_entries(entry.id, tag.id)

entry2 = Entry(
id=2,
folder=lib.folder,
path=pathlib.Path("one/two/bar.md"),
fields=lib.default_fields,
)
assert lib.add_tags_to_entry(entry2.id, tag2.id)
assert lib.add_tags_to_entries(entry2.id, tag2.id)

assert lib.add_entries([entry, entry2])
assert len(lib.tags) == 6
Expand Down
4 changes: 2 additions & 2 deletions tagstudio/tests/qt/test_field_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_meta_tag_category(qt_driver, library, entry_full):
panel = PreviewPanel(library, qt_driver)

# Ensure the Favorite tag is on entry_full
library.add_tags_to_entry(1, entry_full.id)
library.add_tags_to_entries(1, entry_full.id)

# Select the single entry
qt_driver.toggle_item_selection(entry_full.id, append=False, bridge=False)
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_custom_tag_category(qt_driver, library, entry_full):
)

# Ensure the Favorite tag is on entry_full
library.add_tags_to_entry(1, entry_full.id)
library.add_tags_to_entries(1, entry_full.id)

# Select the single entry
qt_driver.toggle_item_selection(entry_full.id, append=False, bridge=False)
Expand Down
8 changes: 4 additions & 4 deletions tagstudio/tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ def test_merge_entries(library: Library):
tag_0 = library.add_tag(Tag(id=1000, name="tag_0"))
tag_1 = library.add_tag(Tag(id=1001, name="tag_1"))
tag_2 = library.add_tag(Tag(id=1002, name="tag_2"))
library.add_tags_to_entry(ids[0], [tag_0.id, tag_2.id])
library.add_tags_to_entry(ids[1], [tag_1.id])
library.add_tags_to_entries(ids[0], [tag_0.id, tag_2.id])
library.add_tags_to_entries(ids[1], [tag_1.id])
library.merge_entries(entry_a, entry_b)
assert library.has_path_entry(Path("b"))
assert not library.has_path_entry(Path("a"))
Expand All @@ -344,11 +344,11 @@ def test_merge_entries(library: Library):
AssertionError()


def test_remove_tag_from_entry(library, entry_full):
def test_remove_tags_from_entries(library, entry_full):
removed_tag_id = -1
for tag in entry_full.tags:
removed_tag_id = tag.id
library.remove_tags_from_entry(entry_full.id, tag.id)
library.remove_tags_from_entries(entry_full.id, tag.id)

entry = next(library.get_entries(with_joins=True))
assert removed_tag_id not in [t.id for t in entry.tags]
Expand Down