Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mutation in HttpFilters #87

Merged
merged 8 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) The original author or authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.flipkart.gjex.examples.helloworld.filter;

import com.flipkart.gjex.core.filter.RequestParams;
import com.flipkart.gjex.core.filter.http.HttpFilter;

import javax.inject.Named;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;

@Named("CustomHeaderHttpFilter")
public class CustomHeaderHttpFilter extends HttpFilter {

@Override
public void doProcessRequest(ServletRequest servletRequest, RequestParams<Map<String, String>> requestParams) {
super.doProcessRequest(servletRequest, requestParams);
}

@Override
public void doProcessResponseHeaders(Map<String, String> responseHeaders) {
super.doProcessResponseHeaders(responseHeaders);
responseHeaders.put("x-custom-header1", "value1");
}

@Override
public void doProcessResponse(ServletResponse response) {
super.doProcessResponse(response);
HttpServletResponse httpServletResponse = (HttpServletResponse) response;
httpServletResponse.addHeader("x-custom-header2", "value2");
}

@Override
public void doHandleException(Exception e) {
super.doHandleException(e);
}

@Override
public HttpFilter getInstance() {
return new CustomHeaderHttpFilter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package com.flipkart.gjex.examples.helloworld.guice;

import com.flipkart.gjex.core.filter.grpc.GrpcFilter;
import com.flipkart.gjex.core.filter.http.HttpFilterParams;
import com.flipkart.gjex.core.filter.http.JavaxFilterParams;
import com.flipkart.gjex.core.tracing.TracingSampler;
import com.flipkart.gjex.examples.helloworld.filter.AuthFilter;
import com.flipkart.gjex.examples.helloworld.filter.CustomHeaderHttpFilter;
import com.flipkart.gjex.examples.helloworld.filter.LoggingFilter;
import com.flipkart.gjex.examples.helloworld.service.GreeterService;
import com.flipkart.gjex.examples.helloworld.tracing.AllWhitelistTracingSampler;
Expand Down Expand Up @@ -54,7 +56,8 @@ protected void configure() {
// bind(AccessLogGrpcFilter.class).to(AccessLogTestFilter.class);
bind(TracingSampler.class).to(AllWhitelistTracingSampler.class);
bind(ResourceConfig.class).annotatedWith(Names.named("HelloWorldResourceConfig")).to(HelloWorldResourceConfig.class);
bind(JavaxFilterParams.class).annotatedWith(Names.named("ExampleJavaxFilter"))
.toInstance(JavaxFilterParams.builder().filter(new ExampleJavaxFilter()).pathSpec("/*").build());
bind(JavaxFilterParams.class).annotatedWith(Names.named("ExampleJavaxFilter")).toInstance(JavaxFilterParams.builder().filter(new ExampleJavaxFilter()).pathSpec("/*").build());
bind(HttpFilterParams.class).annotatedWith(Names.named("CustomHeaderHttpFilter")).toInstance(HttpFilterParams.builder().filter(new CustomHeaderHttpFilter()).pathSpec("/*").build());

}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
* Copyright (c) The original author or authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.flipkart.gjex.examples.helloworld.web.javaxfilter;

import com.flipkart.gjex.core.logging.Logging;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright (c) The original author or authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.flipkart.gjex.http.interceptor;

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;

/**
* A wrapper for HttpServletResponse that captures the output stream and writer.
*/
public class FilterServletResponseWrapper extends HttpServletResponseWrapper {

private final ServletOutputStreamWrapper stream = new ServletOutputStreamWrapper();
private PrintWriter pw;

/**
* Constructs a response object wrapping the given response.
*
* @param response the HttpServletResponse to be wrapped
* @throws IllegalArgumentException if the response is null
*/
public FilterServletResponseWrapper(HttpServletResponse response) {
super(response);
}

/**
* Returns the ServletOutputStream for this response.
*
* @return the ServletOutputStream
* @throws IOException if an I/O error occurs
*/
public ServletOutputStream getOutputStream() throws IOException {
if (pw != null) {
pw.flush();
}
return stream;
}

/**
* Returns a PrintWriter for this response.
*
* @return the PrintWriter
* @throws IOException if an I/O error occurs
*/
public PrintWriter getWriter() throws IOException {
pw = new PrintWriter(stream);
return pw;
}

/**
* Returns the captured bytes from the output stream.
*
* @return a byte array containing the captured bytes
*/
public byte[] getWrapperBytes() {
return stream.getBytes();
}

/**
* A wrapper for ServletOutputStream that captures written bytes.
*/
static class ServletOutputStreamWrapper extends ServletOutputStream {

private final ByteArrayOutputStream out = new ByteArrayOutputStream();

/**
* Writes the specified byte to this output stream.
*
* @param b the byte to be written
* @throws IOException if an I/O error occurs
*/
public void write(int b) throws IOException {
out.write(b);
}

/**
* Returns the captured bytes from the output stream.
*
* @return a byte array containing the captured bytes
*/
public byte[] getBytes() {
return out.toByteArray();
}

/**
* Indicates whether this output stream is ready to be written to.
*
* @return true if the output stream is ready, false otherwise
*/
@Override
public boolean isReady() {
return true;
}

/**
* Sets the WriteListener for this output stream.
*
* @param writeListener the WriteListener to be set
*/
@Override
public void setWriteListener(WriteListener writeListener) {
try {
writeListener.onWritePossible();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,27 @@
import com.flipkart.gjex.core.filter.http.HttpFilter;
import com.flipkart.gjex.core.filter.http.HttpFilterParams;
import org.eclipse.jetty.http.pathmap.ServletPathSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.inject.Named;
import javax.inject.Singleton;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Singleton
@Named("HttpFilterInterceptor")
public class HttpFilterInterceptor implements javax.servlet.Filter {

private static final Logger logger = LoggerFactory.getLogger(HttpFilterInterceptor.class);

private static class ServletPathFiltersHolder {
ServletPathSpec spec;
HttpFilter filter;
Expand Down Expand Up @@ -62,32 +65,45 @@ public void init(FilterConfig filterConfig) throws ServletException {}
public final void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {

List<HttpFilter> filters = new ArrayList<>();
RequestParams.RequestParamsBuilder<Map<String,String>> requestParamsBuilder = RequestParams.builder();
try {
if (request instanceof HttpServletRequest){
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
filters = getMatchingFilters(httpServletRequest.getRequestURI());

Map<String, String> headers = Collections.list(httpServletRequest.getHeaderNames())
.stream().collect(Collectors.toMap(String::toLowerCase, httpServletRequest::getHeader));
requestParamsBuilder.metadata(headers);
requestParamsBuilder.clientIp(getClientIp(request));
requestParamsBuilder.method(httpServletRequest.getMethod());
requestParamsBuilder.resourcePath(getFullURL(httpServletRequest));
}

if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) {
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
HttpServletResponse httpServletResponse = (HttpServletResponse) response;

List<HttpFilter> filters = getMatchingFilters(httpServletRequest.getRequestURI());
Map<String, String> headers = Collections.list(httpServletRequest.getHeaderNames())
.stream().collect(Collectors.toMap(String::toLowerCase, httpServletRequest::getHeader));
requestParamsBuilder.metadata(headers);
requestParamsBuilder.clientIp(getClientIp(request));
requestParamsBuilder.method(httpServletRequest.getMethod());
requestParamsBuilder.resourcePath(getFullURL(httpServletRequest));

RequestParams<Map<String, String>> requestParams = requestParamsBuilder.build();
filters.forEach(filter -> filter.doProcessRequest(request, requestParams));
chain.doFilter(request, response);
} finally {
if (response instanceof HttpServletResponse) {
HttpServletResponse httpServletResponse = (HttpServletResponse) response;
Map<String, String> headers = httpServletResponse.getHeaderNames()
FilterServletResponseWrapper responseWrapper = new FilterServletResponseWrapper(httpServletResponse);

try {
filters.forEach(filter -> filter.doProcessRequest(request, requestParams));
chain.doFilter(request, responseWrapper);

// Allow the filters to process the response headers
Map<String, String> responseHeaders = responseWrapper.getHeaderNames()
.stream().collect(Collectors.toMap(String::toLowerCase, httpServletResponse::getHeader));
filters.forEach(filter -> filter.doProcessResponseHeaders(headers));
filters.forEach(filter -> filter.doProcessResponseHeaders(responseHeaders));
responseHeaders.forEach(responseWrapper::setHeader);
} finally {
// Allow the filters to process the response
filters.forEach(filter -> filter.doProcessResponse(responseWrapper));
response.getOutputStream().write(responseWrapper.getWrapperBytes());
}
filters.forEach(filter -> filter.doProcessResponse(response));

} else {
// For Unsupported request types, pass the request and response as is
chain.doFilter(request, response);
logger.warn("Unsupported request type {}, pass the request and response as is.", request.getClass());
}


}

/**
Expand Down