Skip to content

Commit

Permalink
feat: prewarm (tensorchord#15)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi authored Oct 21, 2024
1 parent 6d5cfc4 commit 87f0d80
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ path = "./src/bin/pgrx_embed.rs"

[features]
default = []
pg12 = ["pgrx/pg12"]
pg13 = ["pgrx/pg13"]
pg14 = ["pgrx/pg14"]
pg15 = ["pgrx/pg15"]
pg16 = ["pgrx/pg16"]
Expand Down
6 changes: 6 additions & 0 deletions src/algorithm/rabitq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ fn orthogonal_matrix(n: usize) -> Vec<Vec<f32>> {

static MATRIXS: [OnceLock<Vec<Vec<f32>>>; 1 + 2000] = [const { OnceLock::new() }; 1 + 2000];

pub fn prewarm(n: usize) {
if n <= 2000 {
MATRIXS[n].get_or_init(|| orthogonal_matrix(n));
}
}

pub fn project(vector: &[f32]) -> Vec<f32> {
use base::scalar::ScalarLike;
let n = vector.len();
Expand Down
7 changes: 7 additions & 0 deletions src/gucs/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
pub mod executing;
pub mod prewarm;

pub unsafe fn init() {
unsafe {
executing::init();
prewarm::init();
prewarm::prewarm();
#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14"))]
pgrx::pg_sys::EmitWarningsOnPlaceholders(c"rabbithole".as_ptr());
#[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))]
pgrx::pg_sys::MarkGUCPrefixReserved(c"rabbithole".as_ptr());
}
}
31 changes: 31 additions & 0 deletions src/gucs/prewarm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting};
use std::ffi::CStr;

static PREWARM_DIM: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);

pub unsafe fn init() {
GucRegistry::define_string_guc(
"rabbithole.prewarm_dim",
"prewarm_dim when the extension is loading.",
"prewarm_dim when the extension is loading.",
&PREWARM_DIM,
GucContext::Userset,
GucFlags::default(),
);
}

pub fn prewarm() {
if let Some(prewarm_dim) = PREWARM_DIM.get() {
if let Ok(prewarm_dim) = prewarm_dim.to_str() {
for dim in prewarm_dim.split(',') {
if let Ok(dim) = dim.trim().parse::<usize>() {
crate::algorithm::rabitq::prewarm(dim as _);
} else {
pgrx::warning!("{dim:?} is not a valid integer");
}
}
} else {
pgrx::warning!("rabbithole.prewarm_dim is not a valid UTF-8 string");
}
}
}

0 comments on commit 87f0d80

Please sign in to comment.