Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.utils;

import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import lombok.Getter;
import lombok.experimental.UtilityClass;
Expand All @@ -18,6 +20,7 @@
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
Expand All @@ -32,16 +35,15 @@
@UtilityClass
public class SQLQueryUtils {

// TODO Handle cases where the query has multiple table Names.
public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) {
public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
statement.accept(sparkSqlTableNameVisitor);
return sparkSqlTableNameVisitor.getFullyQualifiedTableName();
return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
}

public static IndexQueryDetails extractIndexDetails(String sqlQuery) {
Expand Down Expand Up @@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private FullyQualifiedTableName fullyQualifiedTableName;
@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();

public SparkSqlTableNameVisitor() {
this.fullyQualifiedTableName = new FullyQualifiedTableName();
}
public SparkSqlTableNameVisitor() {}

@Override
public Void visitTableName(SqlBaseParser.TableNameContext ctx) {
fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText());
return super.visitTableName(ctx);
public Void visitIdentifierReference(IdentifierReferenceContext ctx) {
fullyQualifiedTableNames.add(new FullyQualifiedTableName(ctx.getText()));
return super.visitIdentifierReference(ctx);
}

@Override
public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDropTable(ctx);
Expand All @@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDescribeRelation(ctx);
Expand All @@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitCreateTableHeader(ctx);
Expand Down
Loading