-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
also added prompt store to query rewriter
- Loading branch information
1 parent
efae48b
commit 25be802
Showing
9 changed files
with
496 additions
and
91 deletions.
There are no files selected for viewing
174 changes: 174 additions & 0 deletions
174
src/KernelMemory.Extensions.FunctionalTests/Helper/LocalFolderPromptStoreTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
using System; | ||
using System.IO; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.Logging; | ||
using Moq; | ||
using Xunit; | ||
|
||
namespace KernelMemory.Extensions.FunctionalTests.Helper; | ||
|
||
public class LocalFolderPromptStoreTests : IDisposable | ||
{ | ||
private readonly string _testDirectory; | ||
private readonly Mock<ILogger<LocalFolderPromptStore>> _loggerMock; | ||
private readonly LocalFolderPromptStore _store; | ||
|
||
public LocalFolderPromptStoreTests() | ||
{ | ||
_testDirectory = Path.Combine(Path.GetTempPath(), $"promptstore_tests_{Guid.NewGuid()}"); | ||
_loggerMock = new Mock<ILogger<LocalFolderPromptStore>>(); | ||
_store = new LocalFolderPromptStore(_testDirectory, _loggerMock.Object); | ||
} | ||
|
||
public void Dispose() | ||
{ | ||
if (Directory.Exists(_testDirectory)) | ||
{ | ||
Directory.Delete(_testDirectory, true); | ||
} | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAsync_NonExistentKey_ReturnsNull() | ||
{ | ||
// Act | ||
var result = await _store.GetPromptAsync("nonexistent"); | ||
|
||
// Assert | ||
Assert.Null(result); | ||
} | ||
|
||
[Fact] | ||
public async Task SetAndGetPromptAsync_ValidKey_ReturnsStoredPrompt() | ||
{ | ||
// Arrange | ||
const string key = "test-key"; | ||
const string expectedPrompt = "This is a test prompt"; | ||
|
||
// Act | ||
await _store.SetPromptAsync(key, expectedPrompt); | ||
var result = await _store.GetPromptAsync(key); | ||
|
||
// Assert | ||
Assert.Equal(expectedPrompt, result); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_NonExistentKey_SetsAndReturnsDefault() | ||
{ | ||
// Arrange | ||
const string key = "default-key"; | ||
const string defaultPrompt = "Default prompt value"; | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
var storedPrompt = await _store.GetPromptAsync(key); | ||
|
||
// Assert | ||
Assert.Equal(defaultPrompt, result); | ||
Assert.Equal(defaultPrompt, storedPrompt); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_ExistingKey_ReturnsExistingPrompt() | ||
{ | ||
// Arrange | ||
const string key = "existing-key"; | ||
const string existingPrompt = "Existing prompt"; | ||
const string defaultPrompt = "Default prompt"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
Assert.Equal(existingPrompt, result); | ||
} | ||
|
||
[Fact] | ||
public async Task SetPromptAsync_KeyWithSpecialCharacters_HandlesCorrectly() | ||
{ | ||
// Arrange | ||
const string key = "special/\\*:?\"<>|characters"; | ||
const string expectedPrompt = "Prompt with special characters"; | ||
|
||
// Act | ||
await _store.SetPromptAsync(key, expectedPrompt); | ||
var result = await _store.GetPromptAsync(key); | ||
|
||
// Assert | ||
Assert.Equal(expectedPrompt, result); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_MissingPlaceholder_LogsError() | ||
{ | ||
// Arrange | ||
const string key = "test-placeholder"; | ||
const string existingPrompt = "A prompt without placeholder"; | ||
const string defaultPrompt = "Default prompt with {{$placeholder}}"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
_loggerMock.Verify( | ||
x => x.Log( | ||
LogLevel.Error, | ||
It.IsAny<EventId>(), | ||
It.Is<It.IsAnyType>((v, t) => v.ToString().Contains("{{$placeholder}}")), | ||
It.IsAny<Exception>(), | ||
It.IsAny<Func<It.IsAnyType, Exception, string>>() | ||
), | ||
Times.Once); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_MultipleMissingPlaceholders_LogsMultipleErrors() | ||
{ | ||
// Arrange | ||
const string key = "test-multiple-placeholders"; | ||
const string existingPrompt = "A prompt without any placeholders"; | ||
const string defaultPrompt = "Default with {{$first}} and {{$second}}"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
_loggerMock.Verify( | ||
x => x.Log( | ||
LogLevel.Error, | ||
It.IsAny<EventId>(), | ||
It.Is<It.IsAnyType>((v, t) => true), | ||
It.IsAny<Exception>(), | ||
It.IsAny<Func<It.IsAnyType, Exception, string>>() | ||
), | ||
Times.Exactly(2)); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_ValidPlaceholders_NoErrors() | ||
{ | ||
// Arrange | ||
const string key = "test-valid-placeholders"; | ||
const string existingPrompt = "A prompt with {{$placeholder}} correctly set"; | ||
const string defaultPrompt = "Default with {{$placeholder}}"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
_loggerMock.Verify( | ||
x => x.Log( | ||
LogLevel.Error, | ||
It.IsAny<EventId>(), | ||
It.Is<It.IsAnyType>((v, t) => true), | ||
It.IsAny<Exception>(), | ||
It.IsAny<Func<It.IsAnyType, Exception, string>>() | ||
), | ||
Times.Never); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
...KernelMemory.Extensions.FunctionalTests/QueryPipeline/SemanticKernelQueryRewriterTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
using KernelMemory.Extensions.Helper; | ||
using KernelMemory.Extensions.QueryPipeline; | ||
using Microsoft.SemanticKernel; | ||
using Microsoft.SemanticKernel.ChatCompletion; | ||
using Moq; | ||
|
||
namespace KernelMemory.Extensions.FunctionalTests.QueryPipeline; | ||
|
||
public class SemanticKernelQueryRewriterTests | ||
{ | ||
private readonly Mock<IPromptStore> _promptStoreMock; | ||
private readonly Mock<ISemanticKernelWrapper> _kernelWrapperMock; | ||
private readonly Mock<IChatCompletionService> _chatCompletionServiceMock; | ||
private readonly SemanticKernelQueryRewriterOptions _options; | ||
|
||
public SemanticKernelQueryRewriterTests() | ||
{ | ||
_promptStoreMock = new Mock<IPromptStore>(); | ||
_kernelWrapperMock = new Mock<ISemanticKernelWrapper>(); | ||
_chatCompletionServiceMock = new Mock<IChatCompletionService>(); | ||
_options = new SemanticKernelQueryRewriterOptions { ModelId = "gpt-4" }; | ||
|
||
_kernelWrapperMock.Setup(x => x.GetChatCompletionService()) | ||
.Returns(_chatCompletionServiceMock.Object); | ||
} | ||
|
||
[Fact] | ||
public async Task RewriteAsync_ShouldUsePromptFromStore() | ||
{ | ||
// Arrange | ||
var conversation = new Conversation(); | ||
var question = "What is the weather?"; | ||
var customPrompt = "Custom prompt template {{question}}"; | ||
|
||
_promptStoreMock.Setup(x => x.GetPromptAndSetDefaultAsync( | ||
"SemanticKernelQueryRewriter", | ||
It.IsAny<string>(), | ||
It.IsAny<CancellationToken>())) | ||
.ReturnsAsync(customPrompt); | ||
|
||
_chatCompletionServiceMock.Setup(x => x.GetChatMessageContentsAsync( | ||
It.IsAny<ChatHistory>(), | ||
It.IsAny<PromptExecutionSettings?>(), | ||
It.IsAny<Kernel?>(), | ||
It.IsAny<CancellationToken>())) | ||
.ReturnsAsync([new ChatMessageContent(AuthorRole.Assistant, "Rewritten question")]); | ||
|
||
var rewriter = new SemanticKernelQueryRewriter(_options, _promptStoreMock.Object, _kernelWrapperMock.Object); | ||
|
||
// Act | ||
var result = await rewriter.RewriteAsync(conversation, question); | ||
|
||
// Assert | ||
_promptStoreMock.Verify(x => x.GetPromptAndSetDefaultAsync( | ||
"SemanticKernelQueryRewriter", | ||
It.IsAny<string>(), | ||
It.IsAny<CancellationToken>()), | ||
Times.Once); | ||
|
||
Assert.Equal("Rewritten question", result); | ||
} | ||
|
||
/// <summary> | ||
/// IF we cannot rewrite, we cannot answer | ||
/// </summary> | ||
/// <returns></returns> | ||
[Fact] | ||
public async Task RewriteAsync_WhenChatCompletionFails_ShouldThrow() | ||
{ | ||
// Arrange | ||
var conversation = new Conversation(); | ||
var question = "What is the weather?"; | ||
|
||
_promptStoreMock.Setup(x => x.GetPromptAndSetDefaultAsync( | ||
It.IsAny<string>(), | ||
It.IsAny<string>(), | ||
It.IsAny<CancellationToken>())) | ||
.ReturnsAsync("prompt"); | ||
|
||
_chatCompletionServiceMock.Setup(x => x.GetChatMessageContentsAsync( | ||
It.IsAny<ChatHistory>(), | ||
It.IsAny<PromptExecutionSettings?>(), | ||
It.IsAny<Kernel?>(), | ||
It.IsAny<CancellationToken>())) | ||
.ReturnsAsync(new List<ChatMessageContent>()); | ||
|
||
var rewriter = new SemanticKernelQueryRewriter(_options, _promptStoreMock.Object, _kernelWrapperMock.Object); | ||
|
||
// Act | ||
await Assert.ThrowsAsync<InvalidOperationException>(() => rewriter.RewriteAsync(conversation, question)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.