Skip to content

Commit

Permalink
Fixed #2 -- Added an abstract base that don't rely on ContentTypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Jan 22, 2016
1 parent e7f624d commit cf46f8f
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 34 deletions.
8 changes: 4 additions & 4 deletions polymodels/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class PolymorphicManyToOneRel(ManyToOneRel):

@property
def limit_choices_to(self):
subclasses_lookup = self.polymorphic_type.subclasses_lookup('pk')
subclasses_filter = self.polymorphic_type.get_subclasses_filter(query_name='pk')
limit_choices_to = self._limit_choices_to
if limit_choices_to is None:
limit_choices_to = subclasses_lookup
limit_choices_to = subclasses_filter
elif isinstance(limit_choices_to, dict):
limit_choices_to = dict(limit_choices_to, **subclasses_lookup)
limit_choices_to = Q(**limit_choices_to) & subclasses_filter
elif isinstance(limit_choices_to, Q):
limit_choices_to = limit_choices_to & Q(**subclasses_lookup)
limit_choices_to = limit_choices_to & subclasses_filter
self.__dict__['limit_choices_to'] = limit_choices_to
return limit_choices_to

Expand Down
8 changes: 3 additions & 5 deletions polymodels/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def select_subclasses(self, *models):
related = accessors[subclass][2]
if related:
relateds.add(related)
queryset = self.filter(
**self.model.content_type_lookup(*tuple(subclasses))
)
queryset = self.filter(self.model.get_subclasses_filter(*tuple(subclasses)))
else:
# Collect all `select_related` required relateds
for accessor in accessors.values():
Expand All @@ -39,7 +37,7 @@ def select_subclasses(self, *models):
return queryset

def exclude_subclasses(self):
return self.filter(**self.model.content_type_lookup())
return self.filter(self.model.get_subclasses_filter(self.model))

def _clone(self, *args, **kwargs):
kwargs.update(type_cast=getattr(self, 'type_cast', False))
Expand Down Expand Up @@ -75,5 +73,5 @@ def get_queryset(self):
opts = model._meta
if opts.proxy:
# Select only associated model and its subclasses.
queryset = queryset.filter(**self.model.subclasses_lookup())
queryset = queryset.filter(self.model.get_subclasses_filter())
return queryset
55 changes: 35 additions & 20 deletions polymodels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,30 @@ def __missing__(self, model_key):
return accessors


class BasePolymorphicModel(models.Model):
class AbstractPolymorphicModel(models.Model):
class Meta:
abstract = True

subclass_accessors = SubclassAccessors()

def get_type(self):
raise NotImplementedError

def set_type(self):
raise NotImplementedError

@classmethod
def get_subclasses_filter(cls, *subclasses, **kwargs):
raise NotImplementedError

def save(self, *args, **kwargs):
if self.pk is None:
self.set_type()
return super(AbstractPolymorphicModel, self).save(*args, **kwargs)

def type_cast(self, to=None):
if to is None:
content_type_id = getattr(self, "%s_id" % self.CONTENT_TYPE_FIELD)
to = ContentType.objects.get_for_id(content_type_id).model_class()
to = self.get_type()
attrs, proxy, _lookup = self.subclass_accessors[to]
# Cast to the right concrete model by going up in the
# SingleRelatedObjectDescriptor chain
Expand All @@ -84,27 +98,28 @@ def type_cast(self, to=None):
type_casted = copy_fields(type_casted, proxy)
return type_casted

def save(self, *args, **kwargs):
if self.pk is None:
content_type = get_content_type(self.__class__)
setattr(self, self.CONTENT_TYPE_FIELD, content_type)
return super(BasePolymorphicModel, self).save(*args, **kwargs)

@classmethod
def content_type_lookup(cls, *models, **kwargs):
query_name = kwargs.pop('query_name', None) or cls.CONTENT_TYPE_FIELD
if models:
query_name = "%s__in" % query_name
value = set(ct.pk for ct in get_content_types(*models).values())
else:
value = get_content_type(cls).pk
return {query_name: value}
class BasePolymorphicModel(AbstractPolymorphicModel):
class Meta:
abstract = True

def get_type(self):
content_type_id = getattr(self, "%s_id" % self.CONTENT_TYPE_FIELD)
return ContentType.objects.get_for_id(content_type_id).model_class()

def set_type(self):
content_type = get_content_type(self.__class__)
setattr(self, self.CONTENT_TYPE_FIELD, content_type)

@classmethod
def subclasses_lookup(cls, query_name=None):
return cls.content_type_lookup(
cls, *tuple(cls.subclass_accessors), query_name=query_name
def get_subclasses_filter(cls, *subclasses, **kwargs):
if not subclasses:
subclasses = set(cls.subclass_accessors)
query_name = "%s__in" % (
kwargs.pop('query_name', None) or cls.CONTENT_TYPE_FIELD
)
value = set(ct.pk for ct in get_content_types(*subclasses).values())
return models.Q(**{query_name: value})

@classmethod
def check(cls, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from django.contrib.contenttypes.models import ContentType
from django.test.testcases import TestCase

from polymodels.models import BasePolymorphicModel
from polymodels.models import AbstractPolymorphicModel


class TestCase(TestCase):
def tearDown(self):
ContentType.objects.clear_cache()
BasePolymorphicModel.subclass_accessors.clear()
AbstractPolymorphicModel.subclass_accessors.clear()
9 changes: 6 additions & 3 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ def test_limit_choices_to(self):
remote_field.limit_choices_to = extra_limit_choices_to
# Cache should be cleared
self.assertNotIn('limit_choices_to', remote_field.__dict__)
remote_field_limit_choices_to = remote_field.limit_choices_to
self.assertEqual(remote_field_limit_choices_to.connector, Q.AND)
self.assertFalse(remote_field_limit_choices_to.negated)
self.assertEqual(
remote_field.limit_choices_to,
dict(extra_limit_choices_to, **limit_choices_to)
remote_field_limit_choices_to.children,
list(extra_limit_choices_to.items()) + limit_choices_to.children
)

# Make sure it works with existing Q `limit_choices_to`
Expand All @@ -70,7 +73,7 @@ def test_limit_choices_to(self):
self.assertFalse(remote_field_limit_choices_to.negated)
self.assertEqual(
remote_field_limit_choices_to.children,
list(extra_limit_choices_to.items()) + list(limit_choices_to.items())
list(extra_limit_choices_to.items()) + limit_choices_to.children
)

# Re-assign the original value
Expand Down

0 comments on commit cf46f8f

Please sign in to comment.