diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f31a11..56d094f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - pyver: [ "3.9", "3.10", "3.11", "3.12", "pypy-3.9", "pypy-3.10" ] + pyver: [ "3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.9", "pypy-3.10" ] redisstack: [ "latest" ] fail-fast: false services: diff --git a/.gitignore b/.gitignore index 8f21f6a..5e27823 100644 --- a/.gitignore +++ b/.gitignore @@ -143,4 +143,7 @@ tests_sync/ # spelling cruft *.dic -.idea \ No newline at end of file +.idea + +# version files +.tool-versions diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index ae7f47c..5a5c75e 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -41,7 +41,7 @@ from .. import redis from ..checks import has_redis_json, has_redisearch from ..connections import get_redis_connection -from ..util import ASYNC_MODE +from ..util import ASYNC_MODE, has_numeric_inner_type, is_numeric_type from .encoders import jsonable_encoder from .render_tree import render_tree from .token_escaper import TokenEscaper @@ -406,7 +406,6 @@ class RediSearchFieldTypes(Enum): # TODO: How to handle Geo fields? -NUMERIC_TYPES = (float, int, decimal.Decimal) DEFAULT_PAGE_SIZE = 1000 @@ -578,7 +577,7 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType ) elif field_type is bool: return RediSearchFieldTypes.TAG - elif any(issubclass(field_type, t) for t in NUMERIC_TYPES): + elif is_numeric_type(field_type): # Index numeric Python types as NUMERIC fields, so we can support # range queries. return RediSearchFieldTypes.NUMERIC @@ -1378,12 +1377,14 @@ def outer_type_or_annotation(field: FieldInfo): def should_index_field(field_info: Union[FieldInfo, PydanticFieldInfo]) -> bool: # for vector, full text search, and sortable fields, we always have to index # We could require the user to set index=True, but that would be a breaking change - index = getattr(field_info, "index", None) is True + _index = getattr(field_info, "index", None) + + index = _index is True vector_options = getattr(field_info, "vector_options", None) is not None full_text_search = getattr(field_info, "full_text_search", None) is True sortable = getattr(field_info, "sortable", None) is True - if index is False and any([vector_options, full_text_search, sortable]): + if _index is False and any([vector_options, full_text_search, sortable]): log.warning( "Field is marked as index=False, but it is a vector, full text search, or sortable field. " "This will be ignored and the field will be indexed.", @@ -1803,7 +1804,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = cls.schema_for_type(name, embedded_cls, field_info) elif typ is bool: schema = f"{name} TAG" - elif any(issubclass(typ, t) for t in NUMERIC_TYPES): + elif is_numeric_type(typ): vector_options: Optional[VectorFieldOptions] = getattr( field_info, "vector_options", None ) @@ -1965,7 +1966,7 @@ def schema_for_type( json_path: str, name: str, name_prefix: str, - typ: Union[type[RedisModel], Any], + typ: Union[Type[RedisModel], Any], field_info: PydanticFieldInfo, parent_type: Optional[Any] = None, ) -> str: @@ -2002,9 +2003,7 @@ def schema_for_type( field_info, "vector_options", None ) try: - is_vector = vector_options and any( - issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES - ) + is_vector = vector_options and has_numeric_inner_type(typ) except IndexError: raise RedisModelError( f"Vector field '{name}' must be annotated as a container type" @@ -2102,7 +2101,11 @@ def schema_for_type( # a proper type, we can pull the type information from the origin of the first argument. if not isinstance(typ, type): type_args = typing_get_args(field_info.annotation) - typ = type_args[0].__origin__ + typ = ( + getattr(type_args[0], "__origin__", type_args[0]) + if type_args + else typ + ) # TODO: GEO field if is_vector and vector_options: @@ -2125,7 +2128,7 @@ def schema_for_type( schema += " CASESENSITIVE" elif typ is bool: schema = f"{path} AS {index_field_name} TAG" - elif any(issubclass(typ, t) for t in NUMERIC_TYPES): + elif is_numeric_type(typ): schema = f"{path} AS {index_field_name} NUMERIC" elif issubclass(typ, str): if full_text_search is True: diff --git a/aredis_om/util.py b/aredis_om/util.py index 268657e..fc6a534 100644 --- a/aredis_om/util.py +++ b/aredis_om/util.py @@ -1,4 +1,6 @@ +import decimal import inspect +from typing import Any, Type, get_args def is_async_mode() -> bool: @@ -10,3 +12,27 @@ async def f() -> None: ASYNC_MODE = is_async_mode() + +NUMERIC_TYPES = (float, int, decimal.Decimal) + + +def is_numeric_type(type_: Type[Any]) -> bool: + try: + return issubclass(type_, NUMERIC_TYPES) + except TypeError: + return False + + +def has_numeric_inner_type(type_: Type[Any]) -> bool: + """ + Check if the type has a numeric inner type. + """ + args = get_args(type_) + + if not args: + return False + + try: + return issubclass(args[0], NUMERIC_TYPES) + except TypeError: + return False diff --git a/pyproject.toml b/pyproject.toml index 95789ca..d89f326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "1.0.0-beta" +version = "1.0.2-beta" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -22,6 +22,7 @@ classifiers = [ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: Python', ] include=[ @@ -36,7 +37,7 @@ include=[ [tool.poetry.dependencies] python = ">=3.8,<4.0" -redis = ">=3.5.3,<6.0.0" +redis = ">=3.5.3,<7.0.0" pydantic = ">=2.0.0,<3.0.0" click = "^8.0.1" types-redis = ">=3.5.9,<5.0.0" @@ -46,7 +47,7 @@ hiredis = ">=2.2.3,<4.0.0" more-itertools = ">=8.14,<11.0" setuptools = ">=70.0" -[tool.poetry.dev-dependencies] +[tool.poetry.group.dev.dependencies] mypy = "^1.9.0" pytest = "^8.0.2" ipdb = "^0.13.9" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 99bc36b..c3b578a 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -180,15 +180,15 @@ async def test_full_text_search_queries(members, m): async def test_pagination_queries(members, m): member1, member2, member3 = members - actual = await m.Member.find(m.Member.last_name == "Brookins").page() + actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("id").page() assert actual == [member1, member2] - actual = await m.Member.find().page(1, 1) + actual = await m.Member.find().sort_by("id").page(1, 1) assert actual == [member2] - actual = await m.Member.find().page(0, 1) + actual = await m.Member.find().sort_by("id").page(0, 1) assert actual == [member1] diff --git a/tests/test_json_model.py b/tests/test_json_model.py index d6428e1..44ae9c6 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -755,8 +755,10 @@ async def test_sorting(members, m): async def test_case_sensitive(members, m): member1, member2, member3 = members - actual = await m.Member.find(m.Member.first_name == "Andrew").all() - assert actual == [member1, member3] + actual = await m.Member.find(m.Member.first_name == "Andrew").sort_by("pk").all() + assert sorted(actual, key=lambda m: m.pk) == sorted( + [member1, member3], key=lambda m: m.pk + ) actual = await m.Member.find(m.Member.first_name == "andrew").all() assert actual == [] diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py index 929f7ba..258e102 100644 --- a/tests/test_knn_expression.py +++ b/tests/test_knn_expression.py @@ -1,7 +1,7 @@ # type: ignore import abc import struct -from typing import Optional +from typing import Optional, Type import pytest_asyncio @@ -29,7 +29,24 @@ class Meta: class Member(BaseJsonModel, index=True): name: str - embeddings: list[list[float]] = Field([], vector_options=vector_field_options) + embeddings: list[float] = Field([], vector_options=vector_field_options) + embeddings_score: Optional[float] = None + + await Migrator().run() + + return Member + + +@pytest_asyncio.fixture +async def n(key_prefix, redis): + class BaseJsonModel(JsonModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + class Member(BaseJsonModel, index=True): + name: str + nested: list[list[float]] = Field([], vector_options=vector_field_options) embeddings_score: Optional[float] = None await Migrator().run() @@ -42,10 +59,10 @@ def to_bytes(vectors: list[float]) -> bytes: @py_test_mark_asyncio -async def test_vector_field(m: type[JsonModel]): +async def test_vector_field(m: Type[JsonModel]): # Create a new instance of the Member model vectors = [0.3 for _ in range(DIMENSIONS)] - member = m(name="seth", embeddings=[vectors]) + member = m(name="seth", embeddings=vectors) # Save the member to Redis await member.save() @@ -63,3 +80,27 @@ async def test_vector_field(m: type[JsonModel]): assert len(members) == 1 assert members[0].embeddings_score is not None + + +@py_test_mark_asyncio +async def test_nested_vector_field(n: Type[JsonModel]): + # Create a new instance of the Member model + vectors = [0.3 for _ in range(DIMENSIONS)] + member = n(name="seth", nested=[vectors]) + + # Save the member to Redis + await member.save() + + knn = KNNExpression( + k=1, + vector_field=n.nested, + score_field=n.embeddings_score, + reference_vector=to_bytes(vectors), + ) + + query = n.find(knn=knn) + + members = await query.all() + + assert len(members) == 1 + assert members[0].embeddings_score is not None