1
1
import time
2
+ import toml
2
3
from typing import List , Optional
3
4
import psycopg2
4
5
from engine .base_client import BaseUploader
@@ -13,28 +14,45 @@ class PGVectorUploader(BaseUploader):
13
14
vector_count : int = None
14
15
15
16
@classmethod
16
- def init_client (cls , host , distance , vector_count , connection_params , upload_params ,
17
- extra_columns_name : list , extra_columns_type : list ):
18
- database , host , port , user , password = process_connection_params (connection_params , host )
19
- cls .conn = psycopg2 .connect (database = database , user = user , password = password , host = host , port = port )
17
+ def init_client (
18
+ cls ,
19
+ host ,
20
+ distance ,
21
+ vector_count ,
22
+ connection_params ,
23
+ upload_params ,
24
+ extra_columns_name : list ,
25
+ extra_columns_type : list ,
26
+ ):
27
+ database , host , port , user , password = process_connection_params (
28
+ connection_params , host
29
+ )
30
+ cls .conn = psycopg2 .connect (
31
+ database = database , user = user , password = password , host = host , port = port
32
+ )
20
33
cls .host = host
21
34
cls .upload_params = upload_params
22
35
cls .engine_type = upload_params .get ("engine_type" , "c" )
23
- cls .distance = DISTANCE_MAPPING_CREATE [distance ] if cls .engine_type == "c" else DISTANCE_MAPPING_CREATE_RUST [
24
- distance ]
36
+ cls .distance = (
37
+ DISTANCE_MAPPING_CREATE [distance ]
38
+ if cls .engine_type == "c"
39
+ else DISTANCE_MAPPING_CREATE_RUST [distance ]
40
+ )
25
41
cls .vector_count = vector_count
26
42
27
43
@classmethod
28
- def upload_batch (cls , ids : List [int ], vectors : List [list ], metadata : List [Optional [dict ]]):
44
+ def upload_batch (
45
+ cls , ids : List [int ], vectors : List [list ], metadata : List [Optional [dict ]]
46
+ ):
29
47
if len (ids ) != len (vectors ):
30
48
raise RuntimeError ("PGVector batch upload unhealthy" )
31
49
# Getting the names of structured data columns based on the first meta information.
32
- col_name_tuple = ('id' , ' vector' )
33
- col_type_tuple = ('%s' , ' %s::real[]' )
50
+ col_name_tuple = ("id" , " vector" )
51
+ col_type_tuple = ("%s" , " %s::real[]" )
34
52
if metadata [0 ] is not None :
35
53
for col_name in list (metadata [0 ].keys ()):
36
54
col_name_tuple += (col_name ,)
37
- col_type_tuple += ('%s' ,)
55
+ col_type_tuple += ("%s" ,)
38
56
39
57
insert_data = []
40
58
for i in range (0 , len (ids )):
@@ -43,7 +61,9 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
43
61
for col_name in list (metadata [i ].keys ()):
44
62
value = metadata [i ][col_name ]
45
63
# Determining if the data is a dictionary type of latitude and longitude.
46
- if isinstance (value , dict ) and ('lon' and 'lat' ) in list (value .keys ()):
64
+ if isinstance (value , dict ) and ("lon" and "lat" ) in list (
65
+ value .keys ()
66
+ ):
47
67
raise RuntimeError ("Postgres doesn't support geo datasets" )
48
68
else :
49
69
temp_tuple += (value ,)
@@ -63,21 +83,22 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
63
83
64
84
@classmethod
65
85
def post_upload (cls , distance ):
66
- index_options_c = ""
67
- index_options_rust = ""
68
- for key in cls .upload_params .get ("index_params" , {}).keys ():
69
- index_options_c += ( "{}={}" if index_options_c == "" else ", {}={}" ). format (
70
- key , cls . upload_params . get ( 'index_params' , {})[ key ])
71
- index_options_rust += ( "{}={}" if index_options_rust == "" else " \n {}={}" ). format (
72
- key , cls .upload_params . get ( 'index_params' , {})[ key ])
73
- create_index_command = f"CREATE INDEX ON { PGVECTOR_INDEX } USING hnsw (vector { cls .distance } ) WITH ( { index_options_c } );"
74
- if cls . engine_type == "rust" :
86
+ if cls . engine_type == "c" :
87
+ index_options_c = ""
88
+ for key in cls .upload_params .get ("index_params" , {}).keys ():
89
+ index_options_c += (
90
+ "{}={}" if index_options_c == "" else " , {}={}"
91
+ ). format ( key , cls . upload_params . get ( "index_params" , {})[ key ])
92
+ create_index_command = f"CREATE INDEX ON { PGVECTOR_INDEX } USING hnsw (vector { cls .distance } ) WITH ( { index_options_c } );"
93
+ elif cls .engine_type == "rust" :
94
+ index_options_rust = toml . dumps ( cls . upload_params . get ( "index_params" , {}))
75
95
create_index_command = f"""
76
96
CREATE INDEX ON { PGVECTOR_INDEX } USING vectors (vector { cls .distance } ) WITH (options=$$
77
- [indexing.hnsw]
78
97
{ index_options_rust }
79
98
$$);
80
99
"""
100
+ else :
101
+ raise ValueError ("PGVector engine type must be c or rust" )
81
102
82
103
# create index (blocking)
83
104
with cls .conn .cursor () as cur :
@@ -86,5 +107,7 @@ def post_upload(cls, distance):
86
107
cls .conn .commit ()
87
108
# wait index finished
88
109
with cls .conn .cursor () as cur :
89
- cur .execute ("SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;" )
110
+ cur .execute (
111
+ "SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;"
112
+ )
90
113
cls .conn .commit ()
0 commit comments