Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Spec For Credential Management and SQLAlchemy Database Connectors #1420

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
33078aa
init commit
sundarshankar89 Jan 15, 2025
e12db4d
Added Base Connector
sundarshankar89 Jan 15, 2025
7fc59ff
Moved the Abstract class to private
sundarshankar89 Jan 15, 2025
978f8bc
Added pyodbc dependency
sundarshankar89 Jan 15, 2025
39cdc3e
fmt fixes
sundarshankar89 Jan 15, 2025
a9d799f
Merge branch 'main' into feature/credential_manager
sundarshankar89 Jan 15, 2025
d1263b5
Added Vault Manager
sundarshankar89 Jan 16, 2025
0c40eca
Added TODO
sundarshankar89 Jan 17, 2025
d7eed08
Adding credential manager for multiple secret
sundarshankar89 Jan 20, 2025
bd0d012
Added reading credentials from env and then falling back to key itself
sundarshankar89 Jan 20, 2025
c7f91f5
fixed case agnostic connection creation.
sundarshankar89 Jan 20, 2025
1dcff15
Added UT
sundarshankar89 Jan 20, 2025
211944e
fmt fixes
sundarshankar89 Jan 20, 2025
a445aba
initial test case setup
sundarshankar89 Jan 20, 2025
3cb9c05
test case setup
sundarshankar89 Jan 21, 2025
8b1c254
Refactored to better
sundarshankar89 Jan 21, 2025
74030d3
Added Integration Test
sundarshankar89 Jan 24, 2025
29b14be
Added Integration Test
sundarshankar89 Jan 24, 2025
0aa457b
fmt fixes
sundarshankar89 Jan 24, 2025
9e1f7fd
added fixture
sundarshankar89 Jan 24, 2025
ee162b0
Merge branch 'main' into feature/credential_manager
sundarshankar89 Jan 27, 2025
79e3a86
add acceptance (#1428)
sundarshankar89 Jan 28, 2025
8e9dea6
Merge branch 'main' into feature/credential_manager
sundarshankar89 Jan 29, 2025
f44a09e
fmt fixes
sundarshankar89 Jan 29, 2025
355b76d
Simplified installation journey
sundarshankar89 Feb 3, 2025
5570790
Merge branch 'main' into feature/simplified_installation
sundarshankar89 Feb 5, 2025
a0eee07
Merge branch 'feature/simplified_installation' into feature/credentia…
sundarshankar89 Feb 5, 2025
329c913
Merge branch 'main' into feature/credential_manager
sundarshankar89 Feb 7, 2025
4d0525c
Merge branch 'main' into feature/credential_manager
sundarshankar89 Feb 7, 2025
7807e94
fmt fixes
sundarshankar89 Feb 7, 2025
750645b
credential manager rewrite
sundarshankar89 Feb 7, 2025
8cdb336
integration tests
sundarshankar89 Feb 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ dependencies = [
"databricks-labs-blueprint[yaml]>=0.2.3",
"databricks-labs-lsql>=0.7.5,<0.14.0", # TODO: Limit the LSQL version until dependencies are correct.
"cryptography>=41.0.3",
"pyodbc",
"SQLAlchemy",
"pygls>=2.0.0a2",

]

[project.urls]
Expand Down
44 changes: 44 additions & 0 deletions src/databricks/labs/remorph/connections/credential_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
import os
from pathlib import Path

import yaml

from databricks.labs.blueprint.wheels import ProductInfo


class Credentials:
    """Resolve connection credentials for source systems from ``credentials.yml``.

    The file is read once, at construction time, from
    ``~/.databricks/labs/<product name>/credentials.yml``. Its top-level
    ``secret_vault_type`` key (default ``local``) selects how each configured
    value is turned into the value handed to a connector; see
    :meth:`get_secret_value`.
    """

    def __init__(self, product_info: ProductInfo) -> None:
        """Load the credentials file that belongs to *product_info*'s product."""
        self._product_info = product_info
        self._credentials: dict[str, str] = self._load_credentials(self._get_local_version_file_path())

    def _get_local_version_file_path(self) -> Path:
        """Return the location of ``credentials.yml`` under the current user's home directory."""
        # ``Path.home()`` is a classmethod, so no Path has to be built from ``__file__`` first.
        return Path.home() / ".databricks" / "labs" / self._product_info.product_name() / "credentials.yml"

    def _load_credentials(self, file_path: Path) -> dict[str, str]:
        """Parse *file_path* as YAML and return its top-level mapping.

        Raises:
            ValueError: If the document is neither empty nor a mapping.
        """
        with open(file_path, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        if loaded is None:
            # ``safe_load`` returns None for an empty document. An empty mapping lets
            # ``get`` report the missing source with its regular KeyError.
            return {}
        if not isinstance(loaded, dict):
            raise ValueError(f"Expected a mapping at the top level of {file_path}, got {type(loaded).__name__}")
        return loaded

    def get(self, source: str) -> dict[str, str]:
        """Return the resolved connection settings for one source system.

        Args:
            source: Top-level section name in ``credentials.yml``, e.g. ``"mssql"``.

        Returns:
            The section's settings, each value passed through :meth:`get_secret_value`.

        Raises:
            KeyError: If the section is missing or is not a mapping.
        """
        section = self._credentials.get(source)
        if not isinstance(section, dict):
            raise KeyError(f"source system: {source} credentials not found in file credentials.yml")
        return {k: self.get_secret_value(v) for k, v in section.items()}

    def get_secret_value(self, key: str) -> str:
        """Resolve one configured value according to ``secret_vault_type``.

        * ``local``: *key* already is the value and is returned unchanged.
        * ``env``: *key* names an environment variable; its value is returned,
          or *key* itself when the variable is not set.
        * ``databricks``: not implemented yet.

        Args:
            key: Raw value read from ``credentials.yml``. Despite the annotation it can be
                a non-string YAML scalar (an integer port, or ``null``), which is
                returned as-is when no lookup applies.

        Raises:
            NotImplementedError: If the vault type is ``databricks``.
            ValueError: If the vault type is not one of the supported names.
        """
        # A YAML ``null`` yields None, which has no ``.lower()``; treat it like a missing key.
        secret_vault_type = str(self._credentials.get('secret_vault_type') or 'local').lower()
        if secret_vault_type == 'local':
            return key
        if secret_vault_type == 'env':
            if key is None:
                # Unset optional setting (e.g. ``port: null``): there is no variable to look up.
                return key
            value = os.getenv(str(key))  # Port numbers can be int
            if value is None:
                # NOTE(review): if the file holds a literal secret instead of a variable
                # name, this message contains that secret.
                logging.getLogger(__name__).warning(
                    "Environment variable %s not found. Falling back to the literal value.", key
                )
                return key
            return value
        if secret_vault_type == 'databricks':
            raise NotImplementedError("Databricks secret vault not implemented")

        raise ValueError(f"Unsupported secret vault type: {secret_vault_type}")
71 changes: 71 additions & 0 deletions src/databricks/labs/remorph/connections/database_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from typing import Any

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine, Result
from sqlalchemy.orm import sessionmaker
from sqlalchemy import text


class _ISourceSystemConnector(ABC):
    """Interface that every source-system connector implements."""

    @abstractmethod
    def _connect(self) -> Engine:
        """Build and return the engine used to talk to the source system."""

    @abstractmethod
    def execute_query(self, query: str) -> Result[Any]:
        """Run *query* against the source system and return its result."""


class _BaseConnector(_ISourceSystemConnector):
    """Shared connector behaviour: build the engine once, then run queries through it."""

    def __init__(self, config: dict[str, Any]):
        """Keep *config* and immediately build the engine via ``_connect``."""
        self.config = config
        self.engine: Engine = self._connect()

    def _connect(self) -> Engine:
        """Placeholder that concrete connectors override to return their engine."""
        raise NotImplementedError("Subclasses should implement this method")

    def execute_query(self, query: str) -> Result[Any]:
        """Execute *query* as textual SQL in a new session bound to the engine.

        Raises:
            ConnectionError: If ``self.engine`` is falsy.
        """
        if self.engine:
            # NOTE(review): the session created here is never closed by this method;
            # the returned result is handed to the caller while it is still open.
            session_factory = sessionmaker(bind=self.engine)
            return session_factory().execute(text(query))
        raise ConnectionError("Not connected to the database.")


def _create_connector(db_type: str, config: dict[str, Any]) -> _ISourceSystemConnector:
    """Instantiate the connector registered for *db_type* (matched case-insensitively).

    Raises:
        ValueError: If *db_type* is not one of the supported names.
    """
    normalized = db_type.lower()
    if normalized == "snowflake":
        connector_class = SnowflakeConnector
    elif normalized in ("mssql", "tsql", "synapse"):
        # All three names are served by the same SQL Server connector.
        connector_class = MSSQLConnector
    else:
        raise ValueError(f"Unsupported database type: {db_type}")
    return connector_class(config)


class SnowflakeConnector(_BaseConnector):
    """Placeholder connector for Snowflake; no connection logic exists yet."""

    def _connect(self) -> Engine:
        """Always raise, because Snowflake connectivity is not implemented."""
        raise NotImplementedError("Snowflake connector not implemented")


class MSSQLConnector(_BaseConnector):
    """Connector for SQL Server style sources, reached through ``mssql+pyodbc``."""

    def _connect(self) -> Engine:
        """Create the engine from ``self.config``.

        Required config keys: ``user``, ``password``, ``server``, ``database`` and
        ``driver``. An optional ``port`` is honoured when it is set.

        Raises:
            KeyError: If a required key is missing from the config.
            ValueError: If ``port`` is set but is not an integer.
        """
        # Function-scope import keeps this block self-contained.
        from sqlalchemy.engine import URL

        def _text_or_none(value: Any) -> Any:
            # Config values can be non-string YAML scalars; ``null`` must stay None
            # instead of becoming the literal text "None" in the URL.
            return None if value is None else str(value)

        port = self.config.get('port')
        # URL.create escapes every component, so passwords containing ``@``, ``/`` or ``:``
        # and driver names containing spaces no longer corrupt the connection URL.
        connection_url = URL.create(
            "mssql+pyodbc",
            username=_text_or_none(self.config['user']),
            password=_text_or_none(self.config['password']),
            host=_text_or_none(self.config['server']),
            port=None if port in (None, '') else int(port),
            database=_text_or_none(self.config['database']),
            query={"driver": str(self.config['driver'])},
        )
        return create_engine(connection_url, echo=True)


class DatabaseManager:
    """Facade that picks a connector by database type and forwards queries to it."""

    def __init__(self, db_type: str, config: dict[str, Any]):
        """Create the connector for *db_type*, configured with *config*."""
        self.connector = _create_connector(db_type, config)

    def execute_query(self, query: str) -> Result[Any]:
        """Delegate *query* to the underlying connector and return its result."""
        connector = self.connector
        return connector.execute_query(query)
33 changes: 33 additions & 0 deletions src/databricks/labs/remorph/resources/config/credentials.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
secret_vault_type: local  # one of: local | env | databricks
secret_vault_name: null
snowflake:
account: example_account
connect_retries: 1
connect_timeout: null
host: null
insecure_mode: false
oauth_client_id: null
oauth_client_secret: null
password: null
port: null
private_key: null
private_key_passphrase: null
private_key_path: null
role: null
token: null
user: null
warehouse: null

mssql:
# TODO: Expand to support SQL pools and legacy DWH
database: DB_NAME
driver: ODBC Driver 18 for SQL Server
server: example_host
port: null
user: null
password: null





90 changes: 90 additions & 0 deletions tests/unit/connections/test_credential_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
from unittest.mock import patch, MagicMock
from pathlib import Path
from databricks.labs.remorph.connections.credential_manager import Credentials
from databricks.labs.blueprint.wheels import ProductInfo


@pytest.fixture
def product_info():
    """Provide a ``ProductInfo`` mock whose ``product_name()`` returns ``"test_product"``."""
    info = MagicMock(spec=ProductInfo)
    info.product_name.return_value = "test_product"
    return info


@pytest.fixture
def local_credentials():
    """Credentials document using the ``local`` vault type with literal values."""
    mssql_section = {
        'database': 'DB_NAME',
        'driver': 'ODBC Driver 18 for SQL Server',
        'server': 'example_host',
        'user': 'local_user',
        'password': 'local_password',
    }
    return {'secret_vault_type': 'local', 'mssql': mssql_section}


@pytest.fixture
def env_credentials():
    """Credentials document using the ``env`` vault type; user and password name env variables."""
    mssql_section = {
        'database': 'DB_NAME',
        'driver': 'ODBC Driver 18 for SQL Server',
        'server': 'example_host',
        'user': 'MSSQL_USER_ENV',
        'password': 'MSSQL_PASSWORD_ENV',
    }
    return {'secret_vault_type': 'env', 'mssql': mssql_section}


@pytest.fixture
def databricks_credentials():
    """Credentials document using the ``databricks`` vault type."""
    mssql_section = {
        'database': 'DB_NAME',
        'driver': 'ODBC Driver 18 for SQL Server',
        'server': 'example_host',
        'user': 'databricks_user',
        'password': 'databricks_password',
    }
    return {
        'secret_vault_type': 'databricks',
        'secret_vault_name': 'databricks_vault_name',
        'mssql': mssql_section,
    }


@patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path')
@patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials')
def test_local_credentials(mock_load_credentials, mock_get_local_version_file_path, product_info, local_credentials):
    """With the ``local`` vault type the configured values come back unchanged."""
    mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml")
    mock_load_credentials.return_value = local_credentials

    resolved = Credentials(product_info).get('mssql')

    assert resolved['user'] == 'local_user'
    assert resolved['password'] == 'local_password'


@patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path')
@patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials')
@patch.dict('os.environ', {'MSSQL_USER_ENV': 'env_user', 'MSSQL_PASSWORD_ENV': 'env_password'})
def test_env_credentials(mock_load_credentials, mock_get_local_version_file_path, product_info, env_credentials):
    """With the ``env`` vault type the values are read from the named environment variables."""
    mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml")
    mock_load_credentials.return_value = env_credentials

    resolved = Credentials(product_info).get('mssql')

    assert resolved['user'] == 'env_user'
    assert resolved['password'] == 'env_password'


@patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path')
@patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials')
def test_databricks_credentials(
    mock_load_credentials, mock_get_local_version_file_path, product_info, databricks_credentials
):
    """The ``databricks`` vault type is rejected with ``NotImplementedError``."""
    mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml")
    mock_load_credentials.return_value = databricks_credentials

    manager = Credentials(product_info)

    with pytest.raises(NotImplementedError):
        manager.get('mssql')
55 changes: 55 additions & 0 deletions tests/unit/connections/test_database_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
from unittest.mock import MagicMock, patch
from databricks.labs.remorph.connections.database_manager import DatabaseManager

# Minimal SQL Server connection settings shared by the tests in this module.
sample_config = dict(
    user='test_user',
    password='test_pass',
    server='test_server',
    database='test_db',
    driver='ODBC Driver 17 for SQL Server',
)


def test_create_connector_unsupported_db_type():
    """An unknown database type is rejected with a ``ValueError`` naming the type."""
    expected_message = "Unsupported database type: unsupported_db"
    with pytest.raises(ValueError, match=expected_message):
        DatabaseManager("unsupported_db", sample_config)


# Test case for MSSQLConnector
@patch('databricks.labs.remorph.connections.database_manager.MSSQLConnector')
def test_mssql_connector(mock_mssql_connector):
    """``"mssql"`` builds the manager's connector from ``MSSQLConnector`` with the given config."""
    fake_connector = MagicMock()
    mock_mssql_connector.return_value = fake_connector

    manager = DatabaseManager("mssql", sample_config)

    mock_mssql_connector.assert_called_once_with(sample_config)
    assert manager.connector == fake_connector


@patch('databricks.labs.remorph.connections.database_manager.MSSQLConnector')
def test_execute_query(mock_mssql_connector):
    """``DatabaseManager.execute_query`` forwards the query and returns the connector's result."""
    sql = "SELECT * FROM users"
    expected = MagicMock()
    fake_connector = MagicMock()
    fake_connector.execute_query.return_value = expected
    mock_mssql_connector.return_value = fake_connector

    actual = DatabaseManager("mssql", sample_config).execute_query(sql)

    assert actual == expected
    fake_connector.execute_query.assert_called_once_with(sql)


def test_execute_query_without_connection():
    """Running a query after the engine has been cleared raises ``ConnectionError``."""
    manager = DatabaseManager("mssql", sample_config)
    # Simulate a connector that has no engine.
    manager.connector.engine = None

    expected_message = "Not connected to the database."
    with pytest.raises(ConnectionError, match=expected_message):
        manager.execute_query("SELECT * FROM users")
Loading