Skip to content

Commit 704237f

Browse files
committed
DRY models
1 parent deeaef7 commit 704237f

File tree

4 files changed

+55
-65
lines changed

4 files changed

+55
-65
lines changed

tagstudio/src/core/library/alchemy/fields.py

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,69 @@
22

33
from dataclasses import dataclass
44
from enum import Enum
5-
from typing import Union, Any, TYPE_CHECKING
5+
from typing import Any, TYPE_CHECKING
66

77
from sqlalchemy import ForeignKey, ForeignKeyConstraint
8-
from sqlalchemy.orm import Mapped, mapped_column, relationship
8+
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr
99

1010
from .db import Base
1111
from .enums import FieldTypeEnum
1212

1313
if TYPE_CHECKING:
1414
from .models import Entry, Tag, LibraryField
1515

16-
Field = Union["TextField", "TagBoxField", "DatetimeField"]
1716

17+
class BaseField(Base):
18+
__abstract__ = True
1819

19-
class BooleanField(Base):
20-
__tablename__ = "boolean_fields"
20+
@declared_attr
21+
def id(cls) -> Mapped[int]:
22+
return mapped_column(primary_key=True, autoincrement=True)
2123

22-
id: Mapped[int] = mapped_column(primary_key=True)
23-
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
24-
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)
24+
@declared_attr
25+
def type_key(cls) -> Mapped[str]:
26+
return mapped_column(ForeignKey("library_fields.key"))
2527

26-
entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
27-
entry: Mapped[Entry] = relationship()
28+
@declared_attr
29+
def type(cls) -> Mapped[LibraryField]:
30+
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore
2831

29-
value: Mapped[bool]
30-
position: Mapped[int]
32+
@declared_attr
33+
def entry_id(cls) -> Mapped[int]:
34+
return mapped_column(ForeignKey("entries.id"))
3135

32-
def __key(self):
33-
return (self.type, self.value)
36+
@declared_attr
37+
def entry(cls) -> Mapped[Entry]:
38+
return relationship(foreign_keys=[cls.entry_id]) # type: ignore
39+
40+
@declared_attr
41+
def position(cls) -> Mapped[int]:
42+
return mapped_column()
3443

3544
def __hash__(self):
3645
return hash(self.__key())
3746

47+
def __key(self):
48+
raise NotImplementedError
49+
50+
value: Any
51+
52+
53+
class BooleanField(BaseField):
54+
__tablename__ = "boolean_fields"
55+
56+
value: Mapped[bool]
57+
58+
def __key(self):
59+
return (self.type, self.value)
60+
3861
def __eq__(self, value) -> bool:
3962
if isinstance(value, BooleanField):
4063
return self.__key() == value.__key()
4164
raise NotImplementedError
4265

4366

44-
class TextField(Base):
67+
class TextField(BaseField):
4568
__tablename__ = "text_fields"
4669
# constrain for combination of: entry_id, type_key and position
4770
__table_args__ = (
@@ -51,21 +74,10 @@ class TextField(Base):
5174
),
5275
)
5376

54-
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
55-
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
56-
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)
57-
58-
entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
59-
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])
60-
6177
value: Mapped[str | None]
62-
position: Mapped[int]
6378

64-
def __key(self):
65-
return (self.type, self.value)
66-
67-
def __hash__(self):
68-
return hash(self.__key())
79+
def __key(self) -> tuple:
80+
return self.type, self.value
6981

7082
def __eq__(self, value) -> bool:
7183
if isinstance(value, TextField):
@@ -75,18 +87,10 @@ def __eq__(self, value) -> bool:
7587
raise NotImplementedError
7688

7789

78-
class TagBoxField(Base):
90+
class TagBoxField(BaseField):
7991
__tablename__ = "tag_box_fields"
8092

81-
id: Mapped[int] = mapped_column(primary_key=True)
82-
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
83-
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)
84-
85-
entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
86-
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])
87-
8893
tags: Mapped[set[Tag]] = relationship(secondary="tag_fields")
89-
position: Mapped[int]
9094

9195
def __key(self):
9296
return (
@@ -99,34 +103,20 @@ def value(self) -> None:
99103
"""For interface compatibility with other field types."""
100104
return None
101105

102-
def __hash__(self):
103-
return hash(self.__key())
104-
105106
def __eq__(self, value) -> bool:
106107
if isinstance(value, TagBoxField):
107108
return self.__key() == value.__key()
108109
raise NotImplementedError
109110

110111

111-
class DatetimeField(Base):
112+
class DatetimeField(BaseField):
112113
__tablename__ = "datetime_fields"
113114

114-
id: Mapped[int] = mapped_column(primary_key=True)
115-
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
116-
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)
117-
118-
entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
119-
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])
120-
121115
value: Mapped[str | None]
122-
position: Mapped[int]
123116

124117
def __key(self):
125118
return (self.type, self.value)
126119

127-
def __hash__(self):
128-
return hash(self.__key())
129-
130120
def __eq__(self, value) -> bool:
131121
if isinstance(value, DatetimeField):
132122
return self.__key() == value.__key()

tagstudio/src/core/library/alchemy/library.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
TagBoxField,
3636
TextField,
3737
_FieldID,
38-
Field,
38+
BaseField,
3939
)
4040
from .joins import TagSubtag, TagField
4141
from .models import Entry, Preferences, Tag, TagAlias, LibraryField, Folder
@@ -488,7 +488,7 @@ def remove_tag_from_field(self, tag: Tag, field: TagBoxField) -> None:
488488

489489
def update_field_position(
490490
self,
491-
field_class: Type[Field],
491+
field_class: Type[BaseField],
492492
field_type: str,
493493
entry_ids: list[int] | int,
494494
):
@@ -512,15 +512,15 @@ def update_field_position(
512512

513513
# Reassign `order` starting from 0
514514
for index, row in enumerate(rows):
515-
row.position = index # type: ignore
515+
row.position = index
516516
session.add(row)
517517
session.flush()
518518
if rows:
519519
session.commit()
520520

521521
def remove_entry_field(
522522
self,
523-
field: Field,
523+
field: BaseField,
524524
entry_ids: list[int],
525525
) -> None:
526526
FieldClass = type(field)
@@ -554,7 +554,7 @@ def remove_entry_field(
554554
def update_entry_field(
555555
self,
556556
entry_ids: list[int] | int,
557-
field: Field,
557+
field: BaseField,
558558
content: str | datetime | set[Tag],
559559
):
560560
if isinstance(entry_ids, int):

tagstudio/src/core/library/alchemy/models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from .enums import TagColor
99
from .fields import (
1010
DatetimeField,
11-
Field,
1211
TagBoxField,
1312
TextField,
1413
FieldTypeEnum,
1514
_FieldID,
15+
BaseField,
1616
)
1717
from .joins import TagSubtag
1818
from ...constants import TAG_FAVORITE, TAG_ARCHIVED
@@ -134,8 +134,8 @@ class Entry(Base):
134134
)
135135

136136
@property
137-
def fields(self) -> list[Field]:
138-
fields: list[Field] = []
137+
def fields(self) -> list[BaseField]:
138+
fields: list[BaseField] = []
139139
fields.extend(self.tag_box_fields)
140140
fields.extend(self.text_fields)
141141
fields.extend(self.datetime_fields)
@@ -171,7 +171,7 @@ def __init__(
171171
self,
172172
path: Path,
173173
folder: Folder,
174-
fields: list[Field] | None = None,
174+
fields: list[BaseField] | None = None,
175175
) -> None:
176176
self.path = path
177177
self.folder = folder

tagstudio/src/qt/widgets/preview_panel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
TagBoxField,
3737
DatetimeField,
3838
FieldTypeEnum,
39-
Field,
4039
_FieldID,
4140
TextField,
41+
BaseField,
4242
)
4343
from src.qt.helpers.file_opener import FileOpenerLabel, FileOpenerHelper, open_file
4444
from src.qt.modals.add_field import AddFieldModal
@@ -719,7 +719,7 @@ def set_tags_updated_slot(self, slot: object):
719719
self.tags_updated.connect(slot)
720720
self.is_connected = True
721721

722-
def write_container(self, index: int, field: Field, is_mixed: bool = False):
722+
def write_container(self, index: int, field: BaseField, is_mixed: bool = False):
723723
"""Update/Create data for a FieldContainer.
724724
725725
:param is_mixed: Relevant when multiple items are selected. If True, field is not present in all selected items
@@ -930,7 +930,7 @@ def write_container(self, index: int, field: Field, is_mixed: bool = False):
930930
container.setHidden(False)
931931
self.place_add_field_button()
932932

933-
def remove_field(self, field: Field):
933+
def remove_field(self, field: BaseField):
934934
"""Remove a field from all selected Entries."""
935935
logger.info("removing field", field=field, selected=self.selected)
936936
entry_ids = []
@@ -945,7 +945,7 @@ def remove_field(self, field: Field):
945945
if field.type_key == _FieldID.TAGS_META.value:
946946
self.driver.update_badges(self.selected)
947947

948-
def update_field(self, field: Field, content: str) -> None:
948+
def update_field(self, field: BaseField, content: str) -> None:
949949
"""Remove a field from all selected Entries, given a field object."""
950950
assert isinstance(
951951
field, (TextField, DatetimeField, TagBoxField)

0 commit comments

Comments
 (0)