Skip to content

Commit

Permalink
added disk based prompt store
Browse files Browse the repository at this point in the history
also added prompt store to query rewriter
  • Loading branch information
alkampfergit committed Jan 9, 2025
1 parent efae48b commit f42c239
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 58 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
using System;
using System.IO;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;

namespace KernelMemory.Extensions.FunctionalTests.Helper;

public class LocalFolderPromptStoreTests : IDisposable
{
private readonly string _testDirectory;
private readonly Mock<ILogger<LocalFolderPromptStore>> _loggerMock;
private readonly LocalFolderPromptStore _store;

public LocalFolderPromptStoreTests()
{
_testDirectory = Path.Combine(Path.GetTempPath(), $"promptstore_tests_{Guid.NewGuid()}");
_loggerMock = new Mock<ILogger<LocalFolderPromptStore>>();
_store = new LocalFolderPromptStore(_testDirectory, _loggerMock.Object);
}

public void Dispose()
{
if (Directory.Exists(_testDirectory))
{
Directory.Delete(_testDirectory, true);
}
}

[Fact]
public async Task GetPromptAsync_NonExistentKey_ReturnsNull()
{
// Act
var result = await _store.GetPromptAsync("nonexistent");

// Assert
Assert.Null(result);
}

[Fact]
public async Task SetAndGetPromptAsync_ValidKey_ReturnsStoredPrompt()
{
// Arrange
const string key = "test-key";
const string expectedPrompt = "This is a test prompt";

// Act
await _store.SetPromptAsync(key, expectedPrompt);
var result = await _store.GetPromptAsync(key);

// Assert
Assert.Equal(expectedPrompt, result);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_NonExistentKey_SetsAndReturnsDefault()
{
// Arrange
const string key = "default-key";
const string defaultPrompt = "Default prompt value";

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
var storedPrompt = await _store.GetPromptAsync(key);

// Assert
Assert.Equal(defaultPrompt, result);
Assert.Equal(defaultPrompt, storedPrompt);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_ExistingKey_ReturnsExistingPrompt()
{
// Arrange
const string key = "existing-key";
const string existingPrompt = "Existing prompt";
const string defaultPrompt = "Default prompt";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
Assert.Equal(existingPrompt, result);
}

[Fact]
public async Task SetPromptAsync_KeyWithSpecialCharacters_HandlesCorrectly()
{
// Arrange
const string key = "special/\\*:?\"<>|characters";
const string expectedPrompt = "Prompt with special characters";

// Act
await _store.SetPromptAsync(key, expectedPrompt);
var result = await _store.GetPromptAsync(key);

// Assert
Assert.Equal(expectedPrompt, result);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_MissingPlaceholder_LogsError()
{
// Arrange
const string key = "test-placeholder";
const string existingPrompt = "A prompt without placeholder";
const string defaultPrompt = "Default prompt with {{$placeholder}}";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
_loggerMock.Verify(
x => x.Log(
LogLevel.Error,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => v.ToString().Contains("{{$placeholder}}")),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception, string>>()
),
Times.Once);
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_MultipleMissingPlaceholders_LogsMultipleErrors()
{
// Arrange
const string key = "test-multiple-placeholders";
const string existingPrompt = "A prompt without any placeholders";
const string defaultPrompt = "Default with {{$first}} and {{$second}}";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
_loggerMock.Verify(
x => x.Log(
LogLevel.Error,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => true),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception, string>>()
),
Times.Exactly(2));
}

[Fact]
public async Task GetPromptAndSetDefaultAsync_ValidPlaceholders_NoErrors()
{
// Arrange
const string key = "test-valid-placeholders";
const string existingPrompt = "A prompt with {{$placeholder}} correctly set";
const string defaultPrompt = "Default with {{$placeholder}}";
await _store.SetPromptAsync(key, existingPrompt);

// Act
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);

// Assert
_loggerMock.Verify(
x => x.Log(
LogLevel.Error,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => true),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception, string>>()
),
Times.Never);
}
}
50 changes: 50 additions & 0 deletions src/KernelMemory.Extensions/BasePromptStore.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using System.Text.RegularExpressions;

namespace KernelMemory.Extensions;

public abstract class BasePromptStore : IPromptStore
{
private static readonly Regex PlaceholderRegex = new(@"\{\{\$\w+\}\}", RegexOptions.Compiled);

protected readonly ILogger _log;

protected BasePromptStore(ILogger log)
{
_log = log;
}

protected void ValidatePlaceholders(string defaultPrompt, string loadedPrompt)
{
var placeholders = PlaceholderRegex.Matches(defaultPrompt)
.Cast<Match>()
.Select(m => m.Value)
.ToList();

foreach (var placeholder in placeholders)
{
if (!loadedPrompt.Contains(placeholder))
{
_log.LogError("The prompt does not contain {Placeholder} placeholder, the prompt will not work correctly", placeholder);
}
}
}

public abstract Task<string?> GetPromptAsync(string key);
public abstract Task SetPromptAsync(string key, string prompt);

public virtual async Task<string> GetPromptAndSetDefaultAsync(string key, string defaultPrompt)
{
var prompt = await GetPromptAsync(key);
if (prompt == null)
{
await SetPromptAsync(key, defaultPrompt);
return defaultPrompt;
}

ValidatePlaceholders(defaultPrompt, prompt);
return prompt;
}
}
19 changes: 19 additions & 0 deletions src/KernelMemory.Extensions/IPromptStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ public interface IPromptStore
/// various components will use some default prompts.</returns>
Task<string?> GetPromptAsync(string key);

/// <summary>
/// Get the prompt for the given key and set the default prompt if the prompt is not present.
/// </summary>
/// <param name="key"></param>
/// <param name="defaultPrompt"></param>
/// <returns></returns>
Task<string> GetPromptAndSetDefaultAsync(string key, string defaultPrompt);

/// <summary>
/// Allow setting prompt value.
/// </summary>
Expand All @@ -45,6 +53,17 @@ public class NullPromptStore : IPromptStore
return Task.FromResult<string?>(null);
}

/// <summary>
/// Return the default prompt always
/// </summary>
/// <param name="key"></param>
/// <param name="defaultPrompt"></param>
/// <returns></returns>
public Task<string> GetPromptAndSetDefaultAsync(string key, string defaultPrompt)
{
return Task.FromResult<string>(defaultPrompt);
}

/// <summary>
/// Allow setting prompt value.
/// </summary>
Expand Down
43 changes: 43 additions & 0 deletions src/KernelMemory.Extensions/LocalFolderPromptStore.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using System;
using System.IO;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using System.Linq;
using Microsoft.KernelMemory.Diagnostics;

namespace KernelMemory.Extensions;

public class LocalFolderPromptStore : BasePromptStore
{
private readonly string _promptDirectory;

public LocalFolderPromptStore(string promptDirectory, ILogger<LocalFolderPromptStore>? log = null)
: base(log ?? DefaultLogger<LocalFolderPromptStore>.Instance)
{
_promptDirectory = promptDirectory;
Directory.CreateDirectory(promptDirectory);
}

private string GetPromptFilePath(string key)
{
var sanitizedKey = string.Join("_", key.Split(Path.GetInvalidFileNameChars()));
return Path.Combine(_promptDirectory, $"{sanitizedKey}.prompt");
}

public override async Task<string?> GetPromptAsync(string key)
{
var filePath = GetPromptFilePath(key);
if (!File.Exists(filePath))
{
return null;
}

return await File.ReadAllTextAsync(filePath);
}

public override async Task SetPromptAsync(string key, string prompt)
{
var filePath = GetPromptFilePath(key);
await File.WriteAllTextAsync(filePath, prompt);
}
}
Loading

0 comments on commit f42c239

Please sign in to comment.