@@ -12,6 +12,7 @@ use sqlparser::dialect::PostgreSqlDialect;
12
12
use sqlparser:: parser:: Parser ;
13
13
14
14
use crate :: config:: Role ;
15
+ use crate :: errors:: Error ;
15
16
use crate :: messages:: BytesMutReader ;
16
17
use crate :: pool:: PoolSettings ;
17
18
use crate :: sharding:: Sharder ;
@@ -324,10 +325,7 @@ impl QueryRouter {
324
325
Some ( ( command, value) )
325
326
}
326
327
327
- /// Try to infer which server to connect to based on the contents of the query.
328
- pub fn infer ( & mut self , message : & BytesMut ) -> bool {
329
- debug ! ( "Inferring role" ) ;
330
-
328
+ pub fn parse ( message : & BytesMut ) -> Result < Vec < sqlparser:: ast:: Statement > , Error > {
331
329
let mut message_cursor = Cursor :: new ( message) ;
332
330
333
331
let code = message_cursor. get_u8 ( ) as char ;
@@ -353,28 +351,33 @@ impl QueryRouter {
353
351
query
354
352
}
355
353
356
- _ => return false ,
354
+ _ => return Err ( Error :: UnsupportedStatement ) ,
357
355
} ;
358
356
359
- let ast = match Parser :: parse_sql ( & PostgreSqlDialect { } , & query) {
360
- Ok ( ast) => ast,
357
+ match Parser :: parse_sql ( & PostgreSqlDialect { } , & query) {
358
+ Ok ( ast) => {
359
+ debug ! ( "AST: {:?}" , ast) ;
360
+ Ok ( ast)
361
+ }
362
+
361
363
Err ( err) => {
362
- // SELECT ... FOR UPDATE won't get parsed correctly.
363
364
debug ! ( "{}: {}" , err, query) ;
364
- self . active_role = Some ( Role :: Primary ) ;
365
- return false ;
365
+ Err ( Error :: QueryRouterParserError ( err. to_string ( ) ) )
366
366
}
367
- } ;
367
+ }
368
+ }
368
369
369
- debug ! ( "AST: {:?}" , ast) ;
370
+ /// Try to infer which server to connect to based on the contents of the query.
371
+ pub fn infer ( & mut self , ast : & Vec < sqlparser:: ast:: Statement > ) -> Result < ( ) , Error > {
372
+ debug ! ( "Inferring role" ) ;
370
373
371
374
if ast. is_empty ( ) {
372
375
// That's weird, no idea, let's go to primary
373
376
self . active_role = Some ( Role :: Primary ) ;
374
- return false ;
377
+ return Err ( Error :: QueryRouterParserError ( "empty query" . into ( ) ) ) ;
375
378
}
376
379
377
- for q in & ast {
380
+ for q in ast {
378
381
match q {
379
382
// All transactions go to the primary, probably a write.
380
383
StartTransaction { .. } => {
@@ -418,7 +421,7 @@ impl QueryRouter {
418
421
} ;
419
422
}
420
423
421
- true
424
+ Ok ( ( ) )
422
425
}
423
426
424
427
/// Parse the shard number from the Bind message
@@ -862,7 +865,7 @@ mod test {
862
865
863
866
for query in queries {
864
867
// It's a recognized query
865
- assert ! ( qr. infer( & query) ) ;
868
+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
866
869
assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
867
870
}
868
871
}
@@ -881,7 +884,7 @@ mod test {
881
884
882
885
for query in queries {
883
886
// It's a recognized query
884
- assert ! ( qr. infer( & query) ) ;
887
+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
885
888
assert_eq ! ( qr. role( ) , Some ( Role :: Primary ) ) ;
886
889
}
887
890
}
@@ -893,7 +896,7 @@ mod test {
893
896
let query = simple_query ( "SELECT * FROM items WHERE id = 5" ) ;
894
897
assert ! ( qr. try_execute_command( & simple_query( "SET PRIMARY READS TO on" ) ) != None ) ;
895
898
896
- assert ! ( qr. infer( & query) ) ;
899
+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
897
900
assert_eq ! ( qr. role( ) , None ) ;
898
901
}
899
902
@@ -913,7 +916,7 @@ mod test {
913
916
res. put ( prepared_stmt) ;
914
917
res. put_i16 ( 0 ) ;
915
918
916
- assert ! ( qr. infer( & res) ) ;
919
+ assert ! ( qr. infer( & QueryRouter :: parse ( & res) . unwrap ( ) ) . is_ok ( ) ) ;
917
920
assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
918
921
}
919
922
@@ -1077,11 +1080,11 @@ mod test {
1077
1080
assert_eq ! ( qr. role( ) , None ) ;
1078
1081
1079
1082
let query = simple_query ( "INSERT INTO test_table VALUES (1)" ) ;
1080
- assert ! ( qr. infer( & query) ) ;
1083
+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
1081
1084
assert_eq ! ( qr. role( ) , Some ( Role :: Primary ) ) ;
1082
1085
1083
1086
let query = simple_query ( "SELECT * FROM test_table" ) ;
1084
- assert ! ( qr. infer( & query) ) ;
1087
+ assert ! ( qr. infer( & QueryRouter :: parse ( & query) . unwrap ( ) ) . is_ok ( ) ) ;
1085
1088
assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
1086
1089
1087
1090
assert ! ( qr. query_parser_enabled( ) ) ;
@@ -1142,15 +1145,24 @@ mod test {
1142
1145
QueryRouter :: setup ( ) ;
1143
1146
1144
1147
let mut qr = QueryRouter :: new ( ) ;
1145
- assert ! ( qr. infer( & simple_query( "BEGIN; SELECT 1; COMMIT;" ) ) ) ;
1148
+ assert ! ( qr
1149
+ . infer( & QueryRouter :: parse( & simple_query( "BEGIN; SELECT 1; COMMIT;" ) ) . unwrap( ) )
1150
+ . is_ok( ) ) ;
1146
1151
assert_eq ! ( qr. role( ) , Role :: Primary ) ;
1147
1152
1148
- assert ! ( qr. infer( & simple_query( "SELECT 1; SELECT 2;" ) ) ) ;
1153
+ assert ! ( qr
1154
+ . infer( & QueryRouter :: parse( & simple_query( "SELECT 1; SELECT 2;" ) ) . unwrap( ) )
1155
+ . is_ok( ) ) ;
1149
1156
assert_eq ! ( qr. role( ) , Role :: Replica ) ;
1150
1157
1151
- assert ! ( qr. infer( & simple_query(
1152
- "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
1153
- ) ) ) ;
1158
+ assert ! ( qr
1159
+ . infer(
1160
+ & QueryRouter :: parse( & simple_query(
1161
+ "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
1162
+ ) )
1163
+ . unwrap( )
1164
+ )
1165
+ . is_ok( ) ) ;
1154
1166
assert_eq ! ( qr. role( ) , Role :: Primary ) ;
1155
1167
}
1156
1168
@@ -1208,47 +1220,84 @@ mod test {
1208
1220
qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1209
1221
qr. pool_settings . shards = 3 ;
1210
1222
1211
- assert ! ( qr. infer( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) ) ;
1223
+ assert ! ( qr
1224
+ . infer( & QueryRouter :: parse( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) . unwrap( ) )
1225
+ . is_ok( ) ) ;
1212
1226
assert_eq ! ( qr. shard( ) , 2 ) ;
1213
1227
1214
- assert ! ( qr. infer( & simple_query(
1215
- "SELECT one, two, three FROM public.data WHERE id = 6"
1216
- ) ) ) ;
1228
+ assert ! ( qr
1229
+ . infer(
1230
+ & QueryRouter :: parse( & simple_query(
1231
+ "SELECT one, two, three FROM public.data WHERE id = 6"
1232
+ ) )
1233
+ . unwrap( )
1234
+ )
1235
+ . is_ok( ) ) ;
1217
1236
assert_eq ! ( qr. shard( ) , 0 ) ;
1218
1237
1219
- assert ! ( qr. infer( & simple_query(
1220
- "SELECT * FROM data
1238
+ assert ! ( qr
1239
+ . infer(
1240
+ & QueryRouter :: parse( & simple_query(
1241
+ "SELECT * FROM data
1221
1242
INNER JOIN t2 ON data.id = 5
1222
1243
AND t2.data_id = data.id
1223
1244
WHERE data.id = 5"
1224
- ) ) ) ;
1245
+ ) )
1246
+ . unwrap( )
1247
+ )
1248
+ . is_ok( ) ) ;
1225
1249
assert_eq ! ( qr. shard( ) , 2 ) ;
1226
1250
1227
1251
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
1228
1252
// in the query.
1229
- assert ! ( qr. infer( & simple_query(
1230
- "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1231
- ) ) ) ;
1253
+ assert ! ( qr
1254
+ . infer(
1255
+ & QueryRouter :: parse( & simple_query(
1256
+ "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1257
+ ) )
1258
+ . unwrap( )
1259
+ )
1260
+ . is_ok( ) ) ;
1232
1261
assert_eq ! ( qr. shard( ) , 2 ) ;
1233
1262
1234
- assert ! ( qr. infer( & simple_query(
1235
- r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1236
- ) ) ) ;
1263
+ assert ! ( qr
1264
+ . infer(
1265
+ & QueryRouter :: parse( & simple_query(
1266
+ r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1267
+ ) )
1268
+ . unwrap( )
1269
+ )
1270
+ . is_ok( ) ) ;
1237
1271
assert_eq ! ( qr. shard( ) , 0 ) ;
1238
1272
1239
- assert ! ( qr. infer( & simple_query(
1240
- r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1241
- ) ) ) ;
1273
+ assert ! ( qr
1274
+ . infer(
1275
+ & QueryRouter :: parse( & simple_query(
1276
+ r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1277
+ ) )
1278
+ . unwrap( )
1279
+ )
1280
+ . is_ok( ) ) ;
1242
1281
assert_eq ! ( qr. shard( ) , 2 ) ;
1243
1282
1244
1283
// Super unique sharding key
1245
1284
qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1246
- assert ! ( qr. infer( & simple_query(
1247
- "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1248
- ) ) ) ;
1285
+ assert ! ( qr
1286
+ . infer(
1287
+ & QueryRouter :: parse( & simple_query(
1288
+ "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1289
+ ) )
1290
+ . unwrap( )
1291
+ )
1292
+ . is_ok( ) ) ;
1249
1293
assert_eq ! ( qr. shard( ) , 0 ) ;
1250
1294
1251
- assert ! ( qr. infer( & simple_query( "SELECT * FROM table_y WHERE another_key = 5" ) ) ) ;
1295
+ assert ! ( qr
1296
+ . infer(
1297
+ & QueryRouter :: parse( & simple_query( "SELECT * FROM table_y WHERE another_key = 5" ) )
1298
+ . unwrap( )
1299
+ )
1300
+ . is_ok( ) ) ;
1252
1301
assert_eq ! ( qr. shard( ) , 0 ) ;
1253
1302
}
1254
1303
@@ -1272,11 +1321,21 @@ mod test {
1272
1321
qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1273
1322
qr. pool_settings . shards = 3 ;
1274
1323
1275
- assert ! ( qr. infer( & simple_query( stmt) ) ) ;
1324
+ assert ! ( qr
1325
+ . infer( & QueryRouter :: parse( & simple_query( stmt) ) . unwrap( ) )
1326
+ . is_ok( ) ) ;
1276
1327
assert_eq ! ( qr. placeholders. len( ) , 1 ) ;
1277
1328
1278
1329
assert ! ( qr. infer_shard_from_bind( & bind) ) ;
1279
1330
assert_eq ! ( qr. shard( ) , 2 ) ;
1280
1331
assert ! ( qr. placeholders. is_empty( ) ) ;
1281
1332
}
1333
+
1334
+ #[ test]
1335
+ fn test_parse ( ) {
1336
+ let query = simple_query ( "SELECT * FROM pg_database" ) ;
1337
+ let ast = QueryRouter :: parse ( & query) ;
1338
+
1339
+ assert ! ( ast. is_ok( ) ) ;
1340
+ }
1282
1341
}
0 commit comments