Skip to content

Commit dbc6d67

Browse files
Fix SQLQueryUtils to extract multiple tables (#2784)
* Fix SQLQueryUtils to extract multiple tables Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Improve test coverage Signed-off-by: Tomoyuki Morita <moritato@amazon.com> --------- Signed-off-by: Tomoyuki Morita <moritato@amazon.com> (cherry picked from commit 883cc7e) Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 775aa90 commit dbc6d67

File tree

2 files changed

+196
-152
lines changed

2 files changed

+196
-152
lines changed

async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.sql.spark.utils;
77

8+
import java.util.LinkedList;
9+
import java.util.List;
810
import java.util.Locale;
911
import lombok.Getter;
1012
import lombok.experimental.UtilityClass;
@@ -18,6 +20,7 @@
1820
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
1921
import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
2022
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
23+
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
2124
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
2225
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
2326
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
@@ -32,16 +35,15 @@
3235
@UtilityClass
3336
public class SQLQueryUtils {
3437

35-
// TODO Handle cases where the query has multiple table Names.
36-
public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) {
38+
public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
3739
SqlBaseParser sqlBaseParser =
3840
new SqlBaseParser(
3941
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
4042
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
4143
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
4244
SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
4345
statement.accept(sparkSqlTableNameVisitor);
44-
return sparkSqlTableNameVisitor.getFullyQualifiedTableName();
46+
return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
4547
}
4648

4749
public static IndexQueryDetails extractIndexDetails(String sqlQuery) {
@@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
7375

7476
public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
7577

76-
@Getter private FullyQualifiedTableName fullyQualifiedTableName;
78+
@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
7779

78-
public SparkSqlTableNameVisitor() {
79-
this.fullyQualifiedTableName = new FullyQualifiedTableName();
80-
}
80+
public SparkSqlTableNameVisitor() {}
8181

8282
@Override
83-
public Void visitTableName(SqlBaseParser.TableNameContext ctx) {
84-
fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText());
85-
return super.visitTableName(ctx);
83+
public Void visitIdentifierReference(IdentifierReferenceContext ctx) {
84+
fullyQualifiedTableNames.add(new FullyQualifiedTableName(ctx.getText()));
85+
return super.visitIdentifierReference(ctx);
8686
}
8787

8888
@Override
8989
public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
9090
for (ParseTree parseTree : ctx.children) {
9191
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
92-
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
92+
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
9393
}
9494
}
9595
return super.visitDropTable(ctx);
@@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
9999
public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
100100
for (ParseTree parseTree : ctx.children) {
101101
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
102-
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
102+
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
103103
}
104104
}
105105
return super.visitDescribeRelation(ctx);
@@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
110110
public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) {
111111
for (ParseTree parseTree : ctx.children) {
112112
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
113-
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
113+
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
114114
}
115115
}
116116
return super.visitCreateTableHeader(ctx);

0 commit comments

Comments
 (0)