Commit
Add custom Cache support (#66)
marcominerva authored May 30, 2023
2 parents 8c8be64 + 5329aab commit 287ce5b
Showing 16 changed files with 361 additions and 98 deletions.
32 changes: 30 additions & 2 deletions README.md
@@ -63,13 +63,41 @@ In Azure OpenAI Service, you're required to first [deploy a model](https://learn
> **Note**
> Some models are not available in all regions. You can refer to the [Model Summary table and region availability page](https://learn.microsoft.com/azure/cognitive-services/openai/concepts/models#model-summary-table-and-region-availability) to check current availability.

- ### MessageLimit and MessageExpiration
+ ### Caching, MessageLimit and MessageExpiration

- ChatGPT is designed to support conversational scenarios: users can talk to ChatGPT without specifying the full context for every interaction. However, conversation history isn't managed by OpenAI or Azure OpenAI Service, so it's up to us to retain the current state. **ChatGptNet** handles this requirement using a [MemoryCache](https://learn.microsoft.com/en-us/dotnet/api/microsoft.extensions.caching.memory.memorycache) that stores messages for each conversation. The behavior can be set using the following properties:
+ ChatGPT is designed to support conversational scenarios: users can talk to ChatGPT without specifying the full context for every interaction. However, conversation history isn't managed by OpenAI or Azure OpenAI Service, so it's up to us to retain the current state. By default, **ChatGptNet** handles this requirement using a [MemoryCache](https://learn.microsoft.com/en-us/dotnet/api/microsoft.extensions.caching.memory.memorycache) that stores messages for each conversation. The behavior can be set using the following properties:

* *MessageLimit*: specifies how many messages for each conversation must be saved. When this limit is reached, the oldest messages are automatically removed.
* *MessageExpiration*: specifies the time interval used to keep messages in the cache, regardless of their count (see the configuration sketch below).
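
Both values live on *ChatGptOptions*. A minimal configuration sketch, assuming the delegate-based **AddChatGpt** overload (the numbers here are purely illustrative, not recommendations):

```csharp
builder.Services.AddChatGpt(options =>
{
    // Illustrative: keep at most 20 messages per conversation.
    options.MessageLimit = 20;

    // Illustrative: evict a conversation from the cache after one hour.
    options.MessageExpiration = TimeSpan.FromHours(1);
});
```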

If necessary, it is possible to provide a custom cache by implementing the [IChatGptCache](https://github.com/marcominerva/ChatGptNet/blob/master/src/ChatGptNet/IChatGptCache.cs) interface and then calling the **WithCache** extension method:

```csharp
public class LocalMessageCache : IChatGptCache
{
    private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();

    public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
    {
        localCache[conversationId] = messages.ToList();
        return Task.CompletedTask;
    }

    public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
    {
        localCache.TryGetValue(conversationId, out var messages);
        return Task.FromResult(messages);
    }

    public Task RemoveAsync(Guid conversationId)
    {
        localCache.Remove(conversationId);
        return Task.CompletedTask;
    }
}
```

```csharp
// Registers the custom cache at application startup.
builder.Services.AddChatGpt(/* ... */).WithCache<LocalMessageCache>();
```
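
Because the interface is storage-agnostic, the same pattern can target a distributed store. The following is a minimal sketch over [IDistributedCache](https://learn.microsoft.com/dotnet/api/microsoft.extensions.caching.distributed.idistributedcache); the *DistributedMessageCache* name and the JSON round-trip are illustrative assumptions, not part of the library:

```csharp
using System.Text.Json;
using ChatGptNet;
using ChatGptNet.Models;
using Microsoft.Extensions.Caching.Distributed;

public class DistributedMessageCache : IChatGptCache
{
    private readonly IDistributedCache cache;

    public DistributedMessageCache(IDistributedCache cache)
    {
        this.cache = cache;
    }

    public async Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
    {
        // Serializes the whole history and honors the expiration provided by the library.
        var json = JsonSerializer.Serialize(messages.ToList());
        await cache.SetStringAsync(conversationId.ToString(), json,
            new DistributedCacheEntryOptions { AbsoluteExpirationRelativeToNow = expiration });
    }

    public async Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
    {
        var json = await cache.GetStringAsync(conversationId.ToString());
        return json is null ? null : JsonSerializer.Deserialize<List<ChatGptMessage>>(json);
    }

    public async Task RemoveAsync(Guid conversationId)
        => await cache.RemoveAsync(conversationId.ToString());
}
```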

We can also set ChatGPT parameters for chat completion at startup. Check the [official documentation](https://platform.openai.com/docs/api-reference/chat/create) for the list of available parameters and their meaning.
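
A minimal sketch, again assuming the delegate-based overload and that *ChatGptParameters* exposes these properties (model name and values are illustrative):

```csharp
builder.Services.AddChatGpt(options =>
{
    options.DefaultModel = "gpt-3.5-turbo";        // illustrative model name
    options.DefaultParameters.Temperature = 0.7;   // sampling temperature
    options.DefaultParameters.MaxTokens = 500;     // cap on completion length
});
```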

## Configuration using an external source
27 changes: 26 additions & 1 deletion samples/ChatGptConsole/Program.cs
@@ -1,5 +1,6 @@
using ChatGptConsole;
using ChatGptNet;
+ using ChatGptNet.Models;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;

@@ -28,5 +29,29 @@ static void ConfigureServices(HostBuilderContext context, IServiceCollection ser
//});

// Adds ChatGPT service using settings from IConfiguration.
- services.AddChatGpt(context.Configuration);
+ services.AddChatGpt(context.Configuration)
+     //.WithCache<LocalMessageCache>() // Uncomment this line to use a custom cache implementation instead of the default MemoryCache.
+     ;
}

public class LocalMessageCache : IChatGptCache
{
    private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();

    public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
    {
        localCache[conversationId] = messages.ToList();
        return Task.CompletedTask;
    }

    public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
    {
        localCache.TryGetValue(conversationId, out var messages);
        return Task.FromResult(messages);
    }

    public Task RemoveAsync(Guid conversationId)
    {
        localCache.Remove(conversationId);
        return Task.CompletedTask;
    }
}
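
With the custom cache registered, the sample's conversation flow goes through it transparently. A hedged sketch of the typical calls (method signatures are taken from this commit; the prompt strings are illustrative and host setup is omitted):

```csharp
// Resolved from the DI container built by the host.
var chatGptClient = host.Services.GetRequiredService<IChatGptClient>();

// Stores the system message via IChatGptCache and returns the conversation id.
var conversationId = await chatGptClient.SetupAsync(Guid.NewGuid(),
    "You are a helpful assistant.");

// Each question and answer is appended to the cached history.
var response = await chatGptClient.AskAsync(conversationId, "What is .NET?");
Console.WriteLine(response.Choices.First().Message.Content);
```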
19 changes: 19 additions & 0 deletions src/ChatGptNet/ChatGptBuilder.cs
@@ -0,0 +1,19 @@
using Microsoft.Extensions.DependencyInjection;

namespace ChatGptNet;

/// <inheritdoc/>
public class ChatGptBuilder : IChatGptBuilder
{
    /// <inheritdoc/>
    public IServiceCollection Services { get; }

    /// <inheritdoc/>
    public IHttpClientBuilder HttpClientBuilder { get; }

    internal ChatGptBuilder(IServiceCollection services, IHttpClientBuilder httpClientBuilder)
    {
        Services = services;
        HttpClientBuilder = httpClientBuilder;
    }
}
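
The **WithCache** extension method itself is among the files not expanded in this view. Based on the builder surface above, a plausible shape is the following; this is an assumption, not the commit's actual code:

```csharp
using Microsoft.Extensions.DependencyInjection;

public static class ChatGptBuilderExtensions
{
    public static IChatGptBuilder WithCache<TCache>(this IChatGptBuilder builder)
        where TCache : class, IChatGptCache
    {
        // Assumed wiring: swap the default IChatGptCache registration for the custom one.
        builder.Services.AddSingleton<IChatGptCache, TCache>();
        return builder;
    }
}
```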
74 changes: 37 additions & 37 deletions src/ChatGptNet/ChatGptClient.cs
@@ -6,22 +6,21 @@
using System.Text.Json.Serialization;
using ChatGptNet.Exceptions;
using ChatGptNet.Models;
- using Microsoft.Extensions.Caching.Memory;

namespace ChatGptNet;

internal class ChatGptClient : IChatGptClient
{
private readonly HttpClient httpClient;
- private readonly IMemoryCache cache;
+ private readonly IChatGptCache cache;
private readonly ChatGptOptions options;

private static readonly JsonSerializerOptions jsonSerializerOptions = new(JsonSerializerDefaults.Web)
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};

- public ChatGptClient(HttpClient httpClient, IMemoryCache cache, ChatGptOptions options)
+ public ChatGptClient(HttpClient httpClient, IChatGptCache cache, ChatGptOptions options)
{
this.httpClient = httpClient;

@@ -34,7 +33,7 @@ public ChatGptClient(HttpClient httpClient, IMemoryCache cache, ChatGptOptions o
this.options = options;
}

- public Task<Guid> SetupAsync(Guid conversationId, string message)
+ public async Task<Guid> SetupAsync(Guid conversationId, string message)
{
ArgumentNullException.ThrowIfNull(message);

@@ -53,9 +52,8 @@ public Task<Guid> SetupAsync(Guid conversationId, string message)
}
};

- cache.Set(conversationId, messages, options.MessageExpiration);
- return Task.FromResult(conversationId);
+ await cache.SetAsync(conversationId, messages, options.MessageExpiration);
+ return conversationId;
}

public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message, ChatGptParameters? parameters = null, string? model = null, CancellationToken cancellationToken = default)
@@ -68,7 +66,7 @@ public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message,
conversationId = Guid.NewGuid();
}

- var messages = CreateMessageList(conversationId, message);
+ var messages = await CreateMessageListAsync(conversationId, message);
var request = CreateRequest(messages, false, parameters, model);

var requestUri = options.ServiceConfiguration.GetServiceEndpoint(model ?? options.DefaultModel);
@@ -81,7 +79,7 @@ public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message,
if (response.IsSuccessful)
{
// Adds the response message to the conversation cache.
- UpdateHistory(conversationId, messages, response.Choices[0].Message);
+ await UpdateHistoryAsync(conversationId, messages, response.Choices.First().Message);
}
else if (options.ThrowExceptionOnError)
{
@@ -101,7 +99,7 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
conversationId = Guid.NewGuid();
}

- var messages = CreateMessageList(conversationId, message);
+ var messages = await CreateMessageListAsync(conversationId, message);
var request = CreateRequest(messages, true, parameters, model);

var requestUri = options.ServiceConfiguration.GetServiceEndpoint(model ?? options.DefaultModel);
@@ -128,15 +126,15 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
var json = line["data: ".Length..];
var response = JsonSerializer.Deserialize<ChatGptResponse>(json, jsonSerializerOptions);

- var content = response!.Choices?[0].Delta?.Content;
+ var content = response!.Choices?.FirstOrDefault()?.Delta?.Content;

if (!string.IsNullOrEmpty(content))
{
if (contentBuilder.Length == 0)
{
// If this is the first response, trims all the initial special characters.
content = content.TrimStart('\n');
- response.Choices![0].Delta!.Content = content;
+ response.Choices!.First().Delta!.Content = content;
}

// Yields the response only if there is an actual content.
@@ -157,7 +155,7 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
}

// Adds the response message to the conversation cache.
- UpdateHistory(conversationId, messages, new()
+ await UpdateHistoryAsync(conversationId, messages, new()
{
Role = ChatGptRoles.Assistant,
Content = contentBuilder.ToString()
@@ -178,30 +176,32 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
}
}

- public Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId)
+ public async Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId)
{
- var messages = cache.Get<IEnumerable<ChatGptMessage>>(conversationId) ?? Enumerable.Empty<ChatGptMessage>();
- return Task.FromResult(messages);
+ var messages = await cache.GetAsync(conversationId) ?? Enumerable.Empty<ChatGptMessage>();
+ return messages;
}

- public Task DeleteConversationAsync(Guid conversationId, bool preserveSetup = false)
+ public async Task DeleteConversationAsync(Guid conversationId, bool preserveSetup = false)
{
if (!preserveSetup)
{
// We don't want to preserve setup message, so just deletes all the cache history.
- cache.Remove(conversationId);
+ await cache.RemoveAsync(conversationId);
}
- else if (cache.TryGetValue<List<ChatGptMessage>>(conversationId, out var messages))
+ else
{
- // Removes all the messages, except system ones.
- messages!.RemoveAll(m => m.Role != ChatGptRoles.System);
- cache.Set(conversationId, messages, options.MessageExpiration);
+ var messages = await cache.GetAsync(conversationId);
+ if (messages is not null)
+ {
+     // Removes all the messages, except system ones.
+     messages.RemoveAll(m => m.Role != ChatGptRoles.System);
+     await cache.SetAsync(conversationId, messages, options.MessageExpiration);
+ }
}

- return Task.CompletedTask;
}

- public Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, bool replaceHistory = true)
+ public async Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, bool replaceHistory = true)
{
ArgumentNullException.ThrowIfNull(messages);

@@ -215,25 +215,25 @@ public Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<ChatGpt
{
// If messages must replace history, just use the current list, discarding all the previously cached content.
// If messages.Count() > ChatGptOptions.MessageLimit, the UpdateCache take care of taking only the last messages.
- UpdateCache(conversationId, messages);
+ await UpdateCacheAsync(conversationId, messages);
}
else
{
// Retrieves the current history and adds new messages.
- var conversationHistory = cache.Get<List<ChatGptMessage>>(conversationId) ?? new List<ChatGptMessage>();
+ var conversationHistory = await cache.GetAsync(conversationId) ?? new List<ChatGptMessage>();
conversationHistory.AddRange(messages);

// If messages total length > ChatGptOptions.MessageLimit, the UpdateCache take care of taking only the last messages.
- UpdateCache(conversationId, conversationHistory);
+ await UpdateCacheAsync(conversationId, conversationHistory);
}

- return Task.FromResult(conversationId);
+ return conversationId;
}

- private IList<ChatGptMessage> CreateMessageList(Guid conversationId, string message)
+ private async Task<List<ChatGptMessage>> CreateMessageListAsync(Guid conversationId, string message)
{
// Checks whether a list of messages for the given conversationId already exists.
- var conversationHistory = cache.Get<IList<ChatGptMessage>>(conversationId);
+ var conversationHistory = await cache.GetAsync(conversationId);
List<ChatGptMessage> messages = conversationHistory is not null ? new(conversationHistory) : new();

messages.Add(new()
@@ -245,11 +245,11 @@ private IList<ChatGptMessage> CreateMessageList(Guid conversationId, string mess
return messages;
}

- private ChatGptRequest CreateRequest(IList<ChatGptMessage> messages, bool stream, ChatGptParameters? parameters = null, string? model = null)
+ private ChatGptRequest CreateRequest(IEnumerable<ChatGptMessage> messages, bool stream, ChatGptParameters? parameters = null, string? model = null)
=> new()
{
Model = model ?? options.DefaultModel,
- Messages = messages.ToArray(),
+ Messages = messages,
Stream = stream,
Temperature = parameters?.Temperature ?? options.DefaultParameters.Temperature,
TopP = parameters?.TopP ?? options.DefaultParameters.TopP,
@@ -259,13 +259,13 @@
User = options.User,
};

- private void UpdateHistory(Guid conversationId, IList<ChatGptMessage> messages, ChatGptMessage message)
+ private async Task UpdateHistoryAsync(Guid conversationId, IList<ChatGptMessage> messages, ChatGptMessage message)
{
messages.Add(message);
- UpdateCache(conversationId, messages);
+ await UpdateCacheAsync(conversationId, messages);
}

- private void UpdateCache(Guid conversationId, IEnumerable<ChatGptMessage> messages)
+ private async Task UpdateCacheAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages)
{
// If the maximum number of messages has been reached, deletes the oldest ones.
// Note: system message does not count for message limit.
@@ -285,7 +285,7 @@ private void UpdateCache(Guid conversationId, IEnumerable<ChatGptMessage> messag
messages = conversation.ToList();
}

- cache.Set(conversationId, messages, options.MessageExpiration);
+ await cache.SetAsync(conversationId, messages, options.MessageExpiration);
}

private static void EnsureErrorIsSet(ChatGptResponse response, HttpResponseMessage httpResponse)
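The cache-aware signatures above also change how histories are restored and trimmed from calling code. A usage sketch (signatures from this diff; `LoadFromYourStore` is a hypothetical helper standing in for your own persistence):

```csharp
// Restores a previously persisted history, replacing whatever is cached for this id.
IEnumerable<ChatGptMessage> savedMessages = LoadFromYourStore();  // hypothetical
await chatGptClient.LoadConversationAsync(conversationId, savedMessages, replaceHistory: true);

// Later: clears user/assistant messages but keeps system ones,
// so the next AskAsync still runs with the original setup message.
await chatGptClient.DeleteConversationAsync(conversationId, preserveSetup: true);
```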
32 changes: 32 additions & 0 deletions src/ChatGptNet/ChatGptMemoryCache.cs
@@ -0,0 +1,32 @@
using ChatGptNet.Models;
using Microsoft.Extensions.Caching.Memory;

namespace ChatGptNet;

internal class ChatGptMemoryCache : IChatGptCache
{
    private readonly IMemoryCache cache;

    public ChatGptMemoryCache(IMemoryCache cache)
    {
        this.cache = cache;
    }

    public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration)
    {
        cache.Set(conversationId, messages, expiration);
        return Task.CompletedTask;
    }

    public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId)
    {
        var messages = cache.Get<List<ChatGptMessage>?>(conversationId);
        return Task.FromResult(messages);
    }

    public Task RemoveAsync(Guid conversationId)
    {
        cache.Remove(conversationId);
        return Task.CompletedTask;
    }
}
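
This in-memory implementation is the default used when **WithCache** is never called; presumably **AddChatGpt** registers it roughly like this (an assumption, since the registration code is not among the files shown in this diff):

```csharp
// Assumed default wiring inside AddChatGpt.
services.AddMemoryCache();
services.AddSingleton<IChatGptCache, ChatGptMemoryCache>();
```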
2 changes: 1 addition & 1 deletion src/ChatGptNet/ChatGptOptions.cs
@@ -45,7 +45,7 @@ public class ChatGptOptions
/// Gets or sets the default parameters for chat completion.
/// </summary>
/// <see cref="ChatGptParameters"/>
- public ChatGptParameters DefaultParameters { get; } = new();
+ public ChatGptParameters DefaultParameters { get; internal set; } = new();

/// <summary>
/// Gets or sets the user identification for chat completion, which can help OpenAI to monitor and detect abuse.