From cf46f8f904ae0d808e8ae756ccc69a550e891977 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 22 Jan 2016 14:21:07 -0500 Subject: [PATCH] Fixed #2 -- Added an abstract base that don't rely on ContentTypes. --- polymodels/fields.py | 8 +++--- polymodels/managers.py | 8 +++--- polymodels/models.py | 55 +++++++++++++++++++++++++++--------------- tests/base.py | 4 +-- tests/test_fields.py | 9 ++++--- 5 files changed, 50 insertions(+), 34 deletions(-) diff --git a/polymodels/fields.py b/polymodels/fields.py index 8bb8e38..ecbb8c1 100644 --- a/polymodels/fields.py +++ b/polymodels/fields.py @@ -25,14 +25,14 @@ class PolymorphicManyToOneRel(ManyToOneRel): @property def limit_choices_to(self): - subclasses_lookup = self.polymorphic_type.subclasses_lookup('pk') + subclasses_filter = self.polymorphic_type.get_subclasses_filter(query_name='pk') limit_choices_to = self._limit_choices_to if limit_choices_to is None: - limit_choices_to = subclasses_lookup + limit_choices_to = subclasses_filter elif isinstance(limit_choices_to, dict): - limit_choices_to = dict(limit_choices_to, **subclasses_lookup) + limit_choices_to = Q(**limit_choices_to) & subclasses_filter elif isinstance(limit_choices_to, Q): - limit_choices_to = limit_choices_to & Q(**subclasses_lookup) + limit_choices_to = limit_choices_to & subclasses_filter self.__dict__['limit_choices_to'] = limit_choices_to return limit_choices_to diff --git a/polymodels/managers.py b/polymodels/managers.py index a08e4ba..73df190 100644 --- a/polymodels/managers.py +++ b/polymodels/managers.py @@ -23,9 +23,7 @@ def select_subclasses(self, *models): related = accessors[subclass][2] if related: relateds.add(related) - queryset = self.filter( - **self.model.content_type_lookup(*tuple(subclasses)) - ) + queryset = self.filter(self.model.get_subclasses_filter(*tuple(subclasses))) else: # Collect all `select_related` required relateds for accessor in accessors.values(): @@ -39,7 +37,7 @@ def select_subclasses(self, *models): return queryset def exclude_subclasses(self): - return self.filter(**self.model.content_type_lookup()) + return self.filter(self.model.get_subclasses_filter(self.model)) def _clone(self, *args, **kwargs): kwargs.update(type_cast=getattr(self, 'type_cast', False)) @@ -75,5 +73,5 @@ def get_queryset(self): opts = model._meta if opts.proxy: # Select only associated model and its subclasses. - queryset = queryset.filter(**self.model.subclasses_lookup()) + queryset = queryset.filter(self.model.get_subclasses_filter()) return queryset diff --git a/polymodels/models.py b/polymodels/models.py index b42bf08..b5975df 100644 --- a/polymodels/models.py +++ b/polymodels/models.py @@ -63,16 +63,30 @@ def __missing__(self, model_key): return accessors -class BasePolymorphicModel(models.Model): +class AbstractPolymorphicModel(models.Model): class Meta: abstract = True subclass_accessors = SubclassAccessors() + def get_type(self): + raise NotImplementedError + + def set_type(self): + raise NotImplementedError + + @classmethod + def get_subclasses_filter(cls, *subclasses, **kwargs): + raise NotImplementedError + + def save(self, *args, **kwargs): + if self.pk is None: + self.set_type() + return super(AbstractPolymorphicModel, self).save(*args, **kwargs) + def type_cast(self, to=None): if to is None: - content_type_id = getattr(self, "%s_id" % self.CONTENT_TYPE_FIELD) - to = ContentType.objects.get_for_id(content_type_id).model_class() + to = self.get_type() attrs, proxy, _lookup = self.subclass_accessors[to] # Cast to the right concrete model by going up in the # SingleRelatedObjectDescriptor chain @@ -84,27 +98,28 @@ def type_cast(self, to=None): type_casted = copy_fields(type_casted, proxy) return type_casted - def save(self, *args, **kwargs): - if self.pk is None: - content_type = get_content_type(self.__class__) - setattr(self, self.CONTENT_TYPE_FIELD, content_type) - return super(BasePolymorphicModel, self).save(*args, **kwargs) - @classmethod - def content_type_lookup(cls, *models, **kwargs): - query_name = kwargs.pop('query_name', None) or cls.CONTENT_TYPE_FIELD - if models: - query_name = "%s__in" % query_name - value = set(ct.pk for ct in get_content_types(*models).values()) - else: - value = get_content_type(cls).pk - return {query_name: value} +class BasePolymorphicModel(AbstractPolymorphicModel): + class Meta: + abstract = True + + def get_type(self): + content_type_id = getattr(self, "%s_id" % self.CONTENT_TYPE_FIELD) + return ContentType.objects.get_for_id(content_type_id).model_class() + + def set_type(self): + content_type = get_content_type(self.__class__) + setattr(self, self.CONTENT_TYPE_FIELD, content_type) @classmethod - def subclasses_lookup(cls, query_name=None): - return cls.content_type_lookup( - cls, *tuple(cls.subclass_accessors), query_name=query_name + def get_subclasses_filter(cls, *subclasses, **kwargs): + if not subclasses: + subclasses = set(cls.subclass_accessors) + query_name = "%s__in" % ( + kwargs.pop('query_name', None) or cls.CONTENT_TYPE_FIELD ) + value = set(ct.pk for ct in get_content_types(*subclasses).values()) + return models.Q(**{query_name: value}) @classmethod def check(cls, **kwargs): diff --git a/tests/base.py b/tests/base.py index 93a8ab8..f43d145 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,10 +3,10 @@ from django.contrib.contenttypes.models import ContentType from django.test.testcases import TestCase -from polymodels.models import BasePolymorphicModel +from polymodels.models import AbstractPolymorphicModel class TestCase(TestCase): def tearDown(self): ContentType.objects.clear_cache() - BasePolymorphicModel.subclass_accessors.clear() + AbstractPolymorphicModel.subclass_accessors.clear() diff --git a/tests/test_fields.py b/tests/test_fields.py index b447761..e1d4ce5 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -56,9 +56,12 @@ def test_limit_choices_to(self): remote_field.limit_choices_to = extra_limit_choices_to # Cache should be cleared self.assertNotIn('limit_choices_to', remote_field.__dict__) + remote_field_limit_choices_to = remote_field.limit_choices_to + self.assertEqual(remote_field_limit_choices_to.connector, Q.AND) + self.assertFalse(remote_field_limit_choices_to.negated) self.assertEqual( - remote_field.limit_choices_to, - dict(extra_limit_choices_to, **limit_choices_to) + remote_field_limit_choices_to.children, + list(extra_limit_choices_to.items()) + limit_choices_to.children ) # Make sure it works with existing Q `limit_choices_to` @@ -70,7 +73,7 @@ def test_limit_choices_to(self): self.assertFalse(remote_field_limit_choices_to.negated) self.assertEqual( remote_field_limit_choices_to.children, - list(extra_limit_choices_to.items()) + list(limit_choices_to.items()) + list(extra_limit_choices_to.items()) + limit_choices_to.children ) # Re-assign the original value