diff --git a/moto/core/responses_custom_registry.py b/moto/core/responses_custom_registry.py index 64c9422dfcd6..13db30a06046 100644 --- a/moto/core/responses_custom_registry.py +++ b/moto/core/responses_custom_registry.py @@ -32,6 +32,23 @@ def add(self, response: responses.BaseResponse) -> responses.BaseResponse: self._registered[response.method].append(response) return response + def replace(self, response: responses.BaseResponse) -> responses.BaseResponse: + registered = self._registered[response.method] + try: + index = registered.index(response) + except ValueError: + raise ValueError(f"Response is not registered for URL {response.url}") + registered[index] = response + return response + + def remove(self, response: responses.BaseResponse) -> List[responses.BaseResponse]: + removed_responses = [] + registered = self._registered[response.method] + while response in registered: + registered.remove(response) + removed_responses.append(response) + return removed_responses + def reset(self) -> None: self._registered.clear() diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py index 3865ba679810..d5ea67d885e8 100644 --- a/tests/test_core/test_request_mocking.py +++ b/tests/test_core/test_request_mocking.py @@ -1,8 +1,12 @@ +from unittest import SkipTest + import boto3 import pytest import requests +from responses import Response from moto import mock_aws, settings +from moto.core.models import responses_mock @mock_aws @@ -48,3 +52,20 @@ def test_decorator_ordering() -> None: resp = requests.get(presigned_url) assert resp.status_code == 200 + + +@mock_aws() +def test_replace_and_remove_mock() -> None: + if not settings.TEST_DECORATOR_MODE: + raise SkipTest("Only need to test responses mock in decorator mode") + rsp1 = Response(method="GET", url="http://example.com", body="test") + responses_mock.add(rsp1) + + assert requests.get("http://example.com/").text == "test" + + rsp2 = Response(method="GET", url="http://example.com", body="test2") + responses_mock.replace(rsp2) + + assert requests.get("http://example.com/").text == "test2" + + responses_mock.remove("GET", "https://example.com/")