diff --git a/src/sqlodbc/opensearch_communication.cpp b/src/sqlodbc/opensearch_communication.cpp index 571357f..7d45abb 100644 --- a/src/sqlodbc/opensearch_communication.cpp +++ b/src/sqlodbc/opensearch_communication.cpp @@ -16,12 +16,18 @@ #include // clang-format on -#define SQL_ENDPOINT_ERROR_STR "Error" +static const std::string SQL_ENDPOINT_OPENSEARCH = "/_plugins/_sql"; +static const std::string SQL_ENDPOINT_ELASTICSEARCH = "/_opendistro/_sql"; +static const std::string SQL_ENDPOINT_ERROR = "Error"; + +static const std::string SERVICE_NAME_DEFAULT = "es"; +static const std::string SERVICE_NAME_SERVERLESS = "aoss"; + +static const std::string CREDENTIALS_PROFILE = "opensearchodbc"; +static const std::string CREDENTIALS_PROVIDER_ALLOCATION_TAG = + "CREDENTIAL_PROVIDER"; static const std::string ctype = "application/json"; -static const std::string ALLOCATION_TAG = "AWS_SIGV4_AUTH"; -static const std::string SERVICE_NAME = "es"; -static const std::string ESODBC_PROFILE_NAME = "opensearchodbc"; static const std::string ERROR_MSG_PREFIX = "[OpenSearch][SQL ODBC Driver][SQL Plugin] "; static const std::string JSON_SCHEMA = @@ -433,7 +439,9 @@ OpenSearchCommunication::IssueRequest( } // Handle authentication - if (m_rt_opts.auth.auth_type == AUTHTYPE_BASIC) { + std::string& auth_type = m_rt_opts.auth.auth_type; + + if (auth_type == AUTHTYPE_BASIC) { std::string userpw_str = m_rt_opts.auth.username + ":" + m_rt_opts.auth.password; Aws::Utils::Array< unsigned char > userpw_arr( @@ -442,14 +450,23 @@ OpenSearchCommunication::IssueRequest( Aws::String hashed_userpw = Aws::Utils::HashingUtils::Base64Encode(userpw_arr); request->SetAuthorization("Basic " + hashed_userpw); - } else if (m_rt_opts.auth.auth_type == AUTHTYPE_IAM) { + } + + // TODO #70: Handle serverless + else if (auth_type == AUTHTYPE_IAM) { std::shared_ptr< Aws::Auth::ProfileConfigFileAWSCredentialsProvider > credential_provider = Aws::MakeShared< Aws::Auth::ProfileConfigFileAWSCredentialsProvider >( - ALLOCATION_TAG.c_str(), ESODBC_PROFILE_NAME.c_str()); + CREDENTIALS_PROVIDER_ALLOCATION_TAG.c_str(), + CREDENTIALS_PROFILE.c_str()); + + const std::string& service_name = + isServerless() + ? SERVICE_NAME_SERVERLESS + : SERVICE_NAME_DEFAULT; Aws::Client::AWSAuthV4Signer signer(credential_provider, - SERVICE_NAME.c_str(), + service_name.c_str(), m_rt_opts.auth.region.c_str()); if (m_rt_opts.auth.tunnel_host.length() > 0) { @@ -548,26 +565,33 @@ bool OpenSearchCommunication::CheckSQLPluginAvailability() { } bool OpenSearchCommunication::EstablishConnection() { - // Generate HttpClient Connection class if it does not exist + LogMsg(OPENSEARCH_ALL, "Attempting to establish DB connection."); + + // Generate HttpClient Connection class if it does not exist if (!m_http_client) { InitializeConnection(); } - // check if the endpoint is initialized + // Check if the endpoint can be determined. if (sql_endpoint.empty()) { SetSqlEndpoint(); } + if (sql_endpoint == SQL_ENDPOINT_ERROR) { + LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); + return false; + } + // Check whether SQL plugin has been installed and enabled in the // OpenSearch server since the SQL plugin is a prerequisite to // use this driver. - if((sql_endpoint != SQL_ENDPOINT_ERROR_STR) && CheckSQLPluginAvailability()) { - return true; + if(!CheckSQLPluginAvailability()) { + LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); + return false; } - LogMsg(OPENSEARCH_ERROR, m_error_message.c_str()); - return false; + return true; } std::vector< std::string > OpenSearchCommunication::GetColumnsWithSelectQuery( @@ -1048,15 +1072,39 @@ std::string OpenSearchCommunication::GetClusterName() { /** * @brief Sets URL endpoint for SQL plugin. On failure to * determine appropriate endpoint, value is set to SQL_ENDPOINT_ERROR_STR - * + * */ void OpenSearchCommunication::SetSqlEndpoint() { + + // TODO #70: Support serverless + if (isServerless()) { + sql_endpoint = SQL_ENDPOINT_OPENSEARCH; + return; + } + std::string distribution = GetServerDistribution(); if (distribution.empty()) { - sql_endpoint = SQL_ENDPOINT_ERROR_STR; - } else if (distribution.compare("opensearch") == 0) { - sql_endpoint = "/_plugins/_sql"; + sql_endpoint = SQL_ENDPOINT_ERROR; + } else if (distribution == "opensearch") { + sql_endpoint = SQL_ENDPOINT_OPENSEARCH; } else { - sql_endpoint = "/_opendistro/_sql"; + sql_endpoint = SQL_ENDPOINT_ELASTICSEARCH; } } + +/** + * Returns whether this is connecting to an OpenSearch Serverless cluster. + * @see + * https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-overview.html + */ +bool OpenSearchCommunication::isServerless() { + + // TODO #70: Support serverless + + // Parse the server URL. + if(m_rt_opts.conn.server.find("aoss.amazonaws.com") != std::string::npos) { + return true; + } + + return false; +} diff --git a/src/sqlodbc/opensearch_communication.h b/src/sqlodbc/opensearch_communication.h index b37ccc3..3a6eea6 100644 --- a/src/sqlodbc/opensearch_communication.h +++ b/src/sqlodbc/opensearch_communication.h @@ -97,6 +97,9 @@ class OpenSearchCommunication { std::string m_response_str; std::shared_ptr< Aws::Http::HttpClient > m_http_client; std::string m_error_message_to_user; + + // TODO #70 - Support serverless + bool isServerless(); }; #endif