Skip to content

Commit

Permalink
Compose the cache message with SystemMessage + UserMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
andreadimaio committed Jun 6, 2024
1 parent 9747dc5 commit 631e904
Show file tree
Hide file tree
Showing 16 changed files with 253 additions and 2,107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ public interface AiCacheBuildConfig {
/**
* Ai Cache embedding model related settings
*/
CacheEmbeddingModelConfig embeddingModel();
CacheEmbeddingModelConfig embedding();
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.jboss.jandex.ClassType;
import org.jboss.jandex.IndexView;

import dev.langchain4j.model.embedding.EmbeddingModel;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.runtime.AiCacheRecorder;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
Expand Down Expand Up @@ -49,23 +51,31 @@ void setupBeans(AiCacheBuildConfig cacheBuildConfig,
}
}

String embeddingModel = NamedConfigUtil.DEFAULT_NAME;
if (cacheBuildConfig.embeddingModel() != null)
embeddingModel = cacheBuildConfig.embeddingModel().name().orElse(NamedConfigUtil.DEFAULT_NAME);
String embeddingModelName = NamedConfigUtil.DEFAULT_NAME;
if (cacheBuildConfig.embedding() != null)
embeddingModelName = cacheBuildConfig.embedding().name().orElse(NamedConfigUtil.DEFAULT_NAME);

aiCacheBuildItemProducer.produce(new AiCacheBuildItem(enableCache, embeddingModel));
aiCacheBuildItemProducer.produce(new AiCacheBuildItem(enableCache, embeddingModelName));

if (enableCache) {
var configurator = SyntheticBeanBuildItem
SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(AiCacheProvider.class)
.setRuntimeInit()
.addInjectionPoint(ClassType.create(AiCacheStore.class))
.scope(ApplicationScoped.class)
.createWith(recorder.messageWindow(cacheConfig))
.createWith(recorder.messageWindow(cacheConfig, embeddingModelName))
.defaultBean();

if (NamedConfigUtil.isDefault(embeddingModelName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL));
} else {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL),
AnnotationInstance.builder(ModelName.class).add("value", embeddingModelName).build());
}

syntheticBeanProducer.produce(configurator.done());
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(AiCacheStore.class));
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(EmbeddingModel.class));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,

String chatModelName = bi.getChatModelName();
String moderationModelName = bi.getModerationModelName();
String aiCacheEmbeddingModelName = aiCacheBuildItem.getEmbeddingModelName();
boolean enableCache = aiCacheBuildItem.isEnable();

// It is not possible to use the cache in combination with the tools.
Expand All @@ -464,7 +463,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
retrievalAugmentorSupplierClassName,
auditServiceClassSupplierName, moderationModelSupplierClassName, chatModelName,
moderationModelName,
aiCacheEmbeddingModelName,
needsStreamingChatModel,
needsModerationModel,
enableCache)))
Expand Down Expand Up @@ -560,13 +558,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
} else {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.AI_CACHE_PROVIDER));
}

if (NamedConfigUtil.isDefault(aiCacheEmbeddingModelName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL));
} else {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.EMBEDDING_MODEL),
AnnotationInstance.builder(ModelName.class).add("value", aiCacheEmbeddingModelName).build());
}
needsAiCacheProvider = true;
}

Expand Down Expand Up @@ -596,7 +587,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
}
if (needsAiCacheProvider) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.AI_CACHE_PROVIDER));
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.EMBEDDING_MODEL));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.time.Duration;
import java.util.function.Function;

import dev.langchain4j.model.embedding.EmbeddingModel;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
Expand All @@ -14,15 +16,27 @@
@Recorder
public class AiCacheRecorder {

public Function<SyntheticCreationalContext<AiCacheProvider>, AiCacheProvider> messageWindow(AiCacheConfig config) {
public Function<SyntheticCreationalContext<AiCacheProvider>, AiCacheProvider> messageWindow(AiCacheConfig config,
String embeddingModelName) {
return new Function<>() {
@Override
public AiCacheProvider apply(SyntheticCreationalContext<AiCacheProvider> context) {

EmbeddingModel embeddingModel;
AiCacheStore aiCacheStore = context.getInjectedReference(AiCacheStore.class);

if (NamedConfigUtil.isDefault(embeddingModelName)) {
embeddingModel = context.getInjectedReference(EmbeddingModel.class);
} else {
embeddingModel = context.getInjectedReference(EmbeddingModel.class,
ModelName.Literal.of(embeddingModelName));
}

double threshold = config.threshold();
int maxSize = config.maxSize();
Duration ttl = config.ttl().orElse(null);
String queryPrefix = config.embedding().queryPrefix().orElse("");
String passagePrefix = config.embedding().passagePrefix().orElse("");

return new AiCacheProvider() {
@Override
Expand All @@ -32,6 +46,9 @@ public AiCache get(Object memoryId) {
.ttl(ttl)
.maxSize(maxSize)
.threshold(threshold)
.queryPrefix(queryPrefix)
.passagePrefix(passagePrefix)
.embeddingModel(embeddingModel)
.store(aiCacheStore)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.retriever.Retriever;
Expand Down Expand Up @@ -253,13 +252,6 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

if (NamedConfigUtil.isDefault(info.aiCacheEmbeddingModelName())) {
aiServiceContext.embeddingModel = creationalContext.getInjectedReference(EmbeddingModel.class);
} else {
aiServiceContext.embeddingModel = creationalContext.getInjectedReference(EmbeddingModel.class,
ModelName.Literal.of(info.aiCacheEmbeddingModelName()));
}

aiServiceContext.aiCaches = new ConcurrentHashMap<>();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
Expand Down Expand Up @@ -197,15 +196,14 @@ public void accept(Response<AiMessage> message) {
if (cache != null) {
log.debug("Attempting to obtain AI response from cache");

Embedding query = context.embeddingModel.embed(userMessage.text()).content();
var cacheResponse = cache.search(query);
var cacheResponse = cache.search(systemMessage.orElse(null), userMessage);

if (cacheResponse.isPresent()) {
log.debug("Return cached response");
response = Response.from(cacheResponse.get());
} else {
response = executeLLMCall(context, messages, moderationFuture, toolSpecifications);
cache.add(query, response.content());
cache.add(systemMessage.orElse(null), userMessage, response.content());
}

} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public record DeclarativeAiServiceCreateInfo(String serviceClassName,
String moderationModelSupplierClassName,
String chatModelName,
String moderationModelName,
String aiCacheEmbeddingModelName,
boolean needsStreamingChatModel,
boolean needsModerationModel,
boolean enableCache) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import java.util.function.BiConsumer;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.service.AiServiceContext;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;
Expand All @@ -16,7 +15,6 @@ public class QuarkusAiServiceContext extends AiServiceContext {
public AuditService auditService;
public Map<Object, AiCache> aiCaches;
public AiCacheProvider aiCacheProvider;
public EmbeddingModel embeddingModel;

// needed by Arc
public QuarkusAiServiceContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import java.util.Optional;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;

/**
* Represents the cache of an AI. It can be used to reduce response time for similar queries.
Expand All @@ -20,18 +21,20 @@ public interface AiCache {
/**
* Cache a new message.
*
* @param query Embedded value to add to the cache.
* @param response Response returned by the AI to add to the cache.
* @param systemMessage {@link SystemMessage} value to add to the cache.
* @param userMessage {@link UserMessage} value to add to the cache.
* @param aiResponse {@link AiMessage} value to add to the cache.
*/
void add(Embedding query, AiMessage response);
void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage aiResponse);

/**
* Check to see if there is a response in the cache that is semantically close to a cached query.
* Checks whether the cache contains a response whose stored query is semantically close to the given messages.
*
* @param query
* @param systemMessage {@link SystemMessage} value to find in the cache.
* @param userMessage {@link UserMessage} value to find in the cache.
* @return an {@link Optional} containing the cached {@link AiMessage} if a semantically similar entry exists, otherwise an empty {@link Optional}
*/
Optional<AiMessage> search(Embedding query);
Optional<AiMessage> search(SystemMessage systemMessage, UserMessage userMessage);

/**
* Clears the cache.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import java.util.Optional;
import java.util.concurrent.locks.ReentrantLock;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord;

Expand All @@ -23,6 +25,9 @@ public class MessageWindowAiCache implements AiCache {
private final AiCacheStore store;
private final Double threshold;
private final Duration ttl;
private final String queryPrefix;
private final String passagePrefix;
private final EmbeddingModel embeddingModel;
private final ReentrantLock lock;

public MessageWindowAiCache(Builder builder) {
Expand All @@ -31,6 +36,9 @@ public MessageWindowAiCache(Builder builder) {
this.store = builder.store;
this.ttl = builder.ttl;
this.threshold = builder.threshold;
this.queryPrefix = builder.queryPrefix;
this.passagePrefix = builder.passagePrefix;
this.embeddingModel = builder.embeddingModel;
this.lock = new ReentrantLock();
}

Expand All @@ -40,12 +48,18 @@ public Object id() {
}

@Override
public void add(Embedding query, AiMessage response) {
public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage aiResponse) {

if (Objects.isNull(query) || Objects.isNull(response)) {
if (Objects.isNull(userMessage) || Objects.isNull(aiResponse)) {
return;
}

String query;
if (Objects.isNull(systemMessage) || Objects.isNull(systemMessage.text()) || systemMessage.text().isBlank())
query = userMessage.text();
else
query = "%s%s%s".formatted(passagePrefix, systemMessage.text(), userMessage.text());

try {

lock.lock();
Expand All @@ -55,7 +69,7 @@ public void add(Embedding query, AiMessage response) {
elements.remove(0);
}

List<CacheRecord> update = new LinkedList<>();
List<CacheRecord> items = new LinkedList<>();
for (int i = 0; i < elements.size(); i++) {

var expiredTime = Date.from(elements.get(i).creation().plus(ttl));
Expand All @@ -64,23 +78,29 @@ public void add(Embedding query, AiMessage response) {
if (currentTime.after(expiredTime))
continue;

update.add(elements.get(i));
items.add(elements.get(i));
}

update.add(CacheRecord.of(query, response));
store.updateCache(id, update);
items.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
store.updateCache(id, items);

} finally {
lock.unlock();
}
}

@Override
public Optional<AiMessage> search(Embedding query) {
public Optional<AiMessage> search(SystemMessage systemMessage, UserMessage userMessage) {

if (Objects.isNull(query))
if (Objects.isNull(userMessage))
return Optional.empty();

String query;
if (Objects.isNull(systemMessage) || Objects.isNull(systemMessage.text()) || systemMessage.text().isBlank())
query = userMessage.text();
else
query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), userMessage.text());

var elements = store.getAll(id);
double maxScore = 0;
AiMessage result = null;
Expand All @@ -95,7 +115,7 @@ public Optional<AiMessage> search(Embedding query) {
continue;
}

var relevanceScore = CosineSimilarity.between(query, cacheRecord.embedded());
var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), cacheRecord.embedded());
var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);

if (score >= threshold.doubleValue() && score >= maxScore) {
Expand All @@ -119,6 +139,9 @@ public static class Builder {
AiCacheStore store;
Double threshold;
Duration ttl;
String queryPrefix;
String passagePrefix;
EmbeddingModel embeddingModel;

private Builder(Object id) {
this.id = id;
Expand Down Expand Up @@ -148,6 +171,21 @@ public Builder ttl(Duration ttl) {
return this;
}

/**
 * Sets the prefix prepended to the concatenated system + user text before it is
 * embedded for a cache lookup. Only applied when a non-blank system message is
 * present (see {@code search}); an empty string effectively disables prefixing.
 *
 * @param queryPrefix prefix for the cache-lookup query; may be empty
 * @return this builder, for chaining
 */
public Builder queryPrefix(String queryPrefix) {
    this.queryPrefix = queryPrefix;
    return this;
}

/**
 * Sets the prefix prepended to the concatenated system + user text before it is
 * embedded when a new entry is stored in the cache. Only applied when a non-blank
 * system message is present (see {@code add}); an empty string effectively
 * disables prefixing.
 *
 * @param passagePrefix prefix for the stored passage; may be empty
 * @return this builder, for chaining
 */
public Builder passagePrefix(String passagePrefix) {
    this.passagePrefix = passagePrefix;
    return this;
}

/**
 * Sets the model used to embed queries and passages for the cosine-similarity
 * comparison performed on both store and lookup. Note: neither this setter nor
 * the {@code MessageWindowAiCache} constructor null-checks the value, so a null
 * model would only fail later when {@code add}/{@code search} calls it.
 *
 * @param embeddingModel the embedding model to use for cache entries
 * @return this builder, for chaining
 */
public Builder embeddingModel(EmbeddingModel embeddingModel) {
    this.embeddingModel = embeddingModel;
    return this;
}

/**
 * Creates a new {@link MessageWindowAiCache} configured with this builder's
 * values. No validation of the collected fields is performed here.
 *
 * @return the newly built cache instance
 */
public AiCache build() {
    return new MessageWindowAiCache(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,9 @@ public interface AiCacheConfig {
* Time to live for messages stored in the cache.
*/
Optional<Duration> ttl();

/**
* Allows customizing the embedding operation.
*/
AiCacheEmbeddingConfig embedding();
}
Loading

0 comments on commit 631e904

Please sign in to comment.