Skip to content

Commit 4b60ab6

Browse files
authored
Add other functions to SQL query validator (#3304)
* Add uncategorized functions to SQL query validator
  Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
* Fix variable name
  Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
* Fix name from uncategorized to other
  Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
---------
Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
1 parent 3450daa commit 4b60ab6

File tree

6 files changed

+98
-9
lines changed

6 files changed

+98
-9
lines changed

async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public enum FunctionType {
3232
CSV("CSV"),
3333
MISC("Misc"),
3434
GENERATOR("Generator"),
35+
OTHER("Other"),
3536
UDF("User Defined Function");
3637

3738
private final String name;
@@ -422,6 +423,51 @@ public enum FunctionType {
422423
"posexplode",
423424
"posexplode_outer",
424425
"stack"))
426+
.put(
427+
OTHER,
428+
Set.of(
429+
"aggregate",
430+
"array_size",
431+
"array_sort",
432+
"cardinality",
433+
"crc32",
434+
"exists",
435+
"filter",
436+
"forall",
437+
"hash",
438+
"ilike",
439+
"in",
440+
"like",
441+
"map_filter",
442+
"map_zip_with",
443+
"md5",
444+
"mod",
445+
"named_struct",
446+
"parse_url",
447+
"raise_error",
448+
"reduce",
449+
"reverse",
450+
"sha",
451+
"sha1",
452+
"sha2",
453+
"size",
454+
"struct",
455+
"transform",
456+
"transform_keys",
457+
"transform_values",
458+
"url_decode",
459+
"url_encode",
460+
"xpath",
461+
"xpath_boolean",
462+
"xpath_double",
463+
"xpath_float",
464+
"xpath_int",
465+
"xpath_long",
466+
"xpath_number",
467+
"xpath_short",
468+
"xpath_string",
469+
"xxhash64",
470+
"zip_with"))
425471
.build();
426472

427473
private static final Map<String, FunctionType> FUNCTION_NAME_TO_FUNCTION_TYPE_MAP =

async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLGrammarElement.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ public enum SQLGrammarElement implements GrammarElement {
7878
CSV_FUNCTIONS("CSV functions"),
7979
GENERATOR_FUNCTIONS("Generator functions"),
8080
MISC_FUNCTIONS("Misc functions"),
81+
OTHER_FUNCTIONS("Other functions"),
8182

8283
// UDF
8384
UDF("User Defined functions");

async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -560,26 +560,30 @@ public Void visitFunctionName(FunctionNameContext ctx) {
560560
return super.visitFunctionName(ctx);
561561
}
562562

563-
private void validateFunctionAllowed(String function) {
564-
FunctionType type = FunctionType.fromFunctionName(function.toLowerCase());
563+
private void validateFunctionAllowed(String functionName) {
564+
String lowerCaseFunctionName = functionName.toLowerCase();
565+
FunctionType type = FunctionType.fromFunctionName(lowerCaseFunctionName);
565566
switch (type) {
566567
case MAP:
567-
validateAllowed(SQLGrammarElement.MAP_FUNCTIONS);
568+
validateAllowed(SQLGrammarElement.MAP_FUNCTIONS, lowerCaseFunctionName);
568569
break;
569570
case BITWISE:
570-
validateAllowed(SQLGrammarElement.BITWISE_FUNCTIONS);
571+
validateAllowed(SQLGrammarElement.BITWISE_FUNCTIONS, lowerCaseFunctionName);
571572
break;
572573
case CSV:
573-
validateAllowed(SQLGrammarElement.CSV_FUNCTIONS);
574+
validateAllowed(SQLGrammarElement.CSV_FUNCTIONS, lowerCaseFunctionName);
574575
break;
575576
case MISC:
576-
validateAllowed(SQLGrammarElement.MISC_FUNCTIONS);
577+
validateAllowed(SQLGrammarElement.MISC_FUNCTIONS, lowerCaseFunctionName);
577578
break;
578579
case GENERATOR:
579-
validateAllowed(SQLGrammarElement.GENERATOR_FUNCTIONS);
580+
validateAllowed(SQLGrammarElement.GENERATOR_FUNCTIONS, lowerCaseFunctionName);
581+
break;
582+
case OTHER:
583+
validateAllowed(SQLGrammarElement.OTHER_FUNCTIONS, lowerCaseFunctionName);
580584
break;
581585
case UDF:
582-
validateAllowed(SQLGrammarElement.UDF);
586+
validateAllowed(SQLGrammarElement.UDF, lowerCaseFunctionName);
583587
break;
584588
}
585589
}
@@ -590,6 +594,12 @@ private void validateAllowed(SQLGrammarElement element) {
590594
}
591595
}
592596

597+
private void validateAllowed(SQLGrammarElement element, String detail) {
598+
if (!grammarElementValidator.isValid(element)) {
599+
throw new IllegalArgumentException(String.format("%s (%s) is not allowed.", element, detail));
600+
}
601+
}
602+
593603
@Override
594604
public Void visitErrorCapturingIdentifier(ErrorCapturingIdentifierContext ctx) {
595605
ErrorCapturingIdentifierExtraContext extra = ctx.errorCapturingIdentifierExtra();

async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ void testExtractionFromFlintSkippingIndexQueries() {
140140
+ " WHERE elb_status_code = 500 "
141141
+ " WITH (auto_refresh = true)",
142142
"DROP SKIPPING INDEX ON myS3.default.alb_logs",
143-
"ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)",
143+
"ALTER SKIPPING INDEX ON myS3.default.alb_logs WITH (auto_refresh = false)"
144144
};
145145

146146
for (String query : createSkippingIndexQueries) {

async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ public void test() {
4242
assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version"));
4343
assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode"));
4444
assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack"));
45+
assertEquals(FunctionType.OTHER, FunctionType.fromFunctionName("aggregate"));
46+
assertEquals(FunctionType.OTHER, FunctionType.fromFunctionName("forall"));
4547
assertEquals(FunctionType.UDF, FunctionType.fromFunctionName("unknown"));
4648
}
4749
}

async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.sql.spark.validator;
77

88
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
910
import static org.junit.jupiter.api.Assertions.assertThrows;
1011
import static org.mockito.ArgumentMatchers.any;
1112
import static org.mockito.Mockito.when;
@@ -192,6 +193,10 @@ private enum TestElement {
192193
// Generator Functions
193194
GENERATOR_FUNCTIONS("SELECT explode(array(1, 2, 3));"),
194195

196+
// Other functions
197+
NAMED_STRUCT("SELECT named_struct('a', 1);"),
198+
PARSE_URL("SELECT parse_url(url) FROM my_table;"),
199+
195200
// UDFs (User-Defined Functions)
196201
SCALAR_USER_DEFINED_FUNCTIONS("SELECT my_udf(name) FROM my_table;"),
197202
USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"),
@@ -323,6 +328,10 @@ void testDenyAllValidator() {
323328
// Generator Functions
324329
v.ng(TestElement.GENERATOR_FUNCTIONS);
325330

331+
// Other Functions
332+
v.ng(TestElement.NAMED_STRUCT);
333+
v.ng(TestElement.PARSE_URL);
334+
326335
// UDFs
327336
v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS);
328337
v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS);
@@ -440,6 +449,10 @@ void testS3glueQueries() {
440449
// Generator Functions
441450
v.ok(TestElement.GENERATOR_FUNCTIONS);
442451

452+
// Other Functions
453+
v.ok(TestElement.NAMED_STRUCT);
454+
v.ok(TestElement.PARSE_URL);
455+
443456
// UDFs
444457
v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS);
445458
v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS);
@@ -621,6 +634,14 @@ void testUnsupportedHiveNativeCommand() {
621634
v.ng("DFS");
622635
}
623636

637+
@Test
638+
void testException() {
639+
when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false);
640+
VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE);
641+
642+
v.ng("SELECT named_struct('a', 1);", "Other functions (named_struct) is not allowed.");
643+
}
644+
624645
@AllArgsConstructor
625646
private static class VerifyValidator {
626647
private final SQLQueryValidator validator;
@@ -645,6 +666,15 @@ public void ng(String query) {
645666
"The query should throw: query=`" + query.toString() + "`");
646667
}
647668

669+
public void ng(String query, String expectedMessage) {
670+
Exception e =
671+
assertThrows(
672+
IllegalArgumentException.class,
673+
() -> runValidate(query),
674+
"The query should throw: query=`" + query.toString() + "`");
675+
assertEquals(expectedMessage, e.getMessage());
676+
}
677+
648678
void runValidate(String[] queries) {
649679
Arrays.stream(queries).forEach(query -> validator.validate(query, dataSourceType));
650680
}

0 commit comments

Comments
 (0)