Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support VPC Lattice as an event source #845

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
30b161f
Added VPCLattice event
mbfreder Mar 7, 2024
38946e4
Added Servlet request and Security Context classes
mbfreder Mar 14, 2024
b9d9ad4
Added Springboot tests
mbfreder Apr 9, 2024
229ddbc
fix nullPointerException on request.getIsBase64Encoded()
mbfreder Apr 11, 2024
0679c1d
Added Jersey tests
mbfreder Apr 16, 2024
2de6093
Added Spring tests
mbfreder Apr 17, 2024
5117dfb
Added Sample app
mbfreder Apr 17, 2024
a71535b
Upgrade to Spring 6.1.6 to fix CVE-2024-22262
mbfreder Apr 17, 2024
7a645ec
remove duplicate lombok dep
mbfreder Apr 17, 2024
679016f
Fix java21 error
mbfreder Apr 17, 2024
2f3fad3
Remove lombok dependency
mbfreder Apr 29, 2024
66a0fc8
return HTTP/1.1 as default protocol
mbfreder Apr 29, 2024
6c57dde
Update SpringBoot version on sample
mbfreder Apr 29, 2024
0108f30
enabled auth to test sample app with awscurl
mbfreder Apr 29, 2024
07a5aba
Merge branch 'main' into vpclattice-integ
mbfreder Apr 29, 2024
4ce1ae6
reduce duplicated code for getServletConnection()
deki May 10, 2024
b587da7
reduce duplicated code for authenticate
deki May 22, 2024
6ca3d07
reduce duplicated code for login/ logout
deki May 22, 2024
6f2b629
reduce duplicated code for upgrade
deki May 22, 2024
5bc26a3
pull up SecurityContext and reduce duplicated code
deki May 22, 2024
389f834
Merge branch 'refs/heads/main' into vpclattice-integ
deki May 22, 2024
eb1b7a4
reduce duplicated code for getRequestDispatcher
deki May 22, 2024
47e8fc9
Merge remote-tracking branch 'origin/main' into vpclattice-integ
deki Aug 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aws-serverless-java-container-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
<version>1.2.3</version>
</dependency>

<dependency>
<groupId>org.jetbrains</groupId>
<artifactId>annotations</artifactId>
<version>24.0.1</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>jakarta.servlet</groupId>
<artifactId>jakarta.servlet-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.amazonaws.serverless.proxy;

import com.amazonaws.serverless.proxy.internal.jaxrs.AwsVpcLatticeV2SecurityContext;
import com.amazonaws.serverless.proxy.model.VPCLatticeV2RequestEvent;
import com.amazonaws.services.lambda.runtime.Context;
import jakarta.ws.rs.core.SecurityContext;

public class AwsVPCLatticeV2SecurityContextWriter implements SecurityContextWriter<VPCLatticeV2RequestEvent>{
@Override
public SecurityContext writeSecurityContext(VPCLatticeV2RequestEvent event, Context lambdaContext) {
return new AwsVpcLatticeV2SecurityContext(lambdaContext, event);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public abstract class RequestReader<RequestType, ContainerRequestType> {
*/
public static final String API_GATEWAY_CONTEXT_PROPERTY = "com.amazonaws.apigateway.request.context";

/**
* The key for the <strong>VPC Lattice V2 context</strong> property in the PropertiesDelegate object
*/
public static final String VPC_LATTICE_V2_CONTEXT_PROPERTY = "com.amazonaws.vpclattice.request.context";

/**
* The key for the <strong>API Gateway stage variables</strong> property in the PropertiesDelegate object
*/
Expand All @@ -55,6 +60,11 @@ public abstract class RequestReader<RequestType, ContainerRequestType> {
*/
public static final String API_GATEWAY_EVENT_PROPERTY = "com.amazonaws.apigateway.request";

/**
* The key to store the entire VPC Lattice V2 event
*/
public static final String VPC_LATTICE_V2_EVENT_PROPERTY = "com.amazonaws.vpclattice.request";

/**
* The key for the <strong>AWS Lambda context</strong> property in the PropertiesDelegate object
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.amazonaws.serverless.proxy.internal.jaxrs;

import com.amazonaws.serverless.proxy.model.VPCLatticeV2RequestEvent;
import com.amazonaws.services.lambda.runtime.Context;
import jakarta.ws.rs.core.SecurityContext;

import java.security.Principal;
import java.util.Objects;

/**
* default implementation of the <code>SecurityContext</code> object. This class supports 1 VPC Lattice authentication type:
* AWS_IAM.
*/
public class AwsVpcLatticeV2SecurityContext implements SecurityContext {

static final String AUTH_SCHEME_AWS_IAM = "AWS_IAM";


private final VPCLatticeV2RequestEvent event;

public AwsVpcLatticeV2SecurityContext(Context lambdaContext, VPCLatticeV2RequestEvent event) {
this.event = event;
}

//-------------------------------------------------------------
// Implementation - SecurityContext
//-------------------------------------------------------------
@Override
public Principal getUserPrincipal() {
if (Objects.equals(getAuthenticationScheme(), AUTH_SCHEME_AWS_IAM)) {
return () -> getEvent().getRequestContext().getIdentity().getPrincipal();
}
return null;
}

private VPCLatticeV2RequestEvent getEvent() {
return event;
}


@Override
public boolean isUserInRole(String role) {
return role.equals(event.getRequestContext().getIdentity().getPrincipal());
}

@Override
public boolean isSecure() {
return getAuthenticationScheme() != null;
}

@Override
public String getAuthenticationScheme() {
if (Objects.equals(getEvent().getRequestContext().getIdentity().getType(), AUTH_SCHEME_AWS_IAM)) {
return AUTH_SCHEME_AWS_IAM;
} else {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ public class AwsHttpApiV2ProxyHttpServletRequest extends AwsHttpServletRequest {
private MultiValuedTreeMap<String, String> queryString;
private Headers headers;
private ContainerConfig config;
private SecurityContext securityContext;
private AwsAsyncContext asyncContext;

/**
Expand All @@ -57,10 +56,9 @@ public class AwsHttpApiV2ProxyHttpServletRequest extends AwsHttpServletRequest {
* @param lambdaContext The Lambda function context. This object is used for utility methods such as log
*/
public AwsHttpApiV2ProxyHttpServletRequest(HttpApiV2ProxyRequest req, Context lambdaContext, SecurityContext sc, ContainerConfig cfg) {
super(lambdaContext);
super(lambdaContext, sc);
request = req;
config = cfg;
securityContext = sc;
queryString = parseRawQueryString(request.getRawQueryString());
headers = headersMapToMultiValue(request.getHeaders());
}
Expand All @@ -69,12 +67,6 @@ public HttpApiV2ProxyRequest getRequest() {
return request;
}

@Override
public String getAuthType() {
// TODO
return null;
}

@Override
public Cookie[] getCookies() {
Cookie[] rhc;
Expand Down Expand Up @@ -108,56 +100,27 @@ public Cookie[] getCookies() {

@Override
public long getDateHeader(String s) {
if (headers == null) {
return -1L;
}
String dateString = headers.getFirst(s);
if (dateString == null) {
return -1L;
}
try {
return Instant.from(ZonedDateTime.parse(dateString, dateFormatter)).toEpochMilli();
} catch (DateTimeParseException e) {
log.warn("Invalid date header in request: " + SecurityUtils.crlf(dateString));
return -1L;
}
return getDateHeader(s, headers);
}

@Override
public String getHeader(String s) {
if (headers == null) {
return null;
}
return headers.getFirst(s);
return getHeader(s, headers);
}

@Override
public Enumeration<String> getHeaders(String s) {
if (headers == null || !headers.containsKey(s)) {
return Collections.emptyEnumeration();
}
return Collections.enumeration(headers.get(s));
return getHeaders(s, headers);
}

@Override
public Enumeration<String> getHeaderNames() {
if (headers == null) {
return Collections.emptyEnumeration();
}
return Collections.enumeration(headers.keySet());
return getHeaderNames(headers);
}

@Override
public int getIntHeader(String s) {
if (headers == null) {
return -1;
}
String headerValue = headers.getFirst(s);
if (headerValue == null || "".equals(headerValue)) {
return -1;
}

return Integer.parseInt(headerValue);
return getIntHeader(s, headers);
}

@Override
Expand Down Expand Up @@ -187,28 +150,6 @@ public String getQueryString() {
return request.getRawQueryString();
}

@Override
public String getRemoteUser() {
if (securityContext == null || securityContext.getUserPrincipal() == null) {
return null;
}
return securityContext.getUserPrincipal().getName();
}

@Override
public boolean isUserInRole(String s) {
// TODO: Not supported
return false;
}

@Override
public Principal getUserPrincipal() {
if (securityContext == null) {
return null;
}
return securityContext.getUserPrincipal();
}

@Override
public String getRequestURI() {
return cleanUri(getContextPath()) + cleanUri(request.getRawPath());
Expand All @@ -219,27 +160,6 @@ public StringBuffer getRequestURL() {
return generateRequestURL(request.getRawPath());
}


@Override
public boolean authenticate(HttpServletResponse httpServletResponse) throws IOException, ServletException {
throw new UnsupportedOperationException();
}

@Override
public void login(String s, String s1) throws ServletException {
throw new UnsupportedOperationException();
}

@Override
public void logout() throws ServletException {
throw new UnsupportedOperationException();
}

@Override
public <T extends HttpUpgradeHandler> T upgrade(Class<T> aClass) throws IOException, ServletException {
throw new UnsupportedOperationException();
}

@Override
public String getCharacterEncoding() {
if (headers == null) {
Expand All @@ -250,30 +170,17 @@ public String getCharacterEncoding() {

@Override
public void setCharacterEncoding(String s) throws UnsupportedEncodingException {
if (headers == null || !headers.containsKey(HttpHeaders.CONTENT_TYPE)) {
log.debug("Called set character encoding to " + SecurityUtils.crlf(s) + " on a request without a content type. Character encoding will not be set");
return;
}
String currentContentType = headers.getFirst(HttpHeaders.CONTENT_TYPE);
headers.putSingle(HttpHeaders.CONTENT_TYPE, appendCharacterEncoding(currentContentType, s));
setCharacterEncoding(s, headers);
}

@Override
public int getContentLength() {
String headerValue = headers.getFirst(HttpHeaders.CONTENT_LENGTH);
if (headerValue == null) {
return -1;
}
return Integer.parseInt(headerValue);
return getContentLength(headers);
}

@Override
public long getContentLengthLong() {
String headerValue = headers.getFirst(HttpHeaders.CONTENT_LENGTH);
if (headerValue == null) {
return -1;
}
return Long.parseLong(headerValue);
return getContentLengthLong(headers);
}

@Override
Expand All @@ -286,17 +193,7 @@ public String getContentType() {

@Override
public String getParameter(String s) {
String queryStringParameter = getFirstQueryParamValue(queryString, s, config.isQueryStringCaseSensitive());
if (queryStringParameter != null) {
return queryStringParameter;
}

String[] bodyParams = getFormBodyParameterCaseInsensitive(s);
if (bodyParams.length == 0) {
return null;
} else {
return bodyParams[0];
}
return getParameter(queryString, s, config.isQueryStringCaseSensitive());
}

@Override
Expand All @@ -315,7 +212,7 @@ public String[] getParameterValues(String s) {

values.addAll(Arrays.asList(getFormBodyParameterCaseInsensitive(s)));

if (values.size() == 0) {
if (values.isEmpty()) {
return null;
} else {
return values.toArray(new String[0]);
Expand Down Expand Up @@ -409,16 +306,6 @@ public Enumeration<Locale> getLocales() {
return Collections.enumeration(locales);
}

@Override
public boolean isSecure() {
return securityContext.isSecure();
}

@Override
public RequestDispatcher getRequestDispatcher(String s) {
return getServletContext().getRequestDispatcher(s);
}

@Override
public int getRemotePort() {
return 0;
Expand Down Expand Up @@ -456,6 +343,8 @@ public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse se
return asyncContext;
}



@Override
public AsyncContext getAsyncContext() {
if (asyncContext == null) {
Expand All @@ -475,11 +364,6 @@ public String getProtocolRequestId() {
return "";
}

@Override
public ServletConnection getServletConnection() {
return null;
}

private MultiValuedTreeMap<String, String> parseRawQueryString(String qs) {
if (qs == null || "".equals(qs.trim())) {
return new MultiValuedTreeMap<>();
Expand All @@ -505,7 +389,7 @@ private MultiValuedTreeMap<String, String> parseRawQueryString(String qs) {
return qsMap;
}

private Headers headersMapToMultiValue(Map<String, String> headers) {
protected static Headers headersMapToMultiValue(Map<String, String> headers) {
if (headers == null || headers.size() == 0) {
return new Headers();
}
Expand Down
Loading
Loading