Skip to content

Add support for KNN vector similarity search #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aredis_om/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
FindQuery,
HashModel,
JsonModel,
VectorFieldOptions,
KNNExpression,
NotFoundError,
QueryNotSupportedError,
QuerySyntaxError,
Expand Down
2 changes: 2 additions & 0 deletions aredis_om/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Field,
HashModel,
JsonModel,
VectorFieldOptions,
KNNExpression,
NotFoundError,
RedisModel,
)
195 changes: 186 additions & 9 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,24 @@ def tree(self):
return render_tree(self)


@dataclasses.dataclass
class KNNExpression:
k: int
vector_field: ModelField
reference_vector: bytes

def __str__(self):
return f"KNN $K @{self.vector_field.name} $knn_ref_vector"

@property
def query_params(self) -> Dict[str, Union[str, bytes]]:
return {"K": str(self.k), "knn_ref_vector": self.reference_vector}

@property
def score_field(self) -> str:
return f"__{self.vector_field.name}_score"


ExpressionOrNegated = Union[Expression, NegatedExpression]


Expand Down Expand Up @@ -349,8 +367,9 @@ def __init__(
self,
expressions: Sequence[ExpressionOrNegated],
model: Type["RedisModel"],
knn: Optional[KNNExpression] = None,
offset: int = 0,
limit: int = DEFAULT_PAGE_SIZE,
limit: Optional[int] = None,
page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None,
nocontent: bool = False,
Expand All @@ -364,13 +383,16 @@ def __init__(

self.expressions = expressions
self.model = model
self.knn = knn
self.offset = offset
self.limit = limit
self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE)
self.page_size = page_size
self.nocontent = nocontent

if sort_fields:
self.sort_fields = self.validate_sort_fields(sort_fields)
elif self.knn:
self.sort_fields = [self.knn.score_field]
else:
self.sort_fields = []

Expand Down Expand Up @@ -425,11 +447,26 @@ def query(self):
if self._query:
return self._query
self._query = self.resolve_redisearch_query(self.expression)
if self.knn:
self._query = (
self._query
if self._query.startswith("(") or self._query == "*"
else f"({self._query})"
) + f"=>[{self.knn}]"
return self._query

@property
def query_params(self):
params: List[Union[str, bytes]] = []
if self.knn:
params += [attr for kv in self.knn.query_params.items() for attr in kv]
return params

def validate_sort_fields(self, sort_fields: List[str]):
for sort_field in sort_fields:
field_name = sort_field.lstrip("-")
if self.knn and field_name == self.knn.score_field:
continue
if field_name not in self.model.__fields__:
raise QueryNotSupportedError(
f"You tried sort by {field_name}, but that field "
Expand Down Expand Up @@ -728,10 +765,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
return result

async def execute(self, exhaust_results=True, return_raw_result=False):
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
args: List[Union[str, bytes]] = [
"FT.SEARCH",
self.model.Meta.index_name,
self.query,
*self.pagination,
]
if self.sort_fields:
args += self.resolve_redisearch_sort_fields()

if self.query_params:
args += ["PARAMS", str(len(self.query_params))] + self.query_params

if self.knn:
# Ensure DIALECT is at least 2
if "DIALECT" not in args:
args += ["DIALECT", "2"]
else:
i_dialect = args.index("DIALECT") + 1
if int(args[i_dialect]) < 2:
args[i_dialect] = "2"

if self.nocontent:
args.append("NOCONTENT")

Expand Down Expand Up @@ -917,11 +971,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
sortable = kwargs.pop("sortable", Undefined)
index = kwargs.pop("index", Undefined)
full_text_search = kwargs.pop("full_text_search", Undefined)
vector_options = kwargs.pop("vector_options", None)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.sortable = sortable
self.index = index
self.full_text_search = full_text_search
self.vector_options = vector_options


class RelationshipInfo(Representation):
Expand All @@ -935,6 +991,94 @@ def __init__(
self.link_model = link_model


@dataclasses.dataclass
class VectorFieldOptions:
class ALGORITHM(Enum):
FLAT = "FLAT"
HNSW = "HNSW"

class TYPE(Enum):
FLOAT32 = "FLOAT32"
FLOAT64 = "FLOAT64"

class DISTANCE_METRIC(Enum):
L2 = "L2"
IP = "IP"
COSINE = "COSINE"

algorithm: ALGORITHM
type: TYPE
dimension: int
distance_metric: DISTANCE_METRIC

# Common optional parameters
initial_cap: Optional[int] = None

# Optional parameters for FLAT
block_size: Optional[int] = None

# Optional parameters for HNSW
m: Optional[int] = None
ef_construction: Optional[int] = None
ef_runtime: Optional[int] = None
epsilon: Optional[float] = None

@staticmethod
def flat(
type: TYPE,
dimension: int,
distance_metric: DISTANCE_METRIC,
initial_cap: Optional[int] = None,
block_size: Optional[int] = None,
):
return VectorFieldOptions(
algorithm=VectorFieldOptions.ALGORITHM.FLAT,
type=type,
dimension=dimension,
distance_metric=distance_metric,
initial_cap=initial_cap,
block_size=block_size,
)

@staticmethod
def hnsw(
type: TYPE,
dimension: int,
distance_metric: DISTANCE_METRIC,
initial_cap: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
ef_runtime: Optional[int] = None,
epsilon: Optional[float] = None,
):
return VectorFieldOptions(
algorithm=VectorFieldOptions.ALGORITHM.HNSW,
type=type,
dimension=dimension,
distance_metric=distance_metric,
initial_cap=initial_cap,
m=m,
ef_construction=ef_construction,
ef_runtime=ef_runtime,
epsilon=epsilon,
)

@property
def schema(self):
attr = []
for k, v in vars(self).items():
if k == "algorithm" or v is None:
continue
attr.extend(
[
k.upper() if k != "dimension" else "DIM",
str(v) if not isinstance(v, Enum) else v.name,
]
)

return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr)


def Field(
default: Any = Undefined,
*,
Expand Down Expand Up @@ -964,6 +1108,7 @@ def Field(
sortable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
full_text_search: Union[bool, UndefinedType] = Undefined,
vector_options: Optional[VectorFieldOptions] = None,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
Expand Down Expand Up @@ -991,6 +1136,7 @@ def Field(
sortable=sortable,
index=index,
full_text_search=full_text_search,
vector_options=vector_options,
**current_schema_extra,
)
field_info._validate()
Expand Down Expand Up @@ -1083,6 +1229,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
new_class._meta.primary_key = PrimaryKey(
name=field_name, field=field
)
if field.field_info.vector_options:
score_attr = f"_{field_name}_score"
setattr(new_class, score_attr, None)
new_class.__annotations__[score_attr] = Union[float, None]

if not getattr(new_class._meta, "global_key_prefix", None):
new_class._meta.global_key_prefix = getattr(
Expand Down Expand Up @@ -1216,8 +1366,12 @@ def db(cls):
return cls._meta.database

@classmethod
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
return FindQuery(expressions=expressions, model=cls)
def find(
cls,
*expressions: Union[Any, Expression],
knn: Optional[KNNExpression] = None,
) -> FindQuery:
return FindQuery(expressions=expressions, knn=knn, model=cls)

@classmethod
def from_redis(cls, res: Any):
Expand All @@ -1237,7 +1391,7 @@ def to_string(s):
for i in range(1, len(res), step):
if res[i + offset] is None:
continue
fields = dict(
fields: Dict[str, str] = dict(
zip(
map(to_string, res[i + offset][::2]),
map(to_string, res[i + offset][1::2]),
Expand All @@ -1247,6 +1401,9 @@ def to_string(s):
if fields.get("$"):
json_fields = json.loads(fields.pop("$"))
doc = cls(**json_fields)
for k, v in fields.items():
if k.startswith("__") and k.endswith("_score"):
setattr(doc, k[1:], float(v))
else:
doc = cls(**fields)

Expand Down Expand Up @@ -1474,7 +1631,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
embedded_cls = embedded_cls[0]
schema = cls.schema_for_type(name, embedded_cls, field_info)
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{name} NUMERIC"
vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
)
if vector_options:
schema = f"{name} {vector_options.schema}"
else:
schema = f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, "full_text_search", False) is True:
schema = (
Expand Down Expand Up @@ -1623,10 +1786,22 @@ def schema_for_type(
# Not a class, probably a type annotation
field_is_model = False

vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
)
try:
is_vector = vector_options and any(
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
)
except IndexError:
raise RedisModelError(
f"Vector field '{name}' must be annotated as a container type"
)

# When we encounter a list or model field, we need to descend
# into the values of the list or the fields of the model to
# find any values marked as indexed.
if is_container_type:
if is_container_type and not is_vector:
field_type = get_origin(typ)
embedded_cls = get_args(typ)
if not embedded_cls:
Expand Down Expand Up @@ -1689,7 +1864,9 @@ def schema_for_type(
)

# TODO: GEO field
if parent_is_container_type or parent_is_model_in_container:
if is_vector and vector_options:
schema = f"{path} AS {index_field_name} {vector_options.schema}"
elif parent_is_container_type or parent_is_model_in_container:
if typ is not str:
raise RedisModelError(
"In this Preview release, list and tuple fields can only "
Expand Down