Skip to content

Commit 412bdd6

Browse files
fix: all_pks() for complex keys (#471)
* fix: all_pks for complex keys * fix tests * more fixes * support Python below 3.9+ * black * linter again
1 parent 250d29d commit 412bdd6

File tree

3 files changed

+76
-16
lines changed

3 files changed

+76
-16
lines changed

aredis_om/model/model.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ def decode_redis_value(
151151
return obj.decode(encoding)
152152

153153

154+
# TODO: replace with `str.removeprefix()` when only Python 3.9+ is supported
155+
def remove_prefix(value: str, prefix: str) -> str:
156+
"""Remove a prefix from a string."""
157+
if value.startswith(prefix):
158+
value = value[len(prefix) :] # noqa: E203
159+
return value
160+
161+
154162
class PipelineError(Exception):
155163
"""A Redis pipeline error."""
156164

@@ -1350,16 +1358,12 @@ async def save(
13501358
@classmethod
13511359
async def all_pks(cls): # type: ignore
13521360
key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
1353-
# TODO: We assume the key ends with the default separator, ":" -- when
1354-
# we make the separator configurable, we need to update this as well.
1355-
# ... And probably lots of other places ...
1356-
#
1357-
# TODO: Also, we need to decide how we want to handle the lack of
1361+
# TODO: We need to decide how we want to handle the lack of
13581362
# decode_responses=True...
13591363
return (
1360-
key.split(":")[-1]
1364+
remove_prefix(key, key_prefix)
13611365
if isinstance(key, str)
1362-
else key.decode(cls.Meta.encoding).split(":")[-1]
1366+
else remove_prefix(key.decode(cls.Meta.encoding), key_prefix)
13631367
async for key in cls.db().scan_iter(f"{key_prefix}*", _type="HASH")
13641368
)
13651369

@@ -1521,16 +1525,12 @@ async def save(
15211525
@classmethod
15221526
async def all_pks(cls): # type: ignore
15231527
key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
1524-
# TODO: We assume the key ends with the default separator, ":" -- when
1525-
# we make the separator configurable, we need to update this as well.
1526-
# ... And probably lots of other places ...
1527-
#
1528-
# TODO: Also, we need to decide how we want to handle the lack of
1528+
# TODO: We need to decide how we want to handle the lack of
15291529
# decode_responses=True...
15301530
return (
1531-
key.split(":")[-1]
1531+
remove_prefix(key, key_prefix)
15321532
if isinstance(key, str)
1533-
else key.decode(cls.Meta.encoding).split(":")[-1]
1533+
else remove_prefix(key.decode(cls.Meta.encoding), key_prefix)
15341534
async for key in cls.db().scan_iter(f"{key_prefix}*", _type="ReJSON-RL")
15351535
)
15361536

tests/test_hash_model.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,37 @@ async def test_all_pks(m):
500500
async for pk in await m.Member.all_pks():
501501
pk_list.append(pk)
502502

503-
assert len(pk_list) == 2
503+
assert sorted(pk_list) == ["0", "1"]
504+
505+
506+
@py_test_mark_asyncio
507+
async def test_all_pks_with_complex_pks(key_prefix):
508+
class City(HashModel):
509+
name: str
510+
511+
class Meta:
512+
global_key_prefix = key_prefix
513+
model_key_prefix = "city"
514+
515+
city1 = City(
516+
pk="ca:on:toronto",
517+
name="Toronto",
518+
)
519+
520+
await city1.save()
521+
522+
city2 = City(
523+
pk="ca:qc:montreal",
524+
name="Montreal",
525+
)
526+
527+
await city2.save()
528+
529+
pk_list = []
530+
async for pk in await City.all_pks():
531+
pk_list.append(pk)
532+
533+
assert sorted(pk_list) == ["ca:on:toronto", "ca:qc:montreal"]
504534

505535

506536
@py_test_mark_asyncio

tests/test_json_model.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,37 @@ async def test_all_pks(address, m, redis):
218218
async for pk in await m.Member.all_pks():
219219
pk_list.append(pk)
220220

221-
assert len(pk_list) == 2
221+
assert sorted(pk_list) == sorted([member.pk, member1.pk])
222+
223+
224+
@py_test_mark_asyncio
225+
async def test_all_pks_with_complex_pks(key_prefix):
226+
class City(JsonModel):
227+
name: str
228+
229+
class Meta:
230+
global_key_prefix = key_prefix
231+
model_key_prefix = "city"
232+
233+
city1 = City(
234+
pk="ca:on:toronto",
235+
name="Toronto",
236+
)
237+
238+
await city1.save()
239+
240+
city2 = City(
241+
pk="ca:qc:montreal",
242+
name="Montreal",
243+
)
244+
245+
await city2.save()
246+
247+
pk_list = []
248+
async for pk in await City.all_pks():
249+
pk_list.append(pk)
250+
251+
assert sorted(pk_list) == ["ca:on:toronto", "ca:qc:montreal"]
222252

223253

224254
@py_test_mark_asyncio

0 commit comments

Comments
 (0)