Skip to content

Commit

Permalink
Method to add an interaction #100
Browse files Browse the repository at this point in the history
  • Loading branch information
marcominerva committed Aug 11, 2023
1 parent 67bc6b4 commit 010b1da
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 38 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ If necessary, it is possible to provide a custom Cache by implementing the [ICh
```csharp
public class LocalMessageCache : IChatGptCache
{
private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();
private readonly Dictionary<Guid, IEnumerable<ChatGptMessage>> localCache = new();

public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default)
{
localCache[conversationId] = messages.ToList();
return Task.CompletedTask;
}

public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
public Task<IEnumerable<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.TryGetValue(conversationId, out var messages);
return Task.FromResult(messages);
Expand Down
2 changes: 1 addition & 1 deletion samples/ChatGptApi/ChatGptApi.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="7.0.9" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="7.0.10" />
<PackageReference Include="MinimalHelpers.OpenApi" Version="1.0.4" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
</ItemGroup>
Expand Down
4 changes: 2 additions & 2 deletions samples/ChatGptApi/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ async IAsyncEnumerable<string> Stream()
})
.WithOpenApi();

app.MapDelete("/api/chat/{conversationId:guid}", async (Guid conversationId, IChatGptClient chatGptClient) =>
app.MapDelete("/api/chat/{conversationId:guid}", async (Guid conversationId, bool? preserveSetup, IChatGptClient chatGptClient) =>
{
await chatGptClient.DeleteConversationAsync(conversationId);
await chatGptClient.DeleteConversationAsync(conversationId, preserveSetup.GetValueOrDefault());
return TypedResults.NoContent();
})
.WithOpenApi();
Expand Down
6 changes: 3 additions & 3 deletions samples/ChatGptBlazor.Wasm/ChatGptBlazor.Wasm.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Markdig" Version="0.31.0" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly" Version="7.0.9" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly.DevServer" Version="7.0.9" PrivateAssets="all" />
<PackageReference Include="Markdig" Version="0.32.0" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly" Version="7.0.10" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly.DevServer" Version="7.0.10" PrivateAssets="all" />
</ItemGroup>

<ItemGroup>
Expand Down
4 changes: 2 additions & 2 deletions samples/ChatGptConsole/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ static void ConfigureServices(HostBuilderContext context, IServiceCollection ser

public class LocalMessageCache : IChatGptCache
{
private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();
private readonly Dictionary<Guid, IEnumerable<ChatGptMessage>> localCache = new();

public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default)
{
localCache[conversationId] = messages.ToList();
return Task.CompletedTask;
}
public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
public Task<IEnumerable<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.TryGetValue(conversationId, out var messages);
return Task.FromResult(messages);
Expand Down
4 changes: 2 additions & 2 deletions samples/ChatGptFunctionCallingConsole/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ static void ConfigureServices(HostBuilderContext context, IServiceCollection ser

public class LocalMessageCache : IChatGptCache
{
private readonly Dictionary<Guid, List<ChatGptMessage>> localCache = new();
private readonly Dictionary<Guid, IEnumerable<ChatGptMessage>> localCache = new();

public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages, TimeSpan expiration, CancellationToken cancellationToken = default)
{
localCache[conversationId] = messages.ToList();
return Task.CompletedTask;
}
public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
public Task<IEnumerable<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
localCache.TryGetValue(conversationId, out var messages);
return Task.FromResult(messages);
Expand Down
73 changes: 50 additions & 23 deletions src/ChatGptNet/ChatGptClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public async Task<ChatGptResponse> AskAsync(Guid conversationId, string message,
using var httpResponse = await httpClient.PostAsJsonAsync(requestUri, request, jsonSerializerOptions, cancellationToken);

var response = await httpResponse.Content.ReadFromJsonAsync<ChatGptResponse>(jsonSerializerOptions, cancellationToken: cancellationToken);
NormalizeResponse(httpResponse, response!, conversationId);
NormalizeResponse(httpResponse, response!, conversationId, model ?? options.DefaultModel);

if (response!.IsSuccessful)
{
Expand Down Expand Up @@ -170,7 +170,7 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI
else
{
var response = await httpResponse.Content.ReadFromJsonAsync<ChatGptResponse>(cancellationToken: cancellationToken);
NormalizeResponse(httpResponse, response!, conversationId);
NormalizeResponse(httpResponse, response!, conversationId, model ?? options.DefaultModel);

if (options.ThrowExceptionOnError)
{
Expand All @@ -183,7 +183,9 @@ public async IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationI

public async Task<IEnumerable<ChatGptMessage>> GetConversationAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
var messages = await cache.GetAsync(conversationId, cancellationToken) ?? Enumerable.Empty<ChatGptMessage>();
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken);
var messages = conversationHistory?.ToList() ?? Enumerable.Empty<ChatGptMessage>();

return messages;
}

Expand All @@ -205,8 +207,8 @@ public async Task DeleteConversationAsync(Guid conversationId, bool preserveSetu
var messages = await cache.GetAsync(conversationId, cancellationToken);
if (messages is not null)
{
// Removes all the messages, except system ones.
messages.RemoveAll(m => m.Role != ChatGptRoles.System);
// Preserves the system message.
messages = messages.Where(m => m.Role == ChatGptRoles.System);
await cache.SetAsync(conversationId, messages, options.MessageExpiration, cancellationToken);
}
}
Expand All @@ -222,48 +224,68 @@ public async Task<Guid> LoadConversationAsync(Guid conversationId, IEnumerable<C
if (replaceHistory)
{
// If messages must replace history, just use the current list, discarding all the previously cached content.
// If messages.Count() > ChatGptOptions.MessageLimit, the UpdateCache take care of taking only the last messages.
// If messages.Count() > ChatGptOptions.MessageLimit, the UpdateCacheAsync method takes care of taking only the last messages.
await UpdateCacheAsync(conversationId, messages, cancellationToken);
}
else
{
// Retrieves the current history and adds new messages.
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken) ?? new List<ChatGptMessage>();
conversationHistory.AddRange(messages);
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken) ?? Enumerable.Empty<ChatGptMessage>();
conversationHistory = conversationHistory.Union(messages);

// If messages total length > ChatGptOptions.MessageLimit, the UpdateCache take care of taking only the last messages.
// If messages.Count() > ChatGptOptions.MessageLimit, the UpdateCacheAsync method takes care of taking only the last messages.
await UpdateCacheAsync(conversationId, conversationHistory, cancellationToken);
}

return conversationId;
}

public async Task AddInteractionAsync(Guid conversationId, string question, string answer, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(question);
ArgumentNullException.ThrowIfNull(answer);

var messages = await cache.GetAsync(conversationId, cancellationToken) ?? Enumerable.Empty<ChatGptMessage>();
messages = messages.Union(new ChatGptMessage[]
{
new()
{
Role = ChatGptRoles.User,
Content = question
},
new()
{
Role = ChatGptRoles.Assistant,
Content = answer
}
});

await UpdateCacheAsync(conversationId, messages, cancellationToken);
}

public async Task AddFunctionResponseAsync(Guid conversationId, string functionName, string content, CancellationToken cancellationToken = default)
{
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken);
if (!conversationHistory?.Any() ?? true)
var messages = await cache.GetAsync(conversationId, cancellationToken);
if (!messages?.Any() ?? true)
{
throw new InvalidOperationException("Cannot add a function response message if the conversation history is empty");
}

var messages = new List<ChatGptMessage>(conversationHistory!)
messages = messages!.Append(new()
{
new()
{
Role = ChatGptRoles.Function,
Name = functionName,
Content = content
}
};
Role = ChatGptRoles.Function,
Name = functionName,
Content = content
});

await UpdateCacheAsync(conversationId, messages, cancellationToken);
}

private async Task<List<ChatGptMessage>> CreateMessageListAsync(Guid conversationId, string message, CancellationToken cancellationToken = default)
private async Task<IList<ChatGptMessage>> CreateMessageListAsync(Guid conversationId, string message, CancellationToken cancellationToken = default)
{
// Checks whether a list of messages for the given conversationId already exists.
var conversationHistory = await cache.GetAsync(conversationId, cancellationToken);
List<ChatGptMessage> messages = conversationHistory is not null ? new(conversationHistory) : new();
var messages = conversationHistory?.ToList() ?? new List<ChatGptMessage>();

messages.Add(new()
{
Expand Down Expand Up @@ -323,16 +345,21 @@ private async Task UpdateCacheAsync(Guid conversationId, IEnumerable<ChatGptMess
conversation = conversation.Prepend(firstMessage);
}

messages = conversation.ToList();
messages = conversation;
}

await cache.SetAsync(conversationId, messages, options.MessageExpiration, cancellationToken);
}

private static void NormalizeResponse(HttpResponseMessage httpResponse, ChatGptResponse response, Guid conversationId)
private static void NormalizeResponse(HttpResponseMessage httpResponse, ChatGptResponse response, Guid conversationId, string? model)
{
response.ConversationId = conversationId;

if (string.IsNullOrWhiteSpace(response.Model) && model is not null)
{
response.Model = model;
}

if (!httpResponse.IsSuccessStatusCode && response.Error is null)
{
response.Error = new ChatGptError
Expand Down
4 changes: 2 additions & 2 deletions src/ChatGptNet/ChatGptMemoryCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public Task SetAsync(Guid conversationId, IEnumerable<ChatGptMessage> messages,
return Task.CompletedTask;
}

public Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
public Task<IEnumerable<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default)
{
var messages = cache.Get<List<ChatGptMessage>?>(conversationId);
var messages = cache.Get<IEnumerable<ChatGptMessage>?>(conversationId);
return Task.FromResult(messages);
}

Expand Down
2 changes: 1 addition & 1 deletion src/ChatGptNet/IChatGptCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public interface IChatGptCache
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The message list of the conversation, or <see langword="null"/> if the Conversation Id does not exist.</returns>
/// <seealso cref="ChatGptMessage"/>
Task<List<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default);
Task<IEnumerable<ChatGptMessage>?> GetAsync(Guid conversationId, CancellationToken cancellationToken = default);

/// <summary>
/// Removes from the cache all the messages for the given <paramref name="conversationId"/>.
Expand Down
11 changes: 11 additions & 0 deletions src/ChatGptNet/IChatGptClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ IAsyncEnumerable<ChatGptResponse> AskStreamAsync(string message, ChatGptParamete
/// <seealso cref="ChatGptParameters"/>
IAsyncEnumerable<ChatGptResponse> AskStreamAsync(Guid conversationId, string message, ChatGptParameters? parameters = null, string? model = null, bool addToConversationHistory = true, CancellationToken cancellationToken = default);

/// <summary>
/// Explicitly adds a new interaction (a question and the corresponding answer) to the conversation history.
/// </summary>
/// <param name="conversationId">The unique identifier of the conversation.</param>
/// <param name="question">The question.</param>
/// <param name="answer">The answer.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>The <see cref="Task"/> corresponding to the asynchronous operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="question"/> or <paramref name="answer"/> is <see langword="null"/>.</exception>
Task AddInteractionAsync(Guid conversationId, string question, string answer, CancellationToken cancellationToken = default);

/// <summary>
/// Retrieves a chat conversation from the cache.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions src/ChatGptNet/Models/ChatGptResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class ChatGptResponse
/// <summary>
/// Gets or sets information about token usage.
/// </summary>
/// <remarks>
/// The <see cref="Usage"/> property is always <see langword="null"/> when requesting response streaming with <see cref="ChatGptClient.AskStreamAsync(Guid, string, ChatGptParameters?, string?, bool, CancellationToken)"/>.
/// </remarks>
public ChatGptUsage? Usage { get; set; }

/// <summary>
Expand Down

0 comments on commit 010b1da

Please sign in to comment.