Skip to content

Commit a79da81

Browse files
authored
feat: sqlalchemy integration with tests (#2)
* feat: basic sqlalchemy type_decorator and proxy class * test: add test cases for sqlalchemy integration
1 parent 20073e4 commit a79da81

File tree

10 files changed

+270
-43
lines changed

10 files changed

+270
-43
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ s3 = ["aioboto3>=15.4.0"]
1414

1515
[dependency-groups]
1616
dev = [
17+
"aiosqlite>=0.21.0",
1718
"pytest>=8.4.2",
1819
"pytest-aioboto3>=0.6.0",
1920
"pytest-asyncio>=1.2.0",
21+
"sqlalchemy>=2.0.44",
2022
]
2123

2224
[tool.uv.build-backend]

src/cloud_storage/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
from .base import AsyncStorageFile
12
from .s3 import AsyncS3Storage
23

34
__version__ = "0.1.0"
4-
__all__ = ["AsyncS3Storage"]
5+
__all__ = ["AsyncStorageFile", "AsyncS3Storage"]

src/cloud_storage/base.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,39 @@
33

44

55
class AsyncBaseStorage:
6-
def get_secure_key(self, key: str) -> str:
6+
def get_name(self, name: str) -> str:
77
raise NotImplementedError()
88

9-
async def get_size(self, key: str) -> int:
9+
async def get_size(self, name: str) -> int:
1010
raise NotImplementedError()
1111

12-
async def get_url(self, key: str) -> str:
12+
async def get_url(self, name: str) -> str:
1313
raise NotImplementedError()
1414

15-
async def upload(self, file: BinaryIO, key: str) -> str:
15+
async def upload(self, file: BinaryIO, name: str) -> str:
1616
raise NotImplementedError()
1717

18-
async def delete(self, key: str) -> None:
18+
async def delete(self, name: str) -> None:
1919
raise NotImplementedError()
20+
21+
22+
class AsyncStorageFile:
23+
def __init__(self, name: str, storage: AsyncBaseStorage):
24+
self._name: str = name
25+
self._storage: AsyncBaseStorage = storage
26+
27+
@property
28+
def name(self) -> str:
29+
return self._name
30+
31+
async def get_size(self) -> int:
32+
return await self._storage.get_size(self._name)
33+
34+
async def get_url(self) -> str:
35+
return await self._storage.get_url(self._name)
36+
37+
async def upload(self, file: BinaryIO) -> str:
38+
return await self._storage.upload(file=file, name=self._name)
39+
40+
async def delete(self) -> None:
41+
await self._storage.delete(self._name)

src/cloud_storage/integrations/__init__.py

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Any, override
2+
from sqlalchemy.engine.interfaces import Dialect
3+
from sqlalchemy.types import TypeDecorator, TypeEngine, Unicode
4+
5+
from cloud_storage.base import AsyncBaseStorage, AsyncStorageFile
6+
7+
8+
class AsyncFileType(TypeDecorator[Any]):
9+
impl: TypeEngine[Any] | type[TypeEngine[Any]] = Unicode
10+
cache_ok: bool | None = True
11+
12+
def __init__(self, storage: AsyncBaseStorage, *args: Any, **kwargs: Any):
13+
super().__init__(*args, **kwargs)
14+
self.storage: AsyncBaseStorage = storage
15+
16+
@override
17+
def process_bind_param(self, value: Any, dialect: Dialect) -> str:
18+
if value is None:
19+
return value
20+
if isinstance(value, str):
21+
return value
22+
23+
name = getattr(value, "name", None)
24+
if name:
25+
return name
26+
return str(value)
27+
28+
@override
29+
def process_result_value(
30+
self, value: Any | None, dialect: Dialect
31+
) -> AsyncStorageFile | None:
32+
if value is None:
33+
return None
34+
return AsyncStorageFile(name=value, storage=self.storage)

src/cloud_storage/s3.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def _get_s3_client(self) -> Any:
5757
)
5858

5959
@override
60-
def get_secure_key(self, key: str) -> str:
61-
parts = Path(key).parts
60+
def get_name(self, name: str) -> str:
61+
parts = Path(name).parts
6262
safe_parts: list[str] = []
6363

6464
for part in parts:
@@ -69,13 +69,13 @@ def get_secure_key(self, key: str) -> str:
6969
return str(safe_path)
7070

7171
@override
72-
async def get_size(self, key: str) -> int:
73-
key = self.get_secure_key(key)
72+
async def get_size(self, name: str) -> int:
73+
name = self.get_name(name)
7474

7575
async with self._get_s3_client() as s3_client:
7676
try:
77-
response = await s3_client.head_object(Bucket=self.bucket_name, Key=key)
78-
return int(response.get("ContentLength", 0))
77+
res = await s3_client.head_object(Bucket=self.bucket_name, Key=name)
78+
return int(res.get("ContentLength", 0))
7979
except ClientError as e:
8080
code = e.response.get("Error", {}).get("Code")
8181
status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
@@ -85,39 +85,39 @@ async def get_size(self, key: str) -> int:
8585
raise
8686

8787
@override
88-
async def get_url(self, key: str, expires_in: int = 3600) -> str:
88+
async def get_url(self, name: str) -> str:
8989
if self.custom_domain:
90-
return f"{self._http_scheme}://{self.custom_domain}/{key}"
90+
return f"{self._http_scheme}://{self.custom_domain}/{name}"
9191
elif self.querystring_auth:
9292
async with self._get_s3_client() as s3_client:
93-
params = {"Bucket": self.bucket_name, "Key": key}
93+
params = {"Bucket": self.bucket_name, "Key": name}
9494
return await s3_client.generate_presigned_url(
95-
"get_object", Params=params, ExpiresIn=expires_in
95+
"get_object", Params=params
9696
)
9797
else:
98-
url = f"{self._http_scheme}://{self.endpoint_url}/{self.bucket_name}/{key}"
98+
url = f"{self._http_scheme}://{self.endpoint_url}/{self.bucket_name}/{name}"
9999
return url
100100

101101
@override
102-
async def upload(self, file: BinaryIO, key: str) -> str:
103-
key = self.get_secure_key(key)
104-
content_type, _ = mimetypes.guess_type(key)
102+
async def upload(self, file: BinaryIO, name: str) -> str:
103+
name = self.get_name(name)
104+
content_type, _ = mimetypes.guess_type(name)
105105
extra_args = {"ContentType": content_type or "application/octet-stream"}
106106
if self.default_acl:
107107
extra_args["ACL"] = self.default_acl
108108

109109
async with self._get_s3_client() as s3_client:
110110
file.seek(0)
111111
await s3_client.put_object(
112-
Bucket=self.bucket_name, Key=key, Body=file, **extra_args
112+
Bucket=self.bucket_name, Key=name, Body=file, **extra_args
113113
)
114-
return key
114+
return name
115115

116116
@override
117-
async def delete(self, key: str) -> None:
117+
async def delete(self, name: str) -> None:
118118
async with self._get_s3_client() as s3_client:
119119
try:
120-
await s3_client.delete_object(Bucket=self.bucket_name, Key=key)
120+
await s3_client.delete_object(Bucket=self.bucket_name, Key=name)
121121
except ClientError as e:
122122
if e.response.get("Error", {}).get("Code") != "NoSuchKey":
123123
raise
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Any
2+
import pytest
3+
4+
from cloud_storage import AsyncS3Storage
5+
6+
7+
@pytest.fixture
8+
async def s3_test_storage(s3_test_env: Any) -> AsyncS3Storage:
9+
bucket_name, endpoint_without_scheme = s3_test_env
10+
11+
return AsyncS3Storage(
12+
bucket_name=bucket_name,
13+
endpoint_url=endpoint_without_scheme,
14+
aws_access_key_id="fake-access-key",
15+
aws_secret_access_key="fake-secret-key",
16+
use_ssl=False,
17+
)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from io import BytesIO
2+
from typing import Any
3+
import pytest
4+
from sqlalchemy import Column, Integer
5+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
6+
from sqlalchemy.ext.asyncio.session import async_sessionmaker
7+
from sqlalchemy.orm import declarative_base
8+
9+
from cloud_storage import AsyncStorageFile
10+
from cloud_storage.integrations.sqlalchemy import AsyncFileType
11+
12+
Base = declarative_base()
13+
14+
15+
class Document(Base):
16+
__tablename__: str = "documents"
17+
id: Column[int] = Column(Integer, primary_key=True)
18+
file: Column[str] = Column(AsyncFileType(storage=None)) # pyright: ignore[reportArgumentType]
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_sqlalchemy_filetype_with_s3(s3_test_storage: Any):
23+
storage = s3_test_storage
24+
# assign s3_storage to file column
25+
Document.__table__.columns.file.type.storage = storage
26+
27+
# create async engine and session
28+
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
29+
async_session = async_sessionmaker(
30+
engine, expire_on_commit=False, class_=AsyncSession
31+
)
32+
33+
# create db tables
34+
async with engine.begin() as conn:
35+
await conn.run_sync(Base.metadata.create_all)
36+
37+
# create demo file object
38+
file_name = "uploads/test-file.txt"
39+
file_content = b"SQLAlchemy + S3 integration test"
40+
file_obj = BytesIO(file_content)
41+
42+
# upload to s3 storage to fetch from db and test methods
43+
await storage.upload(file_obj, file_name)
44+
45+
# insert record into db
46+
async with async_session() as session:
47+
doc = Document(file=file_name)
48+
session.add(doc)
49+
await session.commit()
50+
doc_id = doc.id
51+
52+
# fetch record back and run tests
53+
async with async_session() as session:
54+
doc = await session.get(Document, doc_id)
55+
if doc is None:
56+
return
57+
58+
# check instance type
59+
assert isinstance(doc.file, AsyncStorageFile)
60+
assert doc.file.name == f"{file_name}"
61+
62+
# methods should work
63+
url = await doc.file.get_url()
64+
assert file_name in url
65+
66+
size = await doc.file.get_size()
67+
assert size == len(file_content)
68+
69+
# deleting should not raise
70+
await doc.file.delete()
71+
size_after_delete = await storage.get_size(file_name)
72+
assert size_after_delete == 0
73+
74+
# close all connections
75+
await engine.dispose()

tests/test_s3_storage.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
@pytest.mark.asyncio
99
async def test_s3_storage_methods(s3_test_env: Any):
1010
bucket_name, endpoint_without_scheme = s3_test_env
11-
1211
storage = AsyncS3Storage(
1312
bucket_name=bucket_name,
1413
endpoint_url=endpoint_without_scheme,
@@ -17,28 +16,27 @@ async def test_s3_storage_methods(s3_test_env: Any):
1716
use_ssl=False,
1817
)
1918

19+
file_name = "test/file.txt"
2020
file_content = b"hello moto"
2121
file_obj = BytesIO(file_content)
2222

23-
key = "test/file.txt"
24-
2523
# upload test
26-
returned_key = await storage.upload(file_obj, key)
27-
assert returned_key == storage.get_secure_key(key)
24+
returned_name = await storage.upload(file_obj, file_name)
25+
assert returned_name == storage.get_name(file_name)
2826

2927
# get url test without custom domain or querystring_auth
30-
url = await storage.get_url(key)
31-
assert key in url
28+
url = await storage.get_url(file_name)
29+
assert file_name in url
3230

3331
# get size test
34-
size = await storage.get_size(key)
32+
size = await storage.get_size(file_name)
3533
assert size == len(file_content)
3634

3735
# delete test (should suceed silently)
38-
await storage.delete(key)
36+
await storage.delete(file_name)
3937

4038
# get size test after delete (should return 0)
41-
size_after_delete = await storage.get_size(key)
39+
size_after_delete = await storage.get_size(file_name)
4240
assert size_after_delete == 0
4341

4442

@@ -55,8 +53,8 @@ async def test_s3_storage_querystring_auth(s3_test_env: Any):
5553
querystring_auth=True,
5654
)
5755

58-
key = "test/file.txt"
59-
url = await storage.get_url(key)
56+
name = "test/file.txt"
57+
url = await storage.get_url(name)
6058

6159
assert url.count("AWSAccessKeyId=") == 1
6260
assert url.count("Signature=") == 1
@@ -76,11 +74,11 @@ async def test_s3_storage_custom_domain(s3_test_env: Any):
7674
custom_domain="cdn.example.com",
7775
)
7876

79-
key = "test/file.txt"
80-
url = await storage.get_url(key)
77+
name = "test/file.txt"
78+
url = await storage.get_url(name)
8179

8280
assert url.startswith("http://cdn.example.com/")
83-
assert key in await storage.get_url(key)
81+
assert name in await storage.get_url(name)
8482

8583

8684
@pytest.mark.asyncio
@@ -93,8 +91,8 @@ async def test_get_secure_key_normalization():
9391
use_ssl=False,
9492
)
9593

96-
raw_key = "../../weird ../file name.txt"
97-
normalized_key = storage.get_secure_key(raw_key)
94+
raw_name = "../../weird ../file name.txt"
95+
normalized_name = storage.get_name(raw_name)
9896

99-
assert ".." not in normalized_key
100-
assert ".txt" in normalized_key
97+
assert ".." not in normalized_name
98+
assert ".txt" in normalized_name

0 commit comments

Comments
 (0)