Added SQL Server reconcile code #1403

Merged: 12 commits, Feb 4, 2025
@@ -7,17 +7,21 @@ class JDBCReaderMixin:
_spark: SparkSession

# TODO update the url
def _get_jdbc_reader(self, query, jdbc_url, driver):
def _get_jdbc_reader(self, query, jdbc_url, driver, prepare_query=None):
driver_class = {
"oracle": "oracle.jdbc.driver.OracleDriver",
"snowflake": "net.snowflake.client.jdbc.SnowflakeDriver",
"sqlserver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
}
return (
reader = (
self._spark.read.format("jdbc")
.option("url", jdbc_url)
.option("driver", driver_class.get(driver, driver))
.option("dbtable", f"({query}) tmp")
)
if driver == "sqlserver":
reader = reader.option('prepareQuery', prepare_query)
return reader

@staticmethod
def _get_jdbc_reader_options(options: JdbcReaderOptions):
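For reference, a hedged sketch (not part of the diff) of the Spark reader that the extended _get_jdbc_reader builds when driver is "sqlserver". The SparkSession `spark`, the JDBC URL, and the sample queries are illustrative assumptions:

# Sketch only: equivalent of _get_jdbc_reader(query, jdbc_url, "sqlserver", prepare_query)
# for an assumed CTE query; `spark` is an existing SparkSession.
reader = (
    spark.read.format("jdbc")
    .option("url", "jdbc:sqlserver://host:1433;databaseName=db;")
    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
    .option("dbtable", "(SELECT id, name FROM recent) tmp")
    # prepareQuery carries the CTE prefix that cannot live inside the parenthesized dbtable subquery
    .option("prepareQuery", "WITH recent AS (SELECT id, name FROM dbo.employee)")
)
df = reader.load()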
@@ -1,10 +1,12 @@
from pyspark.sql import SparkSession
from sqlglot import Dialect
from sqlglot.dialects import TSQL

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource
from databricks.labs.remorph.reconcile.connectors.oracle import OracleDataSource
from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource
from databricks.labs.remorph.reconcile.connectors.sql_server import SQLServerDataSource
from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks
from databricks.labs.remorph.transpiler.sqlglot.parsers.oracle import Oracle
from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake
@@ -23,4 +25,6 @@ def create_adapter(
return OracleDataSource(engine, spark, ws, secret_scope)
if isinstance(engine, Databricks):
return DatabricksDataSource(engine, spark, ws, secret_scope)
if isinstance(engine, TSQL):
return SQLServerDataSource(engine, spark, ws, secret_scope)
raise ValueError(f"Unsupported source type --> {engine}")
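A hedged usage sketch (not part of the diff): with the new branch, a tsql engine resolves to the SQL Server connector. The create_adapter import path and its positional signature are assumptions based on this repo's layout; spark and ws are stand-ins:

from unittest.mock import MagicMock, create_autospec

from databricks.sdk import WorkspaceClient
from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect
# assumed module path for create_adapter
from databricks.labs.remorph.reconcile.connectors.source_adapter import create_adapter

engine = get_dialect("tsql")            # sqlglot TSQL dialect instance
spark = MagicMock()                     # stand-in for a SparkSession
ws = create_autospec(WorkspaceClient)   # stand-in workspace client
data_source = create_adapter(engine, spark, ws, "scope")
# isinstance(engine, TSQL) holds, so data_source is a SQLServerDataSource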
src/databricks/labs/remorph/reconcile/connectors/sql_server.py (new file: 132 additions, 0 deletions)
@@ -0,0 +1,132 @@
import re
import logging
from datetime import datetime

from pyspark.errors import PySparkException
from pyspark.sql import DataFrame, DataFrameReader, SparkSession
from pyspark.sql.functions import col
from sqlglot import Dialect

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.connectors.jdbc_reader import JDBCReaderMixin
from databricks.labs.remorph.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema
from databricks.sdk import WorkspaceClient

logger = logging.getLogger(__name__)

_SCHEMA_QUERY = """SELECT
COLUMN_NAME,
CASE
WHEN DATA_TYPE IN ('int', 'bigint')
THEN DATA_TYPE
WHEN DATA_TYPE IN ('smallint', 'tinyint')
THEN 'smallint'
WHEN DATA_TYPE IN ('decimal' ,'numeric')
THEN 'decimal(' +
CAST(NUMERIC_PRECISION AS VARCHAR) + ',' +
CAST(NUMERIC_SCALE AS VARCHAR) + ')'
WHEN DATA_TYPE IN ('float', 'real')
THEN 'double'
WHEN CHARACTER_MAXIMUM_LENGTH IS NOT NULL AND DATA_TYPE IN ('varchar','char','text','nchar','nvarchar','ntext')
THEN DATA_TYPE
WHEN DATA_TYPE IN ('date','time','datetime', 'datetime2','smalldatetime','datetimeoffset')
THEN DATA_TYPE
WHEN DATA_TYPE IN ('bit')
THEN 'boolean'
WHEN DATA_TYPE IN ('binary','varbinary')
THEN 'binary'
ELSE DATA_TYPE
END AS 'DATA_TYPE'
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
LOWER(TABLE_NAME) = LOWER('{table}')
AND LOWER(TABLE_SCHEMA) = LOWER('{schema}')
AND LOWER(TABLE_CATALOG) = LOWER('{catalog}')
"""


class SQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
_DRIVER = "sqlserver"

def __init__(
self,
engine: Dialect,
spark: SparkSession,
ws: WorkspaceClient,
secret_scope: str,
):
self._engine = engine
self._spark = spark
self._ws = ws
self._secret_scope = secret_scope

@property
def get_jdbc_url(self) -> str:
# Construct the JDBC URL
return (
f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};"
f"databaseName={self._get_secret('database')};"
f"user={self._get_secret('user')};"
f"password={self._get_secret('password')};"
f"encrypt={self._get_secret('encrypt')};"
f"trustServerCertificate={self._get_secret('trustServerCertificate')};"
)

def read_data(
self,
catalog: str | None,
schema: str,
table: str,
query: str,
options: JdbcReaderOptions | None,
) -> DataFrame:
table_query = query.replace(":tbl", f"{catalog}.{schema}.{table}")
with_clause_pattern = re.compile(r'WITH\s+.*?\)\s*(?=SELECT)', re.IGNORECASE | re.DOTALL)
match = with_clause_pattern.search(table_query)
if match:
prepareqry_str = match.group(0)
query = table_query.replace(match.group(0), '')
else:
query = table_query
prepareqry_str = ""
try:
if options is None:
df = self.reader(query, prepareqry_str).load()
else:
options = self._get_jdbc_reader_options(options)
df = self._get_jdbc_reader(table_query, self.get_jdbc_url, self._DRIVER).options(**options).load()
return df.select([col(column).alias(column.lower()) for column in df.columns])
except (RuntimeError, PySparkException) as e:
return self.log_and_throw_exception(e, "data", table_query)

def get_schema(
self,
catalog: str | None,
schema: str,
table: str,
) -> list[Schema]:
"""
Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in SQL Server.

If the user's current role does not have the necessary privileges to access the specified
Information Schema object, a RuntimeError will be raised:
"SQL access control error: Insufficient privileges to operate on schema 'INFORMATION_SCHEMA' "
"""
schema_query = re.sub(
r'\s+',
' ',
_SCHEMA_QUERY.format(catalog=catalog, schema=schema, table=table),
)
try:
logger.debug(f"Fetching schema using query: \n`{schema_query}`")
logger.info(f"Fetching Schema: Started at: {datetime.now()}")
schema_metadata = self.reader(schema_query).load().collect()
logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
return [Schema(field.COLUMN_NAME.lower(), field.DATA_TYPE.lower()) for field in schema_metadata]
except (RuntimeError, PySparkException) as e:
return self.log_and_throw_exception(e, "schema", schema_query)

def reader(self, query: str, prepareqry_str="") -> DataFrameReader:
return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, prepareqry_str)
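To make the WITH-clause handling in read_data concrete, here is a small self-contained sketch of the same regex split; the sample query is an assumption, not from the PR:

import re

table_query = (
    "WITH recent AS (SELECT * FROM org.data.employee WHERE active = 1) "
    "SELECT id, name FROM recent"
)
# Same pattern as read_data: capture everything from WITH up to the closing ")" before SELECT.
with_clause_pattern = re.compile(r'WITH\s+.*?\)\s*(?=SELECT)', re.IGNORECASE | re.DOTALL)
match = with_clause_pattern.search(table_query)
prepare_query = match.group(0) if match else ""                    # forwarded as the prepareQuery option
main_query = table_query.replace(prepare_query, "") if match else table_query
print(prepare_query)   # WITH recent AS (SELECT * FROM org.data.employee WHERE active = 1)
print(main_query)      # SELECT id, name FROM recent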
@@ -249,6 +249,14 @@ def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) ->
partial(anonymous, func="CONCAT_WS(',', SORT_ARRAY({}))", dialect=get_dialect("databricks"))
],
},
"tsql": {
"default": [partial(anonymous, func="COALESCE(LTRIM(RTRIM(CAST([{}] AS VARCHAR(256)))), '_null_recon_')")],
exp.DataType.Type.DATE.value: [partial(anonymous, func="COALESCE(CONVERT(DATE, {0}, 101), '1900-01-01')")],
exp.DataType.Type.TIME.value: [partial(anonymous, func="COALESCE(CONVERT(TIME, {0}, 108), '00:00:00')")],
exp.DataType.Type.DATETIME.value: [
partial(anonymous, func="COALESCE(CONVERT(DATETIME, {0}, 120), '1900-01-01 00:00:00')")
],
},
}

sha256_partial = partial(sha2, num_bits="256", is_expr=True)
@@ -267,4 +275,10 @@ def _get_is_string(column_types_dict: dict[str, DataType], column_name: str) ->
source=sha256_partial,
target=sha256_partial,
),
get_dialect("tsql"): HashAlgoMapping(
source=partial(
anonymous, func="CONVERT(VARCHAR(256), HASHBYTES('SHA2_256', CONVERT(VARCHAR(256),{})), 2)", is_expr=True
),
target=sha256_partial,
),
}
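For reference, a sketch (not part of the diff) of the source-side expression the new tsql hash mapping is expected to render; the column name c1 is an assumption:

# The T-SQL template from the mapping above, filled in for an assumed column "c1".
tsql_hash_template = "CONVERT(VARCHAR(256), HASHBYTES('SHA2_256', CONVERT(VARCHAR(256),{})), 2)"
print(tsql_hash_template.format("c1"))
# -> CONVERT(VARCHAR(256), HASHBYTES('SHA2_256', CONVERT(VARCHAR(256),c1)), 2)
# The Databricks target side keeps sha256_partial (SHA2 with 256 bits), per the mapping above.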
tests/unit/reconcile/connectors/test_sql_server.py (new file: 175 additions, 0 deletions)
@@ -0,0 +1,175 @@
import base64
import re
from unittest.mock import MagicMock, create_autospec

import pytest

from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect
from databricks.labs.remorph.reconcile.connectors.sql_server import SQLServerDataSource
from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException
from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Table
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import GetSecretResponse


def mock_secret(scope, key):
scope_secret_mock = {
"scope": {
'user': GetSecretResponse(key='user', value=base64.b64encode('my_user'.encode('utf-8')).decode('utf-8')),
'password': GetSecretResponse(
key='password', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8')
),
'host': GetSecretResponse(key='host', value=base64.b64encode(bytes('my_host', 'utf-8')).decode('utf-8')),
'port': GetSecretResponse(key='port', value=base64.b64encode(bytes('777', 'utf-8')).decode('utf-8')),
'database': GetSecretResponse(
key='database', value=base64.b64encode(bytes('my_database', 'utf-8')).decode('utf-8')
),
'encrypt': GetSecretResponse(key='encrypt', value=base64.b64encode(bytes('true', 'utf-8')).decode('utf-8')),
'trustServerCertificate': GetSecretResponse(
key='trustServerCertificate', value=base64.b64encode(bytes('true', 'utf-8')).decode('utf-8')
),
}
}

return scope_secret_mock[scope][key]


def initial_setup():
pyspark_sql_session = MagicMock()
spark = pyspark_sql_session.SparkSession.builder.getOrCreate()

# Define the source, workspace, and scope
engine = get_dialect("tsql")
ws = create_autospec(WorkspaceClient)
scope = "scope"
ws.secrets.get_secret.side_effect = mock_secret
return engine, spark, ws, scope


def test_get_jdbc_url_happy():
# initial setup
engine, spark, ws, scope = initial_setup()
# create object for SQLServerDataSource
data_source = SQLServerDataSource(engine, spark, ws, scope)
url = data_source.get_jdbc_url
# Assert that the URL is generated correctly
assert url == (
"""jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
)


def test_get_jdbc_url_fail():
# initial setup
engine, spark, ws, scope = initial_setup()
ws.secrets.get_secret.side_effect = mock_secret
# create object for SQLServerDataSource
data_source = SQLServerDataSource(engine, spark, ws, scope)
url = data_source.get_jdbc_url
# Assert that the URL is generated correctly
assert url == (
"""jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
)


def test_read_data_with_options():
# initial setup
engine, spark, ws, scope = initial_setup()

# create object for SQLServerDataSource
data_source = SQLServerDataSource(engine, spark, ws, scope)
# Create a Tables configuration object with JDBC reader options
table_conf = Table(
source_name="src_supplier",
target_name="tgt_supplier",
jdbc_reader_options=JdbcReaderOptions(
number_partitions=100, partition_column="s_partition_key", lower_bound="0", upper_bound="100"
),
)

# Call the read_data method with the Tables configuration
data_source.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options)

# spark assertions
spark.read.format.assert_called_with("jdbc")
spark.read.format().option.assert_called_with(
"url",
"jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;",
)
spark.read.format().option().option.assert_called_with("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from org.data.employee) tmp")
spark.read.format().option().option().option().option.assert_called_with("prepareQuery", None)
actual_args = spark.read.format().option().option().option().option().options.call_args.kwargs
expected_args = {
"numPartitions": 100,
"partitionColumn": "s_partition_key",
"lowerBound": '0',
"upperBound": "100",
"fetchsize": 100,
}
assert actual_args == expected_args
spark.read.format().option().option().option().option().options().load.assert_called_once()


def test_get_schema():
# initial setup
engine, spark, ws, scope = initial_setup()
# Mocking get secret method to return the required values
data_source = SQLServerDataSource(engine, spark, ws, scope)
# call test method
data_source.get_schema("org", "schema", "supplier")
# spark assertions
spark.read.format.assert_called_with("jdbc")
spark.read.format().option().option().option.assert_called_with(
"dbtable",
re.sub(
r'\s+',
' ',
r"""(SELECT
COLUMN_NAME,
CASE
WHEN DATA_TYPE IN ('int', 'bigint')
THEN DATA_TYPE
WHEN DATA_TYPE IN ('smallint', 'tinyint')
THEN 'smallint'
WHEN DATA_TYPE IN ('decimal' ,'numeric')
THEN 'decimal(' +
CAST(NUMERIC_PRECISION AS VARCHAR) + ',' +
CAST(NUMERIC_SCALE AS VARCHAR) + ')'
WHEN DATA_TYPE IN ('float', 'real')
THEN 'double'
WHEN CHARACTER_MAXIMUM_LENGTH IS NOT NULL AND DATA_TYPE IN ('varchar','char','text','nchar','nvarchar','ntext')
THEN DATA_TYPE
WHEN DATA_TYPE IN ('date','time','datetime', 'datetime2','smalldatetime','datetimeoffset')
THEN DATA_TYPE
WHEN DATA_TYPE IN ('bit')
THEN 'boolean'
WHEN DATA_TYPE IN ('binary','varbinary')
THEN 'binary'
ELSE DATA_TYPE
END AS 'DATA_TYPE'
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE LOWER(TABLE_NAME) = LOWER('supplier')
AND LOWER(TABLE_SCHEMA) = LOWER('schema')
AND LOWER(TABLE_CATALOG) = LOWER('org')
) tmp""",
),
)


def test_get_schema_exception_handling():
# initial setup
engine, spark, ws, scope = initial_setup()
data_source = SQLServerDataSource(engine, spark, ws, scope)

spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception")

# Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
# is raised
with pytest.raises(
DataSourceRuntimeException,
match=re.escape(
"""Runtime exception occurred while fetching schema using SELECT COLUMN_NAME, CASE WHEN DATA_TYPE IN ('int', 'bigint') THEN DATA_TYPE WHEN DATA_TYPE IN ('smallint', 'tinyint') THEN 'smallint' WHEN DATA_TYPE IN ('decimal' ,'numeric') THEN 'decimal(' + CAST(NUMERIC_PRECISION AS VARCHAR) + ',' + CAST(NUMERIC_SCALE AS VARCHAR) + ')' WHEN DATA_TYPE IN ('float', 'real') THEN 'double' WHEN CHARACTER_MAXIMUM_LENGTH IS NOT NULL AND DATA_TYPE IN ('varchar','char','text','nchar','nvarchar','ntext') THEN DATA_TYPE WHEN DATA_TYPE IN ('date','time','datetime', 'datetime2','smalldatetime','datetimeoffset') THEN DATA_TYPE WHEN DATA_TYPE IN ('bit') THEN 'boolean' WHEN DATA_TYPE IN ('binary','varbinary') THEN 'binary' ELSE DATA_TYPE END AS 'DATA_TYPE' FROM INFORMATION_SCHEMA.COLUMNS WHERE LOWER(TABLE_NAME) = LOWER('supplier') AND LOWER(TABLE_SCHEMA) = LOWER('schema') AND LOWER(TABLE_CATALOG) = LOWER('org') : Test Exception"""
),
):
data_source.get_schema("org", "schema", "supplier")