diff --git a/vaadin-spring/src/main/java/com/vaadin/flow/spring/SpringLookupInitializer.java b/vaadin-spring/src/main/java/com/vaadin/flow/spring/SpringLookupInitializer.java index 75570deb07d..33c97b2156c 100644 --- a/vaadin-spring/src/main/java/com/vaadin/flow/spring/SpringLookupInitializer.java +++ b/vaadin-spring/src/main/java/com/vaadin/flow/spring/SpringLookupInitializer.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.Map; import java.util.function.BiFunction; import java.util.stream.Collectors; @@ -61,15 +62,44 @@ private static class SpringLookup extends LookupImpl { private final WebApplicationContext context; + private final Map, Object> cachedServices; + + private final Map, Boolean> cacheableServices; + private SpringLookup(WebApplicationContext context, BiFunction, Class, Object> factory, Map, Collection>> services) { super(services, factory); this.context = context; + this.cachedServices = new HashMap<>(); + this.cacheableServices = new HashMap<>(); + } + + private boolean isCacheableService(Class serviceClass) { + return cacheableServices.computeIfAbsent(serviceClass, + key -> LookupInitializer.getDefaultImplementations() + .stream().anyMatch(serviceClass::isAssignableFrom)); + } + + private T getCachedService(Class serviceClass) { + return serviceClass.cast(cachedServices.get(serviceClass)); + } + + private void setCachedService(Class serviceClass, T service) { + cachedServices.put(serviceClass, service); } @Override public T lookup(Class serviceClass) { + boolean cacheableService = isCacheableService(serviceClass); + + if (cacheableService) { + T cached = getCachedService(serviceClass); + if (cached != null) { + return cached; + } + } + Collection beans = context.getBeansOfType(serviceClass).values(); // Check whether we have service objects instantiated without Spring @@ -87,13 +117,19 @@ public T lookup(Class serviceClass) { allFound.addAll(beans); allFound.add(service); } + T lookupResult; if (allFound.size() == 0) { - return null; + lookupResult = null; } else if (allFound.size() == 1) { - return allFound.iterator().next(); + lookupResult = allFound.iterator().next(); + } else { + throw new IllegalStateException(SEVERAL_IMPLS + serviceClass + + SPI + allFound + ONE_IMPL_REQUIRED); + } + if (cacheableService) { + setCachedService(serviceClass, lookupResult); } - throw new IllegalStateException(SEVERAL_IMPLS + serviceClass + SPI - + allFound + ONE_IMPL_REQUIRED); + return lookupResult; } @Override