Skip to content

Commit

Permalink
feat: skip cert rotation if connectivity info hasn't changed (#226)
Browse files Browse the repository at this point in the history
When we receive CIS shadow updates, only rotate if connectivity info
has actually changed.
  • Loading branch information
jcosentino11 authored and MikeDombo committed Mar 3, 2023
1 parent 7cbf543 commit ce5416c
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
import software.amazon.awssdk.services.greengrassv2data.model.ThrottlingException;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -201,7 +204,7 @@ private void processCISShadow(ShadowDeltaUpdatedEvent event) {
processCISShadow(cisVersion, event.state);
}

@SuppressWarnings("PMD.AvoidCatchingGenericException")
@SuppressWarnings({"PMD.AvoidCatchingGenericException", "PMD.PrematureDeclaration"})
private synchronized void processCISShadow(String version, Map<String, Object> desiredState) {
if (version == null) {
LOGGER.atWarn().log("Ignoring CIS shadow response, version is missing");
Expand All @@ -220,6 +223,9 @@ private synchronized void processCISShadow(String version, Map<String, Object> d
// operation (particularly on low end devices) it is imperative that we process this event asynchronously
// to avoid blocking other MQTT subscribers in the Nucleus
CompletableFuture.runAsync(() -> {

Set<String> prevCachedHostAddresses = new HashSet<>(connectivityInformation.getCachedHostAddresses());

Optional<List<ConnectivityInfo>> connectivityInfo;
try {
connectivityInfo = RetryUtils.runWithRetry(
Expand All @@ -245,6 +251,24 @@ private synchronized void processCISShadow(String version, Map<String, Object> d
// We won't retry in this case, but we will update the CIS shadow reported state
// to signal that we have fully processed this version.
try {
LOGGER.atInfo().kv(VERSION, version)
.log("No connectivity info found. Skipping cert re-generation");
updateCISShadowReportedState(desiredState);
} finally {
// Don't process the same version again
lastVersion = version;
}
return;
}

// skip cert rotation if connectivity info hasn't changed
Set<String> cachedHostAddresses = new HashSet<>(connectivityInformation.getCachedHostAddresses());
if (Objects.equals(prevCachedHostAddresses, cachedHostAddresses)) {
try {
LOGGER.atInfo().kv(VERSION, version)
.log("Connectivity info hasn't changed. Skipping cert re-generation");
// update the CIS shadow reported state
// to signal that we have fully processed this version.
updateCISShadowReportedState(desiredState);
} finally {
// Don't process the same version again
Expand All @@ -255,8 +279,7 @@ private synchronized void processCISShadow(String version, Map<String, Object> d

try {
for (CertificateGenerator cg : monitoredCertificateGenerators) {
cg.generateCertificate(connectivityInformation::getCachedHostAddresses,
"connectivity info was updated");
cg.generateCertificate(() -> new ArrayList<>(cachedHostAddresses), "connectivity info was updated");
}
} catch (CertificateGenerationException e) {
LOGGER.atError().kv(VERSION, version).cause(e).log("Failed to generate new certificates");
Expand All @@ -267,6 +290,7 @@ private synchronized void processCISShadow(String version, Map<String, Object> d
// update CIS shadow so reported state matches desired state
updateCISShadowReportedState(desiredState);
} finally {
// Don't process the same version again
lastVersion = version;
}
}, executorService).exceptionally(e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.aws.greengrass.logging.impl.LogManager;
import com.aws.greengrass.util.Coerce;
import com.aws.greengrass.util.GreengrassServiceClientFactory;
import lombok.AccessLevel;
import lombok.Getter;
import software.amazon.awssdk.services.greengrassv2data.model.ConnectivityInfo;
import software.amazon.awssdk.services.greengrassv2data.model.GetConnectivityInfoRequest;
import software.amazon.awssdk.services.greengrassv2data.model.GetConnectivityInfoResponse;
Expand All @@ -33,6 +35,7 @@ public class ConnectivityInformation {

private final DeviceConfiguration deviceConfiguration;
private final GreengrassServiceClientFactory clientFactory;
@Getter(AccessLevel.PACKAGE) // unit testing
private final ConnectivityInfoCache connectivityInfoCache;

private final Map<String, Set<HostAddress>> connectivityInformationMap = new ConcurrentHashMap<>();
Expand All @@ -55,7 +58,7 @@ public ConnectivityInformation(DeviceConfiguration deviceConfiguration,
}

/**
* Get cached connectivity info.
* Get cached connectivity info. Items in this list are unique.
*
* @return list of cached connectivity info items
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,51 @@ void GIVEN_CISShadowMonitor_WHEN_cis_shadow_delta_duplicate_received_THEN_delta_
verifyCertsRotatedWhenConnectivityChanges();
}

@Test
@SuppressWarnings("unchecked")
void GIVEN_CISShadowMonitor_WHEN_connectivity_info_does_not_change_across_multiple_deltas_received_THEN_extra_processing_is_ignored() throws Exception {
// make connectivity call yield the same response each time,
// to match scenario where we receive same shadow delta version multiple times.
connectivityInfoProvider.setMode(FakeConnectivityInformation.Mode.CONSTANT);

// capture the subscription callback for shadow delta update
ArgumentCaptor<Consumer<MqttMessage>> shadowDeltaUpdatedCallback = ArgumentCaptor.forClass(Consumer.class);
when(shadowClientConnection.subscribe(eq(SHADOW_DELTA_UPDATED_TOPIC), any(),
shadowDeltaUpdatedCallback.capture())).thenReturn(DUMMY_PACKET_ID);

// generated list of deltas to feed to the shadow monitor
List<Map<String, Object>> deltas =
IntStream.range(0, 5).mapToObj(i -> Utils.immutableMap("version", (Object) String.valueOf(i)))
.collect(Collectors.toList());
Map<String, Object> lastDelta = deltas.get(deltas.size() - 1);

// notify when last shadow update is published
WhenUpdateIsPublished whenUpdateIsPublished = WhenUpdateIsPublished.builder()
.expectedReportedState(lastDelta) // reported state updated to desired state
.expectedDesiredState(null) // desired state isn't modified
.build();
when(shadowClientConnection.publish(argThat(new ShadowUpdateRequestMatcher()), any(), anyBoolean())).thenAnswer(
whenUpdateIsPublished);

cisShadowMonitor.startMonitor();
cisShadowMonitor.addToMonitor(certificateGenerator);

// trigger update delta subscription callbacks
AtomicInteger version = new AtomicInteger(1);
deltas.forEach(delta -> {
ShadowDeltaUpdatedEvent deltaUpdatedEvent = new ShadowDeltaUpdatedEvent();
deltaUpdatedEvent.version = version.getAndIncrement();
deltaUpdatedEvent.state = new HashMap<>(delta);

// original message
wrapInMessage(SHADOW_DELTA_UPDATED_TOPIC, deltaUpdatedEvent, false).ifPresent(
resp -> shadowDeltaUpdatedCallback.getValue().accept(resp));
});

assertTrue(whenUpdateIsPublished.getLatch().await(5L, TimeUnit.SECONDS));
verifyCertsRotatedWhenConnectivityChanges();
}

@Test
void GIVEN_CISShadowMonitor_WHEN_stop_monitor_THEN_unsubscribe() {
AtomicInteger numSubscriptions = new AtomicInteger();
Expand Down Expand Up @@ -570,14 +615,28 @@ private FakeIotShadowClient(MqttClientConnection connection) {
}


static class FakeConnectivityInfoCache extends ConnectivityInfoCache {

private final Map<String, Set<HostAddress>> cache = new HashMap<>();

@Override
public void put(String source, Set<HostAddress> connectivityInfo) {
cache.put(source, connectivityInfo);
}

@Override
public Set<HostAddress> getAll() {
return cache.values().stream().flatMap(Set::stream).collect(Collectors.toSet());
}
}

static class FakeConnectivityInformation extends ConnectivityInformation {

private final AtomicReference<List<ConnectivityInfo>> CONNECTIVITY_INFO_SAMPLE =
new AtomicReference<>(Collections.singletonList(connectivityInfoWithRandomHost()));
private final Set<Integer> responseHashes = new CopyOnWriteArraySet<>();
private final AtomicReference<Mode> mode = new AtomicReference<>(Mode.RANDOM);
private final AtomicBoolean failed = new AtomicBoolean();
private List<String> cachedHostAddresses;

enum Mode {
/**
Expand All @@ -600,7 +659,7 @@ enum Mode {
}

FakeConnectivityInformation() {
super(null, null, null);
super(null, null, new FakeConnectivityInfoCache());
}

void setMode(Mode mode) {
Expand All @@ -620,9 +679,11 @@ int getNumUniqueConnectivityInfoResponses() {
public Optional<List<ConnectivityInfo>> getConnectivityInfo() {
List<ConnectivityInfo> connectivityInfo = doGetConnectivityInfo();
if (connectivityInfo != null) {
cachedHostAddresses = connectivityInfo.stream().map(ConnectivityInfo::hostAddress).distinct()
.collect(Collectors.toList());
responseHashes.add(cachedHostAddresses.hashCode());
Set<HostAddress> addresses = connectivityInfo.stream()
.map(HostAddress::of)
.collect(Collectors.toSet());
getConnectivityInfoCache().put("source", addresses);
responseHashes.add(addresses.hashCode());
}
return Optional.ofNullable(connectivityInfo);
}
Expand Down

0 comments on commit ce5416c

Please sign in to comment.