-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
also added prompt store to query rewriter
- Loading branch information
1 parent
efae48b
commit f42c239
Showing
5 changed files
with
359 additions
and
58 deletions.
There are no files selected for viewing
174 changes: 174 additions & 0 deletions
174
src/KernelMemory.Extensions.FunctionalTests/Helper/LocalFolderPromptStoreTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
using System; | ||
using System.IO; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.Logging; | ||
using Moq; | ||
using Xunit; | ||
|
||
namespace KernelMemory.Extensions.FunctionalTests.Helper; | ||
|
||
public class LocalFolderPromptStoreTests : IDisposable | ||
{ | ||
private readonly string _testDirectory; | ||
private readonly Mock<ILogger<LocalFolderPromptStore>> _loggerMock; | ||
private readonly LocalFolderPromptStore _store; | ||
|
||
public LocalFolderPromptStoreTests() | ||
{ | ||
_testDirectory = Path.Combine(Path.GetTempPath(), $"promptstore_tests_{Guid.NewGuid()}"); | ||
_loggerMock = new Mock<ILogger<LocalFolderPromptStore>>(); | ||
_store = new LocalFolderPromptStore(_testDirectory, _loggerMock.Object); | ||
} | ||
|
||
public void Dispose() | ||
{ | ||
if (Directory.Exists(_testDirectory)) | ||
{ | ||
Directory.Delete(_testDirectory, true); | ||
} | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAsync_NonExistentKey_ReturnsNull() | ||
{ | ||
// Act | ||
var result = await _store.GetPromptAsync("nonexistent"); | ||
|
||
// Assert | ||
Assert.Null(result); | ||
} | ||
|
||
[Fact] | ||
public async Task SetAndGetPromptAsync_ValidKey_ReturnsStoredPrompt() | ||
{ | ||
// Arrange | ||
const string key = "test-key"; | ||
const string expectedPrompt = "This is a test prompt"; | ||
|
||
// Act | ||
await _store.SetPromptAsync(key, expectedPrompt); | ||
var result = await _store.GetPromptAsync(key); | ||
|
||
// Assert | ||
Assert.Equal(expectedPrompt, result); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_NonExistentKey_SetsAndReturnsDefault() | ||
{ | ||
// Arrange | ||
const string key = "default-key"; | ||
const string defaultPrompt = "Default prompt value"; | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
var storedPrompt = await _store.GetPromptAsync(key); | ||
|
||
// Assert | ||
Assert.Equal(defaultPrompt, result); | ||
Assert.Equal(defaultPrompt, storedPrompt); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_ExistingKey_ReturnsExistingPrompt() | ||
{ | ||
// Arrange | ||
const string key = "existing-key"; | ||
const string existingPrompt = "Existing prompt"; | ||
const string defaultPrompt = "Default prompt"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
Assert.Equal(existingPrompt, result); | ||
} | ||
|
||
[Fact] | ||
public async Task SetPromptAsync_KeyWithSpecialCharacters_HandlesCorrectly() | ||
{ | ||
// Arrange | ||
const string key = "special/\\*:?\"<>|characters"; | ||
const string expectedPrompt = "Prompt with special characters"; | ||
|
||
// Act | ||
await _store.SetPromptAsync(key, expectedPrompt); | ||
var result = await _store.GetPromptAsync(key); | ||
|
||
// Assert | ||
Assert.Equal(expectedPrompt, result); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_MissingPlaceholder_LogsError() | ||
{ | ||
// Arrange | ||
const string key = "test-placeholder"; | ||
const string existingPrompt = "A prompt without placeholder"; | ||
const string defaultPrompt = "Default prompt with {{$placeholder}}"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
_loggerMock.Verify( | ||
x => x.Log( | ||
LogLevel.Error, | ||
It.IsAny<EventId>(), | ||
It.Is<It.IsAnyType>((v, t) => v.ToString().Contains("{{$placeholder}}")), | ||
It.IsAny<Exception>(), | ||
It.IsAny<Func<It.IsAnyType, Exception, string>>() | ||
), | ||
Times.Once); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_MultipleMissingPlaceholders_LogsMultipleErrors() | ||
{ | ||
// Arrange | ||
const string key = "test-multiple-placeholders"; | ||
const string existingPrompt = "A prompt without any placeholders"; | ||
const string defaultPrompt = "Default with {{$first}} and {{$second}}"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
_loggerMock.Verify( | ||
x => x.Log( | ||
LogLevel.Error, | ||
It.IsAny<EventId>(), | ||
It.Is<It.IsAnyType>((v, t) => true), | ||
It.IsAny<Exception>(), | ||
It.IsAny<Func<It.IsAnyType, Exception, string>>() | ||
), | ||
Times.Exactly(2)); | ||
} | ||
|
||
[Fact] | ||
public async Task GetPromptAndSetDefaultAsync_ValidPlaceholders_NoErrors() | ||
{ | ||
// Arrange | ||
const string key = "test-valid-placeholders"; | ||
const string existingPrompt = "A prompt with {{$placeholder}} correctly set"; | ||
const string defaultPrompt = "Default with {{$placeholder}}"; | ||
await _store.SetPromptAsync(key, existingPrompt); | ||
|
||
// Act | ||
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt); | ||
|
||
// Assert | ||
_loggerMock.Verify( | ||
x => x.Log( | ||
LogLevel.Error, | ||
It.IsAny<EventId>(), | ||
It.Is<It.IsAnyType>((v, t) => true), | ||
It.IsAny<Exception>(), | ||
It.IsAny<Func<It.IsAnyType, Exception, string>>() | ||
), | ||
Times.Never); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
using System.Linq; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.Logging; | ||
using System.Text.RegularExpressions; | ||
|
||
namespace KernelMemory.Extensions; | ||
|
||
public abstract class BasePromptStore : IPromptStore | ||
{ | ||
private static readonly Regex PlaceholderRegex = new(@"\{\{\$\w+\}\}", RegexOptions.Compiled); | ||
|
||
protected readonly ILogger _log; | ||
|
||
protected BasePromptStore(ILogger log) | ||
{ | ||
_log = log; | ||
} | ||
|
||
protected void ValidatePlaceholders(string defaultPrompt, string loadedPrompt) | ||
{ | ||
var placeholders = PlaceholderRegex.Matches(defaultPrompt) | ||
.Cast<Match>() | ||
.Select(m => m.Value) | ||
.ToList(); | ||
|
||
foreach (var placeholder in placeholders) | ||
{ | ||
if (!loadedPrompt.Contains(placeholder)) | ||
{ | ||
_log.LogError("The prompt does not contain {Placeholder} placeholder, the prompt will not work correctly", placeholder); | ||
} | ||
} | ||
} | ||
|
||
public abstract Task<string?> GetPromptAsync(string key); | ||
public abstract Task SetPromptAsync(string key, string prompt); | ||
|
||
public virtual async Task<string> GetPromptAndSetDefaultAsync(string key, string defaultPrompt) | ||
{ | ||
var prompt = await GetPromptAsync(key); | ||
if (prompt == null) | ||
{ | ||
await SetPromptAsync(key, defaultPrompt); | ||
return defaultPrompt; | ||
} | ||
|
||
ValidatePlaceholders(defaultPrompt, prompt); | ||
return prompt; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
using System; | ||
using System.IO; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.Logging; | ||
using System.Linq; | ||
using Microsoft.KernelMemory.Diagnostics; | ||
|
||
namespace KernelMemory.Extensions; | ||
|
||
public class LocalFolderPromptStore : BasePromptStore | ||
{ | ||
private readonly string _promptDirectory; | ||
|
||
public LocalFolderPromptStore(string promptDirectory, ILogger<LocalFolderPromptStore>? log = null) | ||
: base(log ?? DefaultLogger<LocalFolderPromptStore>.Instance) | ||
{ | ||
_promptDirectory = promptDirectory; | ||
Directory.CreateDirectory(promptDirectory); | ||
} | ||
|
||
private string GetPromptFilePath(string key) | ||
{ | ||
var sanitizedKey = string.Join("_", key.Split(Path.GetInvalidFileNameChars())); | ||
return Path.Combine(_promptDirectory, $"{sanitizedKey}.prompt"); | ||
} | ||
|
||
public override async Task<string?> GetPromptAsync(string key) | ||
{ | ||
var filePath = GetPromptFilePath(key); | ||
if (!File.Exists(filePath)) | ||
{ | ||
return null; | ||
} | ||
|
||
return await File.ReadAllTextAsync(filePath); | ||
} | ||
|
||
public override async Task SetPromptAsync(string key, string prompt) | ||
{ | ||
var filePath = GetPromptFilePath(key); | ||
await File.WriteAllTextAsync(filePath, prompt); | ||
} | ||
} |
Oops, something went wrong.