Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sqlalchemy_file/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from builtins import RuntimeError
from tempfile import SpooledTemporaryFile
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

INMEMORY_FILESIZE = 1024 * 1024
LOCAL_STORAGE_DRIVER_NAME = "Local Storage"
Expand Down Expand Up @@ -74,3 +74,8 @@ def convert_size(size: Union[str, int]) -> int:
si_map = {"k": 1000, "K": 1000, "M": 1000**2, "Ki": 1024, "Mi": 1024**2}
return int(value) * si_map[si]
return size


def flatmap(lists: List[List[Any]]) -> List[Any]:
"""Flattens a list of lists into a single list."""
return [value for _list in lists for value in _list]
21 changes: 13 additions & 8 deletions sqlalchemy_file/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.orm import ColumnProperty, Mapper, Session, SessionTransaction
from sqlalchemy.orm.attributes import get_history
from sqlalchemy_file.file import File
from sqlalchemy_file.helpers import flatmap
from sqlalchemy_file.mutable_list import MutableList
from sqlalchemy_file.processors import Processor, ThumbnailGenerator
from sqlalchemy_file.storage import StorageManager
Expand Down Expand Up @@ -190,10 +191,10 @@ def extract_files_from_history(cls, data: Union[Tuple[()], List[Any]]) -> List[s
paths = []
for item in data:
if isinstance(item, list):
paths.extend([f["path"] for f in item])
paths.extend([f["files"] for f in item])
elif isinstance(item, File):
paths.append(item["path"])
return paths
paths.append(item["files"])
return flatmap(paths)

@classmethod
def _mapper_configured(cls, mapper: Mapper, class_: Any) -> None: # type: ignore[type-arg]
Expand Down Expand Up @@ -242,10 +243,12 @@ def _after_delete(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type
if value is not None:
cls.add_old_files_to_session(
inspect(obj).session,
[
f["path"]
for f in (value if isinstance(value, list) else [value])
],
flatmap(
[
f["files"]
for f in (value if isinstance(value, list) else [value])
]
),
)

@classmethod
Expand Down Expand Up @@ -280,7 +283,9 @@ def _before_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # typ
)
if isinstance(value, MutableList):
_removed = getattr(value, "_removed", ())
cls.add_old_files_to_session(session, [f["path"] for f in _removed])
cls.add_old_files_to_session(
session, flatmap([f["files"] for f in _removed])
)

@classmethod
def _before_insert(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type: ignore[type-arg]
Expand Down
36 changes: 34 additions & 2 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import tempfile

import pytest
from libcloud.storage.types import ObjectDoesNotExistError
from PIL import Image
from sqlalchemy import Column, Integer, String, select
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy_file.storage import StorageManager
Expand Down Expand Up @@ -51,8 +53,6 @@ def setup_method(self, method) -> None:

def test_create_image_with_thumbnail(self, fake_image) -> None:
with Session(engine) as session:
from PIL import Image

session.add(Book(title="Pointless Meetings", cover=fake_image))
session.flush()
book = session.execute(
Expand All @@ -66,6 +66,38 @@ def test_create_image_with_thumbnail(self, fake_image) -> None:
assert book.cover["thumbnail"]["width"] == thumbnail.width
assert book.cover["thumbnail"]["height"] == thumbnail.height

def test_update_image_with_thumbnail(self, fake_image) -> None:
with Session(engine) as session:
session.add(Book(title="Pointless Meetings", cover=fake_image))
session.commit()
book = session.execute(
select(Book).where(Book.title == "Pointless Meetings")
).scalar_one()
old_file_id = book.cover.path
old_thumbnail_file_id = book.cover.thumbnail["path"]
book.cover = fake_image
session.commit()
with pytest.raises(ObjectDoesNotExistError):
assert StorageManager.get_file(old_file_id)
with pytest.raises(ObjectDoesNotExistError):
assert StorageManager.get_file(old_thumbnail_file_id)

def test_delete_image_with_thumbnail(self, fake_image) -> None:
with Session(engine) as session:
session.add(Book(title="Pointless Meetings", cover=fake_image))
session.commit()
book = session.execute(
select(Book).where(Book.title == "Pointless Meetings")
).scalar_one()
old_file_id = book.cover.path
old_thumbnail_file_id = book.cover.thumbnail["path"]
session.delete(book)
session.commit()
with pytest.raises(ObjectDoesNotExistError):
assert StorageManager.get_file(old_file_id)
with pytest.raises(ObjectDoesNotExistError):
assert StorageManager.get_file(old_thumbnail_file_id)

def teardown_method(self, method):
for obj in StorageManager.get().list_objects():
obj.delete()
Expand Down