Skip to content

Commit

Permalink
Update query.py (#6)
Browse files Browse the repository at this point in the history
* Update query.py

Add `pre_filter` and `pre_exclude` methods to TreeQuerySet.

Make query methods call `_setup_query` to deal with sibling order and pre_filter persistence issues.


* Update compiler.py

Replace the `get_sibling_order_params` with `get_rank_table_params` to support early tree filtering.

Change `sibling_order` and `pre_filter` from class variables to variables that have to be initiated by `_setup_query` so that they don't persist between user queries.

Handle pre_filter params by passing them to the django backend.


* Update test_queries.py

Added tests for the `pre_filter` and `pre_exclude` methods.
  • Loading branch information
rhomboss authored Apr 15, 2024
1 parent a9f2cba commit 6a95af1
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 39 deletions.
118 changes: 118 additions & 0 deletions tests/testapp/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,121 @@ def test_order_by_related(self):
tree.child2_2,
],
)

def test_pre_exclude(self):
tree = self.create_tree()
# Pre-filter should remove children if
# the parent meets the filtering criteria
nodes = Model.objects.pre_exclude(name="2")
self.assertEqual(
list(nodes),
[
tree.root,
tree.child1,
tree.child1_1,
],
)

def test_pre_filter(self):
tree = self.create_tree()
# Pre-filter should remove children if
# the parent does not meet the filtering criteria
nodes = Model.objects.pre_filter(name__in=["root","1-1","2","2-1","2-2"])
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
tree.child2_2,
],
)

def test_pre_filter_chaining(self):
tree = self.create_tree()
# Pre-filter should remove children if
# the parent does not meet the filtering criteria
nodes = Model.objects.pre_exclude(name="2-2").pre_filter(name__in=["root","1-1","2","2-1","2-2"])
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
],
)

def test_pre_filter_related(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...

tree.root = RelatedOrderModel.objects.create(name="root")
tree.root_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.root, order=0
)
tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1")
tree.child1_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child1, order=0
)
tree.child2 = RelatedOrderModel.objects.create(parent=tree.root, name="2")
tree.child2_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child2, order=1
)
tree.child1_1 = RelatedOrderModel.objects.create(parent=tree.child1, name="1-1")
tree.child1_1_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child1_1, order=0
)
tree.child2_1 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-1")
tree.child2_1_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child2_1, order=0
)
tree.child2_2 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-2")
tree.child2_2_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child2_2, order=1
)

nodes = RelatedOrderModel.objects.pre_filter(related__order=0)
self.assertEqual(
list(nodes),
[
tree.root,
tree.child1,
tree.child1_1,
],
)

def test_pre_filter_with_order(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...

tree.root = MultiOrderedModel.objects.create(
name="root", first_position=1,
)
tree.child1 = MultiOrderedModel.objects.create(
parent=tree.root, first_position=0, second_position=1, name="1"
)
tree.child2 = MultiOrderedModel.objects.create(
parent=tree.root, first_position=1, second_position=0, name="2"
)
tree.child1_1 = MultiOrderedModel.objects.create(
parent=tree.child1, first_position=1, second_position=1, name="1-1"
)
tree.child2_1 = MultiOrderedModel.objects.create(
parent=tree.child2, first_position=1, second_position=1, name="2-1"
)
tree.child2_2 = MultiOrderedModel.objects.create(
parent=tree.child2, first_position=1, second_position=0, name="2-2"
)

nodes = (
MultiOrderedModel.objects
.pre_filter(first_position__gt=0)
.order_siblings_by("-second_position")
)
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
tree.child2_2,
],
)
126 changes: 87 additions & 39 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,34 @@ def _find_tree_model(cls):


class TreeQuery(Query):
# Set by TreeQuerySet.order_siblings_by
sibling_order = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._setup_query()

def _setup_query(self):
"""
Run on initialization and at the end of chaining. Any attributes that
would normally be set in __init__() should go here instead.
"""
# We add the variables for `sibling_order` and `pre_filter` here so they
# act as instance variables which do not persist between user queries
# the way class variables do

# Only add the sibling_order attribute if the query doesn't already have one to preserve cloning behavior
if not hasattr(self, "sibling_order"):
# Add an attribute to control the ordering of siblings within trees
opts = _find_tree_model(self.model)._meta
self.sibling_order = (
opts.ordering
if opts.ordering
else opts.pk.attname
)

# Only add the pre_filter attribute if the query doesn't already have one to preserve cloning behavior
if not hasattr(self, "pre_filter"):
self.pre_filter = []


def get_compiler(self, using=None, connection=None, **kwargs):
# Copied from django/db/models/sql/query.py
Expand All @@ -28,13 +54,10 @@ def get_compiler(self, using=None, connection=None, **kwargs):
return TreeCompiler(self, connection, using, **kwargs)

def get_sibling_order(self):
if self.sibling_order is not None:
return self.sibling_order
opts = _find_tree_model(self.model)._meta
if opts.ordering:
return opts.ordering
return opts.pk.attname
return self.sibling_order

def get_pre_filter(self):
return self.pre_filter

class TreeCompiler(SQLCompiler):
CTE_POSTGRESQL = """
Expand All @@ -48,6 +71,7 @@ class TreeCompiler(SQLCompiler):
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{pre_filter}
),
__tree (
"tree_depth",
Expand Down Expand Up @@ -82,6 +106,7 @@ class TreeCompiler(SQLCompiler):
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{pre_filter}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
SELECT
Expand Down Expand Up @@ -113,6 +138,7 @@ class TreeCompiler(SQLCompiler):
{rank_parent},
row_number() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{pre_filter}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
SELECT
Expand All @@ -135,13 +161,14 @@ class TreeCompiler(SQLCompiler):
)
"""

def get_sibling_order_params(self):
def get_rank_table_params(self):
"""
This method uses a simple django queryset to generate sql
that can be used to create the __rank_table that orders
siblings. This is done so that any joins required by order_by
are pre-calculated by django
that can be used to create the __rank_table that pre-filters
and orders siblings. This is done so that any joins required
by order_by or filter/exclude are pre-calculated by django
"""
# Get can validate sibling_order
sibling_order = self.query.get_sibling_order()

if isinstance(sibling_order, (list, tuple)):
Expand All @@ -152,39 +179,57 @@ def get_sibling_order_params(self):
raise ValueError(
"Sibling order must be a string or a list or tuple of strings."
)

# Use Django to make a SQL query whose parts can be repurposed for __rank_table
base_query = (
_find_tree_model(self.query.model)
.objects.only("pk", "parent")
.order_by(*order_fields)
.query
)

# Use the base compiler because we want vanilla sql and want to avoid recursion.

# Get pre_filter
pre_filter = self.query.get_pre_filter()

# Use Django to make a SQL query that can be repurposed for __rank_table
base_query = _find_tree_model(self.query.model).objects.only("pk", "parent")

# Add pre_filters if they exist
if pre_filter:
# Apply filters and excludes to the query in the order provided by the user
for is_filter, filter_fields in pre_filter:
if is_filter:
base_query = base_query.filter(**filter_fields)
else:
base_query = base_query.exclude(**filter_fields)

# Apply sibling_order
base_query = base_query.order_by(*order_fields).query

# Get SQL and parameters
base_compiler = SQLCompiler(base_query, self.connection, None)
base_sql, base_params = base_compiler.as_sql()
result_sql = base_sql % base_params

# Split the base SQL string on the SQL keywords 'FROM' and 'ORDER BY'
from_split = result_sql.split("FROM")
order_split = from_split[1].split("ORDER BY")
# Split sql on the last ORDER BY to get the rank_order param
head, sep, tail = base_sql.rpartition("ORDER BY")

# Identify the FROM and ORDER BY parts of the base SQL
ordering_params = {
"rank_from": order_split[0].strip(),
"rank_order_by": order_split[1].strip(),
# Add rank_order_by to params
rank_table_params = {
"rank_order_by": tail.strip(),
}

# Identify the primary key field and parent_id field from the SELECT section
base_select = from_split[0][6:]
for field in base_select.split(","):
# Split on the first WHERE if present to get the pre_filter param
if pre_filter:
head, sep, tail = head.partition("WHERE")
rank_table_params["pre_filter"] = "WHERE " + tail.strip() # Note the space after WHERE
else:
rank_table_params["pre_filter"] = ""

# Split on the first FROM to get any joins etc.
head, sep, tail = head.partition("FROM")
rank_table_params["rank_from"] = tail.strip()

# Identify the parent and primary key fields
head, sep, tail = head.partition("SELECT")
for field in tail.split(","):
if "parent_id" in field: # XXX Taking advantage of Hardcoded.
ordering_params["rank_parent"] = field.strip()
rank_table_params["rank_parent"] = field.strip()
else:
ordering_params["rank_pk"] = field.strip()
rank_table_params["rank_pk"] = field.strip()

return ordering_params
return rank_table_params, base_params

def as_sql(self, *args, **kwargs):
# Try detecting if we're used in a EXISTS(1 as "a") subquery like
Expand Down Expand Up @@ -229,8 +274,9 @@ def as_sql(self, *args, **kwargs):
"sep": SEPARATOR,
}

# Add ordering params to params
params.update(self.get_sibling_order_params())
# Get params needed by the rank_table
rank_table_params, rank_table_sql_params = self.get_rank_table_params()
params.update(rank_table_params)

if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely
tree_params = params.copy()
Expand Down Expand Up @@ -280,7 +326,9 @@ def as_sql(self, *args, **kwargs):
if sql_0.startswith("EXPLAIN "):
explain, sql_0 = sql_0.split(" ", 1)

return ("".join([explain, cte.format(**params), sql_0]), sql_1)
# Pass any additional rank table sql paramaters so that the db backend can handle them.
# This only works because we know that the CTE is at the start of the query.
return ("".join([explain, cte.format(**params), sql_0]), rank_table_sql_params + sql_1)

def get_converters(self, expressions):
converters = super().get_converters(expressions)
Expand Down
26 changes: 26 additions & 0 deletions tree_queries/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def with_tree_fields(self, tree_fields=True): # noqa: FBT002
"""
if tree_fields:
self.query.__class__ = TreeQuery
self.query._setup_query()
else:
self.query.__class__ = Query
return self
Expand All @@ -45,9 +46,34 @@ def order_siblings_by(self, *order_by):
to order tree siblings by those model fields
"""
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.sibling_order = order_by
return self

def pre_filter(self, **filter):
"""
Sets TreeQuery pre_filter attribute
Pass a dict of fields and their values to filter by
"""
self.query.__class__ = TreeQuery
self.query._setup_query()
filter_tuple = (True, filter)
self.query.pre_filter.append(filter_tuple)
return self

def pre_exclude(self, **filter):
"""
Sets TreeQuery pre_filter attribute
Pass a dict of fields and their values to filter by
"""
self.query.__class__ = TreeQuery
self.query._setup_query()
exclude_tuple = (False, filter)
self.query.pre_filter.append(exclude_tuple)
return self

def as_manager(cls, *, with_tree_fields=False): # noqa: N805
manager_class = TreeManager.from_queryset(cls)
# Only used in deconstruct:
Expand Down

0 comments on commit 6a95af1

Please sign in to comment.