Skip to content

Commit 7d2b695

Browse files
committed
Some queries
1 parent 3601130 commit 7d2b695

File tree

7 files changed

+133
-52
lines changed

7 files changed

+133
-52
lines changed

Cargo.lock

Lines changed: 13 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgcat"
3-
version = "1.0.1"
3+
version = "1.0.2-alpha1"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -19,7 +19,7 @@ serde_derive = "1"
1919
regex = "1"
2020
num_cpus = "1"
2121
once_cell = "1"
22-
sqlparser = "0.33.0"
22+
sqlparser = {version = "0.33", features = ["visitor"] }
2323
log = "0.4"
2424
arc-swap = "1"
2525
env_logger = "0.10"

src/client.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,15 +815,19 @@ where
815815

816816
'Q' => {
817817
if query_router.query_parser_enabled() {
818-
query_router.infer(&message);
818+
if let Ok(ast) = QueryRouter::parse(&message) {
819+
let _ = query_router.infer(&ast);
820+
}
819821
}
820822
}
821823

822824
'P' => {
823825
self.buffer.put(&message[..]);
824826

825827
if query_router.query_parser_enabled() {
826-
query_router.infer(&message);
828+
if let Ok(ast) = QueryRouter::parse(&message) {
829+
let _ = query_router.infer(&ast);
830+
}
827831
}
828832

829833
continue;

src/config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ pub struct General {
291291
pub admin_username: String,
292292
pub admin_password: String,
293293

294+
// Support for auth query
294295
pub auth_query: Option<String>,
295296
pub auth_query_user: Option<String>,
296297
pub auth_query_password: Option<String>,

src/errors.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ pub enum Error {
2323
ParseBytesError(String),
2424
AuthError(String),
2525
AuthPassthroughError(String),
26+
UnsupportedStatement,
27+
QueryRouterParserError(String),
28+
PermissionDeniedTable(String),
2629
}
2730

2831
#[derive(Clone, PartialEq, Debug)]

src/pool.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ impl ConnectionPool {
395395
);
396396
}
397397

398+
debug!("Query router: {}", pool_config.query_parser_enabled);
399+
398400
let pool = ConnectionPool {
399401
databases: shards,
400402
stats: pool_stats,

src/query_router.rs

Lines changed: 106 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use sqlparser::dialect::PostgreSqlDialect;
1212
use sqlparser::parser::Parser;
1313

1414
use crate::config::Role;
15+
use crate::errors::Error;
1516
use crate::messages::BytesMutReader;
1617
use crate::pool::PoolSettings;
1718
use crate::sharding::Sharder;
@@ -324,10 +325,7 @@ impl QueryRouter {
324325
Some((command, value))
325326
}
326327

327-
/// Try to infer which server to connect to based on the contents of the query.
328-
pub fn infer(&mut self, message: &BytesMut) -> bool {
329-
debug!("Inferring role");
330-
328+
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
331329
let mut message_cursor = Cursor::new(message);
332330

333331
let code = message_cursor.get_u8() as char;
@@ -353,28 +351,33 @@ impl QueryRouter {
353351
query
354352
}
355353

356-
_ => return false,
354+
_ => return Err(Error::UnsupportedStatement),
357355
};
358356

359-
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
360-
Ok(ast) => ast,
357+
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
358+
Ok(ast) => {
359+
debug!("AST: {:?}", ast);
360+
Ok(ast)
361+
}
362+
361363
Err(err) => {
362-
// SELECT ... FOR UPDATE won't get parsed correctly.
363364
debug!("{}: {}", err, query);
364-
self.active_role = Some(Role::Primary);
365-
return false;
365+
Err(Error::QueryRouterParserError(err.to_string()))
366366
}
367-
};
367+
}
368+
}
368369

369-
debug!("AST: {:?}", ast);
370+
/// Try to infer which server to connect to based on the contents of the query.
371+
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
372+
debug!("Inferring role");
370373

371374
if ast.is_empty() {
372375
// That's weird, no idea, let's go to primary
373376
self.active_role = Some(Role::Primary);
374-
return false;
377+
return Err(Error::QueryRouterParserError("empty query".into()));
375378
}
376379

377-
for q in &ast {
380+
for q in ast {
378381
match q {
379382
// All transactions go to the primary, probably a write.
380383
StartTransaction { .. } => {
@@ -418,7 +421,7 @@ impl QueryRouter {
418421
};
419422
}
420423

421-
true
424+
Ok(())
422425
}
423426

424427
/// Parse the shard number from the Bind message
@@ -862,7 +865,7 @@ mod test {
862865

863866
for query in queries {
864867
// It's a recognized query
865-
assert!(qr.infer(&query));
868+
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
866869
assert_eq!(qr.role(), Some(Role::Replica));
867870
}
868871
}
@@ -881,7 +884,7 @@ mod test {
881884

882885
for query in queries {
883886
// It's a recognized query
884-
assert!(qr.infer(&query));
887+
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
885888
assert_eq!(qr.role(), Some(Role::Primary));
886889
}
887890
}
@@ -893,7 +896,7 @@ mod test {
893896
let query = simple_query("SELECT * FROM items WHERE id = 5");
894897
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
895898

896-
assert!(qr.infer(&query));
899+
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
897900
assert_eq!(qr.role(), None);
898901
}
899902

@@ -913,7 +916,7 @@ mod test {
913916
res.put(prepared_stmt);
914917
res.put_i16(0);
915918

916-
assert!(qr.infer(&res));
919+
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
917920
assert_eq!(qr.role(), Some(Role::Replica));
918921
}
919922

@@ -1077,11 +1080,11 @@ mod test {
10771080
assert_eq!(qr.role(), None);
10781081

10791082
let query = simple_query("INSERT INTO test_table VALUES (1)");
1080-
assert!(qr.infer(&query));
1083+
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
10811084
assert_eq!(qr.role(), Some(Role::Primary));
10821085

10831086
let query = simple_query("SELECT * FROM test_table");
1084-
assert!(qr.infer(&query));
1087+
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
10851088
assert_eq!(qr.role(), Some(Role::Replica));
10861089

10871090
assert!(qr.query_parser_enabled());
@@ -1142,15 +1145,24 @@ mod test {
11421145
QueryRouter::setup();
11431146

11441147
let mut qr = QueryRouter::new();
1145-
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
1148+
assert!(qr
1149+
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
1150+
.is_ok());
11461151
assert_eq!(qr.role(), Role::Primary);
11471152

1148-
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
1153+
assert!(qr
1154+
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
1155+
.is_ok());
11491156
assert_eq!(qr.role(), Role::Replica);
11501157

1151-
assert!(qr.infer(&simple_query(
1152-
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
1153-
)));
1158+
assert!(qr
1159+
.infer(
1160+
&QueryRouter::parse(&simple_query(
1161+
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
1162+
))
1163+
.unwrap()
1164+
)
1165+
.is_ok());
11541166
assert_eq!(qr.role(), Role::Primary);
11551167
}
11561168

@@ -1208,47 +1220,84 @@ mod test {
12081220
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
12091221
qr.pool_settings.shards = 3;
12101222

1211-
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
1223+
assert!(qr
1224+
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
1225+
.is_ok());
12121226
assert_eq!(qr.shard(), 2);
12131227

1214-
assert!(qr.infer(&simple_query(
1215-
"SELECT one, two, three FROM public.data WHERE id = 6"
1216-
)));
1228+
assert!(qr
1229+
.infer(
1230+
&QueryRouter::parse(&simple_query(
1231+
"SELECT one, two, three FROM public.data WHERE id = 6"
1232+
))
1233+
.unwrap()
1234+
)
1235+
.is_ok());
12171236
assert_eq!(qr.shard(), 0);
12181237

1219-
assert!(qr.infer(&simple_query(
1220-
"SELECT * FROM data
1238+
assert!(qr
1239+
.infer(
1240+
&QueryRouter::parse(&simple_query(
1241+
"SELECT * FROM data
12211242
INNER JOIN t2 ON data.id = 5
12221243
AND t2.data_id = data.id
12231244
WHERE data.id = 5"
1224-
)));
1245+
))
1246+
.unwrap()
1247+
)
1248+
.is_ok());
12251249
assert_eq!(qr.shard(), 2);
12261250

12271251
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
12281252
// in the query.
1229-
assert!(qr.infer(&simple_query(
1230-
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1231-
)));
1253+
assert!(qr
1254+
.infer(
1255+
&QueryRouter::parse(&simple_query(
1256+
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1257+
))
1258+
.unwrap()
1259+
)
1260+
.is_ok());
12321261
assert_eq!(qr.shard(), 2);
12331262

1234-
assert!(qr.infer(&simple_query(
1235-
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1236-
)));
1263+
assert!(qr
1264+
.infer(
1265+
&QueryRouter::parse(&simple_query(
1266+
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1267+
))
1268+
.unwrap()
1269+
)
1270+
.is_ok());
12371271
assert_eq!(qr.shard(), 0);
12381272

1239-
assert!(qr.infer(&simple_query(
1240-
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1241-
)));
1273+
assert!(qr
1274+
.infer(
1275+
&QueryRouter::parse(&simple_query(
1276+
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1277+
))
1278+
.unwrap()
1279+
)
1280+
.is_ok());
12421281
assert_eq!(qr.shard(), 2);
12431282

12441283
// Super unique sharding key
12451284
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
1246-
assert!(qr.infer(&simple_query(
1247-
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1248-
)));
1285+
assert!(qr
1286+
.infer(
1287+
&QueryRouter::parse(&simple_query(
1288+
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1289+
))
1290+
.unwrap()
1291+
)
1292+
.is_ok());
12491293
assert_eq!(qr.shard(), 0);
12501294

1251-
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
1295+
assert!(qr
1296+
.infer(
1297+
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
1298+
.unwrap()
1299+
)
1300+
.is_ok());
12521301
assert_eq!(qr.shard(), 0);
12531302
}
12541303

@@ -1272,11 +1321,21 @@ mod test {
12721321
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
12731322
qr.pool_settings.shards = 3;
12741323

1275-
assert!(qr.infer(&simple_query(stmt)));
1324+
assert!(qr
1325+
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
1326+
.is_ok());
12761327
assert_eq!(qr.placeholders.len(), 1);
12771328

12781329
assert!(qr.infer_shard_from_bind(&bind));
12791330
assert_eq!(qr.shard(), 2);
12801331
assert!(qr.placeholders.is_empty());
12811332
}
1333+
1334+
#[test]
1335+
fn test_parse() {
1336+
let query = simple_query("SELECT * FROM pg_database");
1337+
let ast = QueryRouter::parse(&query);
1338+
1339+
assert!(ast.is_ok());
1340+
}
12821341
}

0 commit comments

Comments
 (0)