Skip to content

Commit

Permalink
type check fixes andtype checks for generators
Browse files Browse the repository at this point in the history
  • Loading branch information
Ronald Krist committed Jan 6, 2025
1 parent 8253117 commit 6f13e9a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 41 deletions.
4 changes: 2 additions & 2 deletions oarepo_communities/ext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Dict

from flask_principal import identity_loaded

Expand Down Expand Up @@ -76,7 +76,7 @@ def init_config(self, app: Flask) -> None:
}

@cached_property
def urlprefix_serviceid_mapping(self) -> str:
def urlprefix_serviceid_mapping(self) -> dict[str, str]:
return get_urlprefix_service_id_mapping()

def get_community_default_workflow(self, **kwargs)->str | None:
Expand Down
78 changes: 41 additions & 37 deletions oarepo_communities/services/permissions/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
MissingDefaultCommunityError,
)
from oarepo_communities.proxies import current_oarepo_communities

#---
from typing import Any
from flask_principal import Need
from invenio_drafts_resources.records import Record as RecordWithDraft #should probably replace with drafts record where parent is expected
from flask_principal import Identity

def _user_in_community_need(user, community):
_Need = namedtuple("Need", ["method", "value", "user", "community"])
Expand All @@ -29,11 +33,11 @@ def _user_in_community_need(user, community):


class InAnyCommunity(Generator):
def __init__(self, permission_generator, **kwargs):
def __init__(self, permission_generator: Generator, **kwargs: Any)->None:
self.permission_generator = permission_generator
super().__init__(**kwargs)

def needs(self, **kwargs):
def needs(self, **kwargs: Any) -> list[Need]:
communities = CommunityMetadata.query.all()
needs = set() # to avoid duplicates
# TODO: this is linear with number of communities, optimize
Expand All @@ -55,7 +59,7 @@ def needs(self, **kwargs):

class CommunityWorkflowPermission(WorkflowPermission):

def _get_workflow_id(self, record=None, **kwargs):
def _get_workflow_id(self, record: RecordWithDraft = None, **kwargs: Any) -> str:
# todo - check the record branch too? idk makes more sense to not use the default community's workflow, there is a deeper problem if there's no workflow on the record
try:
return super()._get_workflow_id(record=None, **kwargs)
Expand All @@ -71,32 +75,32 @@ def _get_workflow_id(self, record=None, **kwargs):
raise MissingWorkflowError("Workflow not defined on record.")


def convert_community_ids_to_uuid(community_id):
def convert_community_ids_to_uuid(community_id: str) -> str:
# if it already is a string representation of uuid, keep it as it is
try:
uuid.UUID(community_id, version=4)
uuid.UUID(community_id, version=4) #?
return community_id
except ValueError:
community = Community.pid.resolve(community_id)
return str(community.id)


class CommunityRoleMixin:
def _get_record_communities(self, record=None, **kwargs):
def _get_record_communities(self, record: RecordWithDraft = None, **kwargs: Any)->list[str]:
try:
return record.parent.communities.ids
except AttributeError:
raise MissingCommunitiesError(f"Communities missing on record {record}.")

def _get_data_communities(self, data=None, **kwargs):
def _get_data_communities(self, data: dict=None, **kwargs: Any)->list[str]:
community_ids = (data or {}).get("parent", {}).get("communities", {}).get("ids")
if not community_ids:
raise MissingCommunitiesError("Communities not defined in input data.")
return [convert_community_ids_to_uuid(x) for x in community_ids]


class DefaultCommunityRoleMixin:
def _get_record_communities(self, record=None, **kwargs):
def _get_record_communities(self, record: RecordWithDraft = None, **kwargs: Any)->list[str]:
try:
return [str(record.parent.communities.default.id)]
except (AttributeError, TypeError) as e:
Expand All @@ -107,7 +111,7 @@ def _get_record_communities(self, record=None, **kwargs):
f"Default community missing on record {record}."
)

def _get_data_communities(self, data=None, **kwargs):
def _get_data_communities(self, data: dict = None, **kwargs: Any)->list[str]:
community_id = (
(data or {}).get("parent", {}).get("communities", {}).get("default")
)
Expand All @@ -120,7 +124,7 @@ def _get_data_communities(self, data=None, **kwargs):

class OARepoCommunityRoles(CommunityRoles):
# Invenio generators do not capture all situations where we need community id from record
def communities(self, identity):
def communities(self, identity: Identity)->list[str]:
"""Communities that an identity can manage."""
roles = self.roles(identity=identity)
community_ids = set()
Expand All @@ -130,18 +134,18 @@ def communities(self, identity):
return list(community_ids)

@abc.abstractmethod
def _get_record_communities(self, record=None, **kwargs):
def _get_record_communities(self, record: RecordWithDraft=None, **kwargs: Any) -> list[str]:
raise NotImplementedError()

@abc.abstractmethod
def _get_data_communities(self, data=None, **kwargs):
def _get_data_communities(self, data: dict = None, **kwargs: Any) -> list[str]:
raise NotImplementedError()

@abc.abstractmethod
def roles(self, **kwargs):
def roles(self, **kwargs: Any)->list[str]:
raise NotImplementedError()

def needs(self, record=None, data=None, **kwargs):
def needs(self, record: RecordWithDraft=None, data:dict=None, **kwargs:Any) -> list[Need]:
"""Set of Needs granting permission."""
if record:
community_ids = self._get_record_communities(record)
Expand All @@ -152,17 +156,17 @@ def needs(self, record=None, data=None, **kwargs):
for c in community_ids:
for role in self.roles(**kwargs):
_needs.add(CommunityRoleNeed(c, role))
return _needs
return list(_needs)

@abc.abstractmethod
def query_filter_field(self):
def query_filter_field(self)->str:
"""Field for query filter.
returns parent.communities.ids or parent.communities.default
"""
raise NotImplementedError()

def query_filter(self, identity=None, **kwargs):
def query_filter(self, identity: Identity=None, **kwargs: Any)->dsl.Q:
"""Filter for current identity."""
community_ids = self.communities(identity)
if not community_ids:
Expand All @@ -172,33 +176,33 @@ def query_filter(self, identity=None, **kwargs):

class CommunityRole(CommunityRoleMixin, OARepoCommunityRoles):

def __init__(self, role):
def __init__(self, role:str)->None:
self._role = role
super().__init__()

def roles(self, **kwargs):
def roles(self, **kwargs: Any)->list[str]:
return [self._role]

def query_filter_field(self):
def query_filter_field(self)->str:
return "parent.communities.ids"


class DefaultCommunityRole(
DefaultCommunityRoleMixin, RecipientGeneratorMixin, OARepoCommunityRoles
):

def __init__(self, role):
def __init__(self, role: str)->None:
self._role = role
super().__init__()

def roles(self, **kwargs):
def roles(self, **kwargs: Any)->list[str]:
return [self._role]

def reference_receivers(self, **kwargs):
def reference_receivers(self, **kwargs: Any)->list[dict[str, str]]:
community_id = self._get_record_communities(**kwargs)[0]
return [{"community_role": f"{community_id}:{self._role}"}]

def query_filter_field(self):
def query_filter_field(self)->str:
return "parent.communities.default"


Expand All @@ -207,7 +211,7 @@ def query_filter_field(self):

class TargetCommunityRole(DefaultCommunityRole):

def _get_data_communities(self, data=None, **kwargs):
def _get_data_communities(self, data: dict=None, **kwargs: Any)->list[str]:
try:
community_id = data["payload"]["community"]
except KeyError:
Expand All @@ -216,28 +220,28 @@ def _get_data_communities(self, data=None, **kwargs):
)
return [community_id]

def reference_receivers(self, **kwargs):
def reference_receivers(self, **kwargs: Any)->list[dict[str, str]]:
community_id = self._get_data_communities(**kwargs)[0]
return [{"community_role": f"{community_id}:{self._role}"}]


class CommunityMembers(CommunityRoleMixin, OARepoCommunityRoles):

def roles(self, **kwargs):
def roles(self, **kwargs: Any)->list[str]:
"""Roles."""
return [r.name for r in current_roles]

def query_filter_field(self):
def query_filter_field(self)->str:
return "parent.communities.ids"


class DefaultCommunityMembers(DefaultCommunityRoleMixin, OARepoCommunityRoles):

def roles(self, **kwargs):
def roles(self, **kwargs: Any)->list[str]:
"""Roles."""
return [r.name for r in current_roles]

def query_filter_field(self):
def query_filter_field(self)->str:
return "parent.communities.default"


Expand All @@ -247,14 +251,14 @@ def query_filter_field(self):
class RecordOwnerInDefaultRecordCommunity(DefaultCommunityRoleMixin, Generator):
default_or_ids = "default"

def _record_communities(self, record, **kwargs):
def _record_communities(self, record: RecordWithDraft=None, **kwargs: Any) -> set[str]:
return set(self._get_record_communities(record, **kwargs))

def needs(self, record=None, **kwargs):
def needs(self, record: RecordWithDraft=None, data:dict=None, **kwargs:Any) -> list[Need]:
record_communities = set(self._get_record_communities(record, **kwargs))
return self._needs(record_communities, record=record)

def _needs(self, record_communities, record=None):
def _needs(self, record_communities: set[str], record: RecordWithDraft = None) -> list[Need]:
owners = getattr(record.parent, "owners", None)
ret = []
for owner in owners:
Expand All @@ -264,7 +268,7 @@ def _needs(self, record_communities, record=None):
]
return ret

def query_filter(self, identity=None, **kwargs):
def query_filter(self, identity: Identity=None, **kwargs: Any)->dsl.Q:
"""Filters for current identity as owner."""
user_in_communities = {
(n.user, n.community)
Expand Down Expand Up @@ -294,9 +298,9 @@ class RecordOwnerInRecordCommunity(
default_or_ids = "ids"

# trick to use CommunityRoleMixin instead of DefaultCommunityRoleMixin
def _record_communities(self, record, **kwargs):
def _record_communities(self, record: RecordWithDraft=None, **kwargs: Any) -> set[str]:
return set(self._get_record_communities(record, **kwargs))

def needs(self, record=None, **kwargs):
def needs(self, record: RecordWithDraft=None, data:dict=None, **kwargs:Any) -> list[Need]:
record_communities = set(self._get_record_communities(record, **kwargs))
return self._needs(record_communities, record=record)
4 changes: 2 additions & 2 deletions oarepo_communities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from invenio_records_resources.records import Record


def get_community_needs_for_identity(identity: Identity)->list[tuple[str, str]]:
def get_community_needs_for_identity(identity: Identity)->list[tuple[str, str]] | None:
# see invenio_communities.utils.load_community_needs
if identity.id is None:
# no user is logged in
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_urlprefix_service_id_mapping()->dict[str, str]:
return ret


def community_id_from_record(record: Record)->str:
def community_id_from_record(record: Record)->str | None:

if isinstance(record, Community):
community_id = record.id
Expand Down

0 comments on commit 6f13e9a

Please sign in to comment.