Skip to content

Commit

Permalink
Added context to cohere.
Browse files Browse the repository at this point in the history
  • Loading branch information
alkampfergit committed Jan 9, 2025
1 parent 554cd00 commit d0d7d3e
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,15 @@ public async Task RunSample2()
call.TokenCount.CachedTokenRead,
call.TokenCount.CachedTokenWrite);
}

if (call.Warnings.Count > 0)
{
Console.WriteLine("Warnings:");
foreach (var warning in call.Warnings)
{
Console.WriteLine(warning);
}
}
}
}
else
Expand Down Expand Up @@ -351,6 +360,7 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(

services.AddSingleton<IKernelMemoryBuilder>(kernelMemoryBuilder);
services.AddSingleton<CohereReRanker>();

services.AddSingleton<HandlebarSemanticKernelQueryRewriter>();
services.AddSingleton<SemanticKernelQueryRewriter>();
services.AddSingleton<StandardVectorSearchQueryHandler>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
using Polly.Fallback;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
Expand Down Expand Up @@ -47,15 +48,18 @@ public class CohereCommandRQueryExecutor : BasicAsyncQueryHandlerWithProgress

private readonly RawCohereClient _rawCohereClient;
private readonly CohereCommandRQueryExecutorConfiguration _config;
private readonly CohereTokenizer _cohereTokenizer;
private readonly ILogger<StandardRagQueryExecutor> _log;

/// <summary>
/// Creates the executor and stores its collaborators in the private fields.
/// </summary>
/// <param name="rawCohereClient">Client stored in <c>_rawCohereClient</c>.</param>
/// <param name="config">Configuration stored in <c>_config</c>.</param>
/// <param name="cohereTokenizer">Tokenizer stored in <c>_cohereTokenizer</c>.</param>
/// <param name="log">
/// Optional logger; when <c>null</c> the shared
/// <c>DefaultLogger&lt;StandardRagQueryExecutor&gt;.Instance</c> is used instead.
/// </param>
public CohereCommandRQueryExecutor(
    RawCohereClient rawCohereClient,
    CohereCommandRQueryExecutorConfiguration config,
    CohereTokenizer cohereTokenizer,
    ILogger<StandardRagQueryExecutor>? log = null)
{
    // Resolve the logger first so the remaining lines are plain field captures.
    _log = log ?? DefaultLogger<StandardRagQueryExecutor>.Instance;

    _cohereTokenizer = cohereTokenizer;
    _config = config;
    _rawCohereClient = rawCohereClient;
}

Expand Down
2 changes: 2 additions & 0 deletions src/KernelMemory.Extensions/Cohere/CohereConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ public static IServiceCollection ConfigureCohereChat(
BaseUrl = baseUrl,
});

services.AddSingleton<CohereTokenizer>();

return services;
}

Expand Down
54 changes: 46 additions & 8 deletions src/KernelMemory.Extensions/Cohere/RawCohereChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,32 @@
using System.Threading.Tasks;
using KernelMemory.Extensions.Helper;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Context;
using Microsoft.KernelMemory.Diagnostics;

namespace KernelMemory.Extensions.Cohere;

public class RawCohereChatClient
{
private readonly HttpClient _httpClient;
private readonly HttpClient _httpClient;
private readonly IContextProvider _contextProvider;
private readonly ILogger<RawCohereChatClient> _log;
private readonly string _apiKey;
private readonly string _baseUrl;

public RawCohereChatClient(
CohereChatConfiguration config,
HttpClient httpClient,
IContextProvider contextProvider,
ILogger<RawCohereChatClient>? log = null)
{
if (String.IsNullOrEmpty(config.ApiKey))
{
throw new ArgumentException("ApiKey is required", nameof(config.ApiKey));
}

this._httpClient = httpClient;
_httpClient = httpClient;
_contextProvider = contextProvider;
_log = log ?? DefaultLogger<RawCohereChatClient>.Instance;
_apiKey = config.ApiKey;
_baseUrl = config.BaseUrl;
Expand Down Expand Up @@ -91,10 +95,9 @@ public async IAsyncEnumerable<CohereRagStreamingResponse> RagQueryStreamingAsync
CohereRagRequest cohereRagRequest,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (cohereRagRequest is null)
{
throw new ArgumentNullException(nameof(cohereRagRequest));
}
ArgumentNullException.ThrowIfNull(cohereRagRequest);

var context = _contextProvider.GetContext();

var client = _httpClient;
//force streaming
Expand Down Expand Up @@ -130,7 +133,7 @@ public async IAsyncEnumerable<CohereRagStreamingResponse> RagQueryStreamingAsync
string line = (await reader.ReadLineAsync(cancellationToken))!;
var data = JsonSerializer.Deserialize<ChatStreamEvent>(line)!;

if (data.EventType == "stream-start" || data.EventType == "stream-end" || data.EventType == "search-results")
if (data.EventType == "stream-start" || data.EventType == "search-results")
{
//not interested in these events
continue;
Expand All @@ -152,12 +155,47 @@ public async IAsyncEnumerable<CohereRagStreamingResponse> RagQueryStreamingAsync
ResponseType = CohereRagResponseType.Citations
};
}
else if (data.EventType == "stream-end")
{
//create log
AddLog(context, "CommandR+RAG", cohereRagRequest.Describe(), data);
}
else
{
//not supported.
_log.LogWarning("Cohere stream api receved unknown event data type {0}", data.EventType);
}
}
}
}
}

/// <summary>
/// Builds an <see cref="LLMCallLog"/> from a stream event and adds it to
/// <paramref name="context"/>.
/// </summary>
/// <param name="context">Context that receives the log through <c>AddCallLog</c>.</param>
/// <param name="name">Value stored in <c>CallName</c>.</param>
/// <param name="input">Value stored in <c>InputPrompt</c>.</param>
/// <param name="data">
/// Stream event; stored as <c>ReturnObject</c>. Its <c>Response</c>, <c>Meta</c>,
/// <c>Tokens</c> and <c>Warnings</c> members are all treated as optional.
/// </param>
private void AddLog(
    IContext context,
    string name,
    string input,
    ChatStreamEvent data)
{
    // Read every level of the payload defensively: an event that lacks
    // "response", "meta" or "tokens" must still produce a call log
    // instead of failing with a NullReferenceException.
    var response = data.Response;
    var meta = response?.Meta;
    var tokens = meta?.Tokens;

    LLMCallLog callLog = new()
    {
        CallName = name,
        ReturnObject = data,
        InputPrompt = input,
        Output = response?.Text ?? string.Empty,
        TokenCount = new TokenCount()
        {
            InputTokens = tokens?.InputTokens ?? 0,
            OutputTokens = tokens?.OutputTokens ?? 0,
        }
    };

    var warnings = meta?.Warnings;
    if (warnings is not null)
    {
        foreach (var warning in warnings)
        {
            callLog.AddWarning(warning);
        }
    }

    context.AddCallLog(callLog);
}
}
64 changes: 63 additions & 1 deletion src/KernelMemory.Extensions/Cohere/RawCohereClientDtos.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using Microsoft.KernelMemory.MemoryStorage;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;

namespace KernelMemory.Extensions.Cohere;
Expand Down Expand Up @@ -43,16 +46,51 @@ public static CohereRagRequest CreateFromMemoryRecord(string question, IEnumerab

foreach (var memory in memoryRecords)
{
//if the text contains more than 250 spaces (roughly 250 words) we need to split it
var text = memory.GetPartitionText();
int start = 0;
int spaceCount = 0;
for (int i = 0; i < text.Length; i++)
{
if (text[i] == ' ')
{
spaceCount++;
}
if (spaceCount > 250)
{
ragRequest.Documents.Add(new RagDocument()
{
DocId = memory.Id,
Text = text[start..i]
});
start = i;
spaceCount = 0;
}
}

ragRequest.Documents.Add(new RagDocument()
{
DocId = memory.Id,
Text = memory.GetPartitionText()
Text = text[start..text.Length]
});
}

return ragRequest;
}

/// <summary>
/// Produces a multi-line, human-readable description of this request:
/// message, model, document count, temperature and stream flag, followed by
/// the text of every document.
/// </summary>
/// <returns>The description as a single string.</returns>
internal string Describe()
{
    var description = new StringBuilder()
        .AppendLine($"Message: {Message}")
        .AppendLine($"Model: {Model}")
        .AppendLine($"Document count: {Documents.Count}")
        .AppendLine($"Temperature: {Temperature}")
        .AppendLine($"Stream: {Stream}");

    // Document texts are joined with a bare "\n" and appended after the header lines.
    var fullDocuments = string.Join("\n", Documents.Select(document => document.Text));
    description.AppendLine($"\n\nFullDocuments\n{fullDocuments}");

    return description.ToString();
}

[JsonPropertyName("message")]
public string Message { get; set; }

Expand Down Expand Up @@ -321,6 +359,30 @@ public class ChatStreamEvent

[JsonPropertyName("citations")]
public List<CohereRagCitation> Citations { get; set; }

[JsonPropertyName("response")]
public ChatStreamingResponse Response { get; set; }
}

/// <summary>
/// DTO bound to the <c>response</c> JSON property of <see cref="ChatStreamEvent"/>.
/// Each property maps one-to-one to the JSON name given in its
/// <see cref="JsonPropertyNameAttribute"/>.
/// </summary>
/// <remarks>
/// NOTE(review): the meaning of each value is defined by the remote API and is
/// not visible in this file; the summaries below only state the JSON mapping.
/// </remarks>
public class ChatStreamingResponse
{
/// <summary>Value of the <c>response_id</c> JSON property.</summary>
[JsonPropertyName("response_id")]
public string ResponseId { get; set; }

/// <summary>Value of the <c>text</c> JSON property.</summary>
[JsonPropertyName("text")]
public string Text { get; set; }

/// <summary>Value of the <c>generation_id</c> JSON property.</summary>
[JsonPropertyName("generation_id")]
public string GenerationId { get; set; }

/// <summary>Items of the <c>chat_history</c> JSON property.</summary>
[JsonPropertyName("chat_history")]
public List<ChatMessage> ChatHistory { get; set; }

/// <summary>Value of the <c>finish_reason</c> JSON property.</summary>
[JsonPropertyName("finish_reason")]
public string FinishReason { get; set; }

/// <summary>Object bound to the <c>meta</c> JSON property.</summary>
[JsonPropertyName("meta")]
public Meta Meta { get; set; }
}

public class CohereRagCitation
Expand Down
12 changes: 12 additions & 0 deletions src/KernelMemory.Extensions/Helper/LLMCallLog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using OpenAI.Chat;
using System;
using System.Collections.Generic;
using System.Linq;

Expand All @@ -20,6 +21,10 @@ public class LLMCallLog

public object? ReturnObject { get; set; }

public IReadOnlyList<string> Warnings => _warnings;

private readonly List<string> _warnings = new();

public TokenCount TokenCount { get; set; } = null!;

public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc)
Expand Down Expand Up @@ -50,6 +55,11 @@ public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc)
};
}
}

/// <summary>
/// Appends <paramref name="warning"/> to the list backing <see cref="Warnings"/>.
/// </summary>
/// <param name="warning">Warning text to record.</param>
public void AddWarning(string warning) => _warnings.Add(warning);
}

public class TokenCount
Expand All @@ -66,6 +76,8 @@ public class TokenCount
/// </summary>
public class LLMCallLogContext
{
public Guid Id { get; private set; } = Guid.NewGuid();

public IReadOnlyList<LLMCallLog> CallLogs => _callLogs;

private readonly List<LLMCallLog> _callLogs = new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace KernelMemory.Extensions.Helper;
public interface ISemanticKernelWrapper
{
KernelFunction CreateFunctionFromMethod(Delegate method, string functionName);

KernelPlugin CreateFromFunctions(string pluginName, IEnumerable<KernelFunction> functions);

KernelFunction CreateFunctionFromPrompt(PromptTemplateConfig config, IPromptTemplateFactory? promptTemplateFactory = null);
Expand Down

0 comments on commit d0d7d3e

Please sign in to comment.