From 1f8cf7333164b8b9ddfe3caf8625c539c28c845d Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Aug 2024 19:11:28 +0300 Subject: [PATCH] [PECO-1857] Use SSL options with HTTPS connection pool (#425) * [PECO-1857] Use SSL options with HTTPS connection pool Signed-off-by: Levko Kravets * Some cleanup Signed-off-by: Levko Kravets * Resolve circular dependencies Signed-off-by: Levko Kravets * Update existing tests Signed-off-by: Levko Kravets * Fix MyPy issues Signed-off-by: Levko Kravets * Fix `_tls_no_verify` handling Signed-off-by: Levko Kravets * Add tests Signed-off-by: Levko Kravets --------- Signed-off-by: Levko Kravets --- src/databricks/sql/auth/thrift_http_client.py | 41 ++-- src/databricks/sql/client.py | 18 +- .../sql/cloudfetch/download_manager.py | 9 +- src/databricks/sql/cloudfetch/downloader.py | 12 +- src/databricks/sql/thrift_backend.py | 43 +--- src/databricks/sql/types.py | 48 +++++ src/databricks/sql/utils.py | 18 +- tests/unit/test_cloud_fetch_queue.py | 30 +-- tests/unit/test_download_manager.py | 5 +- tests/unit/test_downloader.py | 16 +- tests/unit/test_thrift_backend.py | 186 ++++++++++++------ 11 files changed, 267 insertions(+), 159 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f7c22a1e..6273ab28 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,13 +1,11 @@ import base64 import logging import urllib.parse -from typing import Dict, Union +from typing import Dict, Union, Optional import six import thrift -logger = logging.getLogger(__name__) - import ssl import warnings from http.client import HTTPResponse @@ -16,6 +14,9 @@ from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager from urllib3.util import make_headers from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) class THttpClient(thrift.transport.THttpClient.THttpClient): @@ -25,13 +26,12 @@ def __init__( uri_or_host, port=None, path=None, - cafile=None, - cert_file=None, - key_file=None, - ssl_context=None, + ssl_options: Optional[SSLOptions] = None, max_connections: int = 1, retry_policy: Union[DatabricksRetryPolicy, int] = 0, ): + self._ssl_options = ssl_options + if port is not None: warnings.warn( "Please use the THttpClient('http{s}://host:port/path') constructor", @@ -48,13 +48,11 @@ def __init__( self.scheme = parsed.scheme assert self.scheme in ("http", "https") if self.scheme == "https": - self.certfile = cert_file - self.keyfile = key_file - self.context = ( - ssl.create_default_context(cafile=cafile) - if (cafile and not ssl_context) - else ssl_context - ) + if self._ssl_options is not None: + # TODO: Not sure if those options are used anywhere - need to double-check + self.certfile = self._ssl_options.tls_client_cert_file + self.keyfile = self._ssl_options.tls_client_cert_key_file + self.context = self._ssl_options.create_ssl_context() self.port = parsed.port self.host = parsed.hostname self.path = parsed.path @@ -109,12 +107,23 @@ def startRetryTimer(self): def open(self): # self.__pool replaces the self.__http used by the original THttpClient + _pool_kwargs = {"maxsize": self.max_connections} + if self.scheme == "http": pool_class = HTTPConnectionPool elif self.scheme == "https": pool_class = HTTPSConnectionPool - - _pool_kwargs = {"maxsize": self.max_connections} + _pool_kwargs.update( + { + "cert_reqs": ssl.CERT_REQUIRED + if self._ssl_options.tls_verify + else ssl.CERT_NONE, + "ca_certs": self._ssl_options.tls_trusted_ca_file, + "cert_file": self._ssl_options.tls_client_cert_file, + "key_file": self._ssl_options.tls_client_cert_key_file, + "key_password": self._ssl_options.tls_client_cert_key_password, + } + ) if self.using_proxy(): proxy_manager = ProxyManager( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c0bf534d..addc340e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -35,7 +35,7 @@ ) -from databricks.sql.types import Row +from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence @@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]: # _tls_trusted_ca_file # Set to the path of the file containing trusted CA certificates for server certificate # verification. If not provide, uses system truststore. - # _tls_client_cert_file, _tls_client_cert_key_file + # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password # Set client SSL certificate. + # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain # _retry_stop_after_attempts_count # The maximum number of attempts during a request retry sequence (defaults to 24) # _socket_timeout @@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]: base_headers = [("User-Agent", useragent_header)] + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + self.thrift_backend = ThriftBackend( self.host, self.port, http_path, (http_headers or []) + base_headers, auth_provider, + ssl_options=self._ssl_options, _use_arrow_native_complex_types=_use_arrow_native_complex_types, **kwargs, ) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index e30adcd6..7e96cd32 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,6 +1,5 @@ import logging -from ssl import SSLContext from concurrent.futures import ThreadPoolExecutor, Future from typing import List, Union @@ -9,6 +8,8 @@ DownloadableResultSettings, DownloadedFile, ) +from databricks.sql.types import SSLOptions + from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) @@ -20,7 +21,7 @@ def __init__( links: List[TSparkArrowResultLink], max_download_threads: int, lz4_compressed: bool, - ssl_context: SSLContext, + ssl_options: SSLOptions, ): self._pending_links: List[TSparkArrowResultLink] = [] for link in links: @@ -38,7 +39,7 @@ def __init__( self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads) self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) - self._ssl_context = ssl_context + self._ssl_options = ssl_options def get_next_downloaded_file( self, next_row_offset: int @@ -95,7 +96,7 @@ def _schedule_downloads(self): handler = ResultSetDownloadHandler( settings=self._downloadable_result_settings, link=link, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 00ffecd0..03c70054 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -3,13 +3,12 @@ import requests from requests.adapters import HTTPAdapter, Retry -from ssl import SSLContext, CERT_NONE import lz4.frame import time from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink - from databricks.sql.exc import Error +from databricks.sql.types import SSLOptions logger = logging.getLogger(__name__) @@ -66,11 +65,11 @@ def __init__( self, settings: DownloadableResultSettings, link: TSparkArrowResultLink, - ssl_context: SSLContext, + ssl_options: SSLOptions, ): self.settings = settings self.link = link - self._ssl_context = ssl_context + self._ssl_options = ssl_options def run(self) -> DownloadedFile: """ @@ -95,14 +94,13 @@ def run(self) -> DownloadedFile: session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) - ssl_verify = self._ssl_context.verify_mode != CERT_NONE - try: # Get the file via HTTP request response = session.get( self.link.fileLink, timeout=self.settings.download_timeout, - verify=ssl_verify, + verify=self._ssl_options.tls_verify, + # TODO: Pass cert from `self._ssl_options` ) response.raise_for_status() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 56412fce..e89bff26 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -5,7 +5,6 @@ import time import uuid import threading -from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union import pyarrow @@ -36,6 +35,7 @@ convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, ) +from databricks.sql.types import SSLOptions logger = logging.getLogger(__name__) @@ -85,6 +85,7 @@ def __init__( http_path: str, http_headers, auth_provider: AuthProvider, + ssl_options: SSLOptions, staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): @@ -93,16 +94,6 @@ def __init__( # Tag to add to User-Agent header. For use by partners. # _username, _password # Username and password Basic authentication (no official support) - # _tls_no_verify - # Set to True (Boolean) to completely disable SSL verification. - # _tls_verify_hostname - # Set to False (Boolean) to disable SSL hostname verification, but check certificate. - # _tls_trusted_ca_file - # Set to the path of the file containing trusted CA certificates for server certificate - # verification. If not provide, uses system truststore. - # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password - # Set client SSL certificate. - # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain # _connection_uri # Overrides server_hostname and http_path. # RETRY/ATTEMPT POLICY @@ -162,29 +153,7 @@ def __init__( # Cloud fetch self.max_download_threads = kwargs.get("max_download_threads", 10) - # Configure tls context - ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file")) - if kwargs.get("_tls_no_verify") is True: - ssl_context.check_hostname = False - ssl_context.verify_mode = CERT_NONE - elif kwargs.get("_tls_verify_hostname") is False: - ssl_context.check_hostname = False - ssl_context.verify_mode = CERT_REQUIRED - else: - ssl_context.check_hostname = True - ssl_context.verify_mode = CERT_REQUIRED - - tls_client_cert_file = kwargs.get("_tls_client_cert_file") - tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file") - tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password") - if tls_client_cert_file: - ssl_context.load_cert_chain( - certfile=tls_client_cert_file, - keyfile=tls_client_cert_key_file, - password=tls_client_cert_key_password, - ) - - self._ssl_context = ssl_context + self._ssl_options = ssl_options self._auth_provider = auth_provider @@ -225,7 +194,7 @@ def __init__( self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, **additional_transport_args, # type: ignore ) @@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state): max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) else: arrow_queue_opt = None @@ -1008,7 +977,7 @@ def fetch_results( max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index aa11b954..fef22cd9 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -19,6 +19,54 @@ from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar import datetime import decimal +from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context + + +class SSLOptions: + tls_verify: bool + tls_verify_hostname: bool + tls_trusted_ca_file: Optional[str] + tls_client_cert_file: Optional[str] + tls_client_cert_key_file: Optional[str] + tls_client_cert_key_password: Optional[str] + + def __init__( + self, + tls_verify: bool = True, + tls_verify_hostname: bool = True, + tls_trusted_ca_file: Optional[str] = None, + tls_client_cert_file: Optional[str] = None, + tls_client_cert_key_file: Optional[str] = None, + tls_client_cert_key_password: Optional[str] = None, + ): + self.tls_verify = tls_verify + self.tls_verify_hostname = tls_verify_hostname + self.tls_trusted_ca_file = tls_trusted_ca_file + self.tls_client_cert_file = tls_client_cert_file + self.tls_client_cert_key_file = tls_client_cert_key_file + self.tls_client_cert_key_password = tls_client_cert_key_password + + def create_ssl_context(self) -> SSLContext: + ssl_context = create_default_context(cafile=self.tls_trusted_ca_file) + + if self.tls_verify is False: + ssl_context.check_hostname = False + ssl_context.verify_mode = CERT_NONE + elif self.tls_verify_hostname is False: + ssl_context.check_hostname = False + ssl_context.verify_mode = CERT_REQUIRED + else: + ssl_context.check_hostname = True + ssl_context.verify_mode = CERT_REQUIRED + + if self.tls_client_cert_file: + ssl_context.load_cert_chain( + certfile=self.tls_client_cert_file, + keyfile=self.tls_client_cert_key_file, + password=self.tls_client_cert_key_password, + ) + + return ssl_context class Row(tuple): diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c22688bb..2807bd2b 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -9,7 +9,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union import re -from ssl import SSLContext import lz4.frame import pyarrow @@ -21,13 +20,14 @@ TSparkArrowResultLink, TSparkRowSetType, ) +from databricks.sql.types import SSLOptions from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter -BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] - import logging +BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] + logger = logging.getLogger(__name__) @@ -48,7 +48,7 @@ def build_queue( t_row_set: TRowSet, arrow_schema_bytes: bytes, max_download_threads: int, - ssl_context: SSLContext, + ssl_options: SSLOptions, lz4_compressed: bool = True, description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: @@ -62,7 +62,7 @@ def build_queue( lz4_compressed (bool): Whether result data has been lz4 compressed. description (List[List[Any]]): Hive table schema description. max_download_threads (int): Maximum number of downloader thread pool threads. - ssl_context (SSLContext): SSLContext object for CloudFetchQueue + ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue Returns: ResultSetQueue @@ -91,7 +91,7 @@ def build_queue( lz4_compressed=lz4_compressed, description=description, max_download_threads=max_download_threads, - ssl_context=ssl_context, + ssl_options=ssl_options, ) else: raise AssertionError("Row set type is not valid") @@ -137,7 +137,7 @@ def __init__( self, schema_bytes, max_download_threads: int, - ssl_context: SSLContext, + ssl_options: SSLOptions, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -160,7 +160,7 @@ def __init__( self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description - self._ssl_context = ssl_context + self._ssl_options = ssl_options logger.debug( "Initialize CloudFetch loader, row set start offset: {}, file list:".format( @@ -178,7 +178,7 @@ def __init__( links=result_links or [], max_download_threads=self.max_download_threads, lz4_compressed=self.lz4_compressed, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) self.table = self._create_next_table() diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index cd14c676..acd0c392 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -1,10 +1,10 @@ import pyarrow import unittest from unittest.mock import MagicMock, patch -from ssl import create_default_context from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink import databricks.sql.utils as utils +from databricks.sql.types import SSLOptions class CloudFetchQueueSuite(unittest.TestCase): @@ -51,7 +51,7 @@ def test_initializer_adds_links(self, mock_create_next_table): schema_bytes, result_links=result_links, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert len(queue.download_manager._pending_links) == 10 @@ -65,7 +65,7 @@ def test_initializer_no_links_to_add(self): schema_bytes, result_links=result_links, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert len(queue.download_manager._pending_links) == 0 @@ -78,7 +78,7 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): MagicMock(), result_links=[], max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue._create_next_table() is None @@ -95,7 +95,7 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) expected_result = self.make_arrow_table() @@ -120,7 +120,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -140,7 +140,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -160,7 +160,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -180,7 +180,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -199,7 +199,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table is None @@ -216,7 +216,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -235,7 +235,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -254,7 +254,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -273,7 +273,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -293,7 +293,7 @@ def test_remaining_rows_empty_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table is None diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index c084d8e4..a11bc8d4 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -1,9 +1,8 @@ import unittest from unittest.mock import patch, MagicMock -from ssl import create_default_context - import databricks.sql.cloudfetch.download_manager as download_manager +from databricks.sql.types import SSLOptions from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink @@ -17,7 +16,7 @@ def create_download_manager(self, links, max_download_threads=10, lz4_compressed links, max_download_threads, lz4_compressed, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) def create_result_link( diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index b6e473b5..7075ef6c 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -2,10 +2,10 @@ from unittest.mock import Mock, patch, MagicMock import requests -from ssl import create_default_context import databricks.sql.cloudfetch.downloader as downloader from databricks.sql.exc import Error +from databricks.sql.types import SSLOptions def create_response(**kwargs) -> requests.Response: @@ -26,7 +26,7 @@ def test_run_link_expired(self, mock_time): result_link = Mock() # Already expired result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(Error) as context: d.run() @@ -40,7 +40,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(Error) as context: d.run() @@ -58,7 +58,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() self.assertTrue('404' in str(context.exception)) @@ -73,7 +73,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -89,7 +89,7 @@ def test_run_compressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -102,7 +102,7 @@ def test_download_connection_error(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(ConnectionError): d.run() @@ -114,6 +114,6 @@ def test_download_timeout(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 4bcf84d2..0333766c 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -4,11 +4,13 @@ import unittest from unittest.mock import patch, MagicMock, Mock from ssl import CERT_NONE, CERT_REQUIRED +from urllib3 import HTTPSConnectionPool import pyarrow import databricks.sql from databricks.sql import utils +from databricks.sql.types import SSLOptions from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider @@ -67,7 +69,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -77,7 +79,7 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -155,7 +157,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider()) + ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider(), ssl_options=SSLOptions()) t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) def test_proxy_headers_are_set(self): @@ -175,75 +177,140 @@ def test_proxy_headers_are_set(self): assert isinstance(result.get('proxy-authorization'), type(str())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks.sql.types.create_default_context") def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): mock_cert_key_file = Mock() mock_cert_key_password = Mock() mock_trusted_ca_file = Mock() mock_cert_file = Mock() + mock_ssl_options = SSLOptions( + tls_client_cert_file=mock_cert_file, + tls_client_cert_key_file=mock_cert_key_file, + tls_client_cert_key_password=mock_cert_key_password, + tls_trusted_ca_file=mock_trusted_ca_file, + ) + mock_ssl_context = mock_ssl_options.create_ssl_context() + mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) + ThriftBackend( "foo", 123, "bar", [], auth_provider=AuthProvider(), - _tls_client_cert_file=mock_cert_file, - _tls_client_cert_key_file=mock_cert_key_file, - _tls_client_cert_key_password=mock_cert_key_password, - _tls_trusted_ca_file=mock_trusted_ca_file, + ssl_options=mock_ssl_options, ) - mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) - mock_ssl_context = mock_create_default_context.return_value mock_ssl_context.load_cert_chain.assert_called_once_with( certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password ) self.assertTrue(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + + @patch("databricks.sql.types.create_default_context") + def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context): + from databricks.sql.auth.thrift_http_client import THttpClient + + mock_cert_key_file = Mock() + mock_cert_key_password = Mock() + mock_trusted_ca_file = Mock() + mock_cert_file = Mock() + + mock_ssl_options = SSLOptions( + tls_verify=True, + tls_client_cert_file=mock_cert_file, + tls_client_cert_key_file=mock_cert_key_file, + tls_client_cert_key_password=mock_cert_key_password, + tls_trusted_ca_file=mock_trusted_ca_file, + ) + + http_client = THttpClient( + auth_provider=None, + uri_or_host="https://example.com", + ssl_options=mock_ssl_options, + ) + + self.assertEqual(http_client.scheme, 'https') + self.assertEqual(http_client.certfile, mock_ssl_options.tls_client_cert_file) + self.assertEqual(http_client.keyfile, mock_ssl_options.tls_client_cert_key_file) + self.assertIsNotNone(http_client.certfile) + mock_create_default_context.assert_called() + + http_client.open() + + conn_pool = http_client._THttpClient__pool + self.assertIsInstance(conn_pool, HTTPSConnectionPool) + self.assertEqual(conn_pool.cert_reqs, CERT_REQUIRED) + self.assertEqual(conn_pool.ca_certs, mock_ssl_options.tls_trusted_ca_file) + self.assertEqual(conn_pool.cert_file, mock_ssl_options.tls_client_cert_file) + self.assertEqual(conn_pool.key_file, mock_ssl_options.tls_client_cert_key_file) + self.assertEqual(conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password) + + def test_tls_no_verify_is_respected_by_http_client(self): + from databricks.sql.auth.thrift_http_client import THttpClient + + http_client = THttpClient( + auth_provider=None, + uri_or_host="https://example.com", + ssl_options=SSLOptions(tls_verify=False), + ) + self.assertEqual(http_client.scheme, 'https') + + http_client.open() + + conn_pool = http_client._THttpClient__pool + self.assertIsInstance(conn_pool, HTTPSConnectionPool) + self.assertEqual(conn_pool.cert_reqs, CERT_NONE) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks.sql.types.create_default_context") def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): - ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_no_verify=True) + mock_ssl_options = SSLOptions(tls_verify=False) + mock_ssl_context = mock_ssl_options.create_ssl_context() + mock_create_default_context.assert_called() + + ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options) - mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) - self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks.sql.types.create_default_context") def test_tls_verify_hostname_is_respected( self, mock_create_default_context, t_http_client_class ): + mock_ssl_options = SSLOptions(tls_verify_hostname=False) + mock_ssl_context = mock_ssl_options.create_ssl_context() + mock_create_default_context.assert_called() + ThriftBackend( - "foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_verify_hostname=False + "foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options ) - mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) @@ -251,17 +318,17 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129 + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=129 ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=0 + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=0 ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=None + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=None ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) @@ -350,7 +417,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): def test_make_request_checks_status_code(self): error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) for code in error_codes: mock_error_response = Mock() @@ -388,7 +455,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): ), ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) with self.assertRaises(DatabaseError) as cm: @@ -417,7 +484,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(self, build_ closeOperation=None, ), ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -449,7 +516,7 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv tcli_service_instance.GetOperationStatus.return_value = op_state_resp thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) with self.assertRaises(DatabaseError) as cm: @@ -477,7 +544,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -510,7 +577,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -562,7 +629,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) with self.assertRaises(DatabaseError) as cm: @@ -603,7 +670,7 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se op_state_3, ] thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) results_message_response = thrift_backend._handle_execute_response( execute_resp, Mock() @@ -634,7 +701,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) thrift_backend._results_message_to_execute_response = Mock() @@ -818,7 +885,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) arrow_queue, has_more_results = thrift_backend.fetch_results( op_handle=Mock(), max_rows=1, @@ -836,7 +903,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -854,7 +921,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -871,7 +938,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -897,7 +964,7 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -927,7 +994,7 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -957,14 +1024,14 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend.close_command(self.operation_handle) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, @@ -974,7 +1041,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend.close_session(self.session_handle) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle @@ -1012,7 +1079,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1021,7 +1088,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1043,7 +1110,7 @@ def test_create_arrow_table_calls_correct_conversion_method( @patch("lz4.frame.decompress") @patch("pyarrow.ipc.open_stream") def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_decompress_mock): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") @@ -1221,6 +1288,7 @@ def test_make_request_will_retry_GetOperationStatus( "path", [], auth_provider=AuthProvider(), + ssl_options=SSLOptions(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1287,6 +1355,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( "path", [], auth_provider=AuthProvider(), + ssl_options=SSLOptions(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1308,7 +1377,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1333,6 +1402,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( "path", [], auth_provider=AuthProvider(), + ssl_options=SSLOptions(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, _retry_delay_min=0, @@ -1353,7 +1423,7 @@ def test_make_request_will_read_error_message_headers_if_set(self, t_transport_c mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) error_headers = [ [("x-thriftserver-error-message", "thrift server error message")], @@ -1454,7 +1524,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_duration": 100, } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args ) for arg, val in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @@ -1470,7 +1540,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args ) retry_delay_expected_vals = { k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() @@ -1490,7 +1560,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) backend.open_session(mock_config, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1501,7 +1571,7 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(databricks.sql.Error) as cm: backend.open_session(mock_config, None, None) @@ -1520,7 +1590,7 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] for cat, schem in initial_cat_schem_args: @@ -1540,7 +1610,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) backend.open_session({}, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1550,7 +1620,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False # is fine @@ -1588,7 +1658,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), ) - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") @@ -1616,7 +1686,7 @@ def test_execute_command_sets_complex_type_fields_correctly( complex_arg_types["_use_arrow_native_decimals"] = decimals thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), **complex_arg_types + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **complex_arg_types ) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0]