Skip to content

Commit 685e03a

Browse files
authored
Fix handling of connections (#20)
Fix handling of connection in vectorstore
1 parent b48b3fd commit 685e03a

File tree

2 files changed

+80
-83
lines changed

2 files changed

+80
-83
lines changed

langchain_postgres/vectorstores.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Optional,
1515
Tuple,
1616
Type,
17+
Union,
1718
)
1819

1920
import numpy as np
@@ -177,35 +178,19 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]:
177178
return [doc for doc, _ in docs_and_scores]
178179

179180

180-
class PGVector(VectorStore):
181-
"""Vectorstore implementation using Postgres as the backend.
182-
183-
This code has been ported over from langchain_community with minimal changes
184-
to allow users to easily transition from langchain_community to langchain_postgres.
181+
Connection = Union[sqlalchemy.engine.Engine, str]
185182

186-
This vectorstore is in **beta** and **not** recommended for production
187-
usage at enterprise level.
188183

189-
It should be fine to use it for smaller datasets and/or for prototyping workflows.
190-
191-
The main issue with the existing implementation is:
184+
class PGVector(VectorStore):
185+
"""Vectorstore implementation using Postgres as the backend.
192186
193-
1) Handling of connections
194-
2) Lack of support for schema migrations
187+
Currently, there is no mechanism for supporting data migration.
195188
196-
If you're OK with these limitations (and know how to create migrations
197-
for the table) or are fine re-creating the collection and embeddings if
198-
the schema changes, then feel free to use this implementation.
189+
So breaking changes in the vectorstore schema will require the user to recreate
190+
the tables and re-add the documents.
199191
200-
Some changes had to be made to address issues with the community implementation:
201-
* langchain_postgres now works with psycopg3. Please update your
202-
connection strings from `postgresql+psycopg2://...` to
203-
`postgresql+psycopg://langchain:langchain@...`
204-
(yes, the driver name is `psycopg` not `psycopg3`)
205-
* The schema of the embedding store and collection have been changed to make
206-
add_documents work correctly with user specified ids, specifically
207-
when overwriting existing documents.
208-
You will need to recreate the tables if you are using an existing database.
192+
If this is a concern, please use a different vectorstore. If
193+
not, this implementation should be fine for your use case.
209194
210195
To use this vectorstore you need to have the `vector` extension installed.
211196
The `vector` extension is a Postgres extension that provides vector
@@ -228,33 +213,48 @@ class PGVector(VectorStore):
228213
vectorstore = PGVector.from_documents(
229214
embedding=embeddings,
230215
documents=docs,
216+
connection=connection_string,
231217
collection_name=collection_name,
232-
connection_string=connection_string,
233218
use_jsonb=True,
234219
)
220+
221+
222+
This code has been ported over from langchain_community with minimal changes
223+
to allow users to easily transition from langchain_community to langchain_postgres.
224+
225+
Some changes had to be made to address issues with the community implementation:
226+
* langchain_postgres now works with psycopg3. Please update your
227+
connection strings from `postgresql+psycopg2://...` to
228+
`postgresql+psycopg://langchain:langchain@...`
229+
(yes, the driver name is `psycopg` not `psycopg3`)
230+
* The schema of the embedding store and collection have been changed to make
231+
add_documents work correctly with user specified ids, specifically
232+
when overwriting existing documents.
233+
You will need to recreate the tables if you are using an existing database.
234+
* A Connection object has to be provided explicitly. Connections will not be
235+
picked up automatically based on env variables.
235236
"""
236237

237238
def __init__(
238239
self,
239-
connection_string: str,
240240
embedding_function: Embeddings,
241241
*,
242+
connection: Optional[Connection] = None,
242243
embedding_length: Optional[int] = None,
243244
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
244245
collection_metadata: Optional[dict] = None,
245246
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
246247
pre_delete_collection: bool = False,
247248
logger: Optional[logging.Logger] = None,
248249
relevance_score_fn: Optional[Callable[[float], float]] = None,
249-
connection: Optional[sqlalchemy.engine.Connection] = None,
250250
engine_args: Optional[dict[str, Any]] = None,
251251
use_jsonb: bool = True,
252252
create_extension: bool = True,
253253
) -> None:
254254
"""Initialize the PGVector store.
255255
256256
Args:
257-
connection_string: Postgres connection string.
257+
connection: Postgres connection string or a `sqlalchemy.engine.Engine` instance.
258258
embedding_function: Any embedding function implementing
259259
`langchain.embeddings.base.Embeddings` interface.
260260
embedding_length: The length of the embedding vector. (default: None)
@@ -278,7 +278,6 @@ def __init__(
278278
doesn't exist. disabling creation is useful when using ReadOnly
279279
Databases.
280280
"""
281-
self.connection_string = connection_string
282281
self.embedding_function = embedding_function
283282
self._embedding_length = embedding_length
284283
self.collection_name = collection_name
@@ -287,8 +286,17 @@ def __init__(
287286
self.pre_delete_collection = pre_delete_collection
288287
self.logger = logger or logging.getLogger(__name__)
289288
self.override_relevance_score_fn = relevance_score_fn
290-
self.engine_args = engine_args or {}
291-
self._bind = connection if connection else self._create_engine()
289+
290+
if isinstance(connection, str):
291+
self._bind = sqlalchemy.create_engine(url=connection, **(engine_args or {}))
292+
elif isinstance(connection, sqlalchemy.engine.Engine):
293+
self._bind = connection
294+
else:
295+
raise ValueError(
296+
"connection should be a connection string or an instance of "
297+
"sqlalchemy.engine.Engine"
298+
)
299+
292300
self.use_jsonb = use_jsonb
293301
self.create_extension = create_extension
294302

@@ -320,9 +328,6 @@ def __del__(self) -> None:
320328
def embeddings(self) -> Embeddings:
321329
return self.embedding_function
322330

323-
def _create_engine(self) -> sqlalchemy.engine.Engine:
324-
return sqlalchemy.create_engine(url=self.connection_string, **self.engine_args)
325-
326331
def create_vector_extension(self) -> None:
327332
try:
328333
with Session(self._bind) as session: # type: ignore[arg-type]
@@ -421,7 +426,7 @@ def __from(
421426
ids: Optional[List[str]] = None,
422427
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
423428
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
424-
connection_string: Optional[str] = None,
429+
connection: Optional[str] = None,
425430
pre_delete_collection: bool = False,
426431
*,
427432
use_jsonb: bool = True,
@@ -432,11 +437,9 @@ def __from(
432437

433438
if not metadatas:
434439
metadatas = [{} for _ in texts]
435-
if connection_string is None:
436-
connection_string = cls.get_connection_string(kwargs)
437440

438441
store = cls(
439-
connection_string=connection_string,
442+
connection=connection,
440443
collection_name=collection_name,
441444
embedding_function=embedding,
442445
distance_strategy=distance_strategy,
@@ -1054,18 +1057,16 @@ def from_existing_index(
10541057
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
10551058
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
10561059
pre_delete_collection: bool = False,
1060+
connection: Optional[Connection] = None,
10571061
**kwargs: Any,
10581062
) -> PGVector:
10591063
"""
10601064
Get instance of an existing PGVector store. This method will
10611065
return the instance of the store without inserting any new
10621066
embeddings
10631067
"""
1064-
1065-
connection_string = cls.get_connection_string(kwargs)
1066-
10671068
store = cls(
1068-
connection_string=connection_string,
1069+
connection=connection,
10691070
collection_name=collection_name,
10701071
embedding_function=embedding,
10711072
distance_strategy=distance_strategy,
@@ -1097,6 +1098,7 @@ def from_documents(
10971098
documents: List[Document],
10981099
embedding: Embeddings,
10991100
*,
1101+
connection: Optional[Connection] = None,
11001102
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
11011103
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
11021104
ids: Optional[List[str]] = None,
@@ -1113,16 +1115,14 @@ def from_documents(
11131115

11141116
texts = [d.page_content for d in documents]
11151117
metadatas = [d.metadata for d in documents]
1116-
connection_string = cls.get_connection_string(kwargs)
1117-
1118-
kwargs["connection_string"] = connection_string
11191118

11201119
return cls.from_texts(
11211120
texts=texts,
11221121
pre_delete_collection=pre_delete_collection,
11231122
embedding=embedding,
11241123
distance_strategy=distance_strategy,
11251124
metadatas=metadatas,
1125+
connection=connection,
11261126
ids=ids,
11271127
collection_name=collection_name,
11281128
use_jsonb=use_jsonb,
@@ -1140,6 +1140,8 @@ def connection_string_from_db_params(
11401140
password: str,
11411141
) -> str:
11421142
"""Return connection string from database parameters."""
1143+
if driver != "psycopg":
1144+
raise NotImplementedError("Only psycopg3 driver is supported")
11431145
return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
11441146

11451147
def _select_relevance_score_fn(self) -> Callable[[float], float]:

0 commit comments

Comments
 (0)