Skip to content

Commit

Permalink
Expose options for making schema generation conformant with the subse…
Browse files Browse the repository at this point in the history
…t accepted by OpenAI. (dotnet#5619)

* Expose options for making schema generation conformant with the subset accepted by OpenAI.

* Uses the same set of defaults in all layers.
  • Loading branch information
eiriktsarpalis authored and stephentoub committed Nov 19, 2024
1 parent b84a353 commit 1dbecfb
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,36 @@ public sealed class AIJsonSchemaCreateOptions
/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
public bool IncludeTypeInEnumSchemas { get; init; }
public bool IncludeTypeInEnumSchemas { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to generate schemas with the additionalProperties set to false for .NET objects.
/// </summary>
public bool DisallowAdditionalProperties { get; init; }
public bool DisallowAdditionalProperties { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to include the $schema keyword in inferred schemas.
/// </summary>
public bool IncludeSchemaKeyword { get; init; }

/// <summary>
/// Gets a value indicating whether to mark all properties as required in the schema.
/// </summary>
public bool RequireAllProperties { get; init; } = true;

/// <summary>
/// Gets a value indicating whether to filter keywords that are disallowed by certain AI vendors.
/// </summary>
/// <remarks>
/// Filters a number of non-essential schema keywords that are not yet supported by some AI vendors.
/// These include:
/// <list type="bullet">
/// <item>The "minLength", "maxLength", "pattern", and "format" keywords.</item>
/// <item>The "minimum", "maximum", and "multipleOf" keywords.</item>
/// <item>The "patternProperties", "unevaluatedProperties", "propertyNames", "minProperties", and "maxProperties" keywords.</item>
/// <item>The "unevaluatedItems", "contains", "minContains", "maxContains", "minItems", "maxItems", and "uniqueItems" keywords.</item>
/// </list>
/// See also https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported.
/// </remarks>
public bool FilterDisallowedKeywords { get; init; } = true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
#if !NET9_0_OR_GREATER
Expand Down Expand Up @@ -30,7 +31,9 @@
object? DefaultValue,
bool IncludeSchemaUri,
bool DisallowAdditionalProperties,
bool IncludeTypeInEnumSchemas);
bool IncludeTypeInEnumSchemas,
bool RequireAllProperties,
bool FilterDisallowedKeywords);

namespace Microsoft.Extensions.AI;

Expand All @@ -52,6 +55,10 @@ public static partial class AIJsonUtilities
/// <summary>Gets a JSON schema only accepting null values.</summary>
private static readonly JsonElement _nullJsonSchema = ParseJsonElement("""{"type":"null"}"""u8);

// List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors.
// cf. https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
private static readonly string[] _schemaKeywordsDisallowedByAIVendors = ["minLength", "maxLength", "pattern", "format"];

/// <summary>
/// Determines a JSON schema for the provided parameter metadata.
/// </summary>
Expand Down Expand Up @@ -122,7 +129,9 @@ public static JsonElement CreateParameterJsonSchema(
defaultValue,
IncludeSchemaUri: false,
inferenceOptions.DisallowAdditionalProperties,
inferenceOptions.IncludeTypeInEnumSchemas);
inferenceOptions.IncludeTypeInEnumSchemas,
inferenceOptions.RequireAllProperties,
inferenceOptions.FilterDisallowedKeywords);

return GetJsonSchemaCached(serializerOptions, key);
}
Expand Down Expand Up @@ -154,7 +163,9 @@ public static JsonElement CreateJsonSchema(
defaultValue,
inferenceOptions.IncludeSchemaKeyword,
inferenceOptions.DisallowAdditionalProperties,
inferenceOptions.IncludeTypeInEnumSchemas);
inferenceOptions.IncludeTypeInEnumSchemas,
inferenceOptions.RequireAllProperties,
inferenceOptions.FilterDisallowedKeywords);

return GetJsonSchemaCached(serializerOptions, key);
}
Expand Down Expand Up @@ -242,6 +253,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
const string PatternPropertyName = "pattern";
const string EnumPropertyName = "enum";
const string PropertiesPropertyName = "properties";
const string RequiredPropertyName = "required";
const string AdditionalPropertiesPropertyName = "additionalProperties";
const string DefaultPropertyName = "default";
const string RefPropertyName = "$ref";
Expand Down Expand Up @@ -275,11 +287,35 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
}

// Disallow additional properties in object schemas
if (key.DisallowAdditionalProperties && objSchema.ContainsKey(PropertiesPropertyName) && !objSchema.ContainsKey(AdditionalPropertiesPropertyName))
if (key.DisallowAdditionalProperties &&
objSchema.ContainsKey(PropertiesPropertyName) &&
!objSchema.ContainsKey(AdditionalPropertiesPropertyName))
{
objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false);
}

// Mark all properties as required
if (key.RequireAllProperties &&
objSchema.TryGetPropertyValue(PropertiesPropertyName, out JsonNode? properties) &&
properties is JsonObject propertiesObj)
{
_ = objSchema.TryGetPropertyValue(RequiredPropertyName, out JsonNode? required);
if (required is not JsonArray { } requiredArray || requiredArray.Count != propertiesObj.Count)
{
requiredArray = [.. propertiesObj.Select(prop => prop.Key)];
objSchema[RequiredPropertyName] = requiredArray;
}
}

// Filter potentially disallowed keywords.
if (key.FilterDisallowedKeywords)
{
foreach (string keyword in _schemaKeywordsDisallowedByAIVendors)
{
_ = objSchema.Remove(keyword);
}
}

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ static bool IsAsyncMethod(MethodInfo method)
bool sawAIContextParameter = false;
for (int i = 0; i < parameters.Length; i++)
{
if (GetParameterMarshaller(options.SerializerOptions, parameters[i], ref sawAIContextParameter, out _parameterMarshallers[i]) is AIFunctionParameterMetadata parameterView)
if (GetParameterMarshaller(options, parameters[i], ref sawAIContextParameter, out _parameterMarshallers[i]) is AIFunctionParameterMetadata parameterView)
{
parameterMetadata?.Add(parameterView);
}
Expand All @@ -209,7 +209,7 @@ static bool IsAsyncMethod(MethodInfo method)
{
ParameterType = returnType,
Description = method.ReturnParameter.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description,
Schema = AIJsonUtilities.CreateJsonSchema(returnType, serializerOptions: options.SerializerOptions),
Schema = AIJsonUtilities.CreateJsonSchema(returnType, serializerOptions: options.SerializerOptions, inferenceOptions: options.SchemaCreateOptions),
},
AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary<string, object?>.Instance,
JsonSerializerOptions = options.SerializerOptions,
Expand Down Expand Up @@ -272,7 +272,7 @@ static bool IsAsyncMethod(MethodInfo method)
/// Gets a delegate for handling the marshaling of a parameter.
/// </summary>
private static AIFunctionParameterMetadata? GetParameterMarshaller(
JsonSerializerOptions options,
AIFunctionFactoryCreateOptions options,
ParameterInfo parameter,
ref bool sawAIFunctionContext,
out Func<IReadOnlyDictionary<string, object?>, AIFunctionContext?, object?> marshaller)
Expand Down Expand Up @@ -302,7 +302,7 @@ static bool IsAsyncMethod(MethodInfo method)

// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
Type parameterType = parameter.ParameterType;
JsonTypeInfo typeInfo = options.GetTypeInfo(parameterType);
JsonTypeInfo typeInfo = options.SerializerOptions.GetTypeInfo(parameterType);

// Create a marshaller that simply looks up the parameter by name in the arguments dictionary.
marshaller = (IReadOnlyDictionary<string, object?> arguments, AIFunctionContext? _) =>
Expand All @@ -325,7 +325,7 @@ static bool IsAsyncMethod(MethodInfo method)
#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, options.GetTypeInfo(value.GetType()));
string json = JsonSerializer.Serialize(value, options.SerializerOptions.GetTypeInfo(value.GetType()));
return JsonSerializer.Deserialize(json, typeInfo);
}
catch
Expand Down Expand Up @@ -361,7 +361,8 @@ static bool IsAsyncMethod(MethodInfo method)
description,
parameter.HasDefaultValue,
parameter.DefaultValue,
options)
options.SerializerOptions,
options.SchemaCreateOptions)
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Extensions.AI;
public sealed class AIFunctionFactoryCreateOptions
{
private JsonSerializerOptions _options = AIJsonUtilities.DefaultOptions;
private AIJsonSchemaCreateOptions _schemaCreateOptions = AIJsonSchemaCreateOptions.Default;

/// <summary>
/// Initializes a new instance of the <see cref="AIFunctionFactoryCreateOptions"/> class.
Expand All @@ -31,6 +32,15 @@ public JsonSerializerOptions SerializerOptions
set => _options = Throw.IfNull(value);
}

/// <summary>
/// Gets or sets the <see cref="AIJsonSchemaCreateOptions"/> governing the generation of JSON schemas for the function.
/// </summary>
public AIJsonSchemaCreateOptions SchemaCreateOptions
{
get => _schemaCreateOptions;
set => _schemaCreateOptions = Throw.IfNull(value);
}

/// <summary>Gets or sets the name to use for the function.</summary>
/// <value>
/// The name to use for the function. The default value is a name derived from the method represented by the passed <see cref="Delegate"/> or <see cref="MethodInfo"/>.
Expand Down
Loading

0 comments on commit 1dbecfb

Please sign in to comment.