Skip to content

Commit

Permalink
Oracle case insensitive search (#2487)
Browse files Browse the repository at this point in the history
  • Loading branch information
aimethed authored Jan 8, 2025
1 parent f910d92 commit dbce015
Show file tree
Hide file tree
Showing 4 changed files with 536 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@

/*-
* #%L
* athena-oracle
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.oracle;

import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION;

/**
 * Static helpers that adjust the casing of Oracle schema and table names according to the
 * {@code casing_mode} configuration option, and that wrap names in Oracle quote characters.
 *
 * <p>All case conversions use {@link Locale#ROOT} so that results do not depend on the JVM's
 * default locale (for example, the Turkish dotted/dotless "i" mapping).
 */
public class OracleCaseResolver
{
    private static final Logger LOGGER = LoggerFactory.getLogger(OracleCaseResolver.class);
    static final String SCHEMA_NAME_QUERY_TEMPLATE = "SELECT DISTINCT OWNER as \"OWNER\" FROM all_tables WHERE lower(OWNER) = ?";
    static final String TABLE_NAME_QUERY_TEMPLATE = "SELECT DISTINCT TABLE_NAME as \"TABLE_NAME\" FROM all_tables WHERE OWNER = ? and lower(TABLE_NAME) = ?";
    static final String SCHEMA_NAME_COLUMN_KEY = "OWNER";
    static final String TABLE_NAME_COLUMN_KEY = "TABLE_NAME";

    // the environment variable that can be set to specify which casing mode to use
    static final String CASING_MODE = "casing_mode";

    // used for identifying database objects (ex: table names)
    private static final String ORACLE_IDENTIFIER_CHARACTER = "\"";
    // used in SQL statements for character strings (ex: where OWNER = 'example')
    private static final String ORACLE_STRING_LITERAL_CHARACTER = "\'";

    // utility class: static methods only, never instantiated
    private OracleCaseResolver() {}

    /** Supported values of the {@code casing_mode} option (matched case-insensitively). */
    private enum OracleCasingMode
    {
        LOWER,      // casing mode to lower case everything (glue and trino lower case everything)
        UPPER,      // casing mode to upper case everything (oracle by default upper cases everything)
        CASE_INSENSITIVE_SEARCH     // casing mode to perform case insensitive search
    }

    /**
     * Returns a {@link TableName} whose schema and table parts are cased per the configured mode.
     *
     * @param connection connection used for the catalog lookups in case-insensitive-search mode
     * @param tableName the table name as received from the caller
     * @param configOptions connector configuration; read for {@code casing_mode} and the glue connection setting
     * @return the adjusted table name
     * @throws SQLException declared for callers; lookup failures are currently rethrown as {@link RuntimeException}
     * @throws IllegalArgumentException if {@code casing_mode} is set to an unsupported value
     * @throws RuntimeException in case-insensitive-search mode, if a lookup fails or does not match exactly one name
     */
    public static TableName getAdjustedTableObjectName(final Connection connection, TableName tableName, Map<String, String> configOptions)
            throws SQLException
    {
        OracleCasingMode casingMode = getCasingMode(configOptions);
        switch (casingMode) {
            case CASE_INSENSITIVE_SEARCH:
                // resolve the schema first: the table lookup filters on the exact OWNER value
                String schemaNameCaseInsensitively = getSchemaNameCaseInsensitively(connection, tableName.getSchemaName());
                String tableNameCaseInsensitively = getTableNameCaseInsensitively(connection, schemaNameCaseInsensitively, tableName.getTableName());
                TableName tableNameResult = new TableName(schemaNameCaseInsensitively, tableNameCaseInsensitively);
                LOGGER.info("casing mode is `SEARCH`: performing case insensitive search for TableName object. TableName:{}", tableNameResult);
                return tableNameResult;
            case UPPER:
                TableName upperTableName = new TableName(tableName.getSchemaName().toUpperCase(Locale.ROOT), tableName.getTableName().toUpperCase(Locale.ROOT));
                LOGGER.info("casing mode is `UPPER`: adjusting casing from input to upper case for TableName object. TableName:{}", upperTableName);
                return upperTableName;
            case LOWER:
                TableName lowerTableName = new TableName(tableName.getSchemaName().toLowerCase(Locale.ROOT), tableName.getTableName().toLowerCase(Locale.ROOT));
                LOGGER.info("casing mode is `LOWER`: adjusting casing from input to lower case for TableName object. TableName:{}", lowerTableName);
                return lowerTableName;
        }
        // only reached if a casing mode is added to the enum without a matching case above
        LOGGER.warn("casing mode is empty: not adjust casing from input for TableName object. TableName:{}", tableName);
        return tableName;
    }

    /**
     * Returns the schema name cased per the configured mode.
     *
     * @param connection connection used for the catalog lookup in case-insensitive-search mode
     * @param schemaNameInput the schema name as received from the caller
     * @param configOptions connector configuration; read for {@code casing_mode} and the glue connection setting
     * @return the adjusted schema name
     * @throws SQLException declared for callers; lookup failures are currently rethrown as {@link RuntimeException}
     * @throws IllegalArgumentException if {@code casing_mode} is set to an unsupported value
     * @throws RuntimeException in case-insensitive-search mode, if the lookup fails or does not match exactly one schema
     */
    public static String getAdjustedSchemaName(final Connection connection, String schemaNameInput, Map<String, String> configOptions)
            throws SQLException
    {
        OracleCasingMode casingMode = getCasingMode(configOptions);
        switch (casingMode) {
            case CASE_INSENSITIVE_SEARCH:
                LOGGER.info("casing mode is SEARCH: performing case insensitive search for Schema...");
                return getSchemaNameCaseInsensitively(connection, schemaNameInput);
            case UPPER:
                LOGGER.info("casing mode is `UPPER`: adjusting casing from input to upper case for Schema");
                return schemaNameInput.toUpperCase(Locale.ROOT);
            case LOWER:
                LOGGER.info("casing mode is `LOWER`: adjusting casing from input to lower case for Schema");
                return schemaNameInput.toLowerCase(Locale.ROOT);
        }

        // only reached if a casing mode is added to the enum without a matching case above
        return schemaNameInput;
    }

    /**
     * Looks up the stored casing of a schema by running {@link #SCHEMA_NAME_QUERY_TEMPLATE}
     * with the lower-cased input as its parameter.
     *
     * @param connection connection to run the lookup on
     * @param schemaName schema name in any casing
     * @return the {@code OWNER} value of the single matching row
     * @throws SQLException declared for callers; query failures are currently rethrown as {@link RuntimeException}
     * @throws RuntimeException if the query fails, or if the number of matching rows is not exactly one
     */
    public static String getSchemaNameCaseInsensitively(final Connection connection, String schemaName)
            throws SQLException
    {
        String nameFromOracle = null;
        int matchCount = 0;
        try (PreparedStatement preparedStatement = new PreparedStatementBuilder()
                .withConnection(connection)
                .withQuery(SCHEMA_NAME_QUERY_TEMPLATE)
                .withParameters(Arrays.asList(schemaName.toLowerCase(Locale.ROOT))).build();
                ResultSet resultSet = preparedStatement.executeQuery()) {
            while (resultSet.next()) {
                matchCount++;
                String schemaNameCandidate = resultSet.getString(SCHEMA_NAME_COLUMN_KEY);
                LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", SCHEMA_NAME_COLUMN_KEY, schemaNameCandidate);
                nameFromOracle = schemaNameCandidate;
            }
        }
        catch (SQLException e) {
            throw new RuntimeException(String.format("getSchemaNameCaseInsensitively query failed for %s", schemaName), e);
        }

        // zero matches means the schema is unknown; more than one means the name is ambiguous
        if (matchCount != 1) {
            throw new RuntimeException(String.format("Schema name case insensitive match failed, number of match : %d", matchCount));
        }

        return nameFromOracle;
    }

    /**
     * Looks up the stored casing of a table by running {@link #TABLE_NAME_QUERY_TEMPLATE}
     * with the schema name as given and the lower-cased table name as its parameters.
     *
     * @param connection connection to run the lookup on
     * @param schemaName schema name, expected to already be in its stored casing
     * @param tableNameInput table name in any casing
     * @return the {@code TABLE_NAME} value of the single matching row
     * @throws SQLException declared for callers; query failures are currently rethrown as {@link RuntimeException}
     * @throws RuntimeException if the query fails, or if the number of matching rows is not exactly one
     */
    public static String getTableNameCaseInsensitively(final Connection connection, String schemaName, String tableNameInput)
            throws SQLException
    {
        // schema name input should be correct case before searching tableName already
        String nameFromOracle = null;
        int matchCount = 0;
        try (PreparedStatement preparedStatement = new PreparedStatementBuilder()
                .withConnection(connection)
                .withQuery(TABLE_NAME_QUERY_TEMPLATE)
                .withParameters(Arrays.asList(schemaName, tableNameInput.toLowerCase(Locale.ROOT))).build();
                ResultSet resultSet = preparedStatement.executeQuery()) {
            while (resultSet.next()) {
                matchCount++;
                String tableNameCandidate = resultSet.getString(TABLE_NAME_COLUMN_KEY);
                LOGGER.debug("Case insensitive search on columLabel: {}, table name: {}", TABLE_NAME_COLUMN_KEY, tableNameCandidate);
                nameFromOracle = tableNameCandidate;
            }
        }
        catch (SQLException e) {
            throw new RuntimeException(String.format("getTableNameCaseInsensitively query failed for schema: %s tableName: %s", schemaName, tableNameInput), e);
        }

        // zero matches means the table is unknown; more than one means the name is ambiguous
        if (matchCount != 1) {
            throw new RuntimeException(String.format("Table name case insensitive match failed, number of match : %d", matchCount));
        }

        return nameFromOracle;
    }

    /**
     * Determines the casing mode from the configuration.
     *
     * <p>When {@code casing_mode} is absent (or mapped to {@code null}), the default is
     * {@code LOWER} if the glue connection option is non-blank and {@code UPPER} otherwise.
     *
     * @param configOptions connector configuration
     * @return the effective casing mode
     * @throws IllegalArgumentException if {@code casing_mode} is set to a value that is not a supported mode
     */
    private static OracleCasingMode getCasingMode(Map<String, String> configOptions)
    {
        boolean isGlueConnection = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION));
        String configuredMode = configOptions.get(CASING_MODE);
        // a key mapped to null is treated like a missing key instead of failing with a NullPointerException
        if (configuredMode == null) {
            LOGGER.info("CASING MODE not set");
            return isGlueConnection ? OracleCasingMode.LOWER : OracleCasingMode.UPPER;
        }

        try {
            // Locale.ROOT keeps the enum lookup independent of the JVM default locale
            OracleCasingMode oracleCasingMode = OracleCasingMode.valueOf(configuredMode.toUpperCase(Locale.ROOT));
            LOGGER.info("CASING MODE enable: {}", oracleCasingMode.toString());
            return oracleCasingMode;
        }
        catch (IllegalArgumentException ex) {
            // print error log for customer along with list of input
            LOGGER.error("Invalid input for:{}, input value:{}, valid values:{}", CASING_MODE, configuredMode, Arrays.asList(OracleCasingMode.values()), ex);
            throw ex;
        }
    }

    /**
     * Wraps the schema and table parts of a table name in double quotes.
     *
     * <p>A part that already contains a double quote anywhere is left unchanged.
     *
     * @param inputTable the table name to quote
     * @return a new {@link TableName} with the quoted parts
     */
    public static TableName quoteTableName(TableName inputTable)
    {
        String schemaName = inputTable.getSchemaName();
        String tableName = inputTable.getTableName();
        if (!schemaName.contains(ORACLE_IDENTIFIER_CHARACTER)) {
            schemaName = ORACLE_IDENTIFIER_CHARACTER + schemaName + ORACLE_IDENTIFIER_CHARACTER;
        }
        if (!tableName.contains(ORACLE_IDENTIFIER_CHARACTER)) {
            tableName = ORACLE_IDENTIFIER_CHARACTER + tableName + ORACLE_IDENTIFIER_CHARACTER;
        }
        return new TableName(schemaName, tableName);
    }

    /**
     * Wraps the input in single quotes.
     *
     * <p>An input that already contains a single quote anywhere is returned unchanged;
     * embedded quotes are not escaped.
     *
     * @param input the text to wrap
     * @return the input surrounded by single quotes, or the input itself if it already contains one
     */
    public static String convertToLiteral(String input)
    {
        if (!input.contains(ORACLE_STRING_LITERAL_CHARACTER)) {
            input = ORACLE_STRING_LITERAL_CHARACTER + input + ORACLE_STRING_LITERAL_CHARACTER;
        }
        return input;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.athena.AthenaClient;
Expand All @@ -79,7 +78,6 @@
import java.util.Set;
import java.util.stream.Collectors;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.MODULUS_FUNCTION_NAME;
import static com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions.NULLIF_FUNCTION_NAME;
Expand All @@ -95,11 +93,9 @@ public class OracleMetadataHandler
static final String BLOCK_PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase();
static final String ALL_PARTITIONS = "0";
static final String PARTITION_COLUMN_NAME = "PARTITION_NAME".toLowerCase();
static final String CASING_MODE = "casing_mode";
private static final Logger LOGGER = LoggerFactory.getLogger(OracleMetadataHandler.class);
private static final int MAX_SPLITS_PER_REQUEST = 1000_000;
private static final String COLUMN_NAME = "COLUMN_NAME";
private static final String ORACLE_QUOTE_CHARACTER = "\"";

static final String LIST_PAGINATED_TABLES_QUERY = "SELECT TABLE_NAME as \"TABLE_NAME\", OWNER as \"TABLE_SCHEM\" FROM all_tables WHERE owner = ? ORDER BY TABLE_NAME OFFSET ? ROWS FETCH NEXT ? ROWS ONLY";

Expand Down Expand Up @@ -158,10 +154,11 @@ public Schema getPartitionSchema(final String catalogName)
public void getPartitions(final BlockWriter blockWriter, final GetTableLayoutRequest getTableLayoutRequest, QueryStatusChecker queryStatusChecker)
throws Exception
{
LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), transformString(getTableLayoutRequest.getTableName().getSchemaName(), true),
transformString(getTableLayoutRequest.getTableName().getTableName(), true));
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
List<String> parameters = Arrays.asList(transformString(getTableLayoutRequest.getTableName().getTableName(), true));
TableName casedTableName = getTableLayoutRequest.getTableName();
LOGGER.debug("{}: Schema {}, table {}", getTableLayoutRequest.getQueryId(), casedTableName.getSchemaName(),
casedTableName.getTableName());
List<String> parameters = Arrays.asList(OracleCaseResolver.convertToLiteral(casedTableName.getTableName()));
try (PreparedStatement preparedStatement = new PreparedStatementBuilder().withConnection(connection).withQuery(GET_PARTITIONS_QUERY).withParameters(parameters).build();
ResultSet resultSet = preparedStatement.executeQuery()) {
// Return a single partition if no partitions defined
Expand Down Expand Up @@ -256,7 +253,8 @@ protected ListTablesResponse listPaginatedTables(final Connection connection, fi
int t = token != null ? Integer.parseInt(token) : 0;

LOGGER.info("Starting pagination at {} with page size {}", token, pageSize);
List<TableName> paginatedTables = getPaginatedTables(connection, listTablesRequest.getSchemaName(), t, pageSize);
String casedSchemaName = OracleCaseResolver.getAdjustedSchemaName(connection, listTablesRequest.getSchemaName(), configOptions);
List<TableName> paginatedTables = getPaginatedTables(connection, casedSchemaName, t, pageSize);
LOGGER.info("{} tables returned. Next token is {}", paginatedTables.size(), t + pageSize);
return new ListTablesResponse(listTablesRequest.getCatalogName(), paginatedTables, Integer.toString(t + pageSize));
}
Expand Down Expand Up @@ -310,7 +308,7 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge
{
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
Schema partitionSchema = getPartitionSchema(getTableRequest.getCatalogName());
TableName tableName = new TableName(transformString(getTableRequest.getTableName().getSchemaName(), false), transformString(getTableRequest.getTableName().getTableName(), false));
TableName tableName = OracleCaseResolver.getAdjustedTableObjectName(connection, getTableRequest.getTableName(), configOptions);
return new GetTableResponse(getTableRequest.getCatalogName(), tableName, getSchema(connection, tableName, partitionSchema),
partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()));
}
Expand Down Expand Up @@ -413,25 +411,4 @@ private Schema getSchema(Connection jdbcConnection, TableName tableName, Schema
return schemaBuilder.build();
}
}

/**
 * Applies the configured casing to a string and optionally wraps it in double quotes.
 *
 * <p>The string is upper-cased only when the effective {@code casing_mode} is "upper"
 * (compared case-insensitively). When {@code casing_mode} is not configured, the default is
 * "lower" if the glue connection option is non-blank and "upper" otherwise. For any other
 * mode the string's casing is left as is.
 *
 * @param str the string to transform
 * @param quote whether to surround the result with double quotes; a string that already
 *              contains a double quote is never quoted again
 * @return the transformed string
 */
private String transformString(String str, boolean quote)
{
    String defaultMode = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION)) ? "lower" : "upper";
    String effectiveMode = configOptions.getOrDefault(CASING_MODE, defaultMode).toLowerCase();
    String result = "upper".equals(effectiveMode) ? str.toUpperCase() : str;
    if (!quote || result.contains(ORACLE_QUOTE_CHARACTER)) {
        return result;
    }
    return ORACLE_QUOTE_CHARACTER + result + ORACLE_QUOTE_CHARACTER;
}
}
Loading

0 comments on commit dbce015

Please sign in to comment.