11from  typing  import  List , Optional 
2+ 
23import  numpy  as  np 
34
5+ from  gptcache .manager .vector_data .base  import  VectorBase , VectorData 
46from  gptcache .utils  import  import_qdrant 
57from  gptcache .utils .log  import  gptcache_log 
6- from  gptcache .manager .vector_data .base  import  VectorBase , VectorData 
78
89import_qdrant ()
910
10- from  qdrant_client  import  QdrantClient   # pylint: disable=C0413 
11- from  qdrant_client .models  import  PointStruct , HnswConfigDiff , VectorParams , OptimizersConfigDiff , \
12-     Distance   # pylint: disable=C0413 
11+ # pylint: disable=C0413 
12+ from  qdrant_client  import  QdrantClient 
13+ from  qdrant_client .models  import  (
14+     PointStruct ,
15+     HnswConfigDiff ,
16+     VectorParams ,
17+     OptimizersConfigDiff ,
18+     Distance ,
19+ )
1320
1421
1522class  QdrantVectorStore (VectorBase ):
23+     """Qdrant Vector Store""" 
1624
1725    def  __init__ (
18-              self ,
19-              url : Optional [str ] =  None ,
20-              port : Optional [int ] =  6333 ,
21-              grpc_port : int  =  6334 ,
22-              prefer_grpc : bool  =  False ,
23-              https : Optional [bool ] =  None ,
24-              api_key : Optional [str ] =  None ,
25-              prefix : Optional [str ] =  None ,
26-              timeout : Optional [float ] =  None ,
27-              host : Optional [str ] =  None ,
28-              collection_name : Optional [str ] =  "gptcache" ,
29-              location : Optional [str ] =  "./qdrant" ,
30-              dimension : int  =  0 ,
31-              top_k : int  =  1 ,
32-              flush_interval_sec : int  =  5 ,
33-              index_params : Optional [dict ] =  None ,
26+         self ,
27+         url : Optional [str ] =  None ,
28+         port : Optional [int ] =  6333 ,
29+         grpc_port : int  =  6334 ,
30+         prefer_grpc : bool  =  False ,
31+         https : Optional [bool ] =  None ,
32+         api_key : Optional [str ] =  None ,
33+         prefix : Optional [str ] =  None ,
34+         timeout : Optional [float ] =  None ,
35+         host : Optional [str ] =  None ,
36+         collection_name : Optional [str ] =  "gptcache" ,
37+         location : Optional [str ] =  "./qdrant" ,
38+         dimension : int  =  0 ,
39+         top_k : int  =  1 ,
40+         flush_interval_sec : int  =  5 ,
41+         index_params : Optional [dict ] =  None ,
3442    ):
3543        if  dimension  <=  0 :
3644            raise  ValueError (
@@ -44,13 +52,17 @@ def __init__(
4452        if  self ._in_memory  or  location  is  not None :
4553            self ._create_local (location )
4654        else :
47-             self ._create_remote (url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https )
55+             self ._create_remote (
56+                 url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https 
57+             )
4858        self ._create_collection (collection_name , flush_interval_sec , index_params )
4959
5060    def  _create_local (self , location ):
5161        self ._client  =  QdrantClient (location = location )
5262
53-     def  _create_remote (self , url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https ):
63+     def  _create_remote (
64+         self , url , port , api_key , timeout , host , grpc_port , prefer_grpc , prefix , https 
65+     ):
5466        self ._client  =  QdrantClient (
5567            url = url ,
5668            port = port ,
@@ -63,45 +75,70 @@ def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_gr
6375            https = https ,
6476        )
6577
66-     def  _create_collection (self , collection_name : str , flush_interval_sec : int , index_params : Optional [dict ] =  None ):
78+     def  _create_collection (
79+         self ,
80+         collection_name : str ,
81+         flush_interval_sec : int ,
82+         index_params : Optional [dict ] =  None ,
83+     ):
6784        hnsw_config  =  HnswConfigDiff (** (index_params  or  {}))
68-         vectors_config  =  VectorParams (size = self .dimension , distance = Distance .COSINE ,
69-                                       hnsw_config = hnsw_config )
70-         optimizers_config  =  OptimizersConfigDiff (deleted_threshold = 0.2 , vacuum_min_vector_number = 1000 ,
71-                                                  flush_interval_sec = flush_interval_sec )
85+         vectors_config  =  VectorParams (
86+             size = self .dimension , distance = Distance .COSINE , hnsw_config = hnsw_config 
87+         )
88+         optimizers_config  =  OptimizersConfigDiff (
89+             deleted_threshold = 0.2 ,
90+             vacuum_min_vector_number = 1000 ,
91+             flush_interval_sec = flush_interval_sec ,
92+         )
7293        # check if the collection exists 
7394        existing_collections  =  self ._client .get_collections ()
7495        for  existing_collection  in  existing_collections .collections :
7596            if  existing_collection .name  ==  collection_name :
76-                 gptcache_log .warning ("The %s collection already exists, and it will be used directly." , collection_name )
97+                 gptcache_log .warning (
98+                     "The %s collection already exists, and it will be used directly." ,
99+                     collection_name ,
100+                 )
77101                break 
78102        else :
79-             self ._client .create_collection (collection_name = collection_name , vectors_config = vectors_config ,
80-                                            optimizers_config = optimizers_config )
103+             self ._client .create_collection (
104+                 collection_name = collection_name ,
105+                 vectors_config = vectors_config ,
106+                 optimizers_config = optimizers_config ,
107+             )
81108
82109    def  mul_add (self , datas : List [VectorData ]):
83-         points  =  [PointStruct (id = d .id , vector = d .data .reshape (- 1 ).tolist ()) for  d  in  datas ]
84-         self ._client .upsert (collection_name = self ._collection_name , points = points , wait = False )
110+         points  =  [
111+             PointStruct (id = d .id , vector = d .data .reshape (- 1 ).tolist ()) for  d  in  datas 
112+         ]
113+         self ._client .upsert (
114+             collection_name = self ._collection_name , points = points , wait = False 
115+         )
85116
86117    def  search (self , data : np .ndarray , top_k : int  =  - 1 ):
87118        if  top_k  ==  - 1 :
88119            top_k  =  self .top_k 
89120        reshaped_data  =  data .reshape (- 1 ).tolist ()
90-         search_result  =  self ._client .search (collection_name = self ._collection_name , query_vector = reshaped_data ,
91-                                             limit = top_k )
121+         search_result  =  self ._client .search (
122+             collection_name = self ._collection_name ,
123+             query_vector = reshaped_data ,
124+             limit = top_k ,
125+         )
92126        return  list (map (lambda  x : (x .score , x .id ), search_result ))
93127
94128    def  delete (self , ids : List [str ]):
95129        self ._client .delete (collection_name = self ._collection_name , points_selector = ids )
96130
97131    def  rebuild (self , ids = None ):  # pylint: disable=unused-argument 
98-         optimizers_config  =  OptimizersConfigDiff (deleted_threshold = 0.2 , vacuum_min_vector_number = 1000 )
99-         self ._client .update_collection (collection_name = self ._collection_name , optimizer_config = optimizers_config )
132+         optimizers_config  =  OptimizersConfigDiff (
133+             deleted_threshold = 0.2 , vacuum_min_vector_number = 1000 
134+         )
135+         self ._client .update_collection (
136+             collection_name = self ._collection_name , optimizer_config = optimizers_config 
137+         )
100138
101139    def  flush (self ):
102140        # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant 
103141        pass 
104142
105- 
106143    def  close (self ):
107144        self .flush ()
0 commit comments