Skip to content

Commit

Permalink
Allow mutation in HttpFilters (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingster authored Aug 10, 2024
1 parent 081eeec commit 0369cd4
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) The original author or authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.flipkart.gjex.examples.helloworld.filter;

import com.flipkart.gjex.core.filter.RequestParams;
import com.flipkart.gjex.core.filter.http.HttpFilter;

import javax.inject.Named;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;

@Named("CustomHeaderHttpFilter")
public class CustomHeaderHttpFilter extends HttpFilter {

@Override
public void doProcessRequest(ServletRequest servletRequest, RequestParams<Map<String, String>> requestParams) {
super.doProcessRequest(servletRequest, requestParams);
}

@Override
public void doProcessResponseHeaders(Map<String, String> responseHeaders) {
super.doProcessResponseHeaders(responseHeaders);
responseHeaders.put("x-custom-header1", "value1");
}

@Override
public void doProcessResponse(ServletResponse response) {
super.doProcessResponse(response);
HttpServletResponse httpServletResponse = (HttpServletResponse) response;
httpServletResponse.addHeader("x-custom-header2", "value2");
}

@Override
public void doHandleException(Exception e) {
super.doHandleException(e);
}

@Override
public HttpFilter getInstance() {
return new CustomHeaderHttpFilter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package com.flipkart.gjex.examples.helloworld.guice;

import com.flipkart.gjex.core.filter.grpc.GrpcFilter;
import com.flipkart.gjex.core.filter.http.HttpFilterParams;
import com.flipkart.gjex.core.filter.http.JavaxFilterParams;
import com.flipkart.gjex.core.tracing.TracingSampler;
import com.flipkart.gjex.examples.helloworld.filter.AuthFilter;
import com.flipkart.gjex.examples.helloworld.filter.CustomHeaderHttpFilter;
import com.flipkart.gjex.examples.helloworld.filter.LoggingFilter;
import com.flipkart.gjex.examples.helloworld.service.GreeterService;
import com.flipkart.gjex.examples.helloworld.tracing.AllWhitelistTracingSampler;
Expand Down Expand Up @@ -54,7 +56,8 @@ protected void configure() {
// bind(AccessLogGrpcFilter.class).to(AccessLogTestFilter.class);
bind(TracingSampler.class).to(AllWhitelistTracingSampler.class);
bind(ResourceConfig.class).annotatedWith(Names.named("HelloWorldResourceConfig")).to(HelloWorldResourceConfig.class);
bind(JavaxFilterParams.class).annotatedWith(Names.named("ExampleJavaxFilter"))
.toInstance(JavaxFilterParams.builder().filter(new ExampleJavaxFilter()).pathSpec("/*").build());
bind(JavaxFilterParams.class).annotatedWith(Names.named("ExampleJavaxFilter")).toInstance(JavaxFilterParams.builder().filter(new ExampleJavaxFilter()).pathSpec("/*").build());
bind(HttpFilterParams.class).annotatedWith(Names.named("CustomHeaderHttpFilter")).toInstance(HttpFilterParams.builder().filter(new CustomHeaderHttpFilter()).pathSpec("/*").build());

}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
* Copyright (c) The original author or authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.flipkart.gjex.examples.helloworld.web.javaxfilter;

import com.flipkart.gjex.core.logging.Logging;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright (c) The original author or authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.flipkart.gjex.http.interceptor;

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;

/**
* A wrapper for HttpServletResponse that captures the output stream and writer.
*/
public class FilterServletResponseWrapper extends HttpServletResponseWrapper {

private final ServletOutputStreamWrapper stream = new ServletOutputStreamWrapper();
private PrintWriter pw;

/**
* Constructs a response object wrapping the given response.
*
* @param response the HttpServletResponse to be wrapped
* @throws IllegalArgumentException if the response is null
*/
public FilterServletResponseWrapper(HttpServletResponse response) {
super(response);
}

/**
* Returns the ServletOutputStream for this response.
*
* @return the ServletOutputStream
* @throws IOException if an I/O error occurs
*/
public ServletOutputStream getOutputStream() throws IOException {
if (pw != null) {
pw.flush();
}
return stream;
}

/**
* Returns a PrintWriter for this response.
*
* @return the PrintWriter
* @throws IOException if an I/O error occurs
*/
public PrintWriter getWriter() throws IOException {
pw = new PrintWriter(stream);
return pw;
}

/**
* Returns the captured bytes from the output stream.
*
* @return a byte array containing the captured bytes
*/
public byte[] getWrapperBytes() {
return stream.getBytes();
}

/**
* A wrapper for ServletOutputStream that captures written bytes.
*/
static class ServletOutputStreamWrapper extends ServletOutputStream {

private final ByteArrayOutputStream out = new ByteArrayOutputStream();

/**
* Writes the specified byte to this output stream.
*
* @param b the byte to be written
* @throws IOException if an I/O error occurs
*/
public void write(int b) throws IOException {
out.write(b);
}

/**
* Returns the captured bytes from the output stream.
*
* @return a byte array containing the captured bytes
*/
public byte[] getBytes() {
return out.toByteArray();
}

/**
* Indicates whether this output stream is ready to be written to.
*
* @return true if the output stream is ready, false otherwise
*/
@Override
public boolean isReady() {
return true;
}

/**
* Sets the WriteListener for this output stream.
*
* @param writeListener the WriteListener to be set
*/
@Override
public void setWriteListener(WriteListener writeListener) {
try {
writeListener.onWritePossible();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,27 @@
import com.flipkart.gjex.core.filter.http.HttpFilter;
import com.flipkart.gjex.core.filter.http.HttpFilterParams;
import org.eclipse.jetty.http.pathmap.ServletPathSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.inject.Named;
import javax.inject.Singleton;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Singleton
@Named("HttpFilterInterceptor")
public class HttpFilterInterceptor implements javax.servlet.Filter {

private static final Logger logger = LoggerFactory.getLogger(HttpFilterInterceptor.class);

private static class ServletPathFiltersHolder {
ServletPathSpec spec;
HttpFilter filter;
Expand Down Expand Up @@ -62,32 +65,45 @@ public void init(FilterConfig filterConfig) throws ServletException {}
public final void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {

List<HttpFilter> filters = new ArrayList<>();
RequestParams.RequestParamsBuilder<Map<String,String>> requestParamsBuilder = RequestParams.builder();
try {
if (request instanceof HttpServletRequest){
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
filters = getMatchingFilters(httpServletRequest.getRequestURI());

Map<String, String> headers = Collections.list(httpServletRequest.getHeaderNames())
.stream().collect(Collectors.toMap(String::toLowerCase, httpServletRequest::getHeader));
requestParamsBuilder.metadata(headers);
requestParamsBuilder.clientIp(getClientIp(request));
requestParamsBuilder.method(httpServletRequest.getMethod());
requestParamsBuilder.resourcePath(getFullURL(httpServletRequest));
}

if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) {
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
HttpServletResponse httpServletResponse = (HttpServletResponse) response;

List<HttpFilter> filters = getMatchingFilters(httpServletRequest.getRequestURI());
Map<String, String> headers = Collections.list(httpServletRequest.getHeaderNames())
.stream().collect(Collectors.toMap(String::toLowerCase, httpServletRequest::getHeader));
requestParamsBuilder.metadata(headers);
requestParamsBuilder.clientIp(getClientIp(request));
requestParamsBuilder.method(httpServletRequest.getMethod());
requestParamsBuilder.resourcePath(getFullURL(httpServletRequest));

RequestParams<Map<String, String>> requestParams = requestParamsBuilder.build();
filters.forEach(filter -> filter.doProcessRequest(request, requestParams));
chain.doFilter(request, response);
} finally {
if (response instanceof HttpServletResponse) {
HttpServletResponse httpServletResponse = (HttpServletResponse) response;
Map<String, String> headers = httpServletResponse.getHeaderNames()
FilterServletResponseWrapper responseWrapper = new FilterServletResponseWrapper(httpServletResponse);

try {
filters.forEach(filter -> filter.doProcessRequest(request, requestParams));
chain.doFilter(request, responseWrapper);

// Allow the filters to process the response headers
Map<String, String> responseHeaders = responseWrapper.getHeaderNames()
.stream().collect(Collectors.toMap(String::toLowerCase, httpServletResponse::getHeader));
filters.forEach(filter -> filter.doProcessResponseHeaders(headers));
filters.forEach(filter -> filter.doProcessResponseHeaders(responseHeaders));
responseHeaders.forEach(responseWrapper::setHeader);
} finally {
// Allow the filters to process the response
filters.forEach(filter -> filter.doProcessResponse(responseWrapper));
response.getOutputStream().write(responseWrapper.getWrapperBytes());
}
filters.forEach(filter -> filter.doProcessResponse(response));

} else {
// For Unsupported request types, pass the request and response as is
chain.doFilter(request, response);
logger.warn("Unsupported request type {}, pass the request and response as is.", request.getClass());
}


}

/**
Expand Down

0 comments on commit 0369cd4

Please sign in to comment.