Skip to content

Commit

Permalink
Implemented very first vesrion of hyde.
Browse files Browse the repository at this point in the history
  • Loading branch information
alkampfergit committed Nov 23, 2024
1 parent c2d316c commit 477c591
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,28 @@ public async Task RunSample2()
{
var services = new ServiceCollection();

var apiKey = Dotenv.Get("COHERE_API_KEY")!;
var cohereBaseUrl = Dotenv.Get("COHERE_BASE_API_KEY");
if (string.IsNullOrEmpty(cohereBaseUrl))
var cohereAzureBaseUrl = Dotenv.Get("COHERE_AZURE_BASE_URL");
if (string.IsNullOrEmpty(cohereAzureBaseUrl))
{
var apiKey = Dotenv.Get("COHERE_API_KEY")!;
services.ConfigureCohereChat(apiKey);
}
else
{
services.ConfigureCohereChat(apiKey, cohereBaseUrl);
var azureApiKey = Dotenv.Get("COHERE_AZURE_API_KEY");
services.ConfigureCohereChat(azureApiKey, cohereAzureBaseUrl);
}
//verify if rerank has a different api key (because the apikey point on azure ai studio)
var rerankApiKey = Dotenv.Get("COHERE_RERANK_API_KEY");
if (string.IsNullOrEmpty(rerankApiKey))
var rerankAzureBaseUrl = Dotenv.Get("COHERE_AZURE_RERANK_BASE_URL");
if (string.IsNullOrEmpty(rerankAzureBaseUrl))
{
var apiKey = Dotenv.Get("COHERE_API_KEY")!;
services.ConfigureCohereRerank(apiKey);
}
else
{
services.ConfigureCohereRerank(rerankApiKey);
var azureReRankApiKey = Dotenv.Get("COHERE_AZURE_RERANK_API_KEY");
services.ConfigureCohereRerank(azureReRankApiKey, rerankAzureBaseUrl);
}

services.AddHttpClient<RawCohereChatClient>()
Expand Down Expand Up @@ -89,12 +92,15 @@ public async Task RunSample2()
.Title("Select query rewriter")
.AddChoices(["Semantic Kernel Base", "Semantic Kernel Handlebar"]));

var useHyde = AnsiConsole.Confirm("Do you want to use HyDe? (y/n)", false);

var kernelBuider = CreateBasicKernelBuilder();
var builder = CreateBasicKernelMemoryBuilder(
services,
storageToUse == "elasticsearch",
queryExecutorToUse,
queryRewriterTool == "Semantic Kernel Handlebar");
queryRewriterTool == "Semantic Kernel Handlebar",
useHyde);
var kernelMemory = builder.Build<MemoryServerless>();
var kernel = kernelBuider.Build();

Expand Down Expand Up @@ -238,7 +244,8 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
ServiceCollection services,
bool useElasticSearch,
string ragToolToUse,
bool useHandlebarQueryRewriter)
bool useHandlebarQueryRewriter,
bool useHyde)
{
// we need a series of services to use Kernel Memory, the first one is
// an embedding service that will be used to create dense vector for
Expand Down Expand Up @@ -306,6 +313,12 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
services.AddSingleton<HandlebarSemanticKernelQueryRewriter>();
services.AddSingleton<SemanticKernelQueryRewriter>();
services.AddSingleton<StandardVectorSearchQueryHandler>();
services.AddSingleton<HyDeQueryHandler>();
var hydeConfig = new HiDeQueryHandlerConfiguration()
{
Prompt = "Given a question, generate a paragraph of text that answers the question in the context of computer security and IT security"
};
services.AddSingleton(hydeConfig);
services.AddSingleton<KeywordSearchQueryHandler>();

var rewriterOptions = new SemanticKernelQueryRewriterOptions();
Expand Down Expand Up @@ -337,6 +350,11 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
config.AddHandler<KeywordSearchQueryHandler>();
}

if (useHyde)
{
config.AddHandler<HyDeQueryHandler>();
}

if (ragToolToUse == "Cohere CommandR+")
{
config.AddHandler<CohereCommandRQueryExecutor>();
Expand Down
95 changes: 95 additions & 0 deletions src/KernelMemory.Extensions/QueryPipeline/HyDeQueryHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using DocumentFormat.OpenXml.Packaging;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
using Microsoft.SemanticKernel;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace KernelMemory.Extensions;

public class HiDeQueryHandlerConfiguration
{
public string Prompt { get; set; } = "Given a question, generate a paragraph of text that answers the question.";
}

public class HyDeQueryHandler : BasicQueryHandler
{
private readonly IMemoryDb _memoryDb;
private readonly Kernel _kernel;
private readonly HiDeQueryHandlerConfiguration _configuration;
private readonly ILogger<StandardVectorSearchQueryHandler> _log;

public override string Name => "StandardVectorSearchQueryHandler";

public HyDeQueryHandler(
IMemoryDb memory,
Kernel kernel,
HiDeQueryHandlerConfiguration? configuration,
ILogger<StandardVectorSearchQueryHandler>? log = null)
{
_memoryDb = memory;
_kernel = kernel;
_configuration = configuration ?? new HiDeQueryHandlerConfiguration();
_log = log ?? DefaultLogger<StandardVectorSearchQueryHandler>.Instance;
}

/// <summary>
/// Perform a vector search in default memory using the hyde principle.
/// </summary>
/// <param name="userQuestion"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override async Task OnHandleAsync(UserQuestion userQuestion, CancellationToken cancellationToken)
{
// Perform a vector search in default memory
StringBuilder prompt = new StringBuilder(_configuration.Prompt.Length + userQuestion.Question.Length + "Question: ".Length + "Paragraph: ".Length + 20);
prompt.AppendLine(_configuration.Prompt);
prompt.AppendLine("Question: " + userQuestion.Question);
prompt.AppendLine("Paragraph: ");

var result = await _kernel.InvokePromptAsync(prompt.ToString(), cancellationToken: cancellationToken);

var paragraph = result.ToString();

var list = new List<(MemoryRecord memory, double relevance)>();

IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync(
index: userQuestion.UserQueryOptions.Index,
text: paragraph,
filters: userQuestion.Filters,
minRelevance: userQuestion.UserQueryOptions.MinRelevance,
limit: userQuestion.UserQueryOptions.RetrievalQueryLimit,
withEmbeddings: false,
cancellationToken: cancellationToken);

// Memories are sorted by relevance, starting from the most relevant
await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false))
{
list.Add((memory, relevance));
}

var records = new List<MemoryRecord>();
// Memories are sorted by relevance, starting from the most relevant
foreach ((MemoryRecord memory, double relevance) in list)
{
var partitionText = memory.GetPartitionText(this._log).Trim();
if (string.IsNullOrEmpty(partitionText))
{
this._log.LogError("The document partition is empty, doc: {0}", memory.Id);
continue;
}

if (relevance > float.MinValue)
{
this._log.LogTrace("Adding result with relevance {0}", relevance);
records.Add(memory);
}
}

//ok now that you have all the memory record and citations, add to the object
userQuestion.AddMemoryRecordSource("hyde-vector-search", records);
}
}

0 comments on commit 477c591

Please sign in to comment.