Skip to content

Commit b57f6a1

Browse files
committed
chore: add custom binding
1 parent c0567ec commit b57f6a1

32 files changed

+15623
-17
lines changed

Cargo.lock

Lines changed: 202 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ smallvec = { version = "1.13.2", features = ["union", "const_new
4444
strum = { version = "0.27.1", features = ["derive"] }
4545
# this will use tokio if available, otherwise async-std
4646
convert_case = "0.6.0"
47+
prost = "0.13.5"
4748
prost-reflect = "0.15.3"
4849
protox = "0.8.0"
4950
sqlx = { version = "0.8.2", features = ["runtime-tokio", "runtime-async-std", "postgres", "json"] }
5051
syn = "1.0.109"
5152
termcolor = "1.4.1"
5253
test-log = "0.2.17"
54+
thiserror = "1.0.31"
5355
tokio = { version = "1.40.0", features = ["full"] }
5456
tracing = { version = "0.1.40", default-features = false, features = ["std"] }
5557
tracing-bunyan-formatter = { version = "0.3.10 " }
@@ -74,7 +76,9 @@ pgt_lexer = { path = "./crates/pgt_lexer", version = "0.0.0" }
7476
pgt_lexer_codegen = { path = "./crates/pgt_lexer_codegen", version = "0.0.0" }
7577
pgt_lsp = { path = "./crates/pgt_lsp", version = "0.0.0" }
7678
pgt_markup = { path = "./crates/pgt_markup", version = "0.0.0" }
79+
pgt_query = { path = "./crates/pgt_query", version = "0.0.0" }
7780
pgt_query_ext = { path = "./crates/pgt_query_ext", version = "0.0.0" }
81+
pgt_query_macros = { path = "./crates/pgt_query_macros", version = "0.0.0" }
7882
pgt_query_proto_parser = { path = "./crates/pgt_query_proto_parser", version = "0.0.0" }
7983
pgt_schema_cache = { path = "./crates/pgt_schema_cache", version = "0.0.0" }
8084
pgt_statement_splitter = { path = "./crates/pgt_statement_splitter", version = "0.0.0" }

crates/pgt_query/Cargo.toml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[package]
2+
authors.workspace = true
3+
categories.workspace = true
4+
description = "<DESCRIPTION>"
5+
edition.workspace = true
6+
homepage.workspace = true
7+
keywords.workspace = true
8+
license.workspace = true
9+
name = "pgt_query"
10+
repository.workspace = true
11+
version = "0.0.0"
12+
13+
[dependencies]
14+
prost = { workspace = true }
15+
thiserror = { workspace = true }
16+
17+
pgt_query_macros = { workspace = true }
18+
19+
20+
[features]
21+
default = ["postgres-17"]
22+
postgres-15 = []
23+
postgres-16 = []
24+
postgres-17 = []
25+
26+
[build-dependencies]
27+
bindgen = "0.72.0"
28+
cc = "1.0.83"
29+
clippy = { version = "0.0.302", optional = true }
30+
fs_extra = "1.2.0"
31+
glob = "0.3.1"
32+
prost-build = "0.13.5"
33+
which = "6.0.0"
34+
35+
[dev-dependencies]
36+
easy-parallel = "3.2.0"

crates/pgt_query/build.rs

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
#![cfg_attr(feature = "clippy", feature(plugin))]
2+
#![cfg_attr(feature = "clippy", plugin(clippy))]
3+
4+
use fs_extra::dir::CopyOptions;
5+
use glob::glob;
6+
use std::env;
7+
use std::path::PathBuf;
8+
use std::process::Command;
9+
10+
/// Base name of the C library: used for the link directive, the compiled
/// archive name, and the `pg_query.h` / `pg_query.proto` file names.
static LIBRARY_NAME: &str = "pg_query";
/// Git repository libpg_query is cloned from.
static LIBPG_QUERY_REPO: &str = "https://github.com/pganalyze/libpg_query.git";

/// Returns the libpg_query git tag matching the enabled `postgres-*` feature.
///
/// When several version features are enabled at once the lowest version wins
/// (15, then 16, then 17). When none is enabled (for example with
/// `--no-default-features`) the Postgres 17 tag is used, mirroring the crate's
/// default feature, instead of failing to compile with a missing return value.
fn get_libpg_query_tag() -> &'static str {
    if cfg!(feature = "postgres-15") {
        "15-5.3.0"
    } else if cfg!(feature = "postgres-16") {
        "16-6.1.0"
    } else {
        "17-6.1.0"
    }
}
20+
21+
fn main() -> Result<(), Box<dyn std::error::Error>> {
22+
let libpg_query_tag = get_libpg_query_tag();
23+
let out_dir = PathBuf::from(env::var("OUT_DIR")?);
24+
let vendor_dir = out_dir.join("vendor");
25+
let libpg_query_dir = vendor_dir.join("libpg_query").join(libpg_query_tag);
26+
let stamp_file = libpg_query_dir.join(".stamp");
27+
28+
let src_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?).join("src");
29+
let target = env::var("TARGET").unwrap();
30+
let is_emscripten = target.contains("emscripten");
31+
32+
// Configure cargo through stdout
33+
println!("cargo:rustc-link-search=native={}", out_dir.display());
34+
println!("cargo:rustc-link-lib=static={LIBRARY_NAME}");
35+
36+
// Clone libpg_query if not already present
37+
if !stamp_file.exists() {
38+
println!("cargo:warning=Cloning libpg_query {}", libpg_query_tag);
39+
40+
// Create vendor directory
41+
std::fs::create_dir_all(&vendor_dir)?;
42+
43+
// Clone the repository with partial clone for faster download
44+
let status = Command::new("git")
45+
.args([
46+
"clone",
47+
"--filter=blob:none",
48+
"--depth",
49+
"1",
50+
"--branch",
51+
libpg_query_tag,
52+
LIBPG_QUERY_REPO,
53+
libpg_query_dir.to_str().unwrap(),
54+
])
55+
.status()?;
56+
57+
if !status.success() {
58+
return Err("Failed to clone libpg_query".into());
59+
}
60+
61+
// Create stamp file
62+
std::fs::File::create(&stamp_file)?;
63+
}
64+
65+
// Tell cargo to rerun if the stamp file is deleted
66+
println!("cargo:rerun-if-changed={}", stamp_file.display());
67+
68+
// Copy necessary files to OUT_DIR for compilation
69+
let out_header_path = out_dir.join(LIBRARY_NAME).with_extension("h");
70+
let out_protobuf_path = out_dir.join("protobuf");
71+
72+
let source_paths = vec![
73+
libpg_query_dir.join(LIBRARY_NAME).with_extension("h"),
74+
libpg_query_dir.join("Makefile"),
75+
libpg_query_dir.join("src"),
76+
libpg_query_dir.join("protobuf"),
77+
libpg_query_dir.join("vendor"),
78+
];
79+
80+
let copy_options = CopyOptions {
81+
overwrite: true,
82+
..CopyOptions::default()
83+
};
84+
85+
fs_extra::copy_items(&source_paths, &out_dir, &copy_options)?;
86+
87+
// Compile the C library.
88+
let mut build = cc::Build::new();
89+
90+
// Configure for Emscripten if needed
91+
if is_emscripten {
92+
// Use emcc as the compiler instead of gcc/clang
93+
build.compiler("emcc");
94+
// Use emar as the archiver instead of ar
95+
build.archiver("emar");
96+
// Note: We don't add WASM-specific flags here as this creates a static library
97+
// The final linking flags should be added when building the final WASM module
98+
}
99+
100+
build
101+
.files(
102+
glob(out_dir.join("src/*.c").to_str().unwrap())
103+
.unwrap()
104+
.map(|p| p.unwrap()),
105+
)
106+
.files(
107+
glob(out_dir.join("src/postgres/*.c").to_str().unwrap())
108+
.unwrap()
109+
.map(|p| p.unwrap()),
110+
)
111+
.file(out_dir.join("vendor/protobuf-c/protobuf-c.c"))
112+
.file(out_dir.join("vendor/xxhash/xxhash.c"))
113+
.file(out_dir.join("protobuf/pg_query.pb-c.c"))
114+
.include(out_dir.join("."))
115+
.include(out_dir.join("./vendor"))
116+
.include(out_dir.join("./src/postgres/include"))
117+
.include(out_dir.join("./src/include"))
118+
.warnings(false); // Avoid unnecessary warnings, as they are already considered as part of libpg_query development
119+
if env::var("PROFILE").unwrap() == "debug" || env::var("DEBUG").unwrap() == "1" {
120+
build.define("USE_ASSERT_CHECKING", None);
121+
}
122+
if target.contains("windows") && !is_emscripten {
123+
build.include(out_dir.join("./src/postgres/include/port/win32"));
124+
if target.contains("msvc") {
125+
build.include(out_dir.join("./src/postgres/include/port/win32_msvc"));
126+
}
127+
}
128+
build.compile(LIBRARY_NAME);
129+
130+
// Generate bindings for Rust
131+
let mut bindgen_builder = bindgen::Builder::default()
132+
.header(out_header_path.to_str().ok_or("Invalid header path")?)
133+
// Allowlist only the functions we need
134+
.allowlist_function("pg_query_parse_protobuf")
135+
.allowlist_function("pg_query_scan")
136+
.allowlist_function("pg_query_deparse_protobuf")
137+
.allowlist_function("pg_query_normalize")
138+
.allowlist_function("pg_query_fingerprint")
139+
.allowlist_function("pg_query_split_with_parser")
140+
.allowlist_function("pg_query_split_with_scanner")
141+
.allowlist_function("pg_query_free_protobuf_parse_result")
142+
.allowlist_function("pg_query_free_scan_result")
143+
.allowlist_function("pg_query_free_deparse_result")
144+
.allowlist_function("pg_query_free_normalize_result")
145+
.allowlist_function("pg_query_free_fingerprint_result")
146+
.allowlist_function("pg_query_free_split_result")
147+
// Allowlist the types used by these functions
148+
.allowlist_type("PgQueryProtobufParseResult")
149+
.allowlist_type("PgQueryScanResult")
150+
.allowlist_type("PgQueryError")
151+
.allowlist_type("PgQueryProtobuf")
152+
.allowlist_type("PgQueryDeparseResult")
153+
.allowlist_type("PgQueryNormalizeResult")
154+
.allowlist_type("PgQueryFingerprintResult")
155+
.allowlist_type("PgQuerySplitResult")
156+
.allowlist_type("PgQuerySplitStmt")
157+
// Also generate bindings for size_t since it's used in PgQueryProtobuf
158+
.allowlist_type("size_t")
159+
.allowlist_var("PG_VERSION_NUM");
160+
161+
// Configure bindgen for Emscripten target
162+
if is_emscripten {
163+
// Tell bindgen to generate bindings for the wasm32 target
164+
bindgen_builder = bindgen_builder.clang_arg("--target=wasm32-unknown-emscripten");
165+
166+
// Add emscripten sysroot includes
167+
// First try to use EMSDK environment variable (set in CI and when sourcing emsdk_env.sh)
168+
if let Ok(emsdk) = env::var("EMSDK") {
169+
bindgen_builder = bindgen_builder.clang_arg(format!(
170+
"-I{}/upstream/emscripten/cache/sysroot/include",
171+
emsdk
172+
));
173+
} else {
174+
// Fallback to the default path if EMSDK is not set
175+
bindgen_builder =
176+
bindgen_builder.clang_arg("-I/emsdk/upstream/emscripten/cache/sysroot/include");
177+
}
178+
179+
// Ensure we have the basic C standard library headers
180+
bindgen_builder = bindgen_builder.clang_arg("-D__EMSCRIPTEN__");
181+
182+
// Use environment variable if set (from our justfile)
183+
if let Ok(extra_args) = env::var("BINDGEN_EXTRA_CLANG_ARGS") {
184+
for arg in extra_args.split_whitespace() {
185+
bindgen_builder = bindgen_builder.clang_arg(arg);
186+
}
187+
}
188+
}
189+
190+
let bindings = bindgen_builder
191+
.generate()
192+
.map_err(|_| "Unable to generate bindings")?;
193+
194+
let bindings_path = src_dir.join("bindings.rs");
195+
bindings.write_to_file(&bindings_path)?;
196+
197+
// For WASM/emscripten builds, manually add the function declarations
198+
// since bindgen sometimes misses them due to preprocessor conditions
199+
if is_emscripten {
200+
let mut bindings_content = std::fs::read_to_string(&bindings_path)?;
201+
202+
// Check if we need to add the extern "C" block
203+
if !bindings_content.contains("extern \"C\"") {
204+
bindings_content.push_str("\nextern \"C\" {\n");
205+
bindings_content.push_str(" pub fn pg_query_scan(input: *const ::std::os::raw::c_char) -> PgQueryScanResult;\n");
206+
bindings_content.push_str(" pub fn pg_query_parse_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryProtobufParseResult;\n");
207+
bindings_content.push_str(" pub fn pg_query_deparse_protobuf(protobuf: PgQueryProtobuf) -> PgQueryDeparseResult;\n");
208+
bindings_content.push_str(" pub fn pg_query_normalize(input: *const ::std::os::raw::c_char) -> PgQueryNormalizeResult;\n");
209+
bindings_content.push_str(" pub fn pg_query_fingerprint(input: *const ::std::os::raw::c_char) -> PgQueryFingerprintResult;\n");
210+
bindings_content.push_str(" pub fn pg_query_split_with_parser(input: *const ::std::os::raw::c_char) -> PgQuerySplitResult;\n");
211+
bindings_content.push_str(" pub fn pg_query_split_with_scanner(input: *const ::std::os::raw::c_char) -> PgQuerySplitResult;\n");
212+
bindings_content
213+
.push_str(" pub fn pg_query_free_scan_result(result: PgQueryScanResult);\n");
214+
bindings_content.push_str(" pub fn pg_query_free_protobuf_parse_result(result: PgQueryProtobufParseResult);\n");
215+
bindings_content.push_str(
216+
" pub fn pg_query_free_deparse_result(result: PgQueryDeparseResult);\n",
217+
);
218+
bindings_content.push_str(
219+
" pub fn pg_query_free_normalize_result(result: PgQueryNormalizeResult);\n",
220+
);
221+
bindings_content.push_str(
222+
" pub fn pg_query_free_fingerprint_result(result: PgQueryFingerprintResult);\n",
223+
);
224+
bindings_content
225+
.push_str(" pub fn pg_query_free_split_result(result: PgQuerySplitResult);\n");
226+
bindings_content.push_str("}\n");
227+
228+
std::fs::write(&bindings_path, bindings_content)?;
229+
}
230+
}
231+
232+
let protoc_exists = Command::new("protoc").arg("--version").status().is_ok();
233+
if protoc_exists {
234+
println!("generating protobuf bindings");
235+
// HACK: Set OUT_DIR to src/ so that the generated protobuf file is copied to src/protobuf.rs
236+
unsafe {
237+
env::set_var("OUT_DIR", &src_dir);
238+
}
239+
240+
prost_build::compile_protos(
241+
&[&out_protobuf_path.join(LIBRARY_NAME).with_extension("proto")],
242+
&[&out_protobuf_path],
243+
)?;
244+
245+
std::fs::rename(src_dir.join("pg_query.rs"), src_dir.join("protobuf.rs"))?;
246+
247+
// Reset OUT_DIR to the original value
248+
unsafe {
249+
env::set_var("OUT_DIR", &out_dir);
250+
}
251+
} else {
252+
println!("skipping protobuf generation");
253+
}
254+
255+
Ok(())
256+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use pgt_query::{NodeRef, parse};
2+
3+
fn main() {
4+
let mut result = parse("SELECT * FROM users WHERE id IN (SELECT id FROM admins)").unwrap();
5+
6+
// Immutable access
7+
{
8+
let stmts = result.stmts();
9+
let stmt = stmts.first().unwrap();
10+
11+
// nodes() returns a Vec<NodeRef>
12+
let all_nodes = stmt.nodes();
13+
println!("Total nodes in AST: {}", all_nodes.len());
14+
15+
// Can still iterate with iter()
16+
let select_count = stmt
17+
.iter()
18+
.filter(|n| matches!(n, NodeRef::SelectStmt(_)))
19+
.count();
20+
println!("Number of SELECT statements: {}", select_count);
21+
}
22+
23+
// Mutable access - no cloning needed!
24+
{
25+
let mut stmts = result.stmts_mut();
26+
if let Some(stmt) = stmts.first_mut() {
27+
// Now we can iterate mutably without cloning
28+
for mut_node in stmt.iter_mut() {
29+
// Modify nodes here if needed
30+
if let pgt_query::NodeMut::SelectStmt(_select) = mut_node {
31+
println!("Found a SELECT statement to modify");
32+
// You can modify _select here
33+
}
34+
}
35+
}
36+
}
37+
38+
// Alternative: using root_mut() for single statement queries
39+
if let Some(root) = result.root_mut() {
40+
println!("Root node type: {:?}", std::mem::discriminant(root));
41+
}
42+
}

0 commit comments

Comments
 (0)