11from  __future__ import  annotations 
22
3- import  contextlib 
43import  enum 
54import  logging 
65import  uuid 
76from  typing  import  (
87    Any ,
98    Callable ,
109    Dict ,
11-     Generator ,
1210    Iterable ,
1311    List ,
1412    Optional ,
2119import  sqlalchemy 
2220from  sqlalchemy  import  SQLColumnExpression , cast , delete , func 
2321from  sqlalchemy .dialects .postgresql  import  JSON , JSONB , JSONPATH , UUID , insert 
24- from  sqlalchemy .orm  import  Session , relationship 
22+ from  sqlalchemy .orm  import  Session , relationship ,  sessionmaker 
2523
2624try :
2725    from  sqlalchemy .orm  import  declarative_base 
@@ -288,15 +286,19 @@ def __init__(
288286        self .override_relevance_score_fn  =  relevance_score_fn 
289287
290288        if  isinstance (connection , str ):
291-             self ._bind  =  sqlalchemy .create_engine (url = connection , ** (engine_args  or  {}))
289+             self ._engine  =  sqlalchemy .create_engine (
290+                 url = connection , ** (engine_args  or  {})
291+             )
292292        elif  isinstance (connection , sqlalchemy .engine .Engine ):
293-             self ._bind  =  connection 
293+             self ._engine  =  connection 
294294        else :
295295            raise  ValueError (
296296                "connection should be a connection string or an instance of " 
297297                "sqlalchemy.engine.Engine" 
298298            )
299299
300+         self ._session_maker  =  sessionmaker (bind = self ._engine )
301+ 
300302        self .use_jsonb  =  use_jsonb 
301303        self .create_extension  =  create_extension 
302304
@@ -321,16 +323,16 @@ def __post_init__(
321323        self .create_collection ()
322324
323325    def  __del__ (self ) ->  None :
324-         if  isinstance (self ._bind , sqlalchemy .engine .Connection ):
325-             self ._bind .close ()
326+         if  isinstance (self ._engine , sqlalchemy .engine .Connection ):
327+             self ._engine .close ()
326328
327329    @property  
328330    def  embeddings (self ) ->  Embeddings :
329331        return  self .embedding_function 
330332
331333    def  create_vector_extension (self ) ->  None :
332334        try :
333-             with  Session ( self ._bind ) as  session :  # type: ignore[arg-type] 
335+             with  self ._session_maker ( ) as  session :  # type: ignore[arg-type] 
334336                # The advisor lock fixes issue arising from concurrent 
335337                # creation of the vector extension. 
336338                # https://github.com/langchain-ai/langchain/issues/12933 
@@ -348,36 +350,31 @@ def create_vector_extension(self) -> None:
348350            raise  Exception (f"Failed to create vector extension: { e }  " ) from  e 
349351
350352    def  create_tables_if_not_exists (self ) ->  None :
351-         with  Session ( self ._bind ) as  session ,  session . begin ():   # type: ignore[arg-type] 
353+         with  self ._session_maker ( ) as  session : 
352354            Base .metadata .create_all (session .get_bind ())
353355
354356    def  drop_tables (self ) ->  None :
355-         with  Session ( self ._bind ) as  session ,  session . begin ():   # type: ignore[arg-type] 
357+         with  self ._session_maker ( ) as  session : 
356358            Base .metadata .drop_all (session .get_bind ())
357359
358360    def  create_collection (self ) ->  None :
359361        if  self .pre_delete_collection :
360362            self .delete_collection ()
361-         with  Session ( self ._bind ) as  session :  # type: ignore[arg-type] 
363+         with  self ._session_maker ( ) as  session :  # type: ignore[arg-type] 
362364            self .CollectionStore .get_or_create (
363365                session , self .collection_name , cmetadata = self .collection_metadata 
364366            )
365367
366368    def  delete_collection (self ) ->  None :
367369        self .logger .debug ("Trying to delete collection" )
368-         with  Session ( self ._bind ) as  session :  # type: ignore[arg-type] 
370+         with  self ._session_maker ( ) as  session :  # type: ignore[arg-type] 
369371            collection  =  self .get_collection (session )
370372            if  not  collection :
371373                self .logger .warning ("Collection not found" )
372374                return 
373375            session .delete (collection )
374376            session .commit ()
375377
376-     @contextlib .contextmanager  
377-     def  _make_session (self ) ->  Generator [Session , None , None ]:
378-         """Create a context manager for the session, bind to _conn string.""" 
379-         yield  Session (self ._bind )  # type: ignore[arg-type] 
380- 
381378    def  delete (
382379        self ,
383380        ids : Optional [List [str ]] =  None ,
@@ -390,7 +387,7 @@ def delete(
390387            ids: List of ids to delete. 
391388            collection_only: Only delete ids in the collection. 
392389        """ 
393-         with  Session ( self ._bind ) as  session :   # type: ignore[arg-type] 
390+         with  self ._session_maker ( ) as  session :
394391            if  ids  is  not   None :
395392                self .logger .debug (
396393                    "Trying to delete vectors by ids (represented by the model " 
@@ -476,7 +473,7 @@ def add_embeddings(
476473        if  not  metadatas :
477474            metadatas  =  [{} for  _  in  texts ]
478475
479-         with  Session ( self ._bind ) as  session :  # type: ignore[arg-type] 
476+         with  self ._session_maker ( ) as  session :  # type: ignore[arg-type] 
480477            collection  =  self .get_collection (session )
481478            if  not  collection :
482479                raise  ValueError ("Collection not found" )
@@ -901,7 +898,7 @@ def __query_collection(
901898        filter : Optional [Dict [str , str ]] =  None ,
902899    ) ->  List [Any ]:
903900        """Query the collection.""" 
904-         with  Session ( self ._bind ) as  session :  # type: ignore[arg-type] 
901+         with  self ._session_maker ( ) as  session :  # type: ignore[arg-type] 
905902            collection  =  self .get_collection (session )
906903            if  not  collection :
907904                raise  ValueError ("Collection not found" )
@@ -1066,6 +1063,7 @@ def from_existing_index(
10661063            embeddings = embedding ,
10671064            distance_strategy = distance_strategy ,
10681065            pre_delete_collection = pre_delete_collection ,
1066+             ** kwargs ,
10691067        )
10701068
10711069        return  store 
0 commit comments