Skip to content

Commit 54e9865

Browse files
committed
Plugins!!
1 parent 604bf99 commit 54e9865

File tree

12 files changed

+239
-51
lines changed

12 files changed

+239
-51
lines changed

Cargo.lock

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ rand = "0.8"
1414
chrono = "0.4"
1515
sha-1 = "0.10"
1616
toml = "0.7"
17-
serde = "1"
17+
serde = { version = "1", features = ["derive"] }
1818
serde_derive = "1"
1919
regex = "1"
2020
num_cpus = "1"
@@ -44,6 +44,7 @@ webpki-roots = "0.23"
4444
rustls = { version = "0.21", features = ["dangerous_configuration"] }
4545
trust-dns-resolver = "0.22.0"
4646
tokio-test = "0.4.2"
47+
serde_json = "1"
4748

4849
[target.'cfg(not(target_env = "msvc"))'.dependencies]
4950
jemallocator = "0.5.0"

pgcat.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ admin_username = "admin_user"
7777
# Password to access the virtual administrative database
7878
admin_password = "admin_pass"
7979

80+
# Plugins!!
81+
# plugins = ["pg_table_access", "intercept"]
82+
8083
# pool configs are structured as pool.<pool_name>
8184
# the pool_name is what clients use as database name when connecting.
8285
# For a pool named `sharded_db`, clients access that pool using connection string like

src/admin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use tokio::time::Instant;
1212
use crate::config::{get_config, reload_config, VERSION};
1313
use crate::errors::Error;
1414
use crate::messages::*;
15+
use crate::pool::ClientServerMap;
1516
use crate::pool::{get_all_pools, get_pool};
1617
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
17-
use crate::ClientServerMap;
1818

1919
pub fn generate_server_info_for_admin() -> BytesMut {
2020
let mut server_info = BytesMut::new();

src/client.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
1616
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
1717
use crate::constants::*;
1818
use crate::messages::*;
19+
use crate::plugins::PluginOutput;
1920
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
2021
use crate::query_router::{Command, QueryRouter};
2122
use crate::server::Server;
@@ -765,6 +766,9 @@ where
765766

766767
self.stats.register(self.stats.clone());
767768

769+
// Error returned by one of the plugins.
770+
let mut plugin_output = None;
771+
768772
// Our custom protocol loop.
769773
// We expect the client to either start a transaction with regular queries
770774
// or issue commands for our sharding and server selection protocol.
@@ -816,6 +820,22 @@ where
816820
'Q' => {
817821
if query_router.query_parser_enabled() {
818822
if let Ok(ast) = QueryRouter::parse(&message) {
823+
let plugin_result = query_router.execute_plugins(&ast).await;
824+
825+
match plugin_result {
826+
Ok(PluginOutput::Deny(error)) => {
827+
error_response(&mut self.write, &error).await?;
828+
continue;
829+
}
830+
831+
Ok(PluginOutput::Intercept(result)) => {
832+
write_all(&mut self.write, result).await?;
833+
continue;
834+
}
835+
836+
_ => (),
837+
};
838+
819839
let _ = query_router.infer(&ast);
820840
}
821841
}
@@ -826,6 +846,10 @@ where
826846

827847
if query_router.query_parser_enabled() {
828848
if let Ok(ast) = QueryRouter::parse(&message) {
849+
if let Ok(output) = query_router.execute_plugins(&ast).await {
850+
plugin_output = Some(output);
851+
}
852+
829853
let _ = query_router.infer(&ast);
830854
}
831855
}
@@ -861,6 +885,18 @@ where
861885
continue;
862886
}
863887

888+
// Check on plugin results.
889+
match plugin_output {
890+
Some(PluginOutput::Deny(error)) => {
891+
self.buffer.clear();
892+
error_response(&mut self.write, &error).await?;
893+
plugin_output = None;
894+
continue;
895+
}
896+
897+
_ => (),
898+
};
899+
864900
// Get a pool instance referenced by the most up-to-date
865901
// pointer. This ensures we always read the latest config
866902
// when starting a query.
@@ -1089,6 +1125,27 @@ where
10891125
match code {
10901126
// Query
10911127
'Q' => {
1128+
if query_router.query_parser_enabled() {
1129+
if let Ok(ast) = QueryRouter::parse(&message) {
1130+
let plugin_result = query_router.execute_plugins(&ast).await;
1131+
1132+
match plugin_result {
1133+
Ok(PluginOutput::Deny(error)) => {
1134+
error_response(&mut self.write, &error).await?;
1135+
continue;
1136+
}
1137+
1138+
Ok(PluginOutput::Intercept(result)) => {
1139+
write_all(&mut self.write, result).await?;
1140+
continue;
1141+
}
1142+
1143+
_ => (),
1144+
};
1145+
1146+
let _ = query_router.infer(&ast);
1147+
}
1148+
}
10921149
debug!("Sending query to server");
10931150

10941151
self.send_and_receive_loop(
@@ -1128,6 +1185,14 @@ where
11281185
// Parse
11291186
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11301187
'P' => {
1188+
if query_router.query_parser_enabled() {
1189+
if let Ok(ast) = QueryRouter::parse(&message) {
1190+
if let Ok(output) = query_router.execute_plugins(&ast).await {
1191+
plugin_output = Some(output);
1192+
}
1193+
}
1194+
}
1195+
11311196
self.buffer.put(&message[..]);
11321197
}
11331198

@@ -1159,6 +1224,24 @@ where
11591224
'S' => {
11601225
debug!("Sending query to server");
11611226

1227+
match plugin_output {
1228+
Some(PluginOutput::Deny(error)) => {
1229+
error_response(&mut self.write, &error).await?;
1230+
plugin_output = None;
1231+
self.buffer.clear();
1232+
continue;
1233+
}
1234+
1235+
Some(PluginOutput::Intercept(result)) => {
1236+
write_all(&mut self.write, result).await?;
1237+
plugin_output = None;
1238+
self.buffer.clear();
1239+
continue;
1240+
}
1241+
1242+
_ => (),
1243+
};
1244+
11621245
self.buffer.put(&message[..]);
11631246

11641247
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;

src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ pub struct General {
302302
pub auth_query: Option<String>,
303303
pub auth_query_user: Option<String>,
304304
pub auth_query_password: Option<String>,
305+
306+
pub query_router_plugins: Option<Vec<String>>,
305307
}
306308

307309
impl General {
@@ -402,6 +404,7 @@ impl Default for General {
402404
auth_query_user: None,
403405
auth_query_password: None,
404406
server_lifetime: 1000 * 3600 * 24, // 24 hours,
407+
query_router_plugins: None,
405408
}
406409
}
407410
}

src/errors.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Errors.
22
33
/// Various errors.
4-
#[derive(Debug, PartialEq)]
4+
#[derive(Debug, PartialEq, Clone)]
55
pub enum Error {
66
SocketError(String),
77
ClientSocketError(String, ClientIdentifier),
@@ -26,7 +26,9 @@ pub enum Error {
2626
AuthPassthroughError(String),
2727
UnsupportedStatement,
2828
QueryRouterParserError(String),
29+
PermissionDenied(String),
2930
PermissionDeniedTable(String),
31+
QueryDenied(String),
3032
}
3133

3234
#[derive(Clone, PartialEq, Debug)]

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
pub mod admin;
12
pub mod auth_passthrough;
3+
pub mod client;
24
pub mod config;
35
pub mod constants;
46
pub mod dns_cache;
57
pub mod errors;
68
pub mod messages;
79
pub mod mirrors;
810
pub mod multi_logger;
11+
pub mod plugins;
912
pub mod pool;
13+
pub mod prometheus;
14+
pub mod query_router;
1015
pub mod scram;
1116
pub mod server;
1217
pub mod sharding;

src/main.rs

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -61,37 +61,19 @@ use std::str::FromStr;
6161
use std::sync::Arc;
6262
use tokio::sync::broadcast;
6363

64-
mod admin;
65-
mod auth_passthrough;
66-
mod client;
67-
mod config;
68-
mod constants;
69-
mod dns_cache;
70-
mod errors;
71-
mod messages;
72-
mod mirrors;
73-
mod multi_logger;
74-
mod pool;
75-
mod prometheus;
76-
mod query_router;
77-
mod scram;
78-
mod server;
79-
mod sharding;
80-
mod stats;
81-
mod tls;
82-
83-
use crate::config::{get_config, reload_config, VERSION};
84-
use crate::messages::configure_socket;
85-
use crate::pool::{ClientServerMap, ConnectionPool};
86-
use crate::prometheus::start_metric_server;
87-
use crate::stats::{Collector, Reporter, REPORTER};
64+
use pgcat::config::{get_config, reload_config, VERSION};
65+
use pgcat::messages::configure_socket;
66+
use pgcat::pool::{ClientServerMap, ConnectionPool};
67+
use pgcat::prometheus::start_metric_server;
68+
use pgcat::stats::{Collector, Reporter, REPORTER};
69+
use pgcat::dns_cache;
8870

8971
fn main() -> Result<(), Box<dyn std::error::Error>> {
90-
multi_logger::MultiLogger::init().unwrap();
72+
pgcat::multi_logger::MultiLogger::init().unwrap();
9173

9274
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
9375

94-
if !query_router::QueryRouter::setup() {
76+
if !pgcat::query_router::QueryRouter::setup() {
9577
error!("Could not setup query router");
9678
std::process::exit(exitcode::CONFIG);
9779
}
@@ -109,7 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
10991
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
11092

11193
runtime.block_on(async {
112-
match config::parse(&config_file).await {
94+
match pgcat::config::parse(&config_file).await {
11395
Ok(_) => (),
11496
Err(err) => {
11597
error!("Config parse error: {:?}", err);
@@ -168,14 +150,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
168150
// Statistics reporting.
169151
REPORTER.store(Arc::new(Reporter::default()));
170152

171-
// Starts (if enabled) dns cache before pools initialization
172-
match dns_cache::CachedResolver::from_config().await {
173-
Ok(_) => (),
174-
Err(err) => error!("DNS cache initialization error: {:?}", err),
175-
};
153+
// Starts (if enabled) dns cache before pools initialization
154+
match dns_cache::CachedResolver::from_config().await {
155+
Ok(_) => (),
156+
Err(err) => error!("DNS cache initialization error: {:?}", err),
157+
};
176158

177-
// Connection pool that allows to query all shards and replicas.
178-
match ConnectionPool::from_config(client_server_map.clone()).await {
159+
// Connection pool that allows to query all shards and replicas.
160+
match ConnectionPool::from_config(client_server_map.clone()).await {
179161
Ok(_) => (),
180162
Err(err) => {
181163
error!("Pool error: {:?}", err);
@@ -303,7 +285,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
303285
tokio::task::spawn(async move {
304286
let start = chrono::offset::Utc::now().naive_utc();
305287

306-
match client::client_entrypoint(
288+
match pgcat::client::client_entrypoint(
307289
socket,
308290
client_server_map,
309291
shutdown_rx,
@@ -334,7 +316,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
334316

335317
Err(err) => {
336318
match err {
337-
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
319+
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
338320
_ => warn!("Client disconnected with error {:?}", err),
339321
}
340322

0 commit comments

Comments
 (0)