diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index 441c67394..8989cce97 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -1,9 +1,11 @@ from __future__ import annotations import importlib -from importlib.metadata import version +import urllib.parse -from typing import Literal, TYPE_CHECKING, overload +from importlib.metadata import version +from pathlib import Path +from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar from .connectorx import ( read_sql as _read_sql, @@ -20,7 +22,7 @@ import pyarrow as pa # only for typing hints - from .connectorx import _DataframeInfos, _ArrowInfos + from .connectorx import _DataframeInfos, _ArrowInfos __version__ = version(__name__) @@ -42,7 +44,12 @@ Protocol = Literal["csv", "binary", "cursor", "simple", "text"] -def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Protocol]: +_BackendT = TypeVar("_BackendT") + + +def rewrite_conn( + conn: str | ConnectionUrl, protocol: Protocol | None = None +) -> tuple[str, Protocol]: if not protocol: # note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database # drivers to connect. set a compatible protocol and masquerade as the appropriate backend. 
@@ -59,7 +66,7 @@ def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Prot def get_meta( - conn: str, + conn: str | ConnectionUrl, query: str, protocol: Protocol | None = None, ) -> pd.DataFrame: @@ -84,7 +91,7 @@ def get_meta( def partition_sql( - conn: str, + conn: str | ConnectionUrl, query: str, partition_on: str, partition_num: int, @@ -118,7 +125,7 @@ def partition_sql( def read_sql_pandas( sql: list[str] | str, - con: str | dict[str, str], + con: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], index_col: str | None = None, protocol: Protocol | None = None, partition_on: str | None = None, @@ -159,7 +166,7 @@ def read_sql_pandas( # default return pd.DataFrame @overload def read_sql( - conn: str | dict[str, str], + conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], query: list[str] | str, *, protocol: Protocol | None = None, @@ -172,7 +179,7 @@ def read_sql( @overload def read_sql( - conn: str | dict[str, str], + conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], query: list[str] | str, *, return_type: Literal["pandas"], @@ -186,7 +193,7 @@ def read_sql( @overload def read_sql( - conn: str | dict[str, str], + conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], query: list[str] | str, *, return_type: Literal["arrow", "arrow2"], @@ -200,7 +207,7 @@ def read_sql( @overload def read_sql( - conn: str | dict[str, str], + conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], query: list[str] | str, *, return_type: Literal["modin"], @@ -214,7 +221,7 @@ def read_sql( @overload def read_sql( - conn: str | dict[str, str], + conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], query: list[str] | str, *, return_type: Literal["dask"], @@ -228,7 +235,7 @@ def read_sql( @overload def read_sql( - conn: str | dict[str, str], + conn: str | ConnectionUrl | dict[str, str] | dict[str, ConnectionUrl], query: list[str] | str, *, return_type: 
_BackendT = TypeVar("_BackendT")

# Closed set of server-style backends whose URLs take the
# user:password@host:port/database form.
_ServerBackendT = TypeVar(
    "_ServerBackendT",
    bound=Literal[
        "redshift",
        "clickhouse",
        "postgres",
        "postgresql",
        "mysql",
        "mssql",
        "oracle",
        "duckdb",
    ],
)


class ConnectionUrl(Generic[_BackendT], str):
    """A database connection string URL, parameterized by its backend.

    Subclasses ``str`` so instances can be passed anywhere a plain
    connection string is accepted (``read_sql``, ``get_meta``, ...).
    """

    @overload
    def __new__(
        cls,
        *,
        backend: Literal["sqlite"],
        db_path: str | Path,
    ) -> ConnectionUrl[Literal["sqlite"]]:
        """
        Help to build sqlite connection string url.

        Parameters
        ==========
        backend:
            must specify "sqlite".
        db_path:
            the path to the sqlite database file.
        """

    @overload
    def __new__(
        cls,
        *,
        backend: Literal["bigquery"],
        db_path: str | Path,
    ) -> ConnectionUrl[Literal["bigquery"]]:
        """
        Help to build BigQuery connection string url.

        Parameters
        ==========
        backend:
            must specify "bigquery".
        db_path:
            the path to the bigquery database file.
        """

    @overload
    def __new__(
        cls,
        *,
        backend: _ServerBackendT,
        username: str,
        password: str = "",
        server: str,
        port: int,
        database: str = "",
        database_options: dict[str, str] | None = None,
    ) -> ConnectionUrl[_ServerBackendT]:
        """
        Help to build server-side backend database connection string url.

        Parameters
        ==========
        backend:
            the database backend.
        username:
            the database username.
        password:
            the database password.
        server:
            the database server name.
        port:
            the database server port.
        database:
            the database name.
        database_options:
            the database options for connection.
        """

    @overload
    def __new__(
        cls,
        raw_connection: str,
    ) -> ConnectionUrl:
        """
        Build connection from raw connection string url

        Parameters
        ==========
        raw_connection:
            raw connection string
        """

    def __new__(
        cls,
        raw_connection: str | None = None,
        *,
        backend: str = "",
        username: str = "",
        password: str = "",
        server: str = "",
        port: int | None = None,
        database: str = "",
        database_options: dict[str, str] | None = None,
        db_path: str | Path = "",
    ) -> ConnectionUrl:
        # A pre-built connection string is passed through untouched.
        if raw_connection is not None:
            return super().__new__(cls, raw_connection)

        assert backend
        if backend == "sqlite":
            # Percent-encode the path so characters such as spaces
            # survive URL parsing on the Rust side.
            db_path = urllib.parse.quote(str(db_path))
            connection = f"{backend}://{db_path}"
        elif backend == "bigquery":
            # BUG FIX: bigquery previously fell through to the
            # server-style branch below, which ignored `db_path` and
            # produced "bigquery://:@:None/". Per the bigquery overload
            # above, the URL is simply backend://<db_path>.
            connection = f"{backend}://{db_path}"
        else:
            connection = (
                f"{backend}://{username}:{password}@{server}:{port}/{database}"
            )
        if database_options:
            # Extra key/value options become the URL query string.
            connection += "?" + urllib.parse.urlencode(database_options)
        return super().__new__(cls, connection)
import read_sql, ConnectionUrl @pytest.fixture(scope="module") # type: ignore @@ -121,9 +121,7 @@ def test_bigquery_some_empty_partition(bigquery_url: str) -> None: index=range(1), data={ "test_int": pd.Series([1], dtype="Int64"), - "test_string": pd.Series( - ["str1"], dtype="object" - ), + "test_string": pd.Series(["str1"], dtype="object"), "test_float": pd.Series([1.10], dtype="float64"), "test_bool": pd.Series([True], dtype="boolean"), }, @@ -137,10 +135,7 @@ def test_bigquery_some_empty_partition(bigquery_url: str) -> None: ) def test_bigquery_join(bigquery_url: str) -> None: query = "SELECT T.test_int, T.test_string, S.test_str FROM `dataprep-bigquery.dataprep.test_table` T INNER JOIN `dataprep-bigquery.dataprep.test_types` S ON T.test_int = S.test_int" - df = read_sql( - bigquery_url, - query - ) + df = read_sql(bigquery_url, query) df = df.sort_values("test_int").reset_index(drop=True) expected = pd.DataFrame( index=range(2), @@ -151,14 +146,14 @@ def test_bigquery_join(bigquery_url: str) -> None: "str1", "str2", ], - dtype="object" + dtype="object", ), "test_str": pd.Series( [ "๐Ÿ˜๐Ÿ˜‚๐Ÿ˜œ", "ใ“ใ‚“ใซใกใฏะ—ะดั€ะฐฬะฒ", ], - dtype="object" + dtype="object", ), }, ) @@ -188,14 +183,14 @@ def test_bigquery_join_with_partition(bigquery_url: str) -> None: "str1", "str2", ], - dtype="object" + dtype="object", ), "test_str": pd.Series( [ "๐Ÿ˜๐Ÿ˜‚๐Ÿ˜œ", "ใ“ใ‚“ใซใกใฏะ—ะดั€ะฐฬะฒ", ], - dtype="object" + dtype="object", ), }, ) @@ -203,7 +198,6 @@ def test_bigquery_join_with_partition(bigquery_url: str) -> None: assert_frame_equal(df, expected, check_names=True) - @pytest.mark.skipif( not os.environ.get("BIGQUERY_URL"), reason="Test bigquery only when `BIGQUERY_URL` is set", @@ -310,3 +304,11 @@ def test_bigquery_types(bigquery_url: str) -> None: }, ) assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("BIGQUERY_URL"), + reason="Test bigquery only when `BIGQUERY_URL` is set", +) +def 
test_connection_url(bigquery_url: str) -> None: + test_bigquery_types(ConnectionUrl(bigquery_url)) diff --git a/connectorx-python/connectorx/tests/test_clickhouse.py b/connectorx-python/connectorx/tests/test_clickhouse.py index 630ef3772..e67562ef3 100644 --- a/connectorx-python/connectorx/tests/test_clickhouse.py +++ b/connectorx-python/connectorx/tests/test_clickhouse.py @@ -4,7 +4,7 @@ import pytest from pandas.testing import assert_frame_equal -from .. import read_sql +from .. import read_sql, ConnectionUrl @pytest.fixture(scope="module") # type: ignore @@ -81,3 +81,11 @@ def test_clickhouse_types(clickhouse_url: str) -> None: }, ) assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("CLICKHOUSE_URL"), + reason="Do not test Clickhouse unless `CLICKHOUSE_URL` is set", +) +def test_connection_url(clickhouse_url: str) -> None: + test_clickhouse_types(ConnectionUrl(clickhouse_url)) diff --git a/connectorx-python/connectorx/tests/test_mssql.py b/connectorx-python/connectorx/tests/test_mssql.py index ee04d0055..6c9403396 100644 --- a/connectorx-python/connectorx/tests/test_mssql.py +++ b/connectorx-python/connectorx/tests/test_mssql.py @@ -3,6 +3,7 @@ import pandas as pd import pytest from pandas.testing import assert_frame_equal +from connectorx import ConnectionUrl from .. 
import read_sql @@ -92,7 +93,6 @@ def test_mssql_udf(mssql_url: str) -> None: def test_manual_partition(mssql_url: str) -> None: - queries = [ "SELECT * FROM test_table WHERE test_int < 2", "SELECT * FROM test_table WHERE test_int >= 2", @@ -496,3 +496,7 @@ def test_mssql_offset(mssql_url: str) -> None: } ) assert_frame_equal(df, expected, check_names=True) + + +def test_connection_url(mssql_url: str) -> None: + test_mssql_offset(ConnectionUrl(mssql_url)) \ No newline at end of file diff --git a/connectorx-python/connectorx/tests/test_mysql.py b/connectorx-python/connectorx/tests/test_mysql.py index 9376bf505..51ffcf168 100644 --- a/connectorx-python/connectorx/tests/test_mysql.py +++ b/connectorx-python/connectorx/tests/test_mysql.py @@ -4,7 +4,7 @@ import pytest from pandas.testing import assert_frame_equal -from .. import read_sql +from .. import read_sql, ConnectionUrl @pytest.fixture(scope="module") # type: ignore @@ -468,3 +468,7 @@ def test_mysql_cte(mysql_url: str) -> None: }, ) assert_frame_equal(df, expected, check_names=True) + + +def test_connection_url(mysql_url: str) -> None: + test_mysql_cte(ConnectionUrl(mysql_url)) diff --git a/connectorx-python/connectorx/tests/test_oracle.py b/connectorx-python/connectorx/tests/test_oracle.py index 59e489c10..ef6d8c0f2 100644 --- a/connectorx-python/connectorx/tests/test_oracle.py +++ b/connectorx-python/connectorx/tests/test_oracle.py @@ -4,7 +4,7 @@ import pytest from pandas.testing import assert_frame_equal -from .. import read_sql +from .. 
import read_sql, ConnectionUrl @pytest.fixture(scope="module") # type: ignore @@ -440,4 +440,11 @@ def test_oracle_round_function(oracle_url: str) -> None: "TEST_ROUND": pd.Series([1.11, 2.22, 3.33, None], dtype="float64"), } ) - assert_frame_equal(df, expected, check_names=True) \ No newline at end of file + assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("ORACLE_URL"), reason="Test oracle only when `ORACLE_URL` is set" +) +def test_connection_url(oracle_url: str) -> None: + test_oracle_round_function(ConnectionUrl(oracle_url)) \ No newline at end of file diff --git a/connectorx-python/connectorx/tests/test_redshift.py b/connectorx-python/connectorx/tests/test_redshift.py index 7c17f41d9..97f4cb814 100644 --- a/connectorx-python/connectorx/tests/test_redshift.py +++ b/connectorx-python/connectorx/tests/test_redshift.py @@ -5,7 +5,7 @@ import pytest from pandas.testing import assert_frame_equal -from .. import read_sql +from .. import read_sql, ConnectionUrl @pytest.fixture(scope="module") # type: ignore @@ -134,3 +134,11 @@ def test_read_sql_on_utf8(redshift_url: str) -> None: }, ) assert_frame_equal(df, expected, check_names=True) + + +@pytest.mark.skipif( + not os.environ.get("REDSHIFT_URL"), + reason="Do not test Redshift unless `REDSHIFT_URL` is set", +) +def test_connection_url(redshift_url: str) -> None: + test_read_sql_on_utf8(ConnectionUrl(redshift_url)) diff --git a/connectorx-python/connectorx/tests/test_sqlite.py b/connectorx-python/connectorx/tests/test_sqlite.py index f0e8a9977..2456a4d2d 100644 --- a/connectorx-python/connectorx/tests/test_sqlite.py +++ b/connectorx-python/connectorx/tests/test_sqlite.py @@ -5,7 +5,7 @@ import pytest from pandas.testing import assert_frame_equal -from .. import read_sql +from .. 
import read_sql, ConnectionUrl @pytest.fixture(scope="module") # type: ignore @@ -215,7 +215,6 @@ def test_sqlite_with_partition(sqlite_db: str) -> None: def test_manual_partition(sqlite_db: str) -> None: - queries = [ "SELECT test_int, test_nullint, test_str, test_float, test_bool, test_date, test_time, test_datetime FROM test_table WHERE test_int < 2", "SELECT test_int, test_nullint, test_str, test_float, test_bool, test_date, test_time, test_datetime FROM test_table WHERE test_int >= 2", @@ -391,3 +390,7 @@ def test_sqlite_cte(sqlite_db: str) -> None: }, ) assert_frame_equal(df, expected, check_names=True) + + +def test_connection_url(sqlite_db: str) -> None: + test_sqlite_cte(ConnectionUrl(sqlite_db))