Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 106 additions & 6 deletions langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
*,
embedding_length: Optional[int] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
*,
connection: Optional[sqlalchemy.engine.Connection] = None,
engine_args: Optional[dict[str, Any]] = None,
use_jsonb: bool = True,
Expand Down Expand Up @@ -712,6 +712,99 @@ def _handle_field_filter(
else:
raise NotImplementedError()

def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def]
"""Deprecated functionality.

This is for backwards compatibility with the JSON based schema for metadata.
It uses incorrect operator syntax (operators are not prefixed with $).

This implementation is not efficient, and has bugs associated with
the way that it handles numeric filter clauses.
"""
IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"

value_case_insensitive = {k.lower(): v for k, v in value.items()}
if IN in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_(
value_case_insensitive[IN]
)
elif NIN in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in(
value_case_insensitive[NIN]
)
elif BETWEEN in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between(
str(value_case_insensitive[BETWEEN][0]),
str(value_case_insensitive[BETWEEN][1]),
)
elif GT in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str(
value_case_insensitive[GT]
)
elif LT in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str(
value_case_insensitive[LT]
)
elif NE in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str(
value_case_insensitive[NE]
)
elif EQ in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
value_case_insensitive[EQ]
)
elif LIKE in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like(
value_case_insensitive[LIKE]
)
elif CONTAINS in map(str.lower, value):
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains(
value_case_insensitive[CONTAINS]
)
elif OR in map(str.lower, value):
or_clauses = [
self._create_filter_clause(key, sub_value)
for sub_value in value_case_insensitive[OR]
]
filter_by_metadata = sqlalchemy.or_(*or_clauses)
elif AND in map(str.lower, value):
and_clauses = [
self._create_filter_clause(key, sub_value)
for sub_value in value_case_insensitive[AND]
]
filter_by_metadata = sqlalchemy.and_(*and_clauses)

else:
filter_by_metadata = None

return filter_by_metadata

def _create_filter_clause_json_deprecated(
self, filter: Any
) -> List[SQLColumnExpression]:
"""Convert filters from IR to SQL clauses.

**DEPRECATED** This functionality will be deprecated in the future.

It implements translation of filters for a schema that uses JSON
for metadata rather than the JSONB field which is more efficient
for querying.
"""
filter_clauses = []
for key, value in filter.items():
if isinstance(value, dict):
filter_by_metadata = self._create_filter_clause_deprecated(key, value)

if filter_by_metadata is not None:
filter_clauses.append(filter_by_metadata)
else:
filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
value
)
filter_clauses.append(filter_by_metadata)
return filter_clauses

def _create_filter_clause(self, filters: Any) -> Any:
"""Convert LangChain IR filter representation to matching SQLAlchemy clauses.

Expand Down Expand Up @@ -812,9 +905,14 @@ def __query_collection(

filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
if filter:
filter_clauses = self._create_filter_clause(filter)
if filter_clauses is not None:
filter_by.append(filter_clauses)
if self.use_jsonb:
filter_clauses = self._create_filter_clause(filter)
if filter_clauses is not None:
filter_by.append(filter_clauses)
else:
# Old way of doing things
filter_clauses = self._create_filter_clause_json_deprecated(filter)
filter_by.extend(filter_clauses)

_type = self.EmbeddingStore

Expand Down Expand Up @@ -863,11 +961,11 @@ def from_texts(
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
*,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
*,
use_jsonb: bool = True,
**kwargs: Any,
) -> PGVector:
Expand Down Expand Up @@ -897,6 +995,7 @@ def from_embeddings(
cls,
text_embeddings: List[Tuple[str, List[float]]],
embedding: Embeddings,
*,
metadatas: Optional[List[dict]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
Expand Down Expand Up @@ -951,6 +1050,7 @@ def from_embeddings(
def from_existing_index(
cls: Type[PGVector],
embedding: Embeddings,
*,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
pre_delete_collection: bool = False,
Expand Down Expand Up @@ -996,11 +1096,11 @@ def from_documents(
cls: Type[PGVector],
documents: List[Document],
embedding: Embeddings,
*,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
*,
use_jsonb: bool = True,
**kwargs: Any,
) -> PGVector:
Expand Down