Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Start PostgreSQL/PGVector.
docker run --rm -it --name pgvector-container \
-e POSTGRES_USER=langchain \
-e POSTGRES_PASSWORD=langchain \
-e POSTGRES_DB=langchain \
-e POSTGRES_DB=langchain_test \
-p 6024:5432 pgvector/pgvector:pg16 \
postgres -c log_statement=all
```
Expand Down
14 changes: 14 additions & 0 deletions langchain_postgres/v2/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,20 @@ def max_marginal_relevance_search_with_score_by_vector(
)
)

async def aapply_hybrid_search_index(
self,
concurrently: bool = False,
) -> None:
"""Creates a TSV index in the vector store table if possible."""
return await self._engine._run_as_async(self.__vs.aapply_hybrid_search_index(concurrently=concurrently))

def apply_hybrid_search_index(
self,
concurrently: bool = False,
) -> None:
"""Creates a TSV index in the vector store table if possible."""
return self._engine._run_as_sync(self.__vs.aapply_hybrid_search_index(concurrently=concurrently))

async def aapply_vector_index(
self,
index: BaseIndex,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ test = [
"codespell>=2.4.1",
"langchain-tests==0.3.7",
"mypy>=1.15.0",
"psycopg[binary]>=3,<4",
"pytest>=8.3.4",
"pytest-asyncio>=0.25.3",
"pytest-cov>=6.0.0",
Expand Down
119 changes: 116 additions & 3 deletions tests/unit_tests/v2/test_pg_vectorstore_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy import text

from langchain_postgres import PGEngine, PGVectorStore
from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig
from langchain_postgres.v2.indexes import (
DistanceStrategy,
HNSWIndex,
Expand All @@ -17,12 +18,14 @@
from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING

uuid_str = str(uuid.uuid4()).replace("-", "_")
uuid_str_sync = str(uuid.uuid4()).replace("-", "_")
uuid_str_async = str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE = "default" + uuid_str
DEFAULT_TABLE_ASYNC = "default_sync" + uuid_str_sync
DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str
DEFAULT_TABLE_ASYNC = "default_async" + uuid_str_async
DEFAULT_HYBRID_TABLE_ASYNC = "hybrid_async" + uuid_str_async
CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_INDEX_NAME = "index" + uuid_str
DEFAULT_INDEX_NAME_ASYNC = "index" + uuid_str_sync
DEFAULT_INDEX_NAME_ASYNC = "index" + uuid_str_async
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
Expand Down Expand Up @@ -64,6 +67,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
yield engine
await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_HYBRID_TABLE}")
await engine.close()

@pytest_asyncio.fixture(scope="class")
Expand Down Expand Up @@ -118,6 +122,60 @@ async def test_is_valid_index(self, vs: PGVectorStore) -> None:
is_valid = vs.is_valid_index("invalid_index")
assert not is_valid

async def test_apply_hybrid_search_index_non_hybrid_search_vs(
self, vs: PGVectorStore
) -> None:
with pytest.raises(ValueError):
vs.apply_hybrid_search_index()

async def test_apply_hybrid_search_index_table_without_tsv_column(
self, engine: PGEngine, vs: PGVectorStore
) -> None:
tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str
vs_hybrid = PGVectorStore.create_sync(
engine,
embedding_service=embeddings_service,
table_name=DEFAULT_TABLE,
hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name),
)
is_valid_index = vs_hybrid.is_valid_index(tsv_index_name)
assert is_valid_index == False
vs_hybrid.apply_hybrid_search_index()
assert vs_hybrid.is_valid_index(tsv_index_name)
vs_hybrid.drop_vector_index(tsv_index_name)
is_valid_index = vs_hybrid.is_valid_index(tsv_index_name)
assert is_valid_index == False

async def test_apply_hybrid_search_index_table_with_tsv_column(
self, engine: PGEngine
) -> None:
tsv_index_name = "tsv_index_on_table_with_tsv_column_" + uuid_str
config = HybridSearchConfig(
tsv_column="tsv_column",
tsv_lang="pg_catalog.english",
index_name=tsv_index_name,
)
engine.init_vectorstore_table(
DEFAULT_HYBRID_TABLE,
VECTOR_SIZE,
hybrid_search_config=config,
)
vs_hybrid = PGVectorStore.create_sync(
engine,
embedding_service=embeddings_service,
table_name=DEFAULT_HYBRID_TABLE,
hybrid_search_config=config,
)
is_valid_index = vs_hybrid.is_valid_index(tsv_index_name)
assert is_valid_index == False
vs_hybrid.apply_hybrid_search_index()
assert vs_hybrid.is_valid_index(tsv_index_name)
vs_hybrid.reindex(tsv_index_name)
assert vs_hybrid.is_valid_index(tsv_index_name)
vs_hybrid.drop_vector_index(tsv_index_name)
is_valid_index = vs_hybrid.is_valid_index(tsv_index_name)
assert is_valid_index == False


@pytest.mark.enable_socket
@pytest.mark.asyncio(scope="class")
Expand All @@ -127,6 +185,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
yield engine
await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}")
await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_HYBRID_TABLE_ASYNC}")
await engine.close()

@pytest_asyncio.fixture(scope="class")
Expand Down Expand Up @@ -179,3 +238,57 @@ async def test_aapply_vector_index_ivfflat(self, vs: PGVectorStore) -> None:
async def test_is_valid_index(self, vs: PGVectorStore) -> None:
is_valid = await vs.ais_valid_index("invalid_index")
assert not is_valid

async def test_aapply_hybrid_search_index_non_hybrid_search_vs(
self, vs: PGVectorStore
) -> None:
with pytest.raises(ValueError):
await vs.aapply_hybrid_search_index()

async def test_aapply_hybrid_search_index_table_without_tsv_column(
self, engine: PGEngine, vs: PGVectorStore
) -> None:
tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str_async
vs_hybrid = await PGVectorStore.create(
engine,
embedding_service=embeddings_service,
table_name=DEFAULT_TABLE_ASYNC,
hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name),
)
is_valid_index = await vs_hybrid.ais_valid_index(tsv_index_name)
assert is_valid_index == False
await vs_hybrid.aapply_hybrid_search_index()
assert await vs_hybrid.ais_valid_index(tsv_index_name)
await vs_hybrid.adrop_vector_index(tsv_index_name)
is_valid_index = await vs_hybrid.ais_valid_index(tsv_index_name)
assert is_valid_index == False

async def test_aapply_hybrid_search_index_table_with_tsv_column(
self, engine: PGEngine
) -> None:
tsv_index_name = "tsv_index_on_table_with_tsv_column_" + uuid_str_async
config = HybridSearchConfig(
tsv_column="tsv_column",
tsv_lang="pg_catalog.english",
index_name=tsv_index_name,
)
await engine.ainit_vectorstore_table(
DEFAULT_HYBRID_TABLE_ASYNC,
VECTOR_SIZE,
hybrid_search_config=config,
)
vs_hybrid = await PGVectorStore.create(
engine,
embedding_service=embeddings_service,
table_name=DEFAULT_HYBRID_TABLE_ASYNC,
hybrid_search_config=config,
)
is_valid_index = await vs_hybrid.ais_valid_index(tsv_index_name)
assert is_valid_index == False
await vs_hybrid.aapply_hybrid_search_index()
assert await vs_hybrid.ais_valid_index(tsv_index_name)
await vs_hybrid.areindex(tsv_index_name)
assert await vs_hybrid.ais_valid_index(tsv_index_name)
await vs_hybrid.adrop_vector_index(tsv_index_name)
is_valid_index = await vs_hybrid.ais_valid_index(tsv_index_name)
assert is_valid_index == False
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "langchain")
POSTGRES_DB = os.environ.get("POSTGRES_DB", "langchain_test")

POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432")
POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "6024")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert this change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done revert


DSN = (
f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}"
Expand Down
Loading
Loading