From 076898f575e32167822ee1831a8858f04aa44c4d Mon Sep 17 00:00:00 2001 From: Austin Riba Date: Thu, 6 Feb 2025 15:32:50 -0800 Subject: [PATCH] Factor out permissions filter and add logical test --- tom_targets/tests/tests.py | 53 +++++++++++++++++++++++++++++++++++++- tom_targets/views.py | 29 ++++++++++----------- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/tom_targets/tests/tests.py b/tom_targets/tests/tests.py index 58fd4761..cfc36f4b 100644 --- a/tom_targets/tests/tests.py +++ b/tom_targets/tests/tests.py @@ -2,7 +2,7 @@ from datetime import datetime import responses -from django.contrib.auth.models import User, Group +from django.contrib.auth.models import AnonymousUser, User, Group from django.contrib.messages import get_messages from django.contrib.messages.constants import SUCCESS, WARNING from django.core.files.uploadedfile import SimpleUploadedFile @@ -18,6 +18,7 @@ from tom_targets.models import Target, TargetExtra, TargetList, TargetName from tom_targets.utils import import_targets from tom_targets.merge import target_merge +from tom_targets.views import target_permission_filter from tom_dataproducts.models import ReducedDatum, DataProduct from tom_observations.models import ObservationRecord from guardian.shortcuts import assign_perm, get_perms @@ -1946,3 +1947,53 @@ def test_seed_targets_unauthenticated(self): response = self.client.post(reverse('targets:seed')) self.assertEqual(response.status_code, 302) self.assertFalse(Target.objects.exists()) + + +class TestTargetPermissionFiltering(TestCase): + def setUp(self): + self.user = User.objects.create(username='testuser') + self.group = Group.objects.create(name='testgroup') + self.open_target = SiderealTargetFactory.create(permissions=Target.Permissions.OPEN) + self.public_target = SiderealTargetFactory.create(permissions=Target.Permissions.PUBLIC) + self.private_group_target = SiderealTargetFactory.create(permissions=Target.Permissions.PRIVATE) + self.private_user_target = SiderealTargetFactory.create(permissions=Target.Permissions.PRIVATE) + + def test_open_targets_visible(self): + result = target_permission_filter(AnonymousUser(), Target.objects.all()) + self.assertIn(self.open_target, result) + self.assertNotIn(self.public_target, result) + self.assertNotIn(self.private_group_target, result) + self.assertNotIn(self.private_user_target, result) + + def test_public_targets_visible(self): + result = target_permission_filter(self.user, Target.objects.all()) + self.assertIn(self.open_target, result) + self.assertIn(self.public_target, result) + self.assertNotIn(self.private_group_target, result) + self.assertNotIn(self.private_user_target, result) + + def test_private_group_permission(self): + self.group.user_set.add(self.user) + assign_perm('tom_targets.view_target', self.group, self.private_group_target) + result = target_permission_filter(self.user, Target.objects.all()) + self.assertIn(self.open_target, result) + self.assertIn(self.public_target, result) + self.assertIn(self.private_group_target, result) + self.assertNotIn(self.private_user_target, result) + + def test_private_user_permission(self): + assign_perm('tom_targets.view_target', self.user, self.private_user_target) + result = target_permission_filter(self.user, Target.objects.all()) + self.assertIn(self.open_target, result) + self.assertIn(self.public_target, result) + self.assertNotIn(self.private_group_target, result) + self.assertIn(self.private_user_target, result) + + def test_superuser_permission(self): + self.user.is_superuser = True + self.user.save() + result = target_permission_filter(self.user, Target.objects.all()) + self.assertIn(self.open_target, result) + self.assertIn(self.public_target, result) + self.assertIn(self.private_group_target, result) + self.assertIn(self.private_user_target, result) diff --git a/tom_targets/views.py b/tom_targets/views.py index c25c900a..59cd9464 100644 --- a/tom_targets/views.py +++ b/tom_targets/views.py @@ -61,6 +61,19 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +def target_permission_filter(user, qs): + if user.is_authenticated: + if user.is_superuser: + # Do not filter the queryset by permissions at all + return qs + else: + # Exclude targets that are private except for those that the user has explicit permissions to view + private_targets = qs.filter(permissions=Target.Permissions.PRIVATE) + public_targets = qs.exclude(permissions=Target.Permissions.PRIVATE) + return public_targets | get_objects_for_user(user, f'{Target._meta.app_label}.view_target', private_targets) + else: + # Only allow open targets + return qs.exclude(permissions__in=[Target.Permissions.PUBLIC, Target.Permissions.PRIVATE]) class TargetListView(FilterView): """ @@ -94,21 +107,7 @@ def get_context_data(self, *args, **kwargs): def get_queryset(self, *args, **kwargs): qs = super().get_queryset(*args, **kwargs) - - if self.request.user.is_authenticated: - if self.request.user.is_superuser: - # Do not filter the queryset by permissions at all - return qs - else: - # Exclude targets that are private except for those that the user has explicit permissions to view - private_targets = qs.filter(permissions=Target.Permissions.PRIVATE) - public_targets = qs.exclude(permissions=Target.Permissions.PRIVATE) - return public_targets | get_objects_for_user(self.request.user, f'{Target._meta.app_label}.view_target', private_targets) - else: - # Only allow open targets - return qs.exclude(permissions__in=[Target.Permissions.PUBLIC, Target.Permissions.PRIVATE]) - - + return target_permission_filter(self.request.user, qs) class TargetNameSearchView(RedirectView):