55
66package org .opensearch .sql .spark .utils ;
77
8+ import java .util .LinkedList ;
9+ import java .util .List ;
810import java .util .Locale ;
911import lombok .Getter ;
1012import lombok .experimental .UtilityClass ;
1820import org .opensearch .sql .spark .antlr .parser .FlintSparkSqlExtensionsParser ;
1921import org .opensearch .sql .spark .antlr .parser .SqlBaseLexer ;
2022import org .opensearch .sql .spark .antlr .parser .SqlBaseParser ;
23+ import org .opensearch .sql .spark .antlr .parser .SqlBaseParser .IdentifierReferenceContext ;
2124import org .opensearch .sql .spark .antlr .parser .SqlBaseParserBaseVisitor ;
2225import org .opensearch .sql .spark .dispatcher .model .FlintIndexOptions ;
2326import org .opensearch .sql .spark .dispatcher .model .FullyQualifiedTableName ;
3235@ UtilityClass
3336public class SQLQueryUtils {
3437
35- // TODO Handle cases where the query has multiple table Names.
36- public static FullyQualifiedTableName extractFullyQualifiedTableName (String sqlQuery ) {
38+ public static List <FullyQualifiedTableName > extractFullyQualifiedTableNames (String sqlQuery ) {
3739 SqlBaseParser sqlBaseParser =
3840 new SqlBaseParser (
3941 new CommonTokenStream (new SqlBaseLexer (new CaseInsensitiveCharStream (sqlQuery ))));
4042 sqlBaseParser .addErrorListener (new SyntaxAnalysisErrorListener ());
4143 SqlBaseParser .StatementContext statement = sqlBaseParser .statement ();
4244 SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor ();
4345 statement .accept (sparkSqlTableNameVisitor );
44- return sparkSqlTableNameVisitor .getFullyQualifiedTableName ();
46+ return sparkSqlTableNameVisitor .getFullyQualifiedTableNames ();
4547 }
4648
4749 public static IndexQueryDetails extractIndexDetails (String sqlQuery ) {
@@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
7375
7476 public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor <Void > {
7577
76- @ Getter private FullyQualifiedTableName fullyQualifiedTableName ;
78+ @ Getter private List < FullyQualifiedTableName > fullyQualifiedTableNames = new LinkedList <>() ;
7779
78- public SparkSqlTableNameVisitor () {
79- this .fullyQualifiedTableName = new FullyQualifiedTableName ();
80- }
80+ public SparkSqlTableNameVisitor () {}
8181
8282 @ Override
83- public Void visitTableName ( SqlBaseParser . TableNameContext ctx ) {
84- fullyQualifiedTableName = new FullyQualifiedTableName (ctx .getText ());
85- return super .visitTableName (ctx );
83+ public Void visitIdentifierReference ( IdentifierReferenceContext ctx ) {
84+ fullyQualifiedTableNames . add ( new FullyQualifiedTableName (ctx .getText () ));
85+ return super .visitIdentifierReference (ctx );
8686 }
8787
8888 @ Override
8989 public Void visitDropTable (SqlBaseParser .DropTableContext ctx ) {
9090 for (ParseTree parseTree : ctx .children ) {
9191 if (parseTree instanceof SqlBaseParser .IdentifierReferenceContext ) {
92- fullyQualifiedTableName = new FullyQualifiedTableName (parseTree .getText ());
92+ fullyQualifiedTableNames . add ( new FullyQualifiedTableName (parseTree .getText () ));
9393 }
9494 }
9595 return super .visitDropTable (ctx );
@@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
9999 public Void visitDescribeRelation (SqlBaseParser .DescribeRelationContext ctx ) {
100100 for (ParseTree parseTree : ctx .children ) {
101101 if (parseTree instanceof SqlBaseParser .IdentifierReferenceContext ) {
102- fullyQualifiedTableName = new FullyQualifiedTableName (parseTree .getText ());
102+ fullyQualifiedTableNames . add ( new FullyQualifiedTableName (parseTree .getText () ));
103103 }
104104 }
105105 return super .visitDescribeRelation (ctx );
@@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
110110 public Void visitCreateTableHeader (SqlBaseParser .CreateTableHeaderContext ctx ) {
111111 for (ParseTree parseTree : ctx .children ) {
112112 if (parseTree instanceof SqlBaseParser .IdentifierReferenceContext ) {
113- fullyQualifiedTableName = new FullyQualifiedTableName (parseTree .getText ());
113+ fullyQualifiedTableNames . add ( new FullyQualifiedTableName (parseTree .getText () ));
114114 }
115115 }
116116 return super .visitCreateTableHeader (ctx );
0 commit comments