Skip to content

Commit

Permalink
add support for partial indexes
Browse files Browse the repository at this point in the history
Co-authored-by: Emanuel Lupi <[email protected]>
  • Loading branch information
timgraham and WaVEV committed Nov 19, 2024
1 parent c49e345 commit e543e0f
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"get_or_create",
"i18n",
"indexes",
"indexes_",
"inline_formsets",
"introspection",
"invalid_models_tests",
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
from .lookups import register_lookups # noqa: E402
from .query import register_nodes # noqa: E402

register_aggregates()
register_expressions()
register_fields()
register_functions()
register_indexes()
register_lookups()
register_nodes()
11 changes: 10 additions & 1 deletion django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,16 @@ def base_table(self):

@cached_property
def collection_name(self):
return self.base_table.table_alias or self.base_table.table_name
try:
base_table = self.base_table
except StopIteration:
# Use a dummy collection if the query doesn't specify a table
# (such as Constraint.validate() with a condition).
query = self.query_class(self)
query.aggregation_pipeline = [{"$facet": {"__null": []}}]
self.subqueries.insert(0, query)
return "__null"
return base_table.table_alias or base_table.table_name

@cached_property
def collection(self):
Expand Down
7 changes: 5 additions & 2 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def case(self, compiler, connection):
def col(self, compiler, connection): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
if (
# If the query is built with `alias_cols=False`, treat the column as
# belonging to the current collection.
if self.alias is not None and (
self.alias not in compiler.query.alias_refcount
or compiler.query.alias_refcount[self.alias] == 0
):
Expand All @@ -64,7 +66,8 @@ def col(self, compiler, connection): # noqa: ARG001
compiler.column_indices[self] = index
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
# Add the column's collection's alias for columns in joined collections.
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"


Expand Down
20 changes: 15 additions & 5 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import operator

from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property

Expand All @@ -21,12 +23,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_expression_indexes = False
supports_foreign_keys = False
supports_ignore_conflicts = False
# Before MongoDB 6.0, $in cannot be used in partialFilterExpression.
supports_in_index_operator = property(operator.attrgetter("is_mongodb_6_0"))
# Before MongoDB 6.0, $or cannot be used in partialFilterExpression.
supports_or_index_operator = property(operator.attrgetter("is_mongodb_6_0"))
supports_json_field_contains = False
# BSON Date type doesn't support microsecond precision.
supports_microsecond_precision = False
supports_paramstyle_pyformat = False
# Not implemented.
supports_partial_indexes = False
supports_select_difference = False
supports_select_intersection = False
supports_sequence_reset = False
Expand Down Expand Up @@ -91,11 +95,16 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
}
_django_test_expected_failures_partial_expression_in = {
"schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index",
}

@cached_property
def django_test_expected_failures(self):
expected_failures = super().django_test_expected_failures
expected_failures.update(self._django_test_expected_failures)
if not self.supports_in_index_operator:
expected_failures.update(self._django_test_expected_failures_partial_expression_in)
if not self.is_mongodb_6_3:
expected_failures.update(self._django_test_expected_failures_bitwise)
return expected_failures
Expand Down Expand Up @@ -560,9 +569,6 @@ def django_test_expected_failures(self):
# Probably something to do with lack of transaction support.
"migration_test_data_persistence.tests.MigrationDataNormalPersistenceTestCase.test_persistence",
},
"Partial indexes to be supported.": {
"indexes.tests.PartialIndexConditionIgnoredTests.test_condition_ignored",
},
"Database caching not implemented.": {
"cache.tests.CreateCacheTableForDBCacheTests",
"cache.tests.DBCacheTests",
Expand Down Expand Up @@ -597,6 +603,10 @@ def django_test_expected_failures(self):
},
}

@cached_property
def is_mongodb_6_0(self):
return self.connection.get_database_version() >= (6, 0)

@cached_property
def is_mongodb_6_3(self):
return self.connection.get_database_version() >= (6, 3)
74 changes: 74 additions & 0 deletions django_mongodb/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from django.db import NotSupportedError
from django.db.models import Index
from django.db.models.fields.related_lookups import In
from django.db.models.lookups import BuiltinLookup
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, XOR, WhereNode

from .query_utils import process_rhs

MONGO_INDEX_OPERATORS = {
"exact": "$eq",
"gt": "$gt",
"gte": "$gte",
"lt": "$lt",
"lte": "$lte",
"in": "$in",
}


def _get_condition_mql(self, model, schema_editor):
"""Analogous to Index._get_condition_sql()."""
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
return where.as_mql_idx(compiler, schema_editor.connection)


def builtin_lookup_idx(self, compiler, connection):
lhs_mql = self.lhs.target.column
value = process_rhs(self, compiler, connection)
try:
operator = MONGO_INDEX_OPERATORS[self.lookup_name]
except KeyError:
raise NotSupportedError(
f"MongoDB does not support the '{self.lookup_name}' lookup in indexes."
) from None
return {lhs_mql: {operator: value}}


def in_idx(self, compiler, connection):
if not connection.features.supports_in_index_operator:
raise NotSupportedError("MongoDB < 6.0 does not support the 'in' lookup in indexes.")
return builtin_lookup_idx(self, compiler, connection)


def where_node_idx(self, compiler, connection):
if self.connector == AND:
operator = "$and"
elif self.connector == XOR:
raise NotSupportedError("MongoDB does not support the '^' operator lookup in indexes.")
else:
if not connection.features.supports_in_index_operator:
raise NotSupportedError("MongoDB < 6.0 does not support the '|' operator in indexes.")
operator = "$or"
if self.negated:
raise NotSupportedError("MongoDB does not support the '~' operator in indexes.")
children_mql = []
for child in self.children:
mql = child.as_mql_idx(compiler, connection)
children_mql.append(mql)
if len(children_mql) == 1:
mql = children_mql[0]
elif len(children_mql) > 1:
mql = {operator: children_mql}
else:
mql = {}
return mql


def register_indexes():
BuiltinLookup.as_mql_idx = builtin_lookup_idx
In.as_mql_idx = in_idx
Index._get_condition_mql = _get_condition_mql
WhereNode.as_mql_idx = where_node_idx
28 changes: 22 additions & 6 deletions django_mongodb/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import Index, UniqueConstraint
from pymongo import ASCENDING, DESCENDING
Expand Down Expand Up @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
if index.contains_expressions:
return
kwargs = {}
filter_expression = defaultdict(dict)
if index.condition:
filter_expression.update(index._get_condition_mql(model, self))
if unique:
filter_expression = {}
kwargs["unique"] = True
# Indexing on $type matches the value of most SQL databases by
# allowing multiple null values for the unique constraint.
if field:
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
else:
for field_name, _ in index.fields_orders:
field_ = model._meta.get_field(field_name)
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
filter_expression[field_.column].update(
{"$type": field_.db_type(self.connection)}
)
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(field.column, ASCENDING)]
if field
Expand Down Expand Up @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
name=constraint.name,
condition=constraint.condition,
)
self.add_index(model, idx, field=field, unique=True)

def _add_field_unique(self, model, field):
Expand All @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
name=constraint.name,
condition=constraint.condition,
)
self.remove_index(model, idx)

def _remove_field_unique(self, model, field):
Expand Down
Empty file added tests/indexes_/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions tests/indexes_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.db import models


class Article(models.Model):
headline = models.CharField(max_length=100)
number = models.IntegerField()
body = models.TextField()
126 changes: 126 additions & 0 deletions tests/indexes_/test_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import operator

from django.db import NotSupportedError, connection
from django.db.models import Index, Q
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature

from .models import Article


class PartialIndexTests(TestCase):
def assertAddRemoveIndex(self, editor, model, index):
editor.add_index(index=index, model=model)
self.assertIn(
index.name,
connection.introspection.get_constraints(
cursor=None,
table_name=model._meta.db_table,
),
)
editor.remove_index(index=index, model=model)

def test_not_supported(self):
msg = "MongoDB does not support the 'isnull' lookup in indexes."
with connection.schema_editor() as editor, self.assertRaisesMessage(NotSupportedError, msg):
Index(
name="test",
fields=["headline"],
condition=Q(pk__isnull=True),
)._get_condition_mql(Article, schema_editor=editor)

def test_negated_not_supported(self):
msg = "MongoDB does not support the '~' operator in indexes."
with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor:
Index(
name="test",
fields=["headline"],
condition=~Q(pk=True),
)._get_condition_mql(Article, schema_editor=editor)

def test_xor_not_supported(self):
msg = "MongoDB does not support the '^' operator lookup in indexes."
with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor:
Index(
name="test",
fields=["headline"],
condition=Q(pk=True) ^ Q(pk=False),
)._get_condition_mql(Article, schema_editor=editor)

@skipIfDBFeature("supports_or_index_operator")
def test_or_not_supported(self):
msg = "MongoDB < 6.0 does not support the '|' operator in indexes."
with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor:
Index(
name="test",
fields=["headline"],
condition=Q(pk=True) | Q(pk=False),
)._get_condition_mql(Article, schema_editor=editor)

@skipIfDBFeature("supports_in_index_operator")
def test_in_not_supported(self):
msg = "MongoDB < 6.0 does not support the 'in' lookup in indexes."
with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor:
Index(
name="test",
fields=["headline"],
condition=Q(pk__in=[True]),
)._get_condition_mql(Article, schema_editor=editor)

def test_operations(self):
operators = (
("gt", "$gt"),
("gte", "$gte"),
("lt", "$lt"),
("lte", "$lte"),
)
for op, mongo_operator in operators:
with self.subTest(operator=op), connection.schema_editor() as editor:
index = Index(
name="test",
fields=["headline"],
condition=Q(**{f"number__{op}": 3}),
)
self.assertEqual(
{"number": {mongo_operator: 3}},
index._get_condition_mql(Article, schema_editor=editor),
)
self.assertAddRemoveIndex(editor, Article, index)

@skipUnlessDBFeature("supports_in_index_operator")
def test_composite_index(self):
with connection.schema_editor() as editor:
index = Index(
name="test",
fields=["headline"],
condition=Q(number__gte=3) & (Q(body__gt="test1") | Q(body__in=["A", "B"])),
)
self.assertEqual(
index._get_condition_mql(Article, schema_editor=editor),
{
"$and": [
{"number": {"$gte": 3}},
{"$or": [{"body": {"$gt": "test1"}}, {"body": {"$in": ["A", "B"]}}]},
]
},
)
self.assertAddRemoveIndex(editor, Article, index)

def test_composite_op_index(self):
operators = (
(operator.or_, "$or"),
(operator.and_, "$and"),
)
if not connection.features.supports_or_index_operator:
operators = operators[1:]
for op, mongo_operator in operators:
with self.subTest(operator=op), connection.schema_editor() as editor:
index = Index(
name="test",
fields=["headline"],
condition=op(Q(number__gte=3), Q(body__gt="test1")),
)
self.assertEqual(
{mongo_operator: [{"number": {"$gte": 3}}, {"body": {"$gt": "test1"}}]},
index._get_condition_mql(Article, schema_editor=editor),
)
self.assertAddRemoveIndex(editor, Article, index)

0 comments on commit e543e0f

Please sign in to comment.