Skip to content

Commit

Permalink
Override RegionSet in EnpointResolverInterceptor after fetching the S…
Browse files Browse the repository at this point in the history
…igning Properties from Endpoint rules (#5825)

* Override RegionSet in EnpointResolverInterceptor after fetching the Signing Properties from Endpoint rules

* Endpoint Resolver Spec test added

* Add test case for service with MultiAuth only and not using Sigv4

* checkstyle fixed

* Handled comments

* Updated after Integ test failure and some comments
  • Loading branch information
joviegas authored Feb 3, 2025
1 parent 1efd369 commit 3001dad
Show file tree
Hide file tree
Showing 21 changed files with 1,039 additions and 262 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public static Metadata constructMetadata(ServiceModel serviceModel,
.withJsonVersion(serviceMetadata.getJsonVersion())
.withEndpointPrefix(serviceMetadata.getEndpointPrefix())
.withSigningName(serviceMetadata.getSigningName())
.withAuthType(AuthType.fromValue(serviceMetadata.getSignatureVersion()))
.withAuthType(serviceMetadata.getSignatureVersion() != null ?
AuthType.fromValue(serviceMetadata.getSignatureVersion()) : null)
.withUid(serviceMetadata.getUid())
.withServiceId(serviceMetadata.getServiceId())
.withSupportsH2(supportsH2(serviceMetadata))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.SdkMetric;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

Expand Down Expand Up @@ -196,7 +197,6 @@ private MethodSpec generateAuthSchemeParams() {
builder.addStatement("(($T)builder).endpointProvider(($T)endpointProvider)", paramsBuilderClass, endpointProviderClass);
builder.endControlFlow();
builder.endControlFlow();
// TODO: Implement addRegionSet() for legacy services that resolve authentication from endpoints in one of next PRs.
builder.addStatement("return builder.build()");
return builder.build();
}
Expand Down Expand Up @@ -452,19 +452,13 @@ private TypeName toTypeName(Object valueType) {
private void generateSigv4aRegionSet(MethodSpec.Builder builder) {
if (authSchemeSpecUtils.usesSigV4a()) {
builder.addStatement(
"$T regionSet = executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" +
" .filter(regions -> !regions.isEmpty())\n" +
" .map(regions -> $T.create(String.join(\", \", regions)))\n" +
" .orElseGet(() -> {\n" +
" $T fallbackRegion = executionAttributes.getAttribute($T.AWS_REGION);\n" +
" return fallbackRegion != null ? $T.create(fallbackRegion.toString()) : null;\n" +
" });",
RegionSet.class, AwsExecutionAttribute.class,
RegionSet.class, Region.class, AwsExecutionAttribute.class,
"executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" +
" .filter(regionSet -> !$T.isNullOrEmpty(regionSet))\n" +
" .ifPresent(nonEmptyRegionSet -> builder.regionSet($T.create(nonEmptyRegionSet)))",
AwsExecutionAttribute.class,
CollectionUtils.class,
RegionSet.class
);

builder.addStatement("builder.regionSet(regionSet)");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public class EndpointResolverInterceptorSpec implements ClassSpec {
private final JmesPathAcceptorGenerator jmesPathGenerator;
private final boolean dependsOnHttpAuthAws;
private final boolean useSraAuth;
private final boolean multiAuthSigv4a;


public EndpointResolverInterceptorSpec(IntermediateModel model) {
Expand All @@ -116,6 +117,7 @@ public EndpointResolverInterceptorSpec(IntermediateModel model) {
supportedAuthSchemes.contains(AwsV4aAuthScheme.class);

this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
this.multiAuthSigv4a = new AuthSchemeSpecUtils(model).usesSigV4a();
}

@Override
Expand Down Expand Up @@ -192,7 +194,9 @@ private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldNam
endpointRulesSpecUtils.providerInterfaceName(), providerVar, SdkInternalExecutionAttribute.class);
b.beginControlFlow("try");
b.addStatement("long resolveEndpointStart = $T.nanoTime()", System.class);
b.addStatement("$T endpoint = $N.resolveEndpoint(ruleParams(result, executionAttributes)).join()",
b.addStatement("$T endpointParams = ruleParams(result, executionAttributes)",
endpointRulesSpecUtils.parametersClassName());
b.addStatement("$T endpoint = $N.resolveEndpoint(endpointParams).join()",
Endpoint.class, providerVar);
b.addStatement("$1T resolveEndpointDuration = $1T.ofNanos($2T.nanoTime() - resolveEndpointStart)", Duration.class,
System.class);
Expand All @@ -219,7 +223,20 @@ private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldNam
SelectedAuthScheme.class, SdkInternalExecutionAttribute.class);
b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)");
b.addStatement("selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)");

if (multiAuthSigv4a) {
b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications");
b.beginControlFlow("if(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) "
+ "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)",
AwsV4aAuthScheme.class, AwsV4aHttpSigner.class);
b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()",
AuthSchemeOption.Builder.class);
b.addStatement("$T regionSet = $T.create(endpointParams.region().id())",
RegionSet.class, RegionSet.class);
b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class);
b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), "
+ "optionBuilder.build())", SelectedAuthScheme.class);
b.endControlFlow();
}
b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)",
SdkInternalExecutionAttribute.class);
b.endControlFlow();
Expand Down Expand Up @@ -774,7 +791,7 @@ private static CodeBlock copyV4EndpointSignerPropertiesToAuth() {
return code.build();
}

private static CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
private CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
CodeBlock.Builder code = CodeBlock.builder();

code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4aAuthScheme.class);
Expand All @@ -784,10 +801,15 @@ private static CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding())",
AwsV4aHttpSigner.class);
code.endControlFlow();

code.beginControlFlow("if (v4aAuthScheme.signingRegionSet() != null)");
if (multiAuthSigv4a) {
code.beginControlFlow("if (!(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) "
+ "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) != null) "
+ "&& v4aAuthScheme.signingRegionSet() != null)",
AwsV4aAuthScheme.class, AwsV4aHttpSigner.class);
} else {
code.beginControlFlow("if (v4aAuthScheme.signingRegionSet() != null)");
}
code.addStatement("$1T regionSet = $1T.create(v4aAuthScheme.signingRegionSet())", RegionSet.class);

code.addStatement("option.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class);
code.endControlFlow();

Expand Down Expand Up @@ -881,5 +903,4 @@ private MethodSpec constructorMethodSpec(String endpointAuthSchemeFieldName) {
b.addStatement("this.$N = $N.endpointAuthSchemeStrategy()", endpointAuthSchemeFieldName, factoryLocalVarName);
return b.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.model.intermediate.OperationModel;
import software.amazon.awssdk.codegen.model.service.AuthType;
import software.amazon.awssdk.utils.CollectionUtils;

public final class AuthUtils {
private AuthUtils() {
Expand Down Expand Up @@ -76,6 +77,12 @@ private static boolean isServiceSigv4a(IntermediateModel model) {

private static boolean isServiceAwsAuthType(IntermediateModel model) {
AuthType authType = model.getMetadata().getAuthType();
if (authType == null && !CollectionUtils.isNullOrEmpty(model.getMetadata().getAuth())) {
return model.getMetadata().getAuth().stream()
.map(AuthType::value)
.map(AuthType::fromValue)
.anyMatch(AuthUtils::isAuthTypeAws);
}
return isAuthTypeAws(authType);
}

Expand All @@ -85,6 +92,7 @@ private static boolean isAuthTypeAws(AuthType authType) {
}

switch (authType) {
case V4A:
case V4:
case S3:
case S3V4:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@ private static IntermediateModel getModel(boolean useSraAuth) {
model.getCustomizationConfig().setUseSraAuth(useSraAuth);
return model;
}

@Test
void endpointResolverInterceptorClassWithSigv4aMultiAuth() {
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a());
assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams;
import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

Expand Down Expand Up @@ -88,14 +89,9 @@ private DatabaseAuthSchemeParams authSchemeParams(SdkRequest request, ExecutionA
DatabaseAuthSchemeParams.Builder builder = DatabaseAuthSchemeParams.builder().operation(operation);
Region region = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION);
builder.region(region);
RegionSet regionSet = executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET)
.filter(regions -> !regions.isEmpty()).map(regions -> RegionSet.create(String.join(", ", regions)))
.orElseGet(() -> {
Region fallbackRegion = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION);
return fallbackRegion != null ? RegionSet.create(fallbackRegion.toString()) : null;
});
;
builder.regionSet(regionSet);
executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET)
.filter(regionSet -> !CollectionUtils.isNullOrEmpty(regionSet))
.ifPresent(nonEmptyRegionSet -> builder.regionSet(RegionSet.create(nonEmptyRegionSet)));
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
{
"version": "1.2",
"serviceId": "Database Service",
"parameters": {
"region": {
"type": "string",
"builtIn": "AWS::Region",
"required": true,
"documentation": "The region to send requests to"
},
"useDualStackEndpoint": {
"type": "boolean",
"builtIn": "AWS::UseDualStack"
},
"useFIPSEndpoint": {
"type": "boolean",
"builtIn": "AWS::UseFIPS"
},
"AccountId": {
"type": "String",
"builtIn": "AWS::Auth::AccountId"
},
"operationContextParam": {
"type": "string"
}
},
"rules": [
{
"conditions": [
{
"fn": "aws.partition",
"argv": [
{
"ref": "region"
}
],
"assign": "partitionResult"
}
],
"rules": [
{
"conditions": [
{
"fn": "isSet",
"argv": [
{
"ref": "endpointId"
}
]
}
],
"rules": [
{
"conditions": [
{
"fn": "isSet",
"argv": [
{
"ref": "useFIPSEndpoint"
}
]
}
],
"error": "FIPS endpoints not supported with multi-region endpoints",
"type": "error"
},
{
"endpoint": {
"url": "https://{endpointId}.query.{partitionResult#dualStackDnsSuffix}",
"properties": {
"authSchemes": [
{
"name": "sigv4a",
"signingName": "query",
"signingRegionSet": ["*"]
}
]
}
},
"type": "endpoint"
}
],
"type": "tree"
}
],
"type": "tree"
}
]
}

Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttribut
.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER);
try {
long resolveEndpointStart = System.nanoTime();
Endpoint endpoint = provider.resolveEndpoint(ruleParams(result, executionAttributes)).join();
QueryEndpointParams endpointParams = ruleParams(result, executionAttributes);
Endpoint endpoint = provider.resolveEndpoint(endpointParams).join();
Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart);
Optional<MetricCollector> metricCollector = executionAttributes
.getOptionalAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR);
Expand Down
Loading

0 comments on commit 3001dad

Please sign in to comment.