Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Expose options for making schema generation conformant with the subset accepted by OpenAI. #5619

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,36 @@ public sealed class AIJsonSchemaCreateOptions
/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
public bool IncludeTypeInEnumSchemas { get; init; }
public bool IncludeTypeInEnumSchemas { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to generate schemas with the additionalProperties keyword set to false for .NET objects.
/// </summary>
public bool DisallowAdditionalProperties { get; init; }
public bool DisallowAdditionalProperties { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to include the $schema keyword in inferred schemas.
/// </summary>
public bool IncludeSchemaKeyword { get; init; }

/// <summary>
/// Gets a value indicating whether to mark all properties as required in the schema.
/// </summary>
public bool RequireAllProperties { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to filter keywords that are disallowed by certain AI vendors.
/// </summary>
/// <remarks>
/// Filters a number of non-essential schema keywords that are not yet supported by some AI vendors.
/// These include:
/// <list type="bullet">
/// <item>The "minLength", "maxLength", "pattern", and "format" keywords.</item>
/// <item>The "minimum", "maximum", and "multipleOf" keywords.</item>
/// <item>The "patternProperties", "unevaluatedProperties", "propertyNames", "minProperties", and "maxProperties" keywords.</item>
/// <item>The "unevaluatedItems", "contains", "minContains", "maxContains", "minItems", "maxItems", and "uniqueItems" keywords.</item>
/// </list>
/// See also https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported.
/// </remarks>
public bool FilterDisallowedKeywords { get; init; } = true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
#if !NET9_0_OR_GREATER
Expand Down Expand Up @@ -30,7 +31,9 @@
object? DefaultValue,
bool IncludeSchemaUri,
bool DisallowAdditionalProperties,
bool IncludeTypeInEnumSchemas);
bool IncludeTypeInEnumSchemas,
bool RequireAllProperties,
bool FilterDisallowedKeywords);

namespace Microsoft.Extensions.AI;

Expand All @@ -52,6 +55,10 @@ public static partial class AIJsonUtilities
/// <summary>Gets a JSON schema only accepting null values.</summary>
private static readonly JsonElement _nullJsonSchema = ParseJsonElement("""{"type":"null"}"""u8);

// List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors.
// cf. https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
private static readonly string[] _schemaKeywordsDisallowedByAIVendors = ["minLength", "maxLength", "pattern", "format"];

/// <summary>
/// Determines a JSON schema for the provided parameter metadata.
/// </summary>
Expand Down Expand Up @@ -122,7 +129,9 @@ public static JsonElement CreateParameterJsonSchema(
defaultValue,
IncludeSchemaUri: false,
inferenceOptions.DisallowAdditionalProperties,
inferenceOptions.IncludeTypeInEnumSchemas);
inferenceOptions.IncludeTypeInEnumSchemas,
inferenceOptions.RequireAllProperties,
inferenceOptions.FilterDisallowedKeywords);

return GetJsonSchemaCached(serializerOptions, key);
}
Expand Down Expand Up @@ -154,7 +163,9 @@ public static JsonElement CreateJsonSchema(
defaultValue,
inferenceOptions.IncludeSchemaKeyword,
inferenceOptions.DisallowAdditionalProperties,
inferenceOptions.IncludeTypeInEnumSchemas);
inferenceOptions.IncludeTypeInEnumSchemas,
inferenceOptions.RequireAllProperties,
inferenceOptions.FilterDisallowedKeywords);

return GetJsonSchemaCached(serializerOptions, key);
}
Expand Down Expand Up @@ -242,6 +253,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
const string PatternPropertyName = "pattern";
const string EnumPropertyName = "enum";
const string PropertiesPropertyName = "properties";
const string RequiredPropertyName = "required";
const string AdditionalPropertiesPropertyName = "additionalProperties";
const string DefaultPropertyName = "default";
const string RefPropertyName = "$ref";
Expand Down Expand Up @@ -275,11 +287,35 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
}

// Disallow additional properties in object schemas
if (key.DisallowAdditionalProperties && objSchema.ContainsKey(PropertiesPropertyName) && !objSchema.ContainsKey(AdditionalPropertiesPropertyName))
if (key.DisallowAdditionalProperties &&
objSchema.ContainsKey(PropertiesPropertyName) &&
!objSchema.ContainsKey(AdditionalPropertiesPropertyName))
{
objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false);
}

// Mark all properties as required
if (key.RequireAllProperties &&
objSchema.TryGetPropertyValue(PropertiesPropertyName, out JsonNode? properties) &&
properties is JsonObject propertiesObj)
{
_ = objSchema.TryGetPropertyValue(RequiredPropertyName, out JsonNode? required);
if (required is not JsonArray { } requiredArray || requiredArray.Count != propertiesObj.Count)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
requiredArray = [.. propertiesObj.Select(prop => prop.Key)];
objSchema[RequiredPropertyName] = requiredArray;
}
}

// Filter potentially disallowed keywords.
if (key.FilterDisallowedKeywords)
{
foreach (string keyword in _schemaKeywordsDisallowedByAIVendors)
{
_ = objSchema.Remove(keyword);
}
}

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ static bool IsAsyncMethod(MethodInfo method)
bool sawAIContextParameter = false;
for (int i = 0; i < parameters.Length; i++)
{
if (GetParameterMarshaller(options.SerializerOptions, parameters[i], ref sawAIContextParameter, out _parameterMarshallers[i]) is AIFunctionParameterMetadata parameterView)
if (GetParameterMarshaller(options, parameters[i], ref sawAIContextParameter, out _parameterMarshallers[i]) is AIFunctionParameterMetadata parameterView)
{
parameterMetadata?.Add(parameterView);
}
Expand All @@ -209,7 +209,7 @@ static bool IsAsyncMethod(MethodInfo method)
{
ParameterType = returnType,
Description = method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
Schema = AIJsonUtilities.CreateJsonSchema(returnType, serializerOptions: options.SerializerOptions),
Schema = AIJsonUtilities.CreateJsonSchema(returnType, serializerOptions: options.SerializerOptions, inferenceOptions: options.SchemaCreateOptions),
},
AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary<string, object?>.Instance,
JsonSerializerOptions = options.SerializerOptions,
Expand Down Expand Up @@ -272,7 +272,7 @@ static bool IsAsyncMethod(MethodInfo method)
/// Gets a delegate for handling the marshaling of a parameter.
/// </summary>
private static AIFunctionParameterMetadata? GetParameterMarshaller(
JsonSerializerOptions options,
AIFunctionFactoryCreateOptions options,
ParameterInfo parameter,
ref bool sawAIFunctionContext,
out Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?> marshaller)
Expand Down Expand Up @@ -302,7 +302,7 @@ static bool IsAsyncMethod(MethodInfo method)

// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
Type parameterType = parameter.ParameterType;
JsonTypeInfo typeInfo = options.GetTypeInfo(parameterType);
JsonTypeInfo typeInfo = options.SerializerOptions.GetTypeInfo(parameterType);

// Create a marshaller that simply looks up the parameter by name in the arguments dictionary.
marshaller = (IReadOnlyDictionary<string, object?> arguments, AIFunctionContext? _) =>
Expand All @@ -325,7 +325,7 @@ static bool IsAsyncMethod(MethodInfo method)
#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, options.GetTypeInfo(value.GetType()));
string json = JsonSerializer.Serialize(value, options.SerializerOptions.GetTypeInfo(value.GetType()));
return JsonSerializer.Deserialize(json, typeInfo);
}
catch
Expand Down Expand Up @@ -361,7 +361,8 @@ static bool IsAsyncMethod(MethodInfo method)
description,
parameter.HasDefaultValue,
parameter.DefaultValue,
options)
options.SerializerOptions,
options.SchemaCreateOptions)
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Extensions.AI;
public sealed class AIFunctionFactoryCreateOptions
{
private JsonSerializerOptions _options = AIJsonUtilities.DefaultOptions;
private AIJsonSchemaCreateOptions _schemaCreateOptions = AIJsonSchemaCreateOptions.Default;

/// <summary>
/// Initializes a new instance of the <see cref="AIFunctionFactoryCreateOptions"/> class.
Expand All @@ -31,6 +32,15 @@ public JsonSerializerOptions SerializerOptions
set => _options = Throw.IfNull(value);
}

/// <summary>
/// Gets or sets the <see cref="AIJsonSchemaCreateOptions"/> governing the generation of JSON schemas for the function.
/// </summary>
public AIJsonSchemaCreateOptions SchemaCreateOptions
{
get => _schemaCreateOptions;
set => _schemaCreateOptions = Throw.IfNull(value);
}

/// <summary>Gets or sets the name to use for the function.</summary>
/// <value>
/// The name to use for the function. The default value is a name derived from the method represented by the passed <see cref="Delegate"/> or <see cref="MethodInfo"/>.
Expand Down
Loading
Loading