Skip to content

Commit

Permalink
Merge pull request #8 from Ceridan/feature/add-schema-support
Browse files Browse the repository at this point in the history
[KQL] Add support for "schema" (used in cross-database joins)
  • Loading branch information
Ceridan authored Jan 10, 2022
2 parents a4b0412 + cd00a0b commit 340c40d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
32 changes: 30 additions & 2 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
from typing import List, Optional, Tuple

from sqlalchemy import Column, exc
Expand Down Expand Up @@ -60,12 +61,15 @@ def visit_select(
query = self._get_most_inner_element(from_object.element)
(main, lets) = self._extract_let_statements(query.text)
compiled_query_lines.extend(lets)
compiled_query_lines.append(f"let {from_object.name} = ({main});")
compiled_query_lines.append(f"let {from_object.name} = ({self._convert_schema_in_statement(main)});")
compiled_query_lines.append(from_object.name)
elif hasattr(from_object, "name"):
if from_object.schema is not None:
unquoted_schema = from_object.schema.strip("\"'")
compiled_query_lines.append(f'database("{unquoted_schema}").')
compiled_query_lines.append(from_object.name)
else:
compiled_query_lines.append(from_object.text)
compiled_query_lines.append(self._convert_schema_in_statement(from_object.text))

if select._whereclause is not None:
where_clause = select._whereclause._compiler_dispatch(self, **kwargs)
Expand Down Expand Up @@ -145,6 +149,30 @@ def _build_column_projection(column_name: str, column_alias: str = None):
"""Generates column alias semantic for project statement"""
return f"{column_alias} = {column_name}" if column_alias else column_name

@staticmethod
def _convert_schema_in_statement(query: str) -> str:
    """
    Converts a leading schema qualifier in the query from SQL notation to KQL notation.

    Only the table reference at the very start of the query is inspected. Whatever
    follows that reference (pipes, operators, further statements) is returned unchanged.

    Examples:
        - schema.table                 -> database("schema").table
        - schema."table.name"          -> database("schema")."table.name"
        - "schema.name".table          -> database("schema.name").table
        - "schema.name"."table.name"   -> database("schema.name")."table.name"
        - "schema name"."table name"   -> database("schema name")."table name"
        - schema.table | take 5        -> database("schema").table | take 5
        - "table.name"                 -> "table.name"
        - MyTable                      -> MyTable

    :param query: Statement text that may start with a ``schema.table`` reference.
    :return: The query with its leading reference rewritten, or the query unchanged
        when it does not start with a schema-qualified reference.
    """
    # Group 1 is the optional schema, group 2 the table; each is either a bare
    # alphanumeric word or a double-quoted name (which may contain spaces, dots, '-' and '_').
    pattern = r"^([a-zA-Z0-9]+\b|\"[a-zA-Z0-9 \-_.]+\")?\.?([a-zA-Z0-9]+\b|\"[a-zA-Z0-9 \-_.]+\")"
    match = re.search(pattern, query)

    # No leading reference at all, or a reference without a schema part: nothing to convert.
    if not match or not match.group(1):
        return query

    unquoted_schema = match.group(1).strip("\"'")
    # Rewrite only the matched prefix and keep the rest of the statement; replacing
    # the whole query here would silently drop everything after the table reference.
    return f'database("{unquoted_schema}").{match.group(2)}{query[match.end():]}'


class KustoKqlHttpsDialect(KustoBaseDialect):
name = "kustokql"
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_dialect_kql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import sqlalchemy as sa
from sqlalchemy import Column, MetaData, String, Table, column, create_engine, literal_column, select, text
from sqlalchemy.sql.selectable import TextAsFrom
Expand Down Expand Up @@ -146,3 +147,50 @@ def test_quotes():
# fmt: on

assert query_compiled == query_expected


@pytest.mark.parametrize(
    "schema_name,table_name,expected_table_name",
    [
        ("schema", "table", 'database("schema").table'),
        ("schema", '"table.name"', 'database("schema")."table.name"'),
        ('"schema.name"', "table", 'database("schema.name").table'),
        ('"schema.name"', '"table.name"', 'database("schema.name")."table.name"'),
        ('"schema name"', '"table name"', 'database("schema name")."table name"'),
        (None, '"table.name"', '"table.name"'),
        (None, "MyTable", "MyTable"),
    ],
)
def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_name: str):
    """Check how a table whose schema comes from ``MetaData`` is rendered by the compiled query."""
    # A falsy schema (None) means the table is declared without any schema.
    if schema_name:
        metadata = MetaData(schema=schema_name)
    else:
        metadata = MetaData()

    limited_select = Table(table_name, metadata).select().limit(5)

    # Collapse the compiled text onto one line before comparing.
    compiled_text = str(limited_select.compile(engine))
    single_line = compiled_text.replace("\n", "")

    assert single_line == f"{expected_table_name}| take %(param_1)s"


@pytest.mark.parametrize(
    "query_table_name,expected_table_name",
    [
        ("schema.table", 'database("schema").table'),
        ('schema."table.name"', 'database("schema")."table.name"'),
        ('"schema.name".table', 'database("schema.name").table'),
        ('"schema.name"."table.name"', 'database("schema.name")."table.name"'),
        ('"schema name"."table name"', 'database("schema name")."table name"'),
        ('"table.name"', '"table.name"'),
        ("MyTable", "MyTable"),
    ],
)
def test_schema_from_query(query_table_name: str, expected_table_name: str):
    """Check how a table reference given as raw query text is rendered inside the ``let`` statement."""
    # Wrap the raw text as an aliased sub-query and select from it with a row limit.
    inner_source = TextAsFrom(text(query_table_name), ["*"]).alias("inner_qry")
    statement = select("*").select_from(inner_source).limit(5)

    # Literal binds make the limit appear inline as "take 5".
    compiled = statement.compile(engine, compile_kwargs={"literal_binds": True})
    single_line = str(compiled).replace("\n", "")

    assert single_line == f"let inner_qry = ({expected_table_name});inner_qry| take 5"

0 comments on commit 340c40d

Please sign in to comment.