88import_qdrant ()
99
1010from  qdrant_client  import  QdrantClient   # pylint: disable=C0413 
11- from  qdrant_client .models  import  PointStruct , HnswConfigDiff , VectorParams , OptimizersConfigDiff , \
12-     Distance   # pylint: disable=C0413 
11+ from  qdrant_client .models  import  (
12+     PointStruct ,
13+     HnswConfigDiff ,
14+     VectorParams ,
15+     OptimizersConfigDiff ,
16+     Distance ,
17+ )  # pylint: disable=C0413 
1318
1419
1520class  QdrantVectorStore (VectorBase ):
16- 
1721    def  __init__ (
18-              self ,
19-              url : Optional [str ] =  None ,
20-              port : Optional [int ] =  6333 ,
21-              grpc_port : int  =  6334 ,
22-              prefer_grpc : bool  =  False ,
23-              https : Optional [bool ] =  None ,
24-              api_key : Optional [str ] =  None ,
25-              prefix : Optional [str ] =  None ,
26-              timeout : Optional [float ] =  None ,
27-              host : Optional [str ] =  None ,
28-              collection_name : Optional [str ] =  "gptcache" ,
29-              location : Optional [str ] =  "./qdrant" ,
30-              dimension : int  =  0 ,
31-              top_k : int  =  1 ,
32-              flush_interval_sec : int  =  5 ,
33-              index_params : Optional [dict ] =  None ,
22+         self ,
23+         url : Optional [str ] =  None ,
24+         port : Optional [int ] =  6333 ,
25+         grpc_port : int  =  6334 ,
26+         prefer_grpc : bool  =  False ,
27+         https : Optional [bool ] =  None ,
28+         api_key : Optional [str ] =  None ,
29+         prefix : Optional [str ] =  None ,
30+         timeout : Optional [float ] =  None ,
31+         host : Optional [str ] =  None ,
32+         collection_name : Optional [str ] =  "gptcache" ,
33+         location : Optional [str ] =  "./qdrant" ,
34+         dimension : int  =  0 ,
35+         top_k : int  =  1 ,
36+         flush_interval_sec : int  =  5 ,
37+         index_params : Optional [dict ] =  None ,
3438    ):
3539        if  dimension  <=  0 :
3640            raise  ValueError (
@@ -44,13 +48,17 @@ def __init__(
4448        if  self ._in_memory  or  location  is  not None :
4549            self ._create_local (location )
4650        else :
47-             self ._create_remote (url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https )
51+             self ._create_remote (
52+                 url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https 
53+             )
4854        self ._create_collection (collection_name , flush_interval_sec , index_params )
4955
5056    def  _create_local (self , location ):
5157        self ._client  =  QdrantClient (location = location )
5258
53-     def  _create_remote (self , url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https ):
59+     def  _create_remote (
60+         self , url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https 
61+     ):
5462        self ._client  =  QdrantClient (
5563            url = url ,
5664            port = port ,
@@ -63,45 +71,70 @@ def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_gr
6371            https = https ,
6472        )
6573
66-     def  _create_collection (self , collection_name : str , flush_interval_sec : int , index_params : Optional [dict ] =  None ):
74+     def  _create_collection (
75+         self ,
76+         collection_name : str ,
77+         flush_interval_sec : int ,
78+         index_params : Optional [dict ] =  None ,
79+     ):
6780        hnsw_config  =  HnswConfigDiff (** (index_params  or  {}))
68-         vectors_config  =  VectorParams (size = self .dimension , distance = Distance .COSINE ,
69-                                       hnsw_config = hnsw_config )
70-         optimizers_config  =  OptimizersConfigDiff (deleted_threshold = 0.2 , vacuum_min_vector_number = 1000 ,
71-                                                  flush_interval_sec = flush_interval_sec )
81+         vectors_config  =  VectorParams (
82+             size = self .dimension , distance = Distance .COSINE , hnsw_config = hnsw_config 
83+         )
84+         optimizers_config  =  OptimizersConfigDiff (
85+             deleted_threshold = 0.2 ,
86+             vacuum_min_vector_number = 1000 ,
87+             flush_interval_sec = flush_interval_sec ,
88+         )
7289        # check if the collection exists 
7390        existing_collections  =  self ._client .get_collections ()
7491        for  existing_collection  in  existing_collections .collections :
7592            if  existing_collection .name  ==  collection_name :
76-                 gptcache_log .warning ("The %s collection already exists, and it will be used directly." , collection_name )
93+                 gptcache_log .warning (
94+                     "The %s collection already exists, and it will be used directly." ,
95+                     collection_name ,
96+                 )
7797                break 
7898        else :
79-             self ._client .create_collection (collection_name = collection_name , vectors_config = vectors_config ,
80-                                            optimizers_config = optimizers_config )
99+             self ._client .create_collection (
100+                 collection_name = collection_name ,
101+                 vectors_config = vectors_config ,
102+                 optimizers_config = optimizers_config ,
103+             )
81104
82105    def  mul_add (self , datas : List [VectorData ]):
83-         points  =  [PointStruct (id = d .id , vector = d .data .reshape (- 1 ).tolist ()) for  d  in  datas ]
84-         self ._client .upsert (collection_name = self ._collection_name , points = points , wait = False )
106+         points  =  [
107+             PointStruct (id = d .id , vector = d .data .reshape (- 1 ).tolist ()) for  d  in  datas 
108+         ]
109+         self ._client .upsert (
110+             collection_name = self ._collection_name , points = points , wait = False 
111+         )
85112
86113    def  search (self , data : np .ndarray , top_k : int  =  - 1 ):
87114        if  top_k  ==  - 1 :
88115            top_k  =  self .top_k 
89116        reshaped_data  =  data .reshape (- 1 ).tolist ()
90-         search_result  =  self ._client .search (collection_name = self ._collection_name , query_vector = reshaped_data ,
91-                                             limit = top_k )
117+         search_result  =  self ._client .search (
118+             collection_name = self ._collection_name ,
119+             query_vector = reshaped_data ,
120+             limit = top_k ,
121+         )
92122        return  list (map (lambda  x : (x .score , x .id ), search_result ))
93123
94124    def  delete (self , ids : List [str ]):
95125        self ._client .delete (collection_name = self ._collection_name , points_selector = ids )
96126
97127    def  rebuild (self , ids = None ):  # pylint: disable=unused-argument 
98-         optimizers_config  =  OptimizersConfigDiff (deleted_threshold = 0.2 , vacuum_min_vector_number = 1000 )
99-         self ._client .update_collection (collection_name = self ._collection_name , optimizer_config = optimizers_config )
128+         optimizers_config  =  OptimizersConfigDiff (
129+             deleted_threshold = 0.2 , vacuum_min_vector_number = 1000 
130+         )
131+         self ._client .update_collection (
132+             collection_name = self ._collection_name , optimizer_config = optimizers_config 
133+         )
100134
101135    def  flush (self ):
102136        # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant 
103137        pass 
104138
105- 
106139    def  close (self ):
107140        self .flush ()
0 commit comments