diff --git a/oarepo_requests/invenio_patches.py b/oarepo_requests/invenio_patches.py index 69790e3d..e2f01d7e 100644 --- a/oarepo_requests/invenio_patches.py +++ b/oarepo_requests/invenio_patches.py @@ -1,10 +1,16 @@ from functools import cached_property -from flask_resources import ResponseHandler, JSONSerializer +from flask_resources import JSONSerializer, ResponseHandler from invenio_records_resources.resources.records.headers import etag_headers from invenio_records_resources.services.records.params import FilterParam -from invenio_requests.resources.requests.config import RequestSearchRequestArgsSchema, RequestsResourceConfig -from invenio_requests.services.requests.config import RequestSearchOptions, RequestsServiceConfig +from invenio_requests.resources.requests.config import ( + RequestSearchRequestArgsSchema, + RequestsResourceConfig, +) +from invenio_requests.services.requests.config import ( + RequestSearchOptions, + RequestsServiceConfig, +) from marshmallow import fields from opensearch_dsl.query import Bool, Term @@ -22,26 +28,30 @@ def apply(self, identity, search, params): class RequestReceiverFilterParam(FilterParam): def apply(self, identity, search, params): value = params.pop(self.param_name, None) - my_groups = [ - n.value for n in identity.provides if n.method == 'role' - ] + my_groups = [n.value for n in identity.provides if n.method == "role"] if value is not None: - search = search.filter(Bool(should=[ - # explicitly myself - Term(**{f"{self.field_name}.user": identity.id}), - # my roles - *[ - Term(**{f"{self.field_name}.group": group_id}) for group_id in my_groups - ], - # TODO: add my communities where I have a role to accept requests - ], minimum_should_match=1)) + search = search.filter( + Bool( + should=[ + # explicitly myself + Term(**{f"{self.field_name}.user": identity.id}), + # my roles + *[ + Term(**{f"{self.field_name}.group": group_id}) + for group_id in my_groups + ], + # TODO: add my communities where I have a role to accept requests + ], + minimum_should_match=1, + ) + ) return search class EnhancedRequestSearchOptions(RequestSearchOptions): params_interpreters_cls = RequestSearchOptions.params_interpreters_cls + [ RequestOwnerFilterParam.factory("mine", "created_by.user"), - RequestReceiverFilterParam.factory("assigned", "receiver") + RequestReceiverFilterParam.factory("assigned", "receiver"), ] @@ -68,10 +78,7 @@ def serialize_object_list(self): def serialize_object(self): return self.__instance.serialize_object - RequestsResourceConfig.response_handlers = { - "application/json": ResponseHandler(JSONSerializer(), headers=etag_headers), - "application/vnd.inveniordm.v1+json": ResponseHandler( - LazySerializer() - ) - } + "application/json": ResponseHandler(JSONSerializer(), headers=etag_headers), + "application/vnd.inveniordm.v1+json": ResponseHandler(LazySerializer()), + } diff --git a/oarepo_requests/resources/oarepo/resource.py b/oarepo_requests/resources/oarepo/resource.py index ca9e9551..8f0daf1f 100644 --- a/oarepo_requests/resources/oarepo/resource.py +++ b/oarepo_requests/resources/oarepo/resource.py @@ -27,62 +27,51 @@ def __init__( def create_url_rules(self): """Create the URL rules for the record resource.""" - base_routes = super().create_url_rules() - routes = self.config.routes def p(route): """Prefix a route with the URL prefix.""" return f"{self.config.url_prefix}{route}" - def s(route): - """Suffix a route with the URL prefix.""" - return f"{route}{self.config.url_prefix}" + routes = self.config.routes url_rules = [ route("POST", p(routes["list"]), self.create), - route("POST", p(routes["list-extended"]), self.create_extended), + route( + "POST", + p(routes["list-extended"]), + self.create, + endpoint="extended_create", + ), route("GET", p(routes["item-extended"]), self.read_extended), - route("PUT", p(routes["item-extended"]), self.update_extended), + route("PUT", p(routes["item-extended"]), self.update), ] - return url_rules + base_routes + return url_rules @request_extra_args - @request_view_args @request_headers + @request_view_args @request_data @response_handler() - def create(self): - - items = self.oarepo_requests_service.create( + def update(self): + item = self.oarepo_requests_service.update( + id_=resource_requestctx.view_args["id"], identity=g.identity, data=resource_requestctx.data, - request_type=resource_requestctx.data.pop("request_type", None), - topic=( - stringify_first_val(resource_requestctx.data.pop("topic", None)) - if resource_requestctx.data - else None - ), expand=resource_requestctx.args.get("expand", False), ) - - return items.to_dict(), 201 + return item.to_dict(), 200 @request_extra_args @request_view_args @request_headers @request_data @response_handler() - def create_extended(self): - def stringify_first_val(dct): - if isinstance(dct, dict): - for k, v in dct.items(): - dct[k] = str(v) - return dct + def create(self): items = self.oarepo_requests_service.create( identity=g.identity, data=resource_requestctx.data, - type_id=resource_requestctx.data.pop("request_type", None), + request_type=resource_requestctx.data.pop("request_type", None), topic=( stringify_first_val(resource_requestctx.data.pop("topic", None)) if resource_requestctx.data @@ -104,7 +93,3 @@ def read_extended(self): expand=resource_requestctx.args.get("expand", False), ) return item.to_dict(), 200 - - # from parent - def update_extended(self): - return super().update() diff --git a/oarepo_requests/resources/record/config.py b/oarepo_requests/resources/record/config.py index 7aac6e61..8298f000 100644 --- a/oarepo_requests/resources/record/config.py +++ b/oarepo_requests/resources/record/config.py @@ -1,5 +1,5 @@ import marshmallow as ma -from flask_resources import ResponseHandler, JSONSerializer +from flask_resources import JSONSerializer, ResponseHandler from invenio_records_resources.resources import RecordResourceConfig from invenio_records_resources.resources.records.headers import etag_headers @@ -21,5 +21,5 @@ def response_handlers(self): "application/vnd.inveniordm.v1+json": ResponseHandler( OARepoRequestsUIJSONSerializer() ), - "application/json": ResponseHandler(JSONSerializer(), headers=etag_headers) + "application/json": ResponseHandler(JSONSerializer(), headers=etag_headers), } diff --git a/oarepo_requests/resources/record/resource.py b/oarepo_requests/resources/record/resource.py index 5042d510..732b32af 100644 --- a/oarepo_requests/resources/record/resource.py +++ b/oarepo_requests/resources/record/resource.py @@ -21,9 +21,7 @@ def __init__(self, record_requests_config, config, service): """ actual_config = copy.deepcopy(record_requests_config) actual_config.blueprint_name = f"{config.blueprint_name}_requests" - vars_to_overwrite = [ - x for x in dir(config) if not x.startswith("_") - ] + vars_to_overwrite = [x for x in dir(config) if not x.startswith("_")] actual_keys = dir(actual_config) for var in vars_to_overwrite: if var not in actual_keys: diff --git a/oarepo_requests/services/oarepo/service.py b/oarepo_requests/services/oarepo/service.py index 8fd2f0fe..e6539bb0 100644 --- a/oarepo_requests/services/oarepo/service.py +++ b/oarepo_requests/services/oarepo/service.py @@ -1,4 +1,4 @@ -from invenio_records_resources.services.uow import unit_of_work +from invenio_records_resources.services.uow import IndexRefreshOp, unit_of_work from invenio_requests import current_request_type_registry from invenio_requests.services import RequestsService @@ -36,7 +36,7 @@ def create( else: error = None if not error: - return super().create( + result = super().create( identity=identity, data=data, request_type=type_, @@ -46,7 +46,19 @@ def create( expand=expand, uow=uow, ) + uow.register( + IndexRefreshOp(indexer=self.indexer, index=self.record_cls.index) + ) + return result def read(self, identity, id_, expand=False): api_request = super().read(identity, id_, expand) return api_request + + @unit_of_work() + def update(self, identity, id_, data, revision_id=None, uow=None, expand=False): + result = super().update( + identity, id_, data, revision_id=revision_id, uow=uow, expand=expand + ) + uow.register(IndexRefreshOp(indexer=self.indexer, index=self.record_cls.index)) + return result diff --git a/oarepo_requests/services/schema.py b/oarepo_requests/services/schema.py index 9b527cf0..fedc80b8 100644 --- a/oarepo_requests/services/schema.py +++ b/oarepo_requests/services/schema.py @@ -4,9 +4,10 @@ from invenio_requests.proxies import current_request_type_registry from invenio_requests.services.schemas import GenericRequestSchema from marshmallow import fields +from oarepo_runtime.records import is_published_record from oarepo_requests.utils import get_matching_service_for_record -from oarepo_runtime.records import is_published_record + def get_links_schema(): # TODO possibly specify more diff --git a/oarepo_requests/services/ui_schema.py b/oarepo_requests/services/ui_schema.py index 809beef0..d2f2575c 100644 --- a/oarepo_requests/services/ui_schema.py +++ b/oarepo_requests/services/ui_schema.py @@ -36,12 +36,12 @@ def dereference(self, data, **kwargs): return entity_resolvers[reference_type](self.context["identity"], data) else: # TODO log warning - return fallback_entity_reference_ui_resolver(self.context["identity"], data) + return fallback_entity_reference_ui_resolver( + self.context["identity"], data + ) except PIDDeletedError: - return { - **data, - "status": "removed" - } + return {**data, "status": "removed"} + class UIRequestSchemaMixin: created = LocalizedDateTime(dump_only=True) diff --git a/oarepo_requests/utils.py b/oarepo_requests/utils.py index 1e899463..41af6cff 100644 --- a/oarepo_requests/utils.py +++ b/oarepo_requests/utils.py @@ -140,4 +140,3 @@ def stringify_first_val(dct): for k, v in dct.items(): dct[k] = str(v) return dct - diff --git a/oarepo_requests/views/api.py b/oarepo_requests/views/api.py index c914ec84..32208d8e 100644 --- a/oarepo_requests/views/api.py +++ b/oarepo_requests/views/api.py @@ -4,6 +4,7 @@ def create_oarepo_requests(app): blueprint = ext.requests_resource.as_blueprint() from oarepo_requests.invenio_patches import override_invenio_requests_config + blueprint.record_once(override_invenio_requests_config) return blueprint diff --git a/setup.cfg b/setup.cfg index abfe5414..e4f4924f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = oarepo-requests -version = 1.1.11 +version = 1.1.12 description = authors = Ronald Krist readme = README.md diff --git a/tests/conftest.py b/tests/conftest.py index e03fa9ef..85ddf4f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import pytest from flask_principal import Identity, Need, UserNeed -from invenio_access.permissions import system_identity from invenio_requests.customizations import CommentEventType, LogEventType from invenio_requests.proxies import current_request_type_registry, current_requests from invenio_requests.records.api import Request, RequestEventFormat @@ -357,9 +356,9 @@ def example_draft_service_bypass(app, db): @pytest.fixture() def record_factory(record_service): - def record(): - draft = record_service.create(system_identity, {}) - record = record_service.publish(system_identity, draft.id) + def record(identity): + draft = record_service.create(identity, {}) + record = record_service.publish(identity, draft.id) return record._obj return record diff --git a/tests/test_requests/test_create_inmodel.py b/tests/test_requests/test_create_inmodel.py index 1fa4ff57..6c07405a 100644 --- a/tests/test_requests/test_create_inmodel.py +++ b/tests/test_requests/test_create_inmodel.py @@ -21,7 +21,7 @@ def test_record( ): creator = users[0] receiver = users[1] - record1 = record_factory() + record1 = record_factory(identity_simple) record1 = logged_client_request( creator, "get", f"{urls['BASE_URL']}{record1['id']}" ) diff --git a/tests/test_requests/test_delete.py b/tests/test_requests/test_delete.py index 27e5eccf..481f68ad 100644 --- a/tests/test_requests/test_delete.py +++ b/tests/test_requests/test_delete.py @@ -14,9 +14,9 @@ def test_delete( ): creator = users[0] receiver = users[1] - record1 = record_factory() - record2 = record_factory() - record3 = record_factory() + record1 = record_factory(identity_simple) + record2 = record_factory(identity_simple) + record3 = record_factory(identity_simple) ThesisRecord.index.refresh() ThesisDraft.index.refresh() lst = logged_client_request(creator, "get", urls["BASE_URL"]) diff --git a/tests/test_requests/test_edit.py b/tests/test_requests/test_edit.py index 0cf64741..27c70d67 100644 --- a/tests/test_requests/test_edit.py +++ b/tests/test_requests/test_edit.py @@ -13,7 +13,7 @@ def test_edit_autoaccept( search_clear, ): creator = users[0] - record1 = record_factory() + record1 = record_factory(identity_simple) resp_request_create = logged_client_request( creator, diff --git a/tests/test_requests/test_index_refresh.py b/tests/test_requests/test_index_refresh.py new file mode 100644 index 00000000..d6a0f547 --- /dev/null +++ b/tests/test_requests/test_index_refresh.py @@ -0,0 +1,43 @@ +from tests.test_requests.utils import link_api2testclient + + +def test_search( + logged_client_request, + identity_simple, + users, + urls, + publish_request_data_function, + search_clear, +): + creator = users[0] + + draft1 = logged_client_request(creator, "post", urls["BASE_URL"], json={}) + + resp_request_create = logged_client_request( + creator, + "post", + urls["BASE_URL_REQUESTS"], + json=publish_request_data_function(draft1.json["id"]), + ) + # should work without refreshing requests index + requests_search = logged_client_request( + creator, "get", urls["BASE_URL_REQUESTS"] + ).json + + assert len(requests_search["hits"]["hits"]) == 1 + + link = link_api2testclient(requests_search["hits"]["hits"][0]["links"]["self"]) + extended_link = link.replace("/requests/", "/requests/extended/") + + update = logged_client_request( + creator, + "put", + extended_link, + json={"title": "tralala"}, + ) + + requests_search = logged_client_request( + creator, "get", urls["BASE_URL_REQUESTS"] + ).json + + assert requests_search["hits"]["hits"][0]["title"] == "tralala" diff --git a/tests/test_requests/test_record_requests.py b/tests/test_requests/test_record_requests.py index 8701f545..640b1c9b 100644 --- a/tests/test_requests/test_record_requests.py +++ b/tests/test_requests/test_record_requests.py @@ -64,9 +64,9 @@ def test_read_requests_on_record( ): creator_client = logged_clients[0] receiver = users[1] - record1 = record_factory() - record2 = record_factory() - record3 = record_factory() + record1 = record_factory(identity_simple) + record2 = record_factory(identity_simple) + record3 = record_factory(identity_simple) ThesisRecord.index.refresh() ThesisDraft.index.refresh() r1 = creator_client.post( diff --git a/tests/test_ui/test_ui_resource.py b/tests/test_ui/test_ui_resource.py index 794f0fd1..6f6ae661 100644 --- a/tests/test_ui/test_ui_resource.py +++ b/tests/test_ui/test_ui_resource.py @@ -42,7 +42,7 @@ def test_record_delete_request_present( data = json.loads(c.text) assert len(data["creatable_request_types"]) == 2 assert data["creatable_request_types"]["thesis_edit_record"] == { - "description": 'Request re-opening of published record', + "description": "Request re-opening of published record", "links": { "actions": { "create": f"https://127.0.0.1:5000/api/thesis/{example_topic['id']}/requests/thesis_edit_record"