Skip to content

Commit

Permalink
Update compiler.py
Browse files Browse the repository at this point in the history
Change `pre_filter` to `tree_filter` and add support for `Q` objects
  • Loading branch information
rhomboss authored Apr 15, 2024
1 parent 6a95af1 commit 252a4fe
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions tree_queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _setup_query(self):
Run on initialization and at the end of chaining. Any attributes that
would normally be set in __init__() should go here instead.
"""
# We add the variables for `sibling_order` and `pre_filter` here so they
# We add the variables for `sibling_order` and `tree_filter` here so they
# act as instance variables which do not persist between user queries
# the way class variables do

Expand All @@ -36,9 +36,9 @@ def _setup_query(self):
else opts.pk.attname
)

# Only add the pre_filter attribute if the query doesn't already have one to preserve cloning behavior
if not hasattr(self, "pre_filter"):
self.pre_filter = []
# Only add the tree_filter attribute if the query doesn't already have one to preserve cloning behavior
if not hasattr(self, "tree_filter"):
self.tree_filter = []


def get_compiler(self, using=None, connection=None, **kwargs):
Expand All @@ -56,8 +56,8 @@ def get_compiler(self, using=None, connection=None, **kwargs):
def get_sibling_order(self):
return self.sibling_order

def get_pre_filter(self):
return self.pre_filter
def get_tree_filter(self):
return self.tree_filter

class TreeCompiler(SQLCompiler):
CTE_POSTGRESQL = """
Expand All @@ -71,7 +71,7 @@ class TreeCompiler(SQLCompiler):
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{pre_filter}
{tree_filter}
),
__tree (
"tree_depth",
Expand Down Expand Up @@ -106,7 +106,7 @@ class TreeCompiler(SQLCompiler):
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{pre_filter}
{tree_filter}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
SELECT
Expand Down Expand Up @@ -138,7 +138,7 @@ class TreeCompiler(SQLCompiler):
{rank_parent},
row_number() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{pre_filter}
{tree_filter}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
SELECT
Expand All @@ -164,7 +164,7 @@ class TreeCompiler(SQLCompiler):
def get_rank_table_params(self):
"""
This method uses a simple django queryset to generate sql
that can be used to create the __rank_table that pre-filters
that can be used to create the __rank_table that tree-filters
and orders siblings. This is done so that any joins required
by order_by or filter/exclude are pre-calculated by django
"""
Expand All @@ -180,20 +180,20 @@ def get_rank_table_params(self):
"Sibling order must be a string or a list or tuple of strings."
)

# Get pre_filter
pre_filter = self.query.get_pre_filter()
# Get tree_filter
tree_filter = self.query.get_tree_filter()

# Use Django to make a SQL query that can be repurposed for __rank_table
base_query = _find_tree_model(self.query.model).objects.only("pk", "parent")

# Add pre_filters if they exist
if pre_filter:
# Add tree_filters if they exist
if tree_filter:
# Apply filters and excludes to the query in the order provided by the user
for is_filter, filter_fields in pre_filter:
for is_filter, filter_Q, filter_fields in tree_filter:
if is_filter:
base_query = base_query.filter(**filter_fields)
base_query = base_query.filter(*filter_Q, **filter_fields)
else:
base_query = base_query.exclude(**filter_fields)
base_query = base_query.exclude(*filter_Q, **filter_fields)

# Apply sibling_order
base_query = base_query.order_by(*order_fields).query
Expand All @@ -210,12 +210,12 @@ def get_rank_table_params(self):
"rank_order_by": tail.strip(),
}

# Split on the first WHERE if present to get the pre_filter param
if pre_filter:
# Split on the first WHERE if present to get the tree_filter param
if tree_filter:
head, sep, tail = head.partition("WHERE")
rank_table_params["pre_filter"] = "WHERE " + tail.strip() # Note the space after WHERE
rank_table_params["tree_filter"] = "WHERE " + tail.strip() # Note the space after WHERE
else:
rank_table_params["pre_filter"] = ""
rank_table_params["tree_filter"] = ""

# Split on the first FROM to get any joins etc.
head, sep, tail = head.partition("FROM")
Expand Down

0 comments on commit 252a4fe

Please sign in to comment.