Skip to content

Commit 09e54e1

Browse files
authored
Plugins! (#420)
* Some queries * Plugins!! * cleanup * actual names * the actual plugins * comment * fix tests * Tests * unused errors * Increase reaper rate to actually enforce settings * ok
1 parent 23819c8 commit 09e54e1

File tree

17 files changed

+772
-96
lines changed

17 files changed

+772
-96
lines changed

Cargo.lock

Lines changed: 34 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgcat"
3-
version = "1.0.1"
3+
version = "1.0.2-alpha1"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -14,12 +14,12 @@ rand = "0.8"
1414
chrono = "0.4"
1515
sha-1 = "0.10"
1616
toml = "0.7"
17-
serde = "1"
17+
serde = { version = "1", features = ["derive"] }
1818
serde_derive = "1"
1919
regex = "1"
2020
num_cpus = "1"
2121
once_cell = "1"
22-
sqlparser = "0.33.0"
22+
sqlparser = {version = "0.33", features = ["visitor"] }
2323
log = "0.4"
2424
arc-swap = "1"
2525
env_logger = "0.10"
@@ -44,6 +44,7 @@ webpki-roots = "0.23"
4444
rustls = { version = "0.21", features = ["dangerous_configuration"] }
4545
trust-dns-resolver = "0.22.0"
4646
tokio-test = "0.4.2"
47+
serde_json = "1"
4748

4849
[target.'cfg(not(target_env = "msvc"))'.dependencies]
4950
jemallocator = "0.5.0"

pgcat.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ admin_username = "admin_user"
7777
# Password to access the virtual administrative database
7878
admin_password = "admin_pass"
7979

80+
# Plugins!!
81+
# query_router_plugins = ["pg_table_access", "intercept"]
82+
8083
# pool configs are structured as pool.<pool_name>
8184
# the pool_name is what clients use as database name when connecting.
8285
# For a pool named `sharded_db`, clients access that pool using connection string like

src/admin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use tokio::time::Instant;
1212
use crate::config::{get_config, reload_config, VERSION};
1313
use crate::errors::Error;
1414
use crate::messages::*;
15+
use crate::pool::ClientServerMap;
1516
use crate::pool::{get_all_pools, get_pool};
1617
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
17-
use crate::ClientServerMap;
1818

1919
pub fn generate_server_info_for_admin() -> BytesMut {
2020
let mut server_info = BytesMut::new();

src/client.rs

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
1616
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
1717
use crate::constants::*;
1818
use crate::messages::*;
19+
use crate::plugins::PluginOutput;
1920
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
2021
use crate::query_router::{Command, QueryRouter};
2122
use crate::server::Server;
@@ -765,6 +766,9 @@ where
765766

766767
self.stats.register(self.stats.clone());
767768

769+
// Result returned by one of the plugins.
770+
let mut plugin_output = None;
771+
768772
// Our custom protocol loop.
769773
// We expect the client to either start a transaction with regular queries
770774
// or issue commands for our sharding and server selection protocol.
@@ -815,15 +819,39 @@ where
815819

816820
'Q' => {
817821
if query_router.query_parser_enabled() {
818-
query_router.infer(&message);
822+
if let Ok(ast) = QueryRouter::parse(&message) {
823+
let plugin_result = query_router.execute_plugins(&ast).await;
824+
825+
match plugin_result {
826+
Ok(PluginOutput::Deny(error)) => {
827+
error_response(&mut self.write, &error).await?;
828+
continue;
829+
}
830+
831+
Ok(PluginOutput::Intercept(result)) => {
832+
write_all(&mut self.write, result).await?;
833+
continue;
834+
}
835+
836+
_ => (),
837+
};
838+
839+
let _ = query_router.infer(&ast);
840+
}
819841
}
820842
}
821843

822844
'P' => {
823845
self.buffer.put(&message[..]);
824846

825847
if query_router.query_parser_enabled() {
826-
query_router.infer(&message);
848+
if let Ok(ast) = QueryRouter::parse(&message) {
849+
if let Ok(output) = query_router.execute_plugins(&ast).await {
850+
plugin_output = Some(output);
851+
}
852+
853+
let _ = query_router.infer(&ast);
854+
}
827855
}
828856

829857
continue;
@@ -857,6 +885,18 @@ where
857885
continue;
858886
}
859887

888+
// Check on plugin results.
889+
match plugin_output {
890+
Some(PluginOutput::Deny(error)) => {
891+
self.buffer.clear();
892+
error_response(&mut self.write, &error).await?;
893+
plugin_output = None;
894+
continue;
895+
}
896+
897+
_ => (),
898+
};
899+
860900
// Get a pool instance referenced by the most up-to-date
861901
// pointer. This ensures we always read the latest config
862902
// when starting a query.
@@ -1085,6 +1125,27 @@ where
10851125
match code {
10861126
// Query
10871127
'Q' => {
1128+
if query_router.query_parser_enabled() {
1129+
if let Ok(ast) = QueryRouter::parse(&message) {
1130+
let plugin_result = query_router.execute_plugins(&ast).await;
1131+
1132+
match plugin_result {
1133+
Ok(PluginOutput::Deny(error)) => {
1134+
error_response(&mut self.write, &error).await?;
1135+
continue;
1136+
}
1137+
1138+
Ok(PluginOutput::Intercept(result)) => {
1139+
write_all(&mut self.write, result).await?;
1140+
continue;
1141+
}
1142+
1143+
_ => (),
1144+
};
1145+
1146+
let _ = query_router.infer(&ast);
1147+
}
1148+
}
10881149
debug!("Sending query to server");
10891150

10901151
self.send_and_receive_loop(
@@ -1124,6 +1185,14 @@ where
11241185
// Parse
11251186
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11261187
'P' => {
1188+
if query_router.query_parser_enabled() {
1189+
if let Ok(ast) = QueryRouter::parse(&message) {
1190+
if let Ok(output) = query_router.execute_plugins(&ast).await {
1191+
plugin_output = Some(output);
1192+
}
1193+
}
1194+
}
1195+
11271196
self.buffer.put(&message[..]);
11281197
}
11291198

@@ -1155,6 +1224,24 @@ where
11551224
'S' => {
11561225
debug!("Sending query to server");
11571226

1227+
match plugin_output {
1228+
Some(PluginOutput::Deny(error)) => {
1229+
error_response(&mut self.write, &error).await?;
1230+
plugin_output = None;
1231+
self.buffer.clear();
1232+
continue;
1233+
}
1234+
1235+
Some(PluginOutput::Intercept(result)) => {
1236+
write_all(&mut self.write, result).await?;
1237+
plugin_output = None;
1238+
self.buffer.clear();
1239+
continue;
1240+
}
1241+
1242+
_ => (),
1243+
};
1244+
11581245
self.buffer.put(&message[..]);
11591246

11601247
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;

src/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,12 @@ pub struct General {
298298
pub admin_username: String,
299299
pub admin_password: String,
300300

301+
// Support for auth query
301302
pub auth_query: Option<String>,
302303
pub auth_query_user: Option<String>,
303304
pub auth_query_password: Option<String>,
305+
306+
pub query_router_plugins: Option<Vec<String>>,
304307
}
305308

306309
impl General {
@@ -401,6 +404,7 @@ impl Default for General {
401404
auth_query_user: None,
402405
auth_query_password: None,
403406
server_lifetime: 1000 * 3600 * 24, // 24 hours,
407+
query_router_plugins: None,
404408
}
405409
}
406410
}

src/errors.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Errors.
22
33
/// Various errors.
4-
#[derive(Debug, PartialEq)]
4+
#[derive(Debug, PartialEq, Clone)]
55
pub enum Error {
66
SocketError(String),
77
ClientSocketError(String, ClientIdentifier),
@@ -24,6 +24,8 @@ pub enum Error {
2424
ParseBytesError(String),
2525
AuthError(String),
2626
AuthPassthroughError(String),
27+
UnsupportedStatement,
28+
QueryRouterParserError(String),
2729
}
2830

2931
#[derive(Clone, PartialEq, Debug)]

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
pub mod admin;
12
pub mod auth_passthrough;
3+
pub mod client;
24
pub mod config;
35
pub mod constants;
46
pub mod dns_cache;
57
pub mod errors;
68
pub mod messages;
79
pub mod mirrors;
810
pub mod multi_logger;
11+
pub mod plugins;
912
pub mod pool;
13+
pub mod prometheus;
14+
pub mod query_router;
1015
pub mod scram;
1116
pub mod server;
1217
pub mod sharding;

0 commit comments

Comments
 (0)