From 21a45cefdfe65558fbe2e4294a800a9cb1047bed Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Tue, 26 Nov 2024 18:03:26 +0000 Subject: [PATCH] Skeleton implementation of OpenAI serialization methods. --- .../Microsoft.Extensions.AI.OpenAI.csproj | 2 + .../OpenAIModelMapper.ChatMessage.cs | 111 ++++++- .../OpenAIModelMapper.cs | 304 +++++++++++++++++- .../OpenAISerializationHelpers.cs | 134 ++++++++ 4 files changed, 538 insertions(+), 13 deletions(-) create mode 100644 src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj index 43991fa84e6..85c74c7a5a1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj @@ -18,11 +18,13 @@ $(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002;OPENAI002 true true + true true true + true diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs index 42531d5c490..e57069f1dee 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using OpenAI.Chat; @@ -13,6 +14,89 @@ namespace Microsoft.Extensions.AI; internal static partial class OpenAIModelMappers { + internal static IEnumerable FromOpenAIChatMessages(IEnumerable inputs, JsonSerializerOptions options) + { + // Maps all of the OpenAI types to the corresponding M.E.AI types. + // Unrecognized or non-processable content is ignored. + + foreach (OpenAI.Chat.ChatMessage input in inputs) + { + switch (input) + { + case SystemChatMessage systemMessage: + yield return new ChatMessage + { + Role = ChatRole.System, + AuthorName = systemMessage.ParticipantName, + Contents = FromOpenAIChatContent(systemMessage.Content), + }; + break; + + case UserChatMessage userMessage: + yield return new ChatMessage + { + Role = ChatRole.User, + AuthorName = userMessage.ParticipantName, + Contents = FromOpenAIChatContent(userMessage.Content), + }; + break; + + case ToolChatMessage toolMessage: + string textContent = string.Join(string.Empty, toolMessage.Content.Where(part => part.Kind is ChatMessageContentPartKind.Text).Select(part => part.Text)); + object? result = textContent; + if (!string.IsNullOrEmpty(textContent)) + { +#pragma warning disable CA1031 // Do not catch general exception types + try + { + result = JsonSerializer.Deserialize(textContent, options.GetTypeInfo(typeof(object))); + } + catch + { + // If the content can't be deserialized, leave it as a string. + } +#pragma warning restore CA1031 // Do not catch general exception types + } + + yield return new ChatMessage + { + Role = ChatRole.Tool, + Contents = new AIContent[] { new FunctionResultContent(toolMessage.ToolCallId, name: string.Empty, result) }, + }; + break; + + case AssistantChatMessage assistantMessage: + + ChatMessage message = new() + { + Role = ChatRole.Assistant, + AuthorName = assistantMessage.ParticipantName, + Contents = FromOpenAIChatContent(assistantMessage.Content), + }; + + foreach (ChatToolCall toolCall in assistantMessage.ToolCalls) + { + if (!string.IsNullOrWhiteSpace(toolCall.FunctionName)) + { + var callContent = ParseCallContentFromBinaryData(toolCall.FunctionArguments, toolCall.Id, toolCall.FunctionName); + callContent.RawRepresentation = toolCall; + + message.Contents.Add(callContent); + } + } + + if (assistantMessage.Refusal is not null) + { + message.AdditionalProperties ??= new(); + message.AdditionalProperties.Add(nameof(assistantMessage.Refusal), assistantMessage.Refusal); + } + + yield return message; + break; + } + } + } + /// Converts an Extensions chat message enumerable to an OpenAI chat message enumerable. internal static IEnumerable ToOpenAIChatMessages(IEnumerable inputs, JsonSerializerOptions options) { @@ -60,7 +144,7 @@ internal static partial class OpenAIModelMappers foreach (var content in input.Contents) { - if (content is FunctionCallContent { CallId: not null } callRequest) + if (content is FunctionCallContent callRequest) { message.ToolCalls.Add( ChatToolCall.CreateFunctionToolCall( @@ -82,6 +166,31 @@ internal static partial class OpenAIModelMappers } } + private static List FromOpenAIChatContent(IList openAiMessageContentParts) + { + List contents = new(); + foreach (var openAiContentPart in openAiMessageContentParts) + { + switch (openAiContentPart.Kind) + { + case ChatMessageContentPartKind.Text: + contents.Add(new TextContent(openAiContentPart.Text)); + break; + + case ChatMessageContentPartKind.Image when (openAiContentPart.ImageBytes is { } bytes): + contents.Add(new ImageContent(bytes.ToArray(), openAiContentPart.ImageBytesMediaType)); + break; + + case ChatMessageContentPartKind.Image: + contents.Add(new ImageContent(openAiContentPart.ImageUri?.ToString() ?? string.Empty)); + break; + + } + } + + return contents; + } + /// Converts a list of to a list of . private static List ToOpenAIChatContent(IList contents) { diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.cs index 3427403c924..a46ecf9a881 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.cs @@ -2,7 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.ClientModel.Primitives; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; @@ -14,6 +16,9 @@ #pragma warning disable SA1204 // Static elements should appear before instance elements #pragma warning disable S1135 // Track uses of "TODO" tags +#pragma warning disable SA1118 // Parameter should not span multiple lines +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable CA1859 // Use concrete types when possible for improved performance namespace Microsoft.Extensions.AI; @@ -21,6 +26,46 @@ internal static partial class OpenAIModelMappers { private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement; + internal static OpenAI.Chat.ChatCompletion ToOpenAIChatCompletion(ChatCompletion chatCompletion, JsonSerializerOptions options) + { + _ = Throw.IfNull(chatCompletion); + + List? toolCalls = null; + foreach (AIContent content in chatCompletion.Message.Contents) + { + if (content is FunctionCallContent callRequest) + { + toolCalls ??= []; + toolCalls.Add(ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + options.GetTypeInfo(typeof(IDictionary)))))); + } + } + + OpenAI.Chat.ChatTokenUsage? chatTokenUsage = null; + if (chatCompletion.Usage is UsageDetails usageDetails) + { + chatTokenUsage = ToOpenAIUsage(usageDetails); + } + + return OpenAIChatModelFactory.ChatCompletion( + id: chatCompletion.CompletionId, + model: chatCompletion.ModelId, + createdAt: chatCompletion.CreatedAt ?? default, + role: ToOpenAIChatRole(chatCompletion.Message.Role).Value, + finishReason: ToOpenAIFinishReason(chatCompletion.FinishReason), + content: [.. ToOpenAIChatContent(chatCompletion.Message.Contents)], + toolCalls: toolCalls, + refusal: chatCompletion.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.ChatCompletion.Refusal)), + contentTokenLogProbabilities: chatCompletion.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.ChatCompletion.ContentTokenLogProbabilities)), + refusalTokenLogProbabilities: chatCompletion.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.ChatCompletion.RefusalTokenLogProbabilities)), + systemFingerprint: chatCompletion.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)), + usage: chatTokenUsage); + } + internal static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompletion openAICompletion, ChatOptions? options) { _ = Throw.IfNull(openAICompletion); @@ -29,7 +74,7 @@ internal static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompleti ChatMessage returnMessage = new() { RawRepresentation = openAICompletion, - Role = ToChatRole(openAICompletion.Role), + Role = FromOpenAIChatRole(openAICompletion.Role), }; // Populate its content from those in the OpenAI response content. @@ -63,12 +108,12 @@ internal static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompleti CompletionId = openAICompletion.Id, CreatedAt = openAICompletion.CreatedAt, ModelId = openAICompletion.Model, - FinishReason = ToFinishReason(openAICompletion.FinishReason), + FinishReason = FromOpenAIFinishReason(openAICompletion.FinishReason), }; if (openAICompletion.Usage is ChatTokenUsage tokenUsage) { - completion.Usage = ToUsageDetails(tokenUsage); + completion.Usage = FromOpenAIUsage(tokenUsage); } if (openAICompletion.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) @@ -94,6 +139,49 @@ internal static ChatCompletion FromOpenAIChatCompletion(OpenAI.Chat.ChatCompleti return completion; } + internal static async IAsyncEnumerable ToOpenAIStreamingChatCompletionAsync( + IAsyncEnumerable chatCompletions, + JsonSerializerOptions options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var chatCompletionUpdate in chatCompletions.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + List? toolCallUpdates = null; + ChatTokenUsage? chatTokenUsage = null; + + foreach (var content in chatCompletionUpdate.Contents) + { + if (content is FunctionCallContent functionCallContent) + { + toolCallUpdates ??= []; + toolCallUpdates.Add(OpenAIChatModelFactory.StreamingChatToolCallUpdate( + index: toolCallUpdates.Count, + toolCallId: functionCallContent.CallId, + functionName: functionCallContent.Name, + functionArgumentsUpdate: new(JsonSerializer.SerializeToUtf8Bytes(functionCallContent.Arguments, options.GetTypeInfo(typeof(IDictionary)))))); + } + else if (content is UsageContent usageContent) + { + chatTokenUsage = ToOpenAIUsage(usageContent.Details); + } + } + + yield return OpenAIChatModelFactory.StreamingChatCompletionUpdate( + completionId: chatCompletionUpdate.CompletionId, + model: chatCompletionUpdate.ModelId, + createdAt: chatCompletionUpdate.CreatedAt ?? default, + role: ToOpenAIChatRole(chatCompletionUpdate.Role), + finishReason: ToOpenAIFinishReason(chatCompletionUpdate.FinishReason), + contentUpdate: [.. ToOpenAIChatContent(chatCompletionUpdate.Contents)], + toolCallUpdates: toolCallUpdates, + refusalUpdate: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.RefusalUpdate)), + contentTokenLogProbabilities: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.ContentTokenLogProbabilities)), + refusalTokenLogProbabilities: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.RefusalTokenLogProbabilities)), + systemFingerprint: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.SystemFingerprint)), + usage: chatTokenUsage); + } + } + internal static async IAsyncEnumerable FromOpenAIStreamingChatCompletionAsync( IAsyncEnumerable chatCompletionUpdates, [EnumeratorCancellation] CancellationToken cancellationToken = default) @@ -111,8 +199,8 @@ internal static async IAsyncEnumerable FromOpenAI await foreach (OpenAI.Chat.StreamingChatCompletionUpdate chatCompletionUpdate in chatCompletionUpdates.WithCancellation(cancellationToken).ConfigureAwait(false)) { // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. - streamedRole ??= chatCompletionUpdate.Role is ChatMessageRole role ? ToChatRole(role) : null; - finishReason ??= chatCompletionUpdate.FinishReason is OpenAI.Chat.ChatFinishReason reason ? ToFinishReason(reason) : null; + streamedRole ??= chatCompletionUpdate.Role is ChatMessageRole role ? FromOpenAIChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is OpenAI.Chat.ChatFinishReason reason ? FromOpenAIFinishReason(reason) : null; completionId ??= chatCompletionUpdate.CompletionId; createdAt ??= chatCompletionUpdate.CreatedAt; modelId ??= chatCompletionUpdate.Model; @@ -186,7 +274,7 @@ internal static async IAsyncEnumerable FromOpenAI // Transfer over usage updates. if (chatCompletionUpdate.Usage is ChatTokenUsage tokenUsage) { - var usageDetails = ToUsageDetails(tokenUsage); + var usageDetails = FromOpenAIUsage(tokenUsage); completionUpdate.Contents.Add(new UsageContent(usageDetails)); } @@ -236,8 +324,87 @@ internal static async IAsyncEnumerable FromOpenAI } } + internal static ChatOptions FromOpenAIOptions(OpenAI.Chat.ChatCompletionOptions? options) + { + ChatOptions result = new(); + + if (options is not null) + { + result.FrequencyPenalty = options.FrequencyPenalty; + result.MaxOutputTokens = options.MaxOutputTokenCount; + result.TopP = options.TopP; + result.PresencePenalty = options.PresencePenalty; + result.Temperature = options.Temperature; +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. + result.Seed = options.Seed; +#pragma warning restore OPENAI001 + + if (options.StopSequences is { Count: > 0 } stopSequences) + { + result.StopSequences = [.. stopSequences]; + } + + if (options.EndUserId is string endUserId) + { + (result.AdditionalProperties ??= [])[nameof(options.EndUserId)] = endUserId; + } + + if (options.IncludeLogProbabilities is bool includeLogProbabilities) + { + (result.AdditionalProperties ??= [])[nameof(options.IncludeLogProbabilities)] = includeLogProbabilities; + } + + if (options.LogitBiases is { Count: > 0 } logitBiases) + { + (result.AdditionalProperties ??= [])[nameof(options.LogitBiases)] = new Dictionary(logitBiases); + } + + if (options.AllowParallelToolCalls is bool allowParallelToolCalls) + { + (result.AdditionalProperties ??= [])[nameof(options.AllowParallelToolCalls)] = allowParallelToolCalls; + } + + if (options.TopLogProbabilityCount is int topLogProbabilityCount) + { + (result.AdditionalProperties ??= [])[nameof(options.TopLogProbabilityCount)] = topLogProbabilityCount; + } + + if (options.Tools is { Count: > 0 } tools) + { + foreach (ChatTool tool in tools) + { + result.Tools ??= []; + result.Tools.Add(FromOpenAIChatTool(tool)); + } + + using var toolChoiceJson = JsonDocument.Parse(((IJsonModel)options.ToolChoice).Write(ModelReaderWriterOptions.Json)); + JsonElement jsonElement = toolChoiceJson.RootElement; + switch (jsonElement.ValueKind) + { + case JsonValueKind.String: + result.ToolMode = jsonElement.GetString() switch + { + "required" => ChatToolMode.RequireAny, + _ => ChatToolMode.Auto, + }; + + break; + case JsonValueKind.Object: + if (jsonElement.TryGetProperty("function", out JsonElement functionElement)) + { + result.ToolMode = ChatToolMode.RequireSpecific(functionElement.GetString()!); + } + + break; + } + } + } + + return result; + } + /// Converts an extensions options instance to an OpenAI options instance. - internal static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) + internal static OpenAI.Chat.ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) { ChatCompletionOptions result = new(); @@ -330,6 +497,47 @@ internal static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) return result; } + private static AITool FromOpenAIChatTool(ChatTool chatTool) + { + Dictionary additionalProperties = new(); + if (chatTool.FunctionSchemaIsStrict is bool strictValue) + { + additionalProperties["Strict"] = strictValue; + } + + OpenAIChatToolJson openAiChatTool = JsonSerializer.Deserialize(chatTool.FunctionParameters.ToMemory().Span, OpenAIJsonContext.Default.OpenAIChatToolJson)!; + List parameters = new(openAiChatTool.Properties.Count); + foreach (KeyValuePair property in openAiChatTool.Properties) + { + parameters.Add(new(property.Key) + { + Schema = property.Value, + IsRequired = openAiChatTool.Required.Contains(property.Key), + }); + } + + AIFunctionMetadata metadata = new(chatTool.FunctionName) + { + Description = chatTool.FunctionDescription, + AdditionalProperties = additionalProperties, + Parameters = parameters, + ReturnParameter = new() + { + Description = "Return parameter", + Schema = _defaultParameterSchema, + } + }; + + return new MetadataOnlyAIFunction(metadata); + } + + private sealed class MetadataOnlyAIFunction(AIFunctionMetadata metadata) : AIFunction + { + public override AIFunctionMetadata Metadata => metadata; + protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken) => + throw new InvalidOperationException($"The AI function '{metadata.Name}' does not support being invoked."); + } + /// Converts an Extensions function to an OpenAI chat tool. private static ChatTool ToOpenAIChatTool(AIFunction aiFunction) { @@ -351,7 +559,7 @@ strictObj is bool strictValue ? if (parameter.IsRequired) { - tool.Required.Add(parameter.Name); + _ = tool.Required.Add(parameter.Name); } } @@ -362,7 +570,7 @@ strictObj is bool strictValue ? return ChatTool.CreateFunctionTool(aiFunction.Metadata.Name, aiFunction.Metadata.Description, resultParameters, strict); } - private static UsageDetails ToUsageDetails(ChatTokenUsage tokenUsage) + private static UsageDetails FromOpenAIUsage(ChatTokenUsage tokenUsage) { var destination = new UsageDetails { @@ -406,8 +614,54 @@ private static UsageDetails ToUsageDetails(ChatTokenUsage tokenUsage) return destination; } + private static ChatTokenUsage ToOpenAIUsage(UsageDetails usageDetails) + { + ChatOutputTokenUsageDetails? outputTokenUsageDetails = null; + ChatInputTokenUsageDetails? inputTokenUsageDetails = null; + + if (usageDetails.AdditionalCounts is { Count: > 0 } additionalCounts) + { + int? inputAudioTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.AudioTokenCount)}", + out int value) ? value : null; + + int? inputCachedTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.InputTokenDetails)}.{nameof(ChatInputTokenUsageDetails.CachedTokenCount)}", + out value) ? value : null; + + int? outputAudioTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.AudioTokenCount)}", + out value) ? value : null; + + int? outputReasoningTokenCount = additionalCounts.TryGetValue( + $"{nameof(ChatTokenUsage.OutputTokenDetails)}.{nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)}", + out value) ? value : null; + + if (inputAudioTokenCount is not null || inputCachedTokenCount is not null) + { + inputTokenUsageDetails = OpenAIChatModelFactory.ChatInputTokenUsageDetails( + audioTokenCount: inputAudioTokenCount, + cachedTokenCount: inputCachedTokenCount); + } + + if (outputAudioTokenCount is not null || outputReasoningTokenCount is not null) + { + outputTokenUsageDetails = OpenAIChatModelFactory.ChatOutputTokenUsageDetails( + audioTokenCount: outputAudioTokenCount, + reasoningTokenCount: outputReasoningTokenCount ?? 0); + } + } + + return OpenAIChatModelFactory.ChatTokenUsage( + inputTokenCount: usageDetails.InputTokenCount ?? 0, + outputTokenCount: usageDetails.OutputTokenCount ?? 0, + totalTokenCount: usageDetails.TotalTokenCount ?? 0, + outputTokenDetails: outputTokenUsageDetails, + inputTokenDetails: inputTokenUsageDetails); + } + /// Converts an OpenAI role to an Extensions role. - private static ChatRole ToChatRole(ChatMessageRole role) => + private static ChatRole FromOpenAIChatRole(ChatMessageRole role) => role switch { ChatMessageRole.System => ChatRole.System, @@ -417,6 +671,19 @@ private static ChatRole ToChatRole(ChatMessageRole role) => _ => new ChatRole(role.ToString()), }; + /// Converts an Extensions role to an OpenAI role. + [return: NotNullIfNotNull("role")] + private static ChatMessageRole? ToOpenAIChatRole(ChatRole? role) => + role switch + { + null => null, + _ when role == ChatRole.System => ChatMessageRole.System, + _ when role == ChatRole.User => ChatMessageRole.User, + _ when role == ChatRole.Assistant => ChatMessageRole.Assistant, + _ when role == ChatRole.Tool => ChatMessageRole.Tool, + _ => ChatMessageRole.System, + }; + /// Creates an from a . /// The content part to convert into a content. /// The constructed , or null if the content part could not be converted. @@ -456,7 +723,7 @@ private static ChatRole ToChatRole(ChatMessageRole role) => } /// Converts an OpenAI finish reason to an Extensions finish reason. - private static ChatFinishReason? ToFinishReason(OpenAI.Chat.ChatFinishReason? finishReason) => + private static ChatFinishReason? FromOpenAIFinishReason(OpenAI.Chat.ChatFinishReason? finishReason) => finishReason?.ToString() is not string s ? null : finishReason switch { @@ -467,6 +734,16 @@ private static ChatRole ToChatRole(ChatMessageRole role) => _ => new ChatFinishReason(s), }; + /// Converts an Extensions finish reason to an OpenAI finish reason. + private static OpenAI.Chat.ChatFinishReason ToOpenAIFinishReason(ChatFinishReason? finishReason) => + finishReason switch + { + _ when finishReason == ChatFinishReason.Length => OpenAI.Chat.ChatFinishReason.Length, + _ when finishReason == ChatFinishReason.ContentFilter => OpenAI.Chat.ChatFinishReason.ContentFilter, + _ when finishReason == ChatFinishReason.ToolCalls => OpenAI.Chat.ChatFinishReason.ToolCalls, + _ or null => OpenAI.Chat.ChatFinishReason.Stop, + }; + private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) => FunctionCallContent.CreateFromParsedArguments(json, callId, name, argumentParser: static json => JsonSerializer.Deserialize(json, OpenAIJsonContext.Default.IDictionaryStringObject)!); @@ -475,6 +752,9 @@ private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8 FunctionCallContent.CreateFromParsedArguments(ut8Json, callId, name, argumentParser: static json => JsonSerializer.Deserialize(json, OpenAIJsonContext.Default.IDictionaryStringObject)!); + private static T? GetValueOrDefault(this AdditionalPropertiesDictionary? dict, string key) => + dict?.TryGetValue(key, out T? value) is true ? value : default; + /// Used to create the JSON payload for an OpenAI chat tool description. public sealed class OpenAIChatToolJson { @@ -485,7 +765,7 @@ public sealed class OpenAIChatToolJson public string Type { get; set; } = "object"; [JsonPropertyName("required")] - public List Required { get; set; } = []; + public HashSet Required { get; set; } = []; [JsonPropertyName("properties")] public Dictionary Properties { get; set; } = []; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs new file mode 100644 index 00000000000..737b9d214a9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.ServerSentEvents; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; +using OpenAI.Chat; + +#pragma warning disable S1135 // Track uses of "TODO" tags + +namespace Microsoft.Extensions.AI; + +/// +/// Defines a set of helpers used to serialize Microsoft.Extensions.AI content using the OpenAI wire format. +/// +public static class OpenAISerializationHelpers +{ + /// + /// Deserializes a stream using the OpenAI wire format into a pair of and values. + /// + /// The stream containing a message using the OpenAI wire format. + /// The governing deserialization of function call content. + /// A token used to cancel the operation. + /// The deserialized list of chat messages and chat options. + public static async Task<(IList Messages, ChatOptions? Options)> DeserializeFromOpenAIAsync( + Stream stream, + JsonSerializerOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(stream); + options ??= AIJsonUtilities.DefaultOptions; + + BinaryData binaryData = await BinaryData.FromStreamAsync(stream, cancellationToken).ConfigureAwait(false); + ChatCompletionOptions openAiChatOptions = JsonModelHelpers.Deserialize(binaryData); + var openAiMessages = (IList)typeof(ChatCompletionOptions).GetProperty("Messages")!.GetValue(openAiChatOptions)!; + + IList messages = OpenAIModelMappers.FromOpenAIChatMessages(openAiMessages, options).ToList(); + ChatOptions chatOptions = OpenAIModelMappers.FromOpenAIOptions(openAiChatOptions); + return (messages, chatOptions); + } + + /// + /// Serializes a Microsoft.Extensions.AI completion using the OpenAI wire format. + /// + /// The chat completion to serialize. + /// The stream to write the value. + /// The governing function call content serialization. + /// A token used to cancel the serialization operation. + /// A task tracking the serialization operation. + public static async Task SerializeAsOpenAIAsync( + this ChatCompletion chatCompletion, + Stream stream, + JsonSerializerOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(stream); + _ = Throw.IfNull(chatCompletion); + options ??= AIJsonUtilities.DefaultOptions; + + OpenAI.Chat.ChatCompletion openAiChatCompletion = OpenAIModelMappers.ToOpenAIChatCompletion(chatCompletion, options); + BinaryData binaryData = JsonModelHelpers.Serialize(openAiChatCompletion); +#if NET + await stream.WriteAsync(binaryData.ToMemory(), cancellationToken).ConfigureAwait(false); +#else + await stream.WriteAsync(binaryData.ToArray(), 0, binaryData.Length, cancellationToken).ConfigureAwait(false); +#endif + } + + /// + /// Serializes a Microsoft.Extensions.AI streaming completion using the OpenAI wire format. + /// + /// The streaming chat completions to serialize. + /// The stream to write the value. + /// The governing function call content serialization. + /// A token used to cancel the serialization operation. + /// A task tracking the serialization operation. + public static Task SerializeAsOpenAIAsync( + this IAsyncEnumerable streamingChatCompletionUpdates, + Stream stream, + JsonSerializerOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(stream); + _ = Throw.IfNull(streamingChatCompletionUpdates); + options ??= AIJsonUtilities.DefaultOptions; + + var mappedUpdates = OpenAIModelMappers.ToOpenAIStreamingChatCompletionAsync(streamingChatCompletionUpdates, options, cancellationToken); + return SseFormatter.WriteAsync(WrapEventsAsync(mappedUpdates), stream, FormatAsSseEvent, cancellationToken); + + static async IAsyncEnumerable> WrapEventsAsync(IAsyncEnumerable elements) + { + await foreach (T element in elements.ConfigureAwait(false)) + { + yield return new SseItem(element); // TODO specify eventId or reconnection interval? + } + } + + static void FormatAsSseEvent(SseItem sseItem, IBufferWriter writer) + { + BinaryData binaryData = JsonModelHelpers.Serialize(sseItem.Data); + writer.Write(binaryData.ToMemory().Span); + } + } + + private static class JsonModelHelpers + { + public static BinaryData Serialize(TModel value) + where TModel : IJsonModel + { + return value.Write(ModelReaderWriterOptions.Json); + } + + public static TModel Deserialize(BinaryData data) + where TModel : IJsonModel, new() + { + return JsonModelDeserializationWitness.Value.Create(data, ModelReaderWriterOptions.Json); + } + + private sealed class JsonModelDeserializationWitness + where TModel : IJsonModel, new() + { + public static readonly IJsonModel Value = new TModel(); + } + } + +}