diff --git a/examples/src/main/java/com/flipkart/gjex/examples/helloworld/filter/CustomHeaderHttpFilter.java b/examples/src/main/java/com/flipkart/gjex/examples/helloworld/filter/CustomHeaderHttpFilter.java new file mode 100644 index 00000000..007baef7 --- /dev/null +++ b/examples/src/main/java/com/flipkart/gjex/examples/helloworld/filter/CustomHeaderHttpFilter.java @@ -0,0 +1,57 @@ +/* + * Copyright (c) The original author or authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.flipkart.gjex.examples.helloworld.filter; + +import com.flipkart.gjex.core.filter.RequestParams; +import com.flipkart.gjex.core.filter.http.HttpFilter; + +import javax.inject.Named; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletResponse; +import java.util.Map; + +@Named("CustomHeaderHttpFilter") +public class CustomHeaderHttpFilter extends HttpFilter { + + @Override + public void doProcessRequest(ServletRequest servletRequest, RequestParams> requestParams) { + super.doProcessRequest(servletRequest, requestParams); + } + + @Override + public void doProcessResponseHeaders(Map responseHeaders) { + super.doProcessResponseHeaders(responseHeaders); + responseHeaders.put("x-custom-header1", "value1"); + } + + @Override + public void doProcessResponse(ServletResponse response) { + super.doProcessResponse(response); + HttpServletResponse httpServletResponse = (HttpServletResponse) response; + httpServletResponse.addHeader("x-custom-header2", "value2"); + } + + @Override + public void doHandleException(Exception e) { + super.doHandleException(e); + } + + @Override + public HttpFilter getInstance() { + return new CustomHeaderHttpFilter(); + } +} diff --git a/examples/src/main/java/com/flipkart/gjex/examples/helloworld/guice/HelloWorldModule.java b/examples/src/main/java/com/flipkart/gjex/examples/helloworld/guice/HelloWorldModule.java index 060c188d..c8ca9f2f 100644 --- a/examples/src/main/java/com/flipkart/gjex/examples/helloworld/guice/HelloWorldModule.java +++ b/examples/src/main/java/com/flipkart/gjex/examples/helloworld/guice/HelloWorldModule.java @@ -16,9 +16,11 @@ package com.flipkart.gjex.examples.helloworld.guice; import com.flipkart.gjex.core.filter.grpc.GrpcFilter; +import com.flipkart.gjex.core.filter.http.HttpFilterParams; import com.flipkart.gjex.core.filter.http.JavaxFilterParams; import com.flipkart.gjex.core.tracing.TracingSampler; import com.flipkart.gjex.examples.helloworld.filter.AuthFilter; +import com.flipkart.gjex.examples.helloworld.filter.CustomHeaderHttpFilter; import com.flipkart.gjex.examples.helloworld.filter.LoggingFilter; import com.flipkart.gjex.examples.helloworld.service.GreeterService; import com.flipkart.gjex.examples.helloworld.tracing.AllWhitelistTracingSampler; @@ -54,7 +56,8 @@ protected void configure() { // bind(AccessLogGrpcFilter.class).to(AccessLogTestFilter.class); bind(TracingSampler.class).to(AllWhitelistTracingSampler.class); bind(ResourceConfig.class).annotatedWith(Names.named("HelloWorldResourceConfig")).to(HelloWorldResourceConfig.class); - bind(JavaxFilterParams.class).annotatedWith(Names.named("ExampleJavaxFilter")) - .toInstance(JavaxFilterParams.builder().filter(new ExampleJavaxFilter()).pathSpec("/*").build()); + bind(JavaxFilterParams.class).annotatedWith(Names.named("ExampleJavaxFilter")).toInstance(JavaxFilterParams.builder().filter(new ExampleJavaxFilter()).pathSpec("/*").build()); + bind(HttpFilterParams.class).annotatedWith(Names.named("CustomHeaderHttpFilter")).toInstance(HttpFilterParams.builder().filter(new CustomHeaderHttpFilter()).pathSpec("/*").build()); + } } diff --git a/examples/src/main/java/com/flipkart/gjex/examples/helloworld/web/javaxfilter/ExampleJavaxFilter.java b/examples/src/main/java/com/flipkart/gjex/examples/helloworld/web/javaxfilter/ExampleJavaxFilter.java index 37ba4044..9e0c2171 100644 --- a/examples/src/main/java/com/flipkart/gjex/examples/helloworld/web/javaxfilter/ExampleJavaxFilter.java +++ b/examples/src/main/java/com/flipkart/gjex/examples/helloworld/web/javaxfilter/ExampleJavaxFilter.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) The original author or authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.flipkart.gjex.examples.helloworld.web.javaxfilter; import com.flipkart.gjex.core.logging.Logging; diff --git a/guice/src/main/java/com/flipkart/gjex/http/interceptor/FilterServletResponseWrapper.java b/guice/src/main/java/com/flipkart/gjex/http/interceptor/FilterServletResponseWrapper.java new file mode 100644 index 00000000..f545d07d --- /dev/null +++ b/guice/src/main/java/com/flipkart/gjex/http/interceptor/FilterServletResponseWrapper.java @@ -0,0 +1,129 @@ +/* + * Copyright (c) The original author or authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.flipkart.gjex.http.interceptor; + +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintWriter; + +/** + * A wrapper for HttpServletResponse that captures the output stream and writer. + */ +public class FilterServletResponseWrapper extends HttpServletResponseWrapper { + + private final ServletOutputStreamWrapper stream = new ServletOutputStreamWrapper(); + private PrintWriter pw; + + /** + * Constructs a response object wrapping the given response. + * + * @param response the HttpServletResponse to be wrapped + * @throws IllegalArgumentException if the response is null + */ + public FilterServletResponseWrapper(HttpServletResponse response) { + super(response); + } + + /** + * Returns the ServletOutputStream for this response. + * + * @return the ServletOutputStream + * @throws IOException if an I/O error occurs + */ + public ServletOutputStream getOutputStream() throws IOException { + if (pw != null) { + pw.flush(); + } + return stream; + } + + /** + * Returns a PrintWriter for this response. + * + * @return the PrintWriter + * @throws IOException if an I/O error occurs + */ + public PrintWriter getWriter() throws IOException { + pw = new PrintWriter(stream); + return pw; + } + + /** + * Returns the captured bytes from the output stream. + * + * @return a byte array containing the captured bytes + */ + public byte[] getWrapperBytes() { + return stream.getBytes(); + } + + /** + * A wrapper for ServletOutputStream that captures written bytes. + */ + static class ServletOutputStreamWrapper extends ServletOutputStream { + + private final ByteArrayOutputStream out = new ByteArrayOutputStream(); + + /** + * Writes the specified byte to this output stream. + * + * @param b the byte to be written + * @throws IOException if an I/O error occurs + */ + public void write(int b) throws IOException { + out.write(b); + } + + /** + * Returns the captured bytes from the output stream. + * + * @return a byte array containing the captured bytes + */ + public byte[] getBytes() { + return out.toByteArray(); + } + + /** + * Indicates whether this output stream is ready to be written to. + * + * @return true if the output stream is ready, false otherwise + */ + @Override + public boolean isReady() { + return true; + } + + /** + * Sets the WriteListener for this output stream. + * + * @param writeListener the WriteListener to be set + */ + @Override + public void setWriteListener(WriteListener writeListener) { + try { + writeListener.onWritePossible(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + } + +} diff --git a/guice/src/main/java/com/flipkart/gjex/http/interceptor/HttpFilterInterceptor.java b/guice/src/main/java/com/flipkart/gjex/http/interceptor/HttpFilterInterceptor.java index eed8582b..07635174 100644 --- a/guice/src/main/java/com/flipkart/gjex/http/interceptor/HttpFilterInterceptor.java +++ b/guice/src/main/java/com/flipkart/gjex/http/interceptor/HttpFilterInterceptor.java @@ -4,24 +4,27 @@ import com.flipkart.gjex.core.filter.http.HttpFilter; import com.flipkart.gjex.core.filter.http.HttpFilterParams; import org.eclipse.jetty.http.pathmap.ServletPathSpec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.inject.Named; import javax.inject.Singleton; -import javax.servlet.FilterChain; -import javax.servlet.FilterConfig; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; +import javax.servlet.*; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; @Singleton @Named("HttpFilterInterceptor") public class HttpFilterInterceptor implements javax.servlet.Filter { + private static final Logger logger = LoggerFactory.getLogger(HttpFilterInterceptor.class); + private static class ServletPathFiltersHolder { ServletPathSpec spec; HttpFilter filter; @@ -62,32 +65,45 @@ public void init(FilterConfig filterConfig) throws ServletException {} public final void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - List filters = new ArrayList<>(); RequestParams.RequestParamsBuilder> requestParamsBuilder = RequestParams.builder(); - try { - if (request instanceof HttpServletRequest){ - HttpServletRequest httpServletRequest = (HttpServletRequest) request; - filters = getMatchingFilters(httpServletRequest.getRequestURI()); - - Map headers = Collections.list(httpServletRequest.getHeaderNames()) - .stream().collect(Collectors.toMap(String::toLowerCase, httpServletRequest::getHeader)); - requestParamsBuilder.metadata(headers); - requestParamsBuilder.clientIp(getClientIp(request)); - requestParamsBuilder.method(httpServletRequest.getMethod()); - requestParamsBuilder.resourcePath(getFullURL(httpServletRequest)); - } + + if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) { + HttpServletRequest httpServletRequest = (HttpServletRequest) request; + HttpServletResponse httpServletResponse = (HttpServletResponse) response; + + List filters = getMatchingFilters(httpServletRequest.getRequestURI()); + Map headers = Collections.list(httpServletRequest.getHeaderNames()) + .stream().collect(Collectors.toMap(String::toLowerCase, httpServletRequest::getHeader)); + requestParamsBuilder.metadata(headers); + requestParamsBuilder.clientIp(getClientIp(request)); + requestParamsBuilder.method(httpServletRequest.getMethod()); + requestParamsBuilder.resourcePath(getFullURL(httpServletRequest)); + RequestParams> requestParams = requestParamsBuilder.build(); - filters.forEach(filter -> filter.doProcessRequest(request, requestParams)); - chain.doFilter(request, response); - } finally { - if (response instanceof HttpServletResponse) { - HttpServletResponse httpServletResponse = (HttpServletResponse) response; - Map headers = httpServletResponse.getHeaderNames() + FilterServletResponseWrapper responseWrapper = new FilterServletResponseWrapper(httpServletResponse); + + try { + filters.forEach(filter -> filter.doProcessRequest(request, requestParams)); + chain.doFilter(request, responseWrapper); + + // Allow the filters to process the response headers + Map responseHeaders = responseWrapper.getHeaderNames() .stream().collect(Collectors.toMap(String::toLowerCase, httpServletResponse::getHeader)); - filters.forEach(filter -> filter.doProcessResponseHeaders(headers)); + filters.forEach(filter -> filter.doProcessResponseHeaders(responseHeaders)); + responseHeaders.forEach(responseWrapper::setHeader); + } finally { + // Allow the filters to process the response + filters.forEach(filter -> filter.doProcessResponse(responseWrapper)); + response.getOutputStream().write(responseWrapper.getWrapperBytes()); } - filters.forEach(filter -> filter.doProcessResponse(response)); + + } else { + // For Unsupported request types, pass the request and response as is + chain.doFilter(request, response); + logger.warn("Unsupported request type {}, pass the request and response as is.", request.getClass()); } + + } /**