Skip to content

Commit fb3b1c9

Browse files
committed
support complex index params for pgvecto.rs
1 parent ed90d56 commit fb3b1c9

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

engine/clients/pgvector/upload.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
import toml
23
from typing import List, Optional
34
import psycopg2
45
from engine.base_client import BaseUploader
@@ -13,28 +14,45 @@ class PGVectorUploader(BaseUploader):
1314
vector_count: int = None
1415

1516
@classmethod
16-
def init_client(cls, host, distance, vector_count, connection_params, upload_params,
17-
extra_columns_name: list, extra_columns_type: list):
18-
database, host, port, user, password = process_connection_params(connection_params, host)
19-
cls.conn = psycopg2.connect(database=database, user=user, password=password, host=host, port=port)
17+
def init_client(
18+
cls,
19+
host,
20+
distance,
21+
vector_count,
22+
connection_params,
23+
upload_params,
24+
extra_columns_name: list,
25+
extra_columns_type: list,
26+
):
27+
database, host, port, user, password = process_connection_params(
28+
connection_params, host
29+
)
30+
cls.conn = psycopg2.connect(
31+
database=database, user=user, password=password, host=host, port=port
32+
)
2033
cls.host = host
2134
cls.upload_params = upload_params
2235
cls.engine_type = upload_params.get("engine_type", "c")
23-
cls.distance = DISTANCE_MAPPING_CREATE[distance] if cls.engine_type == "c" else DISTANCE_MAPPING_CREATE_RUST[
24-
distance]
36+
cls.distance = (
37+
DISTANCE_MAPPING_CREATE[distance]
38+
if cls.engine_type == "c"
39+
else DISTANCE_MAPPING_CREATE_RUST[distance]
40+
)
2541
cls.vector_count = vector_count
2642

2743
@classmethod
28-
def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]):
44+
def upload_batch(
45+
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
46+
):
2947
if len(ids) != len(vectors):
3048
raise RuntimeError("PGVector batch upload unhealthy")
3149
# Getting the names of structured data columns based on the first meta information.
32-
col_name_tuple = ('id', 'vector')
33-
col_type_tuple = ('%s', '%s::real[]')
50+
col_name_tuple = ("id", "vector")
51+
col_type_tuple = ("%s", "%s::real[]")
3452
if metadata[0] is not None:
3553
for col_name in list(metadata[0].keys()):
3654
col_name_tuple += (col_name,)
37-
col_type_tuple += ('%s',)
55+
col_type_tuple += ("%s",)
3856

3957
insert_data = []
4058
for i in range(0, len(ids)):
@@ -43,7 +61,9 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
4361
for col_name in list(metadata[i].keys()):
4462
value = metadata[i][col_name]
4563
# Determining if the data is a dictionary type of latitude and longitude.
46-
if isinstance(value, dict) and ('lon' and 'lat') in list(value.keys()):
64+
if isinstance(value, dict) and ("lon" and "lat") in list(
65+
value.keys()
66+
):
4767
raise RuntimeError("Postgres doesn't support geo datasets")
4868
else:
4969
temp_tuple += (value,)
@@ -63,21 +83,22 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
6383

6484
@classmethod
6585
def post_upload(cls, distance):
66-
index_options_c = ""
67-
index_options_rust = ""
68-
for key in cls.upload_params.get("index_params", {}).keys():
69-
index_options_c += ("{}={}" if index_options_c == "" else ", {}={}").format(
70-
key, cls.upload_params.get('index_params', {})[key])
71-
index_options_rust += ("{}={}" if index_options_rust == "" else "\n{}={}").format(
72-
key, cls.upload_params.get('index_params', {})[key])
73-
create_index_command = f"CREATE INDEX ON {PGVECTOR_INDEX} USING hnsw (vector {cls.distance}) WITH ({index_options_c});"
74-
if cls.engine_type == "rust":
86+
if cls.engine_type == "c":
87+
index_options_c = ""
88+
for key in cls.upload_params.get("index_params", {}).keys():
89+
index_options_c += (
90+
"{}={}" if index_options_c == "" else ", {}={}"
91+
).format(key, cls.upload_params.get("index_params", {})[key])
92+
create_index_command = f"CREATE INDEX ON {PGVECTOR_INDEX} USING hnsw (vector {cls.distance}) WITH ({index_options_c});"
93+
elif cls.engine_type == "rust":
94+
index_options_rust = toml.dumps(cls.upload_params.get("index_params", {}))
7595
create_index_command = f"""
7696
CREATE INDEX ON {PGVECTOR_INDEX} USING vectors (vector {cls.distance}) WITH (options=$$
77-
[indexing.hnsw]
7897
{index_options_rust}
7998
$$);
8099
"""
100+
else:
101+
raise ValueError("PGVector engine type must be c or rust")
81102

82103
# create index (blocking)
83104
with cls.conn.cursor() as cur:
@@ -86,5 +107,7 @@ def post_upload(cls, distance):
86107
cls.conn.commit()
87108
# wait index finished
88109
with cls.conn.cursor() as cur:
89-
cur.execute("SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;")
110+
cur.execute(
111+
"SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;"
112+
)
90113
cls.conn.commit()

experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-ip.json

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,12 @@
5353
"parallel": 16,
5454
"batch_size": 64,
5555
"index_params": {
56-
"m": 12,
57-
"ef_construction": 100
56+
"indexing": {
57+
"hnsw": {
58+
"m": 12,
59+
"ef_construction": 100
60+
}
61+
}
5862
},
5963
"index_type": "hnsw",
6064
"engine_type": "rust"

experiments/needs_editing/pgvector_rust_HNSW_single_node_laion-768-5m-probability-ip.json

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,12 @@
105105
"parallel": 16,
106106
"batch_size": 64,
107107
"index_params": {
108-
"m": 12,
109-
"ef_construction": 100
108+
"indexing": {
109+
"hnsw": {
110+
"m": 12,
111+
"ef_construction": 100
112+
}
113+
}
110114
},
111115
"index_type": "hnsw",
112116
"engine_type": "rust"

0 commit comments

Comments
 (0)