diff --git a/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java index 1c56741a4ba2..cee9c9e7ce3b 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java @@ -23,6 +23,7 @@ import io.micrometer.observation.ObservationRegistry; import jakarta.servlet.AsyncEvent; import jakarta.servlet.AsyncListener; +import jakarta.servlet.DispatcherType; import jakarta.servlet.FilterChain; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; @@ -97,6 +98,11 @@ public static Optional findObservationContext(H return Optional.ofNullable((ServerRequestObservationContext) request.getAttribute(CURRENT_OBSERVATION_CONTEXT_ATTRIBUTE)); } + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + @Override @SuppressWarnings("try") protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) @@ -117,8 +123,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse if (request.isAsyncStarted()) { request.getAsyncContext().addListener(new ObservationAsyncListener(observation)); } - // Stop Observation right now if async processing has not been started. - else { + // scope is opened for ASYNC dispatches, but the observation will be closed + // by the async listener. + else if (request.getDispatcherType() != DispatcherType.ASYNC){ Throwable error = fetchException(request); if (error != null) { observation.error(error); @@ -188,7 +195,6 @@ public void onComplete(AsyncEvent event) { @Override public void onError(AsyncEvent event) { this.currentObservation.error(unwrapServletException(event.getThrowable())); - this.currentObservation.stop(); } } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java index 7bca8a323fab..78a895fe7704 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java @@ -16,12 +16,18 @@ package org.springframework.web.filter; +import java.io.IOException; + import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; +import jakarta.servlet.DispatcherType; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; @@ -29,6 +35,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.server.observation.ServerRequestObservationContext; import org.springframework.util.Assert; +import org.springframework.web.testfixture.servlet.MockAsyncContext; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; @@ -45,18 +52,18 @@ class ServerHttpObservationFilterTests { private final TestObservationRegistry observationRegistry = TestObservationRegistry.create(); - private final MockFilterChain mockFilterChain = new MockFilterChain(); - private final MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/resource/test"); private final MockHttpServletResponse response = new MockHttpServletResponse(); + private MockFilterChain mockFilterChain = new MockFilterChain(); + private ServerHttpObservationFilter filter = new ServerHttpObservationFilter(this.observationRegistry); @Test - void filterShouldNotProcessAsyncDispatch() { - assertThat(this.filter.shouldNotFilterAsyncDispatch()).isTrue(); + void filterShouldProcessAsyncDispatch() { + assertThat(this.filter.shouldNotFilterAsyncDispatch()).isFalse(); } @Test @@ -72,6 +79,12 @@ void filterShouldFillObservationContext() throws Exception { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS").hasBeenStopped(); } + @Test + void filterShouldOpenScope() throws Exception { + this.mockFilterChain = new MockFilterChain(new ScopeCheckingServlet(this.observationRegistry)); + filter.doFilter(this.request, this.response, this.mockFilterChain); + } + @Test void filterShouldAcceptNoOpObservationContext() throws Exception { this.filter = new ServerHttpObservationFilter(ObservationRegistry.NOOP); @@ -136,9 +149,52 @@ void shouldCloseObservationAfterAsyncCompletion() throws Exception { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS").hasBeenStopped(); } + @Test + void shouldCloseObservationAfterAsyncError() throws Exception { + this.request.setAsyncSupported(true); + this.request.startAsync(); + this.filter.doFilter(this.request, this.response, this.mockFilterChain); + MockAsyncContext asyncContext = (MockAsyncContext) this.request.getAsyncContext(); + for (AsyncListener listener : asyncContext.getListeners()) { + listener.onError(new AsyncEvent(this.request.getAsyncContext(), new IllegalStateException("test error"))); + } + asyncContext.complete(); + assertThatHttpObservation().hasLowCardinalityKeyValue("exception", "IllegalStateException").hasBeenStopped(); + } + + @Test + void shouldNotCloseObservationDuringAsyncDispatch() throws Exception { + this.mockFilterChain = new MockFilterChain(new ScopeCheckingServlet(this.observationRegistry)); + this.request.setDispatcherType(DispatcherType.ASYNC); + this.filter.doFilter(this.request, this.response, this.mockFilterChain); + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .hasObservationWithNameEqualTo("http.server.requests") + .that().isNotStopped(); + } + private TestObservationRegistryAssert.TestObservationRegistryAssertReturningObservationContextAssert assertThatHttpObservation() { + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .hasNumberOfObservationsWithNameEqualTo("http.server.requests", 1); + return TestObservationRegistryAssert.assertThat(this.observationRegistry) - .hasObservationWithNameEqualTo("http.server.requests").that(); + .hasObservationWithNameEqualTo("http.server.requests") + .that() + .hasBeenStopped(); + } + + @SuppressWarnings("serial") + static class ScopeCheckingServlet extends HttpServlet { + + private final ObservationRegistry observationRegistry; + + public ScopeCheckingServlet(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + assertThat(this.observationRegistry.getCurrentObservation()).isNotNull(); + } } static class CustomObservationFilter extends ServerHttpObservationFilter {