Skip to content

Commit

Permalink
added disk based prompt store
Browse files Browse the repository at this point in the history
also added prompt store to query rewriter
  • Loading branch information
alkampfergit committed Jan 9, 2025
1 parent efae48b commit 25be802
Show file tree
Hide file tree
Showing 9 changed files with 496 additions and 91 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
using System;
using System.IO;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;

namespace KernelMemory.Extensions.FunctionalTests.Helper;

public class LocalFolderPromptStoreTests : IDisposable
{
private readonly string _testDirectory;
private readonly Mock<ILogger<LocalFolderPromptStore>> _loggerMock;
private readonly LocalFolderPromptStore _store;

public LocalFolderPromptStoreTests()
{
_testDirectory = Path.Combine(Path.GetTempPath(), $"promptstore_tests_{Guid.NewGuid()}");
_loggerMock = new Mock<ILogger<LocalFolderPromptStore>>();
_store = new LocalFolderPromptStore(_testDirectory, _loggerMock.Object);
}

public void Dispose()
{
if (Directory.Exists(_testDirectory))
{
Directory.Delete(_testDirectory, true);
}
}

[Fact]
public async Task GetPromptAsync_NonExistentKey_ReturnsNull()
{
// Act
var result = await _store.GetPromptAsync("nonexistent");

// Assert
Assert.Null(result);
}

[Fact]
public async Task SetAndGetPromptAsync_ValidKey_ReturnsStoredPrompt()
{
// Arrange
const string key = "test-key";
const string expectedPrompt = "This is a test prompt";

// Act
await _store.SetPromptAsync(key, expectedPrompt);
var result = await _store.GetPromptAsync(key);

// Assert
Assert.Equal(expectedPrompt, result);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_NonExistentKey_SetsAndReturnsDefault()
{
// Arrange
const string key = "default-key";
const string defaultPrompt = "Default prompt value";

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
var storedPrompt = await _store.GetPromptAsync(key);

// Assert
Assert.Equal(defaultPrompt, result);
Assert.Equal(defaultPrompt, storedPrompt);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_ExistingKey_ReturnsExistingPrompt()
{
// Arrange
const string key = "existing-key";
const string existingPrompt = "Existing prompt";
const string defaultPrompt = "Default prompt";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
Assert.Equal(existingPrompt, result);
}

[Fact]
public async Task SetPromptAsync_KeyWithSpecialCharacters_HandlesCorrectly()
{
// Arrange
const string key = "special/\\*:?\"<>|characters";
const string expectedPrompt = "Prompt with special characters";

// Act
await _store.SetPromptAsync(key, expectedPrompt);
var result = await _store.GetPromptAsync(key);

// Assert
Assert.Equal(expectedPrompt, result);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_MissingPlaceholder_LogsError()
{
// Arrange
const string key = "test-placeholder";
const string existingPrompt = "A prompt without placeholder";
const string defaultPrompt = "Default prompt with {{$placeholder}}";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
_loggerMock.Verify(
x => x.Log(
LogLevel.Error,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => v.ToString().Contains("{{$placeholder}}")),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception, string>>()
),
Times.Once);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_MultipleMissingPlaceholders_LogsMultipleErrors()
{
// Arrange
const string key = "test-multiple-placeholders";
const string existingPrompt = "A prompt without any placeholders";
const string defaultPrompt = "Default with {{$first}} and {{$second}}";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
_loggerMock.Verify(
x => x.Log(
LogLevel.Error,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => true),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception, string>>()
),
Times.Exactly(2));
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_ValidPlaceholders_NoErrors()
{
// Arrange
const string key = "test-valid-placeholders";
const string existingPrompt = "A prompt with {{$placeholder}} correctly set";
const string defaultPrompt = "Default with {{$placeholder}}";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
_loggerMock.Verify(
x => x.Log(
LogLevel.Error,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => true),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception, string>>()
),
Times.Never);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ namespace KernelMemory.Extensions.FunctionalTests.Helper;

public class OpenaiRagQueryExecutorTests
{
private Kernel _kernel;
private OpenaiRagQueryExecutor _sut;
private Mock<IPromptStore> _mockPromptStore;
private Mock<ILogger<StandardRagQueryExecutor>> _mockLogger;
private Mock<ISemanticKernelWrapper> _mockKernel;
private readonly OpenaiRagQueryExecutor _sut;
private readonly Mock<IPromptStore> _mockPromptStore;
private readonly Mock<ILogger<StandardRagQueryExecutor>> _mockLogger;
private readonly Mock<ISemanticKernelWrapper> _mockKernel;

public OpenaiRagQueryExecutorTests()
{
_kernel = new Kernel();
var kernel = new Kernel();
_mockPromptStore = new Mock<IPromptStore>();
_mockLogger = new Mock<ILogger<StandardRagQueryExecutor>>();
_mockKernel = new Mock<ISemanticKernelWrapper>();
Expand All @@ -29,7 +28,7 @@ public async Task GetPromptAsync_ShouldReturnPromptFromMock()
{
// Arrange
var expectedPrompt = "Test Prompt";
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync(expectedPrompt);
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync(expectedPrompt);

// Act
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
Expand All @@ -44,7 +43,7 @@ public async Task GetPromptAsync_Should_validate_log_called()
{
// Arrange
var invalidPrompt = "Invalid Prompt";
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync(invalidPrompt);
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync(invalidPrompt);

// Act
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
Expand Down Expand Up @@ -73,15 +72,15 @@ public async Task GetPromptAsync_Should_validate_log_called()

// verify that store method of Ipromptstore is not called
_mockPromptStore.Verify(
store => store.SetPromptAsync(It.IsAny<string>(), It.IsAny<string>()),
store => store.SetPromptAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()),
Times.Never);
}

[Fact]
public async Task If_prompt_not_saved_reload()
{
// Arrange
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync((String?) null);
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync((String?) null);

// Act
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
Expand All @@ -90,7 +89,7 @@ public async Task If_prompt_not_saved_reload()
// Assert
// verify that store method of Ipromptstore is called
_mockPromptStore.Verify(
store => store.SetPromptAsync("OpenaiRagQueryExecutor", It.IsAny<string>()),
store => store.SetPromptAsync("OpenaiRagQueryExecutor", It.IsAny<string>(), It.IsAny<CancellationToken>()),
Times.Once);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using KernelMemory.Extensions.Helper;
using KernelMemory.Extensions.QueryPipeline;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Moq;

namespace KernelMemory.Extensions.FunctionalTests.QueryPipeline;

public class SemanticKernelQueryRewriterTests
{
private readonly Mock<IPromptStore> _promptStoreMock;
private readonly Mock<ISemanticKernelWrapper> _kernelWrapperMock;
private readonly Mock<IChatCompletionService> _chatCompletionServiceMock;
private readonly SemanticKernelQueryRewriterOptions _options;

public SemanticKernelQueryRewriterTests()
{
_promptStoreMock = new Mock<IPromptStore>();
_kernelWrapperMock = new Mock<ISemanticKernelWrapper>();
_chatCompletionServiceMock = new Mock<IChatCompletionService>();
_options = new SemanticKernelQueryRewriterOptions { ModelId = "gpt-4" };

_kernelWrapperMock.Setup(x => x.GetChatCompletionService())
.Returns(_chatCompletionServiceMock.Object);
}

[Fact]
public async Task RewriteAsync_ShouldUsePromptFromStore()
{
// Arrange
var conversation = new Conversation();
var question = "What is the weather?";
var customPrompt = "Custom prompt template {{question}}";

_promptStoreMock.Setup(x => x.GetPromptAndSetDefaultAsync(
"SemanticKernelQueryRewriter",
It.IsAny<string>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(customPrompt);

_chatCompletionServiceMock.Setup(x => x.GetChatMessageContentsAsync(
It.IsAny<ChatHistory>(),
It.IsAny<PromptExecutionSettings?>(),
It.IsAny<Kernel?>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync([new ChatMessageContent(AuthorRole.Assistant, "Rewritten question")]);

var rewriter = new SemanticKernelQueryRewriter(_options, _promptStoreMock.Object, _kernelWrapperMock.Object);

// Act
var result = await rewriter.RewriteAsync(conversation, question);

// Assert
_promptStoreMock.Verify(x => x.GetPromptAndSetDefaultAsync(
"SemanticKernelQueryRewriter",
It.IsAny<string>(),
It.IsAny<CancellationToken>()),
Times.Once);

Assert.Equal("Rewritten question", result);
}

/// <summary>
/// IF we cannot rewrite, we cannot answer
/// </summary>
/// <returns></returns>
[Fact]
public async Task RewriteAsync_WhenChatCompletionFails_ShouldThrow()
{
// Arrange
var conversation = new Conversation();
var question = "What is the weather?";

_promptStoreMock.Setup(x => x.GetPromptAndSetDefaultAsync(
It.IsAny<string>(),
It.IsAny<string>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync("prompt");

_chatCompletionServiceMock.Setup(x => x.GetChatMessageContentsAsync(
It.IsAny<ChatHistory>(),
It.IsAny<PromptExecutionSettings?>(),
It.IsAny<Kernel?>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new List<ChatMessageContent>());

var rewriter = new SemanticKernelQueryRewriter(_options, _promptStoreMock.Object, _kernelWrapperMock.Object);

// Act
await Assert.ThrowsAsync<InvalidOperationException>(() => rewriter.RewriteAsync(conversation, question));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public Task<IReadOnlyCollection<MemoryRecord>> ReRankAsync(string question, IRea

private class TestQueryRewriter : IConversationQueryRewriter
{
public Task<string> RewriteAsync(Conversation conversation, string question)
public Task<string> RewriteAsync(Conversation conversation, string question, CancellationToken cancellationToken = default)
{
return Task.FromResult(question);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ public async Task Basic_conversation_handling()

//Generate mock for conversation rewriter
var conversationRewriterMock = new Mock<IConversationQueryRewriter>();
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>()))
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.Returns(Task.FromResult("New rewritten question"));
sut.SetConversationQueryRewriter(conversationRewriterMock.Object);

Expand All @@ -199,7 +199,7 @@ public async Task Basic_conversation_handling_async()

//Generate mock for conversation rewriter
var conversationRewriterMock = new Mock<IConversationQueryRewriter>();
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>()))
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
.Returns(Task.FromResult("New rewritten question"));
sut.SetConversationQueryRewriter(conversationRewriterMock.Object);

Expand Down
Loading

0 comments on commit 25be802

Please sign in to comment.