Skip to content

Commit

Permalink
chore: type check and external test
Browse files Browse the repository at this point in the history
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
  • Loading branch information
cutecutecat committed Dec 24, 2024
1 parent 6335852 commit dabca33
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 5 deletions.
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def kmeans_cluster(
verbose=True,
niter=niter,
seed=SEED,
spherical=metric == "cos",
spherical=metric != "l2",
)
child_kmeans.train(child_train)
centroids.append(child_kmeans.centroids)
Expand Down
28 changes: 24 additions & 4 deletions src/vchordrq/algorithm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,34 @@ impl Structure {
) -> Vec<Self> {
use std::collections::BTreeMap;
let VchordrqExternalBuildOptions { table } = external_build;
let query = format!("SELECT id, parent, vector FROM {table};");
let dump_query = format!("SELECT id, parent, vector FROM {table};");
let table_name = table.split('.').last().unwrap().to_string();
let type_check_query = format!(
"SELECT COUNT(*)::INTEGER
FROM pg_catalog.pg_extension e
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace
LEFT JOIN information_schema.columns i ON i.udt_schema = n.nspname
WHERE e.extname = 'vector' AND i.udt_name = 'vector'
AND i.table_name = '{table_name}' AND i.column_name = 'vector';"
);
let mut parents = BTreeMap::new();
let mut vectors = BTreeMap::new();
pgrx::spi::Spi::connect(|client| {
use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput;
use base::vector::VectorBorrowed;
use pgrx::pg_sys::panic::ErrorReportable;
let table = client.select(&query, None, None).unwrap_or_report();
// Check the column of centroid table named `vector`, which type should be pgvector::vector
let type_check = client
.select(&type_check_query, None, None)
.unwrap_or_report();
let count: Result<Option<i32>, _> = type_check.first().get_by_name("count");
if count != Ok(Some(1)) {
pgrx::warning!("{:?}", count);
pgrx::error!(
"extern build: `vector` column should be pgvector::vector type at the centroid table"
);
}
let table = client.select(&dump_query, None, None).unwrap_or_report();
for row in table {
let id: Option<i32> = row.get_by_name("id").unwrap();
let parent: Option<i32> = row.get_by_name("parent").unwrap();
Expand All @@ -220,7 +240,7 @@ impl Structure {
let pop = parents.insert(id, parent);
if pop.is_some() {
pgrx::error!(
"external build: there are at least two lines have same id, id = {id}"
"extern build: there are at least two lines have same id, id = {id}"
);
}
if vector_options.dims != vector.as_borrowed().dims() {
Expand Down Expand Up @@ -263,7 +283,7 @@ impl Structure {
parent.push(id);
} else {
pgrx::error!(
"external build: parent does not exist, id = {id}, parent = {parent}"
"extern build: parent does not exist, id = {id}, parent = {parent}"
);
}
} else {
Expand Down
79 changes: 79 additions & 0 deletions tests/logic/external_build.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
statement ok
CREATE TABLE t (val0 vector(3), val1 halfvec(3));

statement ok
INSERT INTO t (val0, val1)
SELECT
ARRAY[random(), random(), random()]::real[]::vector,
ARRAY[random(), random(), random()]::real[]::halfvec
FROM generate_series(1, 100);

statement ok
CREATE TABLE vector_centroid (id integer, parent integer, vector vector(3));

statement ok
INSERT INTO vector_centroid (id, vector) VALUES
(0, '[1.0, 0.0, 0.0]'),
(1, '[0.0, 1.0, 0.0]'),
(2, '[0.0, 0.0, 1.0]');

statement ok
CREATE TABLE bad_type_centroid (id integer, parent integer, vector halfvec(3));

statement ok
INSERT INTO bad_type_centroid (id, vector) VALUES
(0, '[1.0, 0.0, 0.0]'),
(1, '[0.0, 1.0, 0.0]'),
(2, '[0.0, 0.0, 1.0]');

statement ok
CREATE TABLE bad_duplicate_id (id integer, parent integer, vector vector(3));

statement ok
INSERT INTO bad_duplicate_id (id, vector) VALUES
(1, '[1.0, 0.0, 0.0]'),
(1, '[0.0, 1.0, 0.0]'),
(2, '[0.0, 0.0, 1.0]');

# external build for vector column

statement ok
CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops)
WITH (options = $$
residual_quantization = true
[build.external]
table = 'public.vector_centroid'
$$);

# external build for halfvec column

statement ok
CREATE INDEX ON t USING vchordrq (val1 halfvec_l2_ops)
WITH (options = $$
residual_quantization = true
[build.external]
table = 'public.vector_centroid'
$$);

# failed: bad vector bad_type

statement error extern build: `vector` column should be pgvector::vector type at the centroid table
CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops)
WITH (options = $$
residual_quantization = true
[build.external]
table = 'public.bad_type_centroid'
$$);

# failed: duplicate id

statement error extern build: there are at least two lines have same id, id = 1
CREATE INDEX ON t USING vchordrq (val0 vector_l2_ops)
WITH (options = $$
residual_quantization = true
[build.external]
table = 'public.bad_duplicate_id'
$$);

statement ok
DROP TABLE t, vector_centroid, bad_type_centroid, bad_duplicate_id;

0 comments on commit dabca33

Please sign in to comment.