From f42c239bb4c3e0b238ce13f9f9cf424bfecab811 Mon Sep 17 00:00:00 2001 From: Gian Maria Ricci Date: Thu, 9 Jan 2025 12:44:15 +0100 Subject: [PATCH] added disk based prompt store also added prompt store to query rewriter --- .../Helper/LocalFolderPromptStoreTests.cs | 174 ++++++++++++++++++ .../BasePromptStore.cs | 50 +++++ src/KernelMemory.Extensions/IPromptStore.cs | 19 ++ .../LocalFolderPromptStore.cs | 43 +++++ .../IConversationQueryRewriter.cs | 131 +++++++------ 5 files changed, 359 insertions(+), 58 deletions(-) create mode 100644 src/KernelMemory.Extensions.FunctionalTests/Helper/LocalFolderPromptStoreTests.cs create mode 100644 src/KernelMemory.Extensions/BasePromptStore.cs create mode 100644 src/KernelMemory.Extensions/LocalFolderPromptStore.cs diff --git a/src/KernelMemory.Extensions.FunctionalTests/Helper/LocalFolderPromptStoreTests.cs b/src/KernelMemory.Extensions.FunctionalTests/Helper/LocalFolderPromptStoreTests.cs new file mode 100644 index 0000000..44c8b70 --- /dev/null +++ b/src/KernelMemory.Extensions.FunctionalTests/Helper/LocalFolderPromptStoreTests.cs @@ -0,0 +1,174 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; + +namespace KernelMemory.Extensions.FunctionalTests.Helper; + +public class LocalFolderPromptStoreTests : IDisposable +{ + private readonly string _testDirectory; + private readonly Mock> _loggerMock; + private readonly LocalFolderPromptStore _store; + + public LocalFolderPromptStoreTests() + { + _testDirectory = Path.Combine(Path.GetTempPath(), $"promptstore_tests_{Guid.NewGuid()}"); + _loggerMock = new Mock>(); + _store = new LocalFolderPromptStore(_testDirectory, _loggerMock.Object); + } + + public void Dispose() + { + if (Directory.Exists(_testDirectory)) + { + Directory.Delete(_testDirectory, true); + } + } + + [Fact] + public async Task GetPromptAsync_NonExistentKey_ReturnsNull() + { + // Act + var result = await _store.GetPromptAsync("nonexistent"); + + // Assert + Assert.Null(result); + } + + [Fact] + public async Task SetAndGetPromptAsync_ValidKey_ReturnsStoredPrompt() + { + // Arrange + const string key = "test-key"; + const string expectedPrompt = "This is a test prompt"; + + // Act + await _store.SetPromptAsync(key, expectedPrompt); + var result = await _store.GetPromptAsync(key); + + // Assert + Assert.Equal(expectedPrompt, result); + } + + [Fact] + public async Task GetPromptAndSetDefaultAsync_NonExistentKey_SetsAndReturnsDefault() + { + // Arrange + const string key = "default-key"; + const string defaultPrompt = "Default prompt value"; + + // Act + var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); + var storedPrompt = await _store.GetPromptAsync(key); + + // Assert + Assert.Equal(defaultPrompt, result); + Assert.Equal(defaultPrompt, storedPrompt); + } + + [Fact] + public async Task GetPromptAndSetDefaultAsync_ExistingKey_ReturnsExistingPrompt() + { + // Arrange + const string key = "existing-key"; + const string existingPrompt = "Existing prompt"; + const string defaultPrompt = "Default prompt"; + await _store.SetPromptAsync(key, existingPrompt); + + // Act + var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); + + // Assert + Assert.Equal(existingPrompt, result); + } + + [Fact] + public async Task SetPromptAsync_KeyWithSpecialCharacters_HandlesCorrectly() + { + // Arrange + const string key = "special/\\*:?\"<>|characters"; + const string expectedPrompt = "Prompt with special characters"; + + // Act + await _store.SetPromptAsync(key, expectedPrompt); + var result = await _store.GetPromptAsync(key); + + // Assert + Assert.Equal(expectedPrompt, result); + } + + [Fact] + public async Task GetPromptAndSetDefaultAsync_MissingPlaceholder_LogsError() + { + // Arrange + const string key = "test-placeholder"; + const string existingPrompt = "A prompt without placeholder"; + const string defaultPrompt = "Default prompt with {{$placeholder}}"; + await _store.SetPromptAsync(key, existingPrompt); + + // Act + var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); + + // Assert + _loggerMock.Verify( + x => x.Log( + LogLevel.Error, + It.IsAny(), + It.Is((v, t) => v.ToString().Contains("{{$placeholder}}")), + It.IsAny(), + It.IsAny>() + ), + Times.Once); + } + + [Fact] + public async Task GetPromptAndSetDefaultAsync_MultipleMissingPlaceholders_LogsMultipleErrors() + { + // Arrange + const string key = "test-multiple-placeholders"; + const string existingPrompt = "A prompt without any placeholders"; + const string defaultPrompt = "Default with {{$first}} and {{$second}}"; + await _store.SetPromptAsync(key, existingPrompt); + + // Act + var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); + + // Assert + _loggerMock.Verify( + x => x.Log( + LogLevel.Error, + It.IsAny(), + It.Is((v, t) => true), + It.IsAny(), + It.IsAny>() + ), + Times.Exactly(2)); + } + + [Fact] + public async Task GetPromptAndSetDefaultAsync_ValidPlaceholders_NoErrors() + { + // Arrange + const string key = "test-valid-placeholders"; + const string existingPrompt = "A prompt with {{$placeholder}} correctly set"; + const string defaultPrompt = "Default with {{$placeholder}}"; + await _store.SetPromptAsync(key, existingPrompt); + + // Act + var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); + + // Assert + _loggerMock.Verify( + x => x.Log( + LogLevel.Error, + It.IsAny(), + It.Is((v, t) => true), + It.IsAny(), + It.IsAny>() + ), + Times.Never); + } +} diff --git a/src/KernelMemory.Extensions/BasePromptStore.cs b/src/KernelMemory.Extensions/BasePromptStore.cs new file mode 100644 index 0000000..835a029 --- /dev/null +++ b/src/KernelMemory.Extensions/BasePromptStore.cs @@ -0,0 +1,50 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using System.Text.RegularExpressions; + +namespace KernelMemory.Extensions; + +public abstract class BasePromptStore : IPromptStore +{ + private static readonly Regex PlaceholderRegex = new(@"\{\{\$\w+\}\}", RegexOptions.Compiled); + + protected readonly ILogger _log; + + protected BasePromptStore(ILogger log) + { + _log = log; + } + + protected void ValidatePlaceholders(string defaultPrompt, string loadedPrompt) + { + var placeholders = PlaceholderRegex.Matches(defaultPrompt) + .Cast() + .Select(m => m.Value) + .ToList(); + + foreach (var placeholder in placeholders) + { + if (!loadedPrompt.Contains(placeholder)) + { + _log.LogError("The prompt does not contain {Placeholder} placeholder, the prompt will not work correctly", placeholder); + } + } + } + + public abstract Task GetPromptAsync(string key); + public abstract Task SetPromptAsync(string key, string prompt); + + public virtual async Task GetPromptAndSetDefaultAsync(string key, string defaultPrompt) + { + var prompt = await GetPromptAsync(key); + if (prompt == null) + { + await SetPromptAsync(key, defaultPrompt); + return defaultPrompt; + } + + ValidatePlaceholders(defaultPrompt, prompt); + return prompt; + } +} diff --git a/src/KernelMemory.Extensions/IPromptStore.cs b/src/KernelMemory.Extensions/IPromptStore.cs index cc51a1c..c81f053 100644 --- a/src/KernelMemory.Extensions/IPromptStore.cs +++ b/src/KernelMemory.Extensions/IPromptStore.cs @@ -23,6 +23,14 @@ public interface IPromptStore /// various components will use some default prompts. Task GetPromptAsync(string key); + /// + /// Get the prompt for the given key and set the default prompt if the prompt is not present. + /// + /// + /// + /// + Task GetPromptAndSetDefaultAsync(string key, string defaultPrompt); + /// /// Allow setting prompt value. /// @@ -45,6 +53,17 @@ public class NullPromptStore : IPromptStore return Task.FromResult(null); } + /// + /// Return the default prompt always + /// + /// + /// + /// + public Task GetPromptAndSetDefaultAsync(string key, string defaultPrompt) + { + return Task.FromResult(defaultPrompt); + } + /// /// Allow setting prompt value. /// diff --git a/src/KernelMemory.Extensions/LocalFolderPromptStore.cs b/src/KernelMemory.Extensions/LocalFolderPromptStore.cs new file mode 100644 index 0000000..99e319d --- /dev/null +++ b/src/KernelMemory.Extensions/LocalFolderPromptStore.cs @@ -0,0 +1,43 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using System.Linq; +using Microsoft.KernelMemory.Diagnostics; + +namespace KernelMemory.Extensions; + +public class LocalFolderPromptStore : BasePromptStore +{ + private readonly string _promptDirectory; + + public LocalFolderPromptStore(string promptDirectory, ILogger? log = null) + : base(log ?? DefaultLogger.Instance) + { + _promptDirectory = promptDirectory; + Directory.CreateDirectory(promptDirectory); + } + + private string GetPromptFilePath(string key) + { + var sanitizedKey = string.Join("_", key.Split(Path.GetInvalidFileNameChars())); + return Path.Combine(_promptDirectory, $"{sanitizedKey}.prompt"); + } + + public override async Task GetPromptAsync(string key) + { + var filePath = GetPromptFilePath(key); + if (!File.Exists(filePath)) + { + return null; + } + + return await File.ReadAllTextAsync(filePath); + } + + public override async Task SetPromptAsync(string key, string prompt) + { + var filePath = GetPromptFilePath(key); + await File.WriteAllTextAsync(filePath, prompt); + } +} diff --git a/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs b/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs index 3f3837b..0d9c988 100644 --- a/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs +++ b/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs @@ -3,7 +3,9 @@ using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.OpenAI; using Microsoft.SemanticKernel.PromptTemplates.Handlebars; +using System.Collections.Generic; using System.Threading.Tasks; +using System.Collections.Concurrent; namespace KernelMemory.Extensions.QueryPipeline; @@ -28,13 +30,16 @@ public interface IConversationQueryRewriter public class SemanticKernelQueryRewriter : IConversationQueryRewriter { private readonly SemanticKernelQueryRewriterOptions _semanticKernelQueryRewriterOptions; + private readonly IPromptStore _promptStore; private readonly ISemanticKernelWrapper _kernel; public SemanticKernelQueryRewriter( SemanticKernelQueryRewriterOptions semanticKernelQueryRewriterOptions, + IPromptStore promptStore, ISemanticKernelWrapper kernel) { _semanticKernelQueryRewriterOptions = semanticKernelQueryRewriterOptions; + _promptStore = promptStore; _kernel = kernel; } @@ -56,11 +61,12 @@ public async Task RewriteAsync(Conversation conversation, string questio chatMessages.AddAssistantMessage("I do not know the answer"); } } - string prompt = $@"You will reformulate the question based on the conversation up to this point so the question will + string defPrompt = $@"You will reformulate the question based on the conversation up to this point so the question will be a standalone question that contains also the previous context. If there is no correlation between the conversation and the question you will output the question unchanged. You will answer only with the rewritten question no other text must be included. Question: {question}"; + var prompt = await _promptStore.GetPromptAndSetDefaultAsync("SemanticKernelQueryRewriter", defPrompt); chatMessages.AddUserMessage(prompt); var result = await chatCompletionService.GetChatMessageContentAsync(chatMessages, new PromptExecutionSettings() @@ -70,80 +76,89 @@ You will answer only with the rewritten question no other text must be included. return result?.ToString() ?? question; } -} - +} + /// /// Allows some parametrization of the rewriter. /// -public class SemanticKernelQueryRewriterOptions -{ - public string? ModelId { get; set; } - - public float Temperature { get; set; } = 0.0f; +public class SemanticKernelQueryRewriterOptions +{ + public string? ModelId { get; set; } + + public float Temperature { get; set; } = 0.0f; } public class HandlebarSemanticKernelQueryRewriter : IConversationQueryRewriter { + private const string DefaultPromptTemplate = @"system: +* Given the following conversation history and the users next question, rephrase the +follow up input to be a stand alone question. +If the conversation is irrelevant or empty, just restate the original question. +Do not add more details than necessary to the question. + +chat history: +{{#each history}} +question: +{{question}} +answer: +{{answer}} +{{/each}} + +Follow up Input: {{ chat_input }} +Standalone Question:"; + + private readonly ConcurrentDictionary _functionCache = new(); private readonly SemanticKernelQueryRewriterOptions _semanticKernelQueryRewriterOptions; - private readonly ISemanticKernelWrapper _kernel; - private readonly KernelFunction _chatFunction; - + private readonly ISemanticKernelWrapper _kernel; + private readonly IPromptStore _promptStore; + public HandlebarSemanticKernelQueryRewriter( SemanticKernelQueryRewriterOptions semanticKernelQueryRewriterOptions, - ISemanticKernelWrapper kernel) + ISemanticKernelWrapper kernel, + IPromptStore promptStore) { _semanticKernelQueryRewriterOptions = semanticKernelQueryRewriterOptions; - _kernel = kernel; - - // Create a template for chat with settings - _chatFunction = kernel.CreateFunctionFromPrompt(new PromptTemplateConfig() - { - Name = "TestRewrite", - Description = "Rewrite a query for kernel memory.", - Template = @"system: -* Given the following conversation history and the users next question, rephrase the -follow up input to be a stand alone question. -If the conversation is irrelevant or empty, just restate the original question. -Do not add more details than necessary to the question. - -chat history: -{{#each history}} -question: -{{question}} -answer: -{{answer}} -{{/each}} - -Follow up Input: {{ chat_input }} -Standalone Question:", - TemplateFormat = "handlebars", - InputVariables = - [ - new() { Name = "chat_input", Description = "New question of the user", IsRequired = false, Default = "" }, - new() { Name = "history", Description = "The history of the RAG CHAT.", IsRequired = true } - ], - ExecutionSettings = - { - { "default", new OpenAIPromptExecutionSettings() - { - MaxTokens = 1000, - Temperature = 0, - ModelId = "gpt35", - } - }, - } - }, - promptTemplateFactory: new HandlebarsPromptTemplateFactory()); + _kernel = kernel; + _promptStore = promptStore; + } + + private async Task CreateRewriteFunction() + { + var template = await _promptStore.GetPromptAndSetDefaultAsync("HandlebarSemanticKernelQueryRewriter", DefaultPromptTemplate); + + return _functionCache.GetOrAdd(template, _ => _kernel.CreateFunctionFromPrompt(new PromptTemplateConfig() + { + Name = "TestRewrite", + Description = "Rewrite a query for kernel memory.", + Template = template, + TemplateFormat = "handlebars", + InputVariables = + [ + new() { Name = "chat_input", Description = "New question of the user", IsRequired = false, Default = "" }, + new() { Name = "history", Description = "The history of the RAG CHAT.", IsRequired = true } + ], + ExecutionSettings = + { + { "default", new OpenAIPromptExecutionSettings() + { + MaxTokens = 1000, + Temperature = 0, + ModelId = "gpt35", + } + }, + } + }, + promptTemplateFactory: new HandlebarsPromptTemplateFactory())); } public async Task RewriteAsync(Conversation conversation, string question) - { - KernelArguments ka = new(); - ka["chat_input"] = question; - + { + KernelArguments ka = new(); + ka["chat_input"] = question; ka["history"] = conversation.GetQuestions(); - var result = await _kernel.InvokeAsync("RewriteQuery", _chatFunction, ka); + var chatFunction = await CreateRewriteFunction(); + var result = await _kernel.InvokeAsync("RewriteQuery", chatFunction, ka); return result?.ToString() ?? question; }