Skip to content

Commit

Permalink
Snowflake case insensitive search improve with ANNOTATION added back …
Browse files Browse the repository at this point in the history
…for backward compatibility (awslabs#2437)
  • Loading branch information
chngpe authored Nov 27, 2024
1 parent ac93ab3 commit 8a43a53
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,74 +20,126 @@
package com.amazonaws.athena.connectors.snowflake;

import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connectors.jdbc.manager.PreparedStatementBuilder;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT_GLUE_CONNECTION;

public class SnowflakeCaseInsensitiveResolver
{
private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeCaseInsensitiveResolver.class);
private static final String SCHEMA_NAME_QUERY = "select * from INFORMATION_SCHEMA.SCHEMATA where lower(SCHEMA_NAME) = ";
private static final String TABLE_NAME_QUERY = "select * from INFORMATION_SCHEMA.TABLES where TABLE_SCHEMA = ";
private static final String SCHEMA_NAME_QUERY_TEMPLATE = "select * from INFORMATION_SCHEMA.SCHEMATA where lower(SCHEMA_NAME) = ?";
private static final String TABLE_NAME_QUERY_TEMPLATE = "select * from INFORMATION_SCHEMA.TABLES where TABLE_SCHEMA = ? and lower(TABLE_NAME) = ?";
private static final String SCHEMA_NAME_COLUMN_KEY = "SCHEMA_NAME";
private static final String TABLE_NAME_COLUMN_KEY = "TABLE_NAME";

private static final String ENABLE_CASE_INSENSITIVE_MATCH = "enable_case_insensitive_match";
private static final String CASING_MODE = "casing_mode";
private static final String ANNOTATION_CASE_UPPER = "upper";
private static final String ANNOTATION_CASE_LOWER = "lower";

// Private no-op constructor: every member visible in this class is static, so instances are not meant to be created.
private SnowflakeCaseInsensitiveResolver()
{
}

public static TableName getTableNameObjectCaseInsensitiveMatch(final Connection connection, TableName tableName, Map<String, String> configOptions)
/**
 * Strategies for adjusting the casing of schema/table names before they are used against Snowflake.
 * The constant names are matched (after upper-casing) against the value of the {@code casing_mode} config option.
 */
private enum SnowflakeCasingMode
{
// names are returned exactly as received
NONE,
// the real casing is resolved by querying INFORMATION_SCHEMA with a lower-cased comparison
CASE_INSENSITIVE_SEARCH,
// casing is taken from an "@schemaCase=...&tableCase=..." annotation appended to the table name
ANNOTATION
}

/**
 * Returns the {@link TableName} to use, adjusted according to the casing mode resolved from {@code configOptions}.
 * <ul>
 *   <li>{@code CASE_INSENSITIVE_SEARCH}: resolves the schema name first, then the table name within that schema,
 *       through the case-insensitive lookups, and returns a new TableName built from both results.</li>
 *   <li>{@code ANNOTATION}: delegates to {@code getTableNameFromQueryHint}.</li>
 *   <li>{@code NONE}: returns the input unchanged.</li>
 * </ul>
 *
 * NOTE(review): this span comes from a rendered diff without +/- markers. The {@code isCaseInsensitiveMatchEnable}
 * guard directly below the opening brace and the unconditional lookup/return that follows the switch appear to be
 * lines of the previous revision rather than part of the new method - confirm against the repository.
 *
 * @param connection open JDBC connection used for the lookups
 * @param tableName schema/table pair as received from the request
 * @param configOptions connector configuration; read by {@code getCasingMode}
 * @return the adjusted (or unchanged) table name
 * @throws SQLException declared by the signature
 */
public static TableName getAdjustedTableObjectNameBasedOnConfig(final Connection connection, TableName tableName, Map<String, String> configOptions)
throws SQLException
{
if (!isCaseInsensitiveMatchEnable(configOptions)) {
return tableName;
SnowflakeCasingMode casingMode = getCasingMode(configOptions);
switch (casingMode) {
case CASE_INSENSITIVE_SEARCH:
// the schema must be resolved first: the table lookup filters on the exact schema name
String schemaNameCaseInsensitively = getSchemaNameCaseInsensitively(connection, tableName.getSchemaName(), configOptions);
String tableNameCaseInsensitively = getTableNameCaseInsensitively(connection, schemaNameCaseInsensitively, tableName.getTableName(), configOptions);
TableName tableNameResult = new TableName(schemaNameCaseInsensitively, tableNameCaseInsensitively);
LOGGER.info("casing mode is `CASE_INSENSITIVE_SEARCH`: adjusting casing from Slowflake case insensitive search for TableName object. TableName:{}", tableNameResult);
return tableNameResult;
case ANNOTATION:
TableName tableNameFromQueryHint = getTableNameFromQueryHint(tableName);
LOGGER.info("casing mode is `ANNOTATION`: adjusting casing from input if annotation found for TableName object. TableName:{}", tableNameFromQueryHint);
return tableNameFromQueryHint;
case NONE:
LOGGER.info("casing mode is `NONE`: not adjust casing from input for TableName object. TableName:{}", tableName);
return tableName;
}

String schemaNameCaseInsensitively = getSchemaNameCaseInsensitively(connection, tableName.getSchemaName(), configOptions);
String tableNameCaseInsensitively = getTableNameCaseInsensitively(connection, schemaNameCaseInsensitively, tableName.getTableName(), configOptions);

return new TableName(schemaNameCaseInsensitively, tableNameCaseInsensitively);
// fallback when no switch case returned: warn and hand the input back unchanged
LOGGER.warn("casing mode is empty: not adjust casing from input for TableName object. TableName:{}", tableName);
return tableName;
}

public static String getSchemaNameCaseInsensitively(final Connection connection, String schemaNameInput, Map<String, String> configOptions)
/**
 * Returns the schema name to use, adjusted according to the casing mode resolved from {@code configOptions}.
 * <ul>
 *   <li>{@code CASE_INSENSITIVE_SEARCH}: delegates to {@code getSchemaNameCaseInsensitively}.</li>
 *   <li>{@code NONE}: returns the input unchanged.</li>
 *   <li>{@code ANNOTATION}: only logs that schema adjustment is not supported; no value is returned from that case.</li>
 * </ul>
 *
 * NOTE(review): this span comes from a rendered diff without +/- markers. The {@code isCaseInsensitiveMatchEnable}
 * guard and the {@code getNameCaseInsensitively(...)} return after the switch appear to be lines of the previous
 * revision - confirm against the repository.
 *
 * @param connection open JDBC connection used for the lookup
 * @param schemaNameInput schema name as received from the request
 * @param configOptions connector configuration; read by {@code getCasingMode}
 * @return the adjusted (or unchanged) schema name
 * @throws SQLException declared by the signature
 */
public static String getAdjustedSchemaNameBasedOnConfig(final Connection connection, String schemaNameInput, Map<String, String> configOptions)
throws SQLException
{
if (!isCaseInsensitiveMatchEnable(configOptions)) {
return schemaNameInput;
SnowflakeCasingMode casingMode = getCasingMode(configOptions);
switch (casingMode) {
case CASE_INSENSITIVE_SEARCH:
LOGGER.info("casing mode is `CASE_INSENSITIVE_SEARCH`: adjusting casing from Slowflake case insensitive search for Schema...");
return getSchemaNameCaseInsensitively(connection, schemaNameInput, configOptions);
case NONE:
LOGGER.info("casing mode is `NONE`: not adjust casing from input for Schema");
return schemaNameInput;
case ANNOTATION:
// last case of the switch and it has no return: control leaves the switch after logging
LOGGER.info("casing mode is `ANNOTATION`: adjust casing for SCHEMA is NOT SUPPORTED. Skip casing adjustment");
}

return getNameCaseInsensitively(connection, SCHEMA_NAME_COLUMN_KEY, SCHEMA_NAME_QUERY + "'" + schemaNameInput.toLowerCase() + "'", configOptions);
return schemaNameInput;
}

public static String getTableNameCaseInsensitively(final Connection connection, String schemaName, String tableNameInput, Map<String, String> configOptions)
/**
 * Looks up the schema whose lower-cased name equals the lower-cased input, using a parameterized query
 * ({@code SCHEMA_NAME_QUERY_TEMPLATE}) against INFORMATION_SCHEMA.SCHEMATA, and returns its name as stored
 * in the {@code SCHEMA_NAME} column.
 *
 * Exactly one match is required: zero matches or more than one match raise a {@link RuntimeException}.
 * A {@link SQLException} raised while preparing/executing the query is caught and rethrown wrapped in a
 * {@link RuntimeException}, even though the signature declares {@code throws SQLException}.
 *
 * NOTE(review): {@code configOptions} is not read by the lines visible here.
 * NOTE(review): this span comes from a rendered diff without +/- markers. The {@code isCaseInsensitiveMatchEnable}
 * guard and the {@code getNameCaseInsensitively(... TABLE_NAME_QUERY ...)} return placed between the try block and
 * its catch appear to be lines of the previous revision - confirm against the repository.
 *
 * @param connection open JDBC connection used to run the lookup
 * @param schemaNameInput schema name in any casing
 * @param configOptions connector configuration
 * @return the schema name as returned by Snowflake
 * @throws SQLException declared by the signature
 */
public static String getSchemaNameCaseInsensitively(final Connection connection, String schemaNameInput, Map<String, String> configOptions)
throws SQLException
{
if (!isCaseInsensitiveMatchEnable(configOptions)) {
return tableNameInput;
String nameFromSnowFlake = null;
// counts the rows returned so ambiguous or missing matches can be rejected below
int i = 0;
try (PreparedStatement preparedStatement = new PreparedStatementBuilder()
.withConnection(connection)
.withQuery(SCHEMA_NAME_QUERY_TEMPLATE)
.withParameters(Arrays.asList(schemaNameInput.toLowerCase())).build();
ResultSet resultSet = preparedStatement.executeQuery()) {
while (resultSet.next()) {
i++;
String schemaNameCandidate = resultSet.getString(SCHEMA_NAME_COLUMN_KEY);
LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", SCHEMA_NAME_COLUMN_KEY, schemaNameCandidate);
// when several rows match, the last one wins here; the count check below rejects that situation
nameFromSnowFlake = schemaNameCandidate;
}
}
//'?' and lower(TABLE_NAME) = '?'
return getNameCaseInsensitively(connection, TABLE_NAME_COLUMN_KEY, TABLE_NAME_QUERY + "'" + schemaName + "' and lower(TABLE_NAME) = '" + tableNameInput.toLowerCase() + "'", configOptions);
catch (SQLException e) {
throw new RuntimeException(e);
}

// exactly one match is required
if (i == 0 || i > 1) {
throw new RuntimeException(String.format("Schema name case insensitive match failed, number of match : %d", i));
}

return nameFromSnowFlake;
}

public static String getNameCaseInsensitively(final Connection connection, String columnLabel, String query, Map<String, String> configOptions)
public static String getTableNameCaseInsensitively(final Connection connection, String schemaName, String tableNameInput, Map<String, String> configOptions)
throws SQLException
{
LOGGER.debug("getNameCaseInsensitively, query:" + query);
// schema name input should be correct case before searching tableName already
String nameFromSnowFlake = null;
int i = 0;
try (Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery(query)) {
try (PreparedStatement preparedStatement = new PreparedStatementBuilder()
.withConnection(connection)
.withQuery(TABLE_NAME_QUERY_TEMPLATE)
.withParameters(Arrays.asList(schemaName, tableNameInput.toLowerCase())).build();
ResultSet resultSet = preparedStatement.executeQuery()) {
while (resultSet.next()) {
i++;
String schemaNameCandidate = resultSet.getString(columnLabel);
LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", columnLabel, schemaNameCandidate);
String schemaNameCandidate = resultSet.getString(TABLE_NAME_COLUMN_KEY);
LOGGER.debug("Case insensitive search on columLabel: {}, schema name: {}", TABLE_NAME_COLUMN_KEY, schemaNameCandidate);
nameFromSnowFlake = schemaNameCandidate;
}
}
Expand All @@ -102,13 +154,70 @@ public static String getNameCaseInsensitively(final Connection connection, Strin
return nameFromSnowFlake;
}

private static boolean isCaseInsensitiveMatchEnable(Map<String, String> configOptions)
/**
 * Adjusts the casing of a schema/table pair from an optional annotation appended to the table name.
 * This keeps the previous query-hint implementation for backward compatibility.
 *
 * <p>Annotation form: {@code MY_TABLE@schemaCase=upper&tableCase=lower}. The text before the first {@code @}
 * is the table name; the text between the first and second {@code @} is a {@code &}-separated list of
 * {@code key=value} hints. A key containing {@code schema} sets the schema casing, a key containing
 * {@code table} sets the table casing. Both default to {@code upper}.
 *
 * <p>Rules:
 * <ul>
 *   <li>No {@code @} in the table name: schema and table are both upper-cased.</li>
 *   <li>Both casings are {@code upper}/{@code lower} (compared ignoring case): each name is cased as requested.</li>
 *   <li>Either casing has any other value: schema and table are both upper-cased.</li>
 *   <li>A hint without a value (for example {@code tbl@schemaCase} or {@code tbl@}) is ignored and the default
 *       applies, instead of failing with an {@link ArrayIndexOutOfBoundsException}.</li>
 * </ul>
 * Case conversion uses {@link Locale#ROOT} so identifiers are not affected by the JVM default locale.
 *
 * @param table table name as received, possibly carrying an annotation in its table part
 * @return a new TableName with the annotation removed and the casing applied
 */
public static TableName getTableNameFromQueryHint(TableName table)
{
    LOGGER.info("getTableNameFromQueryHint: {}", table);
    String schemaName = table.getSchemaName();
    String annotatedTableName = table.getTableName();

    // if no query hint has been passed, upper-case both names
    if (!annotatedTableName.contains("@")) {
        return new TableName(schemaName.toUpperCase(Locale.ROOT), annotatedTableName.toUpperCase(Locale.ROOT));
    }

    // String.split drops trailing empty strings, so "tbl@" yields one element and "@" yields none
    String[] tbNameWithQueryHint = annotatedTableName.split("@");
    String tableName = tbNameWithQueryHint.length > 0 ? tbNameWithQueryHint[0] : "";

    // analyze the hint to find the requested schema and table casing
    String schemaCase = ANNOTATION_CASE_UPPER;
    String tableCase = ANNOTATION_CASE_UPPER;
    if (tbNameWithQueryHint.length > 1) {
        for (String hint : tbNameWithQueryHint[1].split("&")) {
            String[] hintDetail = hint.split("=");
            if (hintDetail.length < 2) {
                // key without a value: keep the default instead of indexing past the array
                LOGGER.warn("getTableNameFromQueryHint: ignoring hint without a value: {}", hint);
                continue;
            }
            if (hintDetail[0].contains("schema")) {
                schemaCase = hintDetail[1];
            }
            else if (hintDetail[0].contains("table")) {
                tableCase = hintDetail[1];
            }
        }
    }

    boolean schemaLower = ANNOTATION_CASE_LOWER.equalsIgnoreCase(schemaCase);
    boolean tableLower = ANNOTATION_CASE_LOWER.equalsIgnoreCase(tableCase);
    boolean schemaRecognized = schemaLower || ANNOTATION_CASE_UPPER.equalsIgnoreCase(schemaCase);
    boolean tableRecognized = tableLower || ANNOTATION_CASE_UPPER.equalsIgnoreCase(tableCase);
    if (!(schemaRecognized && tableRecognized)) {
        // an unrecognized value on either side falls back to upper-casing both names
        schemaLower = false;
        tableLower = false;
    }

    String adjustedSchema = schemaLower ? schemaName.toLowerCase(Locale.ROOT) : schemaName.toUpperCase(Locale.ROOT);
    String adjustedTable = tableLower ? tableName.toLowerCase(Locale.ROOT) : tableName.toUpperCase(Locale.ROOT);
    return new TableName(adjustedSchema, adjustedTable);
}

/**
 * Resolves the casing mode from the connector configuration.
 *
 * When the {@code casing_mode} option is absent, the default depends on whether a Glue connection is configured
 * (a non-blank value under {@code DEFAULT_GLUE_CONNECTION}). The defaults differ to stay backward compatible for
 * customers who are not using a Glue connection:
 *   - with a Glue connection the default is {@code NONE}: the connector does not adjust any casing;
 *   - without a Glue connection the default is {@code ANNOTATION}: the customer can write
 *     MY_TABLE@schemaCase=upper&tableCase=upper.
 *
 * When the option is present, its value is upper-cased and parsed with {@code SnowflakeCasingMode.valueOf}.
 * A value that cannot be parsed is logged at error level together with the valid values, and the exception
 * is rethrown.
 *
 * NOTE(review): this span comes from a rendered diff without +/- markers. The lines that read
 * {@code ENABLE_CASE_INSENSITIVE_MATCH} and the {@code return enableCaseInsensitiveMatch;} statement appear to
 * belong to the previous {@code isCaseInsensitiveMatchEnable} method - confirm against the repository.
 */
private static SnowflakeCasingMode getCasingMode(Map<String, String> configOptions)
{
String enableCaseInsensitiveMatchEnvValue = configOptions.getOrDefault(ENABLE_CASE_INSENSITIVE_MATCH, "false").toLowerCase();
boolean enableCaseInsensitiveMatch = enableCaseInsensitiveMatchEnvValue.equals("true");
LOGGER.info("{} environment variable set to: {}. Resolved to: {}",
ENABLE_CASE_INSENSITIVE_MATCH, enableCaseInsensitiveMatchEnvValue, enableCaseInsensitiveMatch);
boolean isGlueConnection = StringUtils.isNotBlank(configOptions.get(DEFAULT_GLUE_CONNECTION));
// no explicit casing_mode: fall back to the connection-dependent default
if (!configOptions.containsKey(CASING_MODE)) {
LOGGER.info("CASING MODE disable");
return isGlueConnection ? SnowflakeCasingMode.NONE : SnowflakeCasingMode.ANNOTATION;
}

return enableCaseInsensitiveMatch;
try {
// upper-case the value before valueOf so the option is accepted in any casing
SnowflakeCasingMode snowflakeCasingMode = SnowflakeCasingMode.valueOf(configOptions.get(CASING_MODE).toUpperCase());
LOGGER.info("CASING MODE enable: {}", snowflakeCasingMode.toString());
return snowflakeCasingMode;
}
catch (Exception ex) {
// log the rejected value together with the list of valid values, then rethrow
LOGGER.error("Invalid input for:{}, input value:{}, valid values:{}", CASING_MODE, configOptions.get(CASING_MODE), Arrays.asList(SnowflakeCasingMode.values()), ex);
throw ex;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ private Optional<String> getPrimaryKey(TableName tableName) throws Exception
}
}

String primaryKey = String.join(", ", primaryKeys);
if (!Strings.isNullOrEmpty(primaryKey) && hasUniquePrimaryKey(tableName, primaryKey)) {
return Optional.of(primaryKey);
String primaryKeyString = primaryKeys.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
if (!Strings.isNullOrEmpty(primaryKeyString) && hasUniquePrimaryKey(tableName, primaryKeyString)) {
return Optional.of(primaryKeyString);
}
}
return Optional.empty();
Expand All @@ -228,7 +228,7 @@ private Optional<String> getPrimaryKey(TableName tableName) throws Exception
private boolean hasUniquePrimaryKey(TableName tableName, String primaryKey) throws Exception
{
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT " + primaryKey + ", count(*) as COUNTS FROM " + tableName.getTableName() + " GROUP BY " + primaryKey + " ORDER BY COUNTS DESC");
try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT " + primaryKey + ", count(*) as COUNTS FROM " + "\"" + tableName.getSchemaName() + "\".\"" + tableName.getTableName() + "\"" + " GROUP BY " + primaryKey + " ORDER BY COUNTS DESC");
ResultSet rs = preparedStatement.executeQuery()) {
if (rs.next()) {
if (rs.getInt(COUNTS_COLUMN_NAME) == 1) {
Expand Down Expand Up @@ -258,7 +258,7 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest getTabl
getTableLayoutRequest.getTableName().getTableName());

try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
TableName tableName = SnowflakeCaseInsensitiveResolver.getTableNameObjectCaseInsensitiveMatch(connection, getTableLayoutRequest.getTableName(), configOptions);
TableName tableName = getTableLayoutRequest.getTableName();
/**
* "MAX_PARTITION_COUNT" is currently set to 50 to limit the number of partitions.
* this is to handle timeout issues because of huge partitions
Expand Down Expand Up @@ -383,7 +383,7 @@ public GetTableResponse doGetTable(final BlockAllocator blockAllocator, final Ge
LOGGER.debug("doGetTable getTableName:{}", getTableRequest.getTableName());
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
Schema partitionSchema = getPartitionSchema(getTableRequest.getCatalogName());
TableName tableName = SnowflakeCaseInsensitiveResolver.getTableNameObjectCaseInsensitiveMatch(connection, getTableRequest.getTableName(), configOptions);
TableName tableName = SnowflakeCaseInsensitiveResolver.getAdjustedTableObjectNameBasedOnConfig(connection, getTableRequest.getTableName(), configOptions);
GetTableResponse getTableResponse = new GetTableResponse(getTableRequest.getCatalogName(), tableName, getSchema(connection, tableName, partitionSchema),
partitionSchema.getFields().stream().map(Field::getName).collect(Collectors.toSet()));
return getTableResponse;
Expand All @@ -397,7 +397,7 @@ public ListTablesResponse doListTables(final BlockAllocator blockAllocator, fina
try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
LOGGER.info("{}: List table names for Catalog {}, Schema {}", listTablesRequest.getQueryId(),
listTablesRequest.getCatalogName(), listTablesRequest.getSchemaName());
String schemaName = SnowflakeCaseInsensitiveResolver.getSchemaNameCaseInsensitively(connection, listTablesRequest.getSchemaName(), configOptions);
String schemaName = SnowflakeCaseInsensitiveResolver.getAdjustedSchemaNameBasedOnConfig(connection, listTablesRequest.getSchemaName(), configOptions);

String token = listTablesRequest.getNextToken();
int pageSize = listTablesRequest.getPageSize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ public PreparedStatement buildSplitSql(Connection jdbcConnection, String catalog
preparedStatement = buildQueryPassthroughSql(jdbcConnection, constraints);
}
else {
TableName tableName = SnowflakeCaseInsensitiveResolver.getTableNameObjectCaseInsensitiveMatch(jdbcConnection, tableNameInput, configOptions);
preparedStatement = jdbcSplitQueryBuilder.buildSql(jdbcConnection, null, tableName.getSchemaName(), tableName.getTableName(), schema, constraints, split);
preparedStatement = jdbcSplitQueryBuilder.buildSql(jdbcConnection, null, tableNameInput.getSchemaName(), tableNameInput.getTableName(), schema, constraints, split);
}

// Disable fetching all rows.
Expand Down
Loading

0 comments on commit 8a43a53

Please sign in to comment.