Skip to content

Commit

Permalink
Separate SQL query building into SQLBuilder object (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Feb 15, 2024
1 parent ee8b9f8 commit 1d9481d
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 103 deletions.
6 changes: 5 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
0.11.3 (unreleased)
-------------------

- Nothing changed yet.
- Moved SQL query building from SQLGateway to a separate SQLBuilder class.
Applications that use the SQLGateway should review custom query building functionality.

- Moved SQL row <-> domain model mapping to SQLGateway.mapper. Applications
overriding this mapping (dict_to_row, rows_to_dict) should adapt.


0.11.2 (2024-01-31)
Expand Down
125 changes: 125 additions & 0 deletions clean_python/sql/sql_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from datetime import datetime

from sqlalchemy import and_
from sqlalchemy import asc
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import Executable
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import true
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import false

from clean_python import ctx
from clean_python import Filter
from clean_python import Id
from clean_python import Json
from clean_python import PageOptions


class SQLBuilder:
def __init__(self, table: Table, multitenant: bool):
self.table = table
self.multitenant = multitenant

@property
def current_tenant(self) -> int | None:
if not self.multitenant:
return None
if ctx.tenant is None:
raise RuntimeError(f"{self.__class__} requires a tenant in the context")
return ctx.tenant.id

def _filter_to_sql(self, filter: Filter) -> ColumnElement:
try:
column = getattr(self.table.c, filter.field)
except AttributeError:
return false()
if len(filter.values) == 0:
return false()
elif len(filter.values) == 1:
return column == filter.values[0]
else:
return column.in_(filter.values)

def _filters_to_sql(self, filters: list[Filter]) -> ColumnElement:
qs = [self._filter_to_sql(x) for x in filters]
if self.multitenant:
qs.append(self.table.c.tenant == self.current_tenant)
return and_(*qs)

def _id_filter_to_sql(self, id: Id) -> ColumnElement:
return self._filters_to_sql([Filter(field="id", values=[id])])

def _santize_item(self, item: Json) -> Json:
known = {c.key for c in self.table.c}
result = {k: item[k] for k in item.keys() if k in known}
if "id" in result and result["id"] is None:
del result["id"]
if self.multitenant:
result["tenant"] = self.current_tenant
return result

def select_for_update(self, id: Id) -> Executable:
return select(self.table).with_for_update().where(self._id_filter_to_sql(id))

def select(self, filters: list[Filter], params: PageOptions | None) -> Executable:
query = select(self.table).where(self._filters_to_sql(filters))
if params is not None:
sort = asc(params.order_by) if params.ascending else desc(params.order_by)
query = query.order_by(sort).limit(params.limit).offset(params.offset)
return query

def insert(self, item: Json) -> Executable:
return (
insert(self.table).values(**self._santize_item(item)).returning(self.table)
)

def upsert(self, item: Json) -> Executable:
item = self._santize_item(item)
return (
insert(self.table)
.values(**item)
.on_conflict_do_update(
index_elements=["id", "tenant"] if self.multitenant else ["id"],
set_=item,
)
.returning(self.table)
)

def update(self, id: Id, item: Json, if_unmodified_since: datetime | None):
q = self._id_filter_to_sql(id)
if if_unmodified_since is not None:
q &= self.table.c.updated_at == if_unmodified_since
return (
update(self.table)
.where(q)
.values(**self._santize_item(item))
.returning(self.table)
)

def delete(self, id: Id) -> Executable:
return (
delete(self.table)
.where(self._id_filter_to_sql(id))
.returning(self.table.c.id)
)

def count(self, filters: list[Filter]) -> Executable:
return (
select(func.count().label("count"))
.select_from(self.table)
.where(self._filters_to_sql(filters))
)

def exists(self, filters: list[Filter]) -> Executable:
return (
select(true().label("exists"))
.select_from(self.table)
.where(self._filters_to_sql(filters))
.limit(1)
)
117 changes: 15 additions & 102 deletions clean_python/sql/sql_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,19 @@
from typing import TypeVar

import inject
from sqlalchemy import and_
from sqlalchemy import asc
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import true
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql import Executable
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import false

from clean_python import Conflict
from clean_python import ctx
from clean_python import DoesNotExist
from clean_python import Filter
from clean_python import Gateway
from clean_python import Id
from clean_python import Json
from clean_python import Mapper
from clean_python import PageOptions

from .sql_builder import SQLBuilder
from .sql_provider import SQLDatabase
from .sql_provider import SQLProvider

Expand All @@ -42,9 +32,9 @@

class SQLGateway(Gateway):
table: Table
nested: bool
multitenant: bool
has_related: bool
mapper: Mapper = Mapper()

def __init__(
self,
Expand All @@ -53,6 +43,7 @@ def __init__(
):
self.provider_override = provider_override
self.nested = nested
self.builder = SQLBuilder(self.table, self.multitenant)

@property
def provider(self):
Expand All @@ -68,18 +59,6 @@ def __init_subclass__(
cls.has_related = has_related
super().__init_subclass__()

def rows_to_dict(self, rows: List[Json]) -> List[Json]:
return rows

def dict_to_row(self, obj: Json) -> Json:
known = {c.key for c in self.table.c}
result = {k: obj[k] for k in obj.keys() if k in known}
if "id" in result and result["id"] is None:
del result["id"]
if self.multitenant:
result["tenant"] = self.current_tenant
return result

@asynccontextmanager
async def transaction(self: T) -> AsyncIterator[T]:
if self.nested:
Expand All @@ -88,27 +67,17 @@ async def transaction(self: T) -> AsyncIterator[T]:
async with self.provider.transaction() as provider:
yield self.__class__(provider, nested=True)

@property
def current_tenant(self) -> Optional[int]:
if not self.multitenant:
return None
if ctx.tenant is None:
raise RuntimeError(f"{self.__class__} requires a tenant in the context")
return ctx.tenant.id

async def get_related(self, items: List[Json]) -> None:
"""Implement this to use transactions for consistently getting nested records"""

async def set_related(self, item: Json, result: Json) -> None:
"""Implement this to use transactions for consistently setting nested records"""

async def execute(self, query: Executable) -> List[Json]:
return self.rows_to_dict(await self.provider.execute(query))
return [self.mapper.to_internal(x) for x in await self.provider.execute(query)]

async def add(self, item: Json) -> Json:
query = (
insert(self.table).values(**self.dict_to_row(item)).returning(self.table)
)
query = self.builder.insert(self.mapper.to_external(item))
if self.has_related:
async with self.transaction() as transaction:
(result,) = await transaction.execute(query)
Expand All @@ -123,14 +92,8 @@ async def update(
id_ = item.get("id")
if id_ is None:
raise DoesNotExist("record", id_)
q = self._id_filter_to_sql(id_)
if if_unmodified_since is not None:
q &= self.table.c.updated_at == if_unmodified_since
query = (
update(self.table)
.where(q)
.values(**self.dict_to_row(item))
.returning(self.table)
query = self.builder.update(
id_, self.mapper.to_external(item), if_unmodified_since
)
if self.has_related:
async with self.transaction() as transaction:
Expand All @@ -147,10 +110,9 @@ async def update(
return result[0]

async def _select_for_update(self, id: Id) -> Json:
query = self.builder.select_for_update(id)
async with self.transaction() as transaction:
result = await transaction.execute(
select(self.table).with_for_update().where(self._id_filter_to_sql(id)),
)
result = await transaction.execute(query)
if not result:
raise DoesNotExist("record", id)
await transaction.get_related(result)
Expand All @@ -165,16 +127,7 @@ async def update_transactional(self, id: Id, func: Callable[[Json], Json]) -> Js
async def upsert(self, item: Json) -> Json:
if item.get("id") is None:
return await self.add(item)
values = self.dict_to_row(item)
query = (
insert(self.table)
.values(**values)
.on_conflict_do_update(
index_elements=["id", "tenant"] if self.multitenant else ["id"],
set_=values,
)
.returning(self.table)
)
query = self.builder.upsert(self.mapper.to_external(item))
if self.has_related:
async with self.transaction() as transaction:
result = await transaction.execute(query)
Expand All @@ -184,41 +137,12 @@ async def upsert(self, item: Json) -> Json:
return result[0]

async def remove(self, id: Id) -> bool:
query = (
delete(self.table)
.where(self._id_filter_to_sql(id))
.returning(self.table.c.id)
)
return bool(await self.execute(query))

def _filter_to_sql(self, filter: Filter) -> ColumnElement:
try:
column = getattr(self.table.c, filter.field)
except AttributeError:
return false()
if len(filter.values) == 0:
return false()
elif len(filter.values) == 1:
return column == filter.values[0]
else:
return column.in_(filter.values)

def _filters_to_sql(self, filters: List[Filter]) -> ColumnElement:
qs = [self._filter_to_sql(x) for x in filters]
if self.multitenant:
qs.append(self.table.c.tenant == self.current_tenant)
return and_(*qs)

def _id_filter_to_sql(self, id: Id) -> ColumnElement:
return self._filters_to_sql([Filter(field="id", values=[id])])
return bool(await self.execute(self.builder.delete(id)))

async def filter(
self, filters: List[Filter], params: Optional[PageOptions] = None
) -> List[Json]:
query = select(self.table).where(self._filters_to_sql(filters))
if params is not None:
sort = asc(params.order_by) if params.ascending else desc(params.order_by)
query = query.order_by(sort).limit(params.limit).offset(params.offset)
query = self.builder.select(filters, params)
if self.has_related:
async with self.transaction() as transaction:
result = await transaction.execute(query)
Expand All @@ -228,21 +152,10 @@ async def filter(
return result

async def count(self, filters: List[Filter]) -> int:
query = (
select(func.count().label("count"))
.select_from(self.table)
.where(self._filters_to_sql(filters))
)
return (await self.execute(query))[0]["count"]
return (await self.execute(self.builder.count(filters)))[0]["count"]

async def exists(self, filters: List[Filter]) -> bool:
query = (
select(true().label("exists"))
.select_from(self.table)
.where(self._filters_to_sql(filters))
.limit(1)
)
return len(await self.execute(query)) > 0
return len(await self.execute(self.builder.exists(filters))) > 0

async def _get_related_one_to_many(
self,
Expand Down

0 comments on commit 1d9481d

Please sign in to comment.