Skip to content

Commit

Permalink
Add UT for LambdaClientFactory
Browse files Browse the repository at this point in the history
Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg committed Nov 17, 2024
1 parent cd72a82 commit 0eca813
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.dataprepper.plugins.lambda.processor;

import org.junit.jupiter.params.provider.EnumSource;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import org.mockito.junit.jupiter.MockitoSettings;
Expand All @@ -21,7 +20,6 @@

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
import org.opensearch.dataprepper.model.configuration.PluginModel;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
Expand Down Expand Up @@ -50,25 +48,21 @@
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.lenient;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,185 +1,126 @@
package org.opensearch.dataprepper.plugins.lambda.common.client;

import java.time.Duration;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;

import java.time.Duration;
import java.util.HashMap;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class LambdaClientFactoryTest {

@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;
private AwsAuthenticationOptions awsAuthenticationOptions;

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;
private AwsCredentialsSupplier awsCredentialsSupplier;

@Mock
private AwsCredentialsProvider awsCredentialsProvider;

private Duration sdkTimeout = Duration.ofSeconds(60);

/*@Test
void createLambdaAsyncClient_with_real_LambdaAsyncClient() {
try (MockedStatic<LambdaAsyncClient> mockedStaticLambdaAsyncClient = mockStatic(
LambdaAsyncClient.class);
MockedStatic<PluginMetrics> mockedPluginMetrics = mockStatic(PluginMetrics.class)) {
PluginMetrics pluginMetricsMock = mock(PluginMetrics.class);
mockedPluginMetrics.when(() -> PluginMetrics.fromNames("sdk", "aws"))
.thenReturn(pluginMetricsMock);
LambdaAsyncClientBuilder lambdaAsyncClientBuilder = mock(LambdaAsyncClientBuilder.class);
LambdaAsyncClient lambdaAsyncClientMock = mock(LambdaAsyncClient.class);
mockedStaticLambdaAsyncClient.when(LambdaAsyncClient::builder)
.thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.region(any(Region.class))).thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.credentialsProvider(
any(AwsCredentialsProvider.class))).thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.overrideConfiguration(
any(ClientOverrideConfiguration.class))).thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.build()).thenReturn(lambdaAsyncClientMock);
Region region = Region.US_WEST_2;
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
String stsRoleArn = UUID.randomUUID().toString();
String stsExternalId = UUID.randomUUID().toString();
Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(),
UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(stsExternalId);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(
awsCredentialsProvider);
// Act
int maxConnectionRetries = 3;
LambdaAsyncClient lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
maxConnectionRetries,
awsCredentialsSupplier,
sdkTimeout
);
// Verify
assertThat(lambdaAsyncClient, notNullValue());
verify(lambdaAsyncClientBuilder).region(region);
verify(lambdaAsyncClientBuilder).credentialsProvider(awsCredentialsProvider);
// Capture and verify ClientOverrideConfiguration
ArgumentCaptor<ClientOverrideConfiguration> configCaptor = ArgumentCaptor.forClass(
ClientOverrideConfiguration.class);
verify(lambdaAsyncClientBuilder).overrideConfiguration(configCaptor.capture());
ClientOverrideConfiguration config = configCaptor.getValue();
assertThat(config.apiCallTimeout(), equalTo(Optional.of(sdkTimeout)));
// Verify RetryPolicy
assertThat(config.retryPolicy().isPresent(), equalTo(true));
RetryPolicy retryPolicy = config.retryPolicy().get();
assertThat(retryPolicy.numRetries(), equalTo(maxConnectionRetries));
// Verify MetricPublisher
assertThat(config.metricPublishers(), notNullValue());
assertThat(config.metricPublishers().size(), equalTo(1));
MetricPublisher metricPublisher = config.metricPublishers().get(0);
assertThat(metricPublisher, instanceOf(MicrometerMetricPublisher.class));
// Verify that awsCredentialsSupplier.getProvider was called with correct AwsCredentialsOptions
ArgumentCaptor<AwsCredentialsOptions> optionsCaptor = ArgumentCaptor.forClass(
AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsCaptor.capture());
AwsCredentialsOptions credentialsOptions = optionsCaptor.getValue();
assertThat(credentialsOptions.getRegion(), equalTo(region));
assertThat(credentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(credentialsOptions.getStsExternalId(), equalTo(stsExternalId));
assertThat(credentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}
@BeforeEach
void setUp() {
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_WEST_2);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("arn:aws:iam::123456789012:role/example-role");
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn("externalId");
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(new HashMap<>());

when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(awsCredentialsProvider);
}

@Test
void testCreateAsyncLambdaClient() {
int maxConnectionRetries = 3;
Duration sdkTimeout = Duration.ofSeconds(120);

LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
maxConnectionRetries,
awsCredentialsSupplier,
sdkTimeout
);

assertNotNull(client);
assertEquals(Region.US_WEST_2, client.serviceClientConfiguration().region());
}

@Test
void testCreateAsyncLambdaClientWithDifferentRegion() {
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.EU_CENTRAL_1);

LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
3,
awsCredentialsSupplier,
Duration.ofSeconds(60)
);

assertNotNull(client);
assertEquals(Region.EU_CENTRAL_1, client.serviceClientConfiguration().region());
}

@Test
void testCreateAsyncLambdaClientWithCustomSdkTimeout() {
Duration customTimeout = Duration.ofMinutes(5);

LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
3,
awsCredentialsSupplier,
customTimeout
);

assertNotNull(client);
assertEquals(customTimeout, client.serviceClientConfiguration().overrideConfiguration().apiCallTimeout().get());
}

@Test
void testCreateAsyncLambdaClientWithMaxRetries() {
int maxRetries = 5;

LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
maxRetries,
awsCredentialsSupplier,
Duration.ofSeconds(60)
);

assertNotNull(client);
}

@Test
void createAsyncLambdaClient_with_correct_configuration() {
try (MockedStatic<LambdaAsyncClient> mockedStaticLambdaAsyncClient = mockStatic(
LambdaAsyncClient.class);
MockedStatic<PluginMetrics> mockedPluginMetrics = mockStatic(PluginMetrics.class)) {
PluginMetrics pluginMetricsMock = mock(PluginMetrics.class);
mockedPluginMetrics.when(() -> PluginMetrics.fromNames("sdk", "aws"))
.thenReturn(pluginMetricsMock);
LambdaAsyncClientBuilder lambdaAsyncClientBuilder = mock(LambdaAsyncClientBuilder.class);
LambdaAsyncClient lambdaAsyncClientMock = mock(LambdaAsyncClient.class);
mockedStaticLambdaAsyncClient.when(LambdaAsyncClient::builder)
.thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.region(any(Region.class))).thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.credentialsProvider(
any(AwsCredentialsProvider.class))).thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.overrideConfiguration(
any(ClientOverrideConfiguration.class))).thenReturn(lambdaAsyncClientBuilder);
when(lambdaAsyncClientBuilder.build()).thenReturn(lambdaAsyncClientMock);
Region region = Region.US_WEST_2;
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
String stsRoleArn = UUID.randomUUID().toString();
String stsExternalId = UUID.randomUUID().toString();
Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(),
UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(stsExternalId);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
// when(awsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(awsCredentialsProvider);
// Act
int maxConnectionRetries = 3;
LambdaAsyncClient lambdaAsyncClient = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
maxConnectionRetries,
awsCredentialsSupplier,
sdkTimeout
);
// Verify
assertThat(lambdaAsyncClient, notNullValue());
// Verify builder methods
verify(lambdaAsyncClientBuilder).region(region);
verify(lambdaAsyncClientBuilder).credentialsProvider(awsCredentialsProvider);
// Capture and verify ClientOverrideConfiguration
ArgumentCaptor<ClientOverrideConfiguration> configCaptor = ArgumentCaptor.forClass(
ClientOverrideConfiguration.class);
verify(lambdaAsyncClientBuilder).overrideConfiguration(configCaptor.capture());
ClientOverrideConfiguration config = configCaptor.getValue();
// Verify apiCallTimeout
assertThat(config.apiCallTimeout(), equalTo(Optional.of(sdkTimeout)));
// Verify RetryPolicy
assertThat(config.retryPolicy().isPresent(), equalTo(true));
RetryPolicy retryPolicy = config.retryPolicy().get();
assertThat(retryPolicy.numRetries(), equalTo(maxConnectionRetries));
// Verify MetricPublisher
assertThat(config.metricPublishers(), notNullValue());
assertThat(config.metricPublishers().size(), equalTo(1));
MetricPublisher metricPublisher = config.metricPublishers().get(0);
assertThat(metricPublisher, instanceOf(MicrometerMetricPublisher.class));
// Verify that awsCredentialsSupplier.getProvider was called with correct AwsCredentialsOptions
ArgumentCaptor<AwsCredentialsOptions> optionsCaptor = ArgumentCaptor.forClass(
AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsCaptor.capture());
AwsCredentialsOptions credentialsOptions = optionsCaptor.getValue();
assertThat(credentialsOptions.getRegion(), equalTo(region));
assertThat(credentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(credentialsOptions.getStsExternalId(), equalTo(stsExternalId));
assertThat(credentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}
}*/
}
void testCreateAsyncLambdaClientOverrideConfiguration() {
Duration sdkTimeout = Duration.ofSeconds(90);
int maxRetries = 4;

LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
maxRetries,
awsCredentialsSupplier,
sdkTimeout
);

assertNotNull(client);
ClientOverrideConfiguration overrideConfig = client.serviceClientConfiguration().overrideConfiguration();

assertEquals(sdkTimeout, overrideConfig.apiCallTimeout().get());
assertEquals(maxRetries, overrideConfig.retryPolicy().get().numRetries());
assertNotNull(overrideConfig.metricPublishers());
assertFalse(overrideConfig.metricPublishers().isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,11 @@
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -53,7 +50,6 @@
import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
Expand Down

0 comments on commit 0eca813

Please sign in to comment.