Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dynamic AI Backend Configuration #17

Merged
merged 5 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions Game.Api/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
using Game.DialogManagement.Api;
using Game.Shared.External;
using Game.Shared.External.Providers.Ai;
using Game.Shared.External.Providers.Ai.Llama;

var builder = WebApplication.CreateBuilder(args);

// Add services to the container.
builder.Services.Configure<AiEndpointConfig>(builder.Configuration.GetSection("AiEndpoint"));
builder.Services.AddAiProvider<LlamaAiProvider>();
builder.Services.Configure<AiBackendConfig>(builder.Configuration.GetSection("AiEndpoint"));

RegisterTypeFromConfiguration(builder, "AiBackendProvider", builder.Services.AddAiProvider);
RegisterTypeFromConfiguration(
builder,
"AiPromptTemplateProvider",
builder.Services.AddPromptTemplateProvider
);

builder.Services.AddAccountManagementModule().AddDialogManagementModule();
builder.Services.AddControllers();
Expand All @@ -31,3 +36,25 @@
app.MapControllers();

app.Run();

static void RegisterTypeFromConfiguration(
WebApplicationBuilder builder,
string configuration,
Func<Type, IServiceCollection> action
)
{
string? implementationType = builder.Configuration[configuration];
if (implementationType == null)
throw new InvalidOperationException($"Invalid configuration, {configuration} is null");

Type? aiBackendType = Type.GetType(implementationType);

if (aiBackendType != null)
{
action.Invoke(aiBackendType);
}
else
{
throw new InvalidOperationException($"Invalid configuration for {configuration}");
}
}
2 changes: 2 additions & 0 deletions Game.Api/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
"AiEndpoint": {
"EndpointUrl": ""
},
"AiBackendProvider": "Game.Shared.External.Providers.Ai.LlamaCpp.LlamaCppBackendProvider, Game.Shared.External, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null",
"AiPromptTemplateProvider": "Game.Shared.External.Providers.Ai.AlpacaPromptProvider, Game.Shared.External, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null",
"AllowedHosts": "*"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
<ItemGroup>
<Folder Include="PersistenceLayer\" />
<Folder Include="PersistenceLayer\Repositories\" />
<Folder Include="Providers\" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Game.DialogManagement.Domain\Game.DialogManagement.Domain.csproj" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public class AiChatProvider : IAiChatProvider
/// <summary>
/// The AI provider we're using for prompts.
/// </summary>
private readonly IAiProvider _aiProvider;
private readonly IAiBackendProvider _aiProvider;
private readonly IPromptTemplateProvider _templateProvider;

/// <summary>
/// The amount of tokens to generate.
Expand All @@ -25,10 +26,15 @@ public class AiChatProvider : IAiChatProvider
/// <summary>
/// Constructs a new AI chat provider.
/// </summary>
/// <param name="aiProvider">The backing AI provider.</param>
public AiChatProvider(IAiProvider aiProvider)
/// <param name="aiProvider">The AI backen service provider.</param>
/// <param name="templateProvider">The AI prompt templating provider</param>
public AiChatProvider(
IAiBackendProvider aiProvider,
IPromptTemplateProvider templateProvider
)
{
_aiProvider = aiProvider;
_templateProvider = templateProvider;
}

/// <summary>
Expand All @@ -37,7 +43,7 @@ public AiChatProvider(IAiProvider aiProvider)
/// <param name="ctx">The dialogue context.</param>
/// <param name="message">The new message.</param>
/// <returns>The prompt to be sent to the AI.</returns>
private static string GeneratePrompt(DialogueContext ctx, string message)
private string GeneratePrompt(DialogueContext ctx, string message)
{
// TODO: Should these prompts be passed from somewhere else?
// (i.e. appsettings, or some other module)
Expand All @@ -47,16 +53,14 @@ private static string GeneratePrompt(DialogueContext ctx, string message)

var dialogue = GenerateDialogue(ctx, message);

return $"""
{systemPrompt}

### Instruction:
{ctx.NpcStory}
{dialogue}

### Response {responsePrompt}:
{ctx.NpcName}:
""";
var promptTemplate = new PromptTemplate
{
System = systemPrompt,
Input = ctx.NpcStory + "\n" + dialogue,
Response = ctx.NpcName + ":",
ResponseParams = responsePrompt
};
return _templateProvider.GetPrompt(promptTemplate);
}

/// <summary>
Expand Down
38 changes: 36 additions & 2 deletions Shared/Game.Shared.External/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,43 @@ public static class DependencyInjection
/// Adds the shared external modules into the dependency injection container.
/// </summary>
public static IServiceCollection AddAiProvider<TProvider>(this IServiceCollection services)
where TProvider : class, IAiProvider
where TProvider : class, IAiBackendProvider
{
services.AddSingleton<IAiProvider, TProvider>();
return services.AddAiProvider(typeof(TProvider));
}

/// <summary>
/// Adds the shared external modules into the dependency injection container.
/// </summary>
public static IServiceCollection AddAiProvider(
this IServiceCollection services,
Type provider
)
{
services.AddSingleton(typeof(IAiBackendProvider), provider);
return services;
}

/// <summary>
/// Adds the shared external modules into the dependency injection container.
/// </summary>
public static IServiceCollection AddPromptTemplateProvider<TProvider>(
this IServiceCollection services
)
where TProvider : class, IPromptTemplateProvider
{
return services.AddAiProvider(typeof(TProvider));
}

/// <summary>
/// Adds the shared external modules into the dependency injection container.
/// </summary>
public static IServiceCollection AddPromptTemplateProvider(
this IServiceCollection services,
Type provider
)
{
services.AddSingleton(typeof(IPromptTemplateProvider), provider);
return services;
}
}
Expand Down
4 changes: 0 additions & 4 deletions Shared/Game.Shared.External/Game.Shared.External.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
<None Remove="Providers\" />
<None Remove="Providers\Ai\" />
</ItemGroup>
<ItemGroup>
<Folder Include="Providers\" />
<Folder Include="Providers\Ai\" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Options" Version="2.2.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
/// <summary>
/// The config for the AI provider's endpoint.
/// </summary>
public class AiEndpointConfig
public class AiBackendConfig
{
/// <summary>
/// The endpoint's URL.
Expand Down
18 changes: 18 additions & 0 deletions Shared/Game.Shared.External/Providers/Ai/AlpacaPromptProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace Game.Shared.External.Providers.Ai
{
public class AlpacaPromptProvider : IPromptTemplateProvider
{
public string GetPrompt(PromptTemplate template)
{
return $"""
{template.System}

### Instruction:
{template.Input}

### Response {template.ResponseParams}:
{template.Response}
""";
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using System;
using Game.Shared.External.Providers.Ai;
using System;

namespace Game.Shared.External.Providers.Ai
{
/// <summary>
/// Provides the ability to generate the response to a prompt.
/// </summary>
public interface IAiProvider
public interface IAiBackendProvider
{
/// <summary>
/// Generates the response to a prompt.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Game.Shared.External.Providers.Ai
{
public interface IPromptTemplateProvider
{
string GetPrompt(PromptTemplate template);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System.Text.Json.Serialization;

namespace Game.Shared.External.Providers.Ai.Llama.Internal
namespace Game.Shared.External.Providers.Ai.LlamaCpp.Internal
{
/// <summary>
/// The JSON prompt we're sending to the Llama AI.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using System.Text.Json;
using Game.Shared.External.Providers.Ai;
using Game.Shared.External.Providers.Ai.Llama.Internal;
using Game.Shared.External.Providers.Ai.LlamaCpp.Internal;
using Microsoft.Extensions.Options;

namespace Game.Shared.External.Providers.Ai.Llama
namespace Game.Shared.External.Providers.Ai.LlamaCpp
{
/// <summary>
/// Provides access to the Llama AI model.
/// </summary>
public class LlamaAiProvider : IAiProvider
public class LlamaCppBackendProvider : IAiBackendProvider
{
/// <summary>
/// The client used to communicate with the AI provider.
Expand All @@ -17,7 +19,7 @@ public class LlamaAiProvider : IAiProvider
/// <summary>
/// Constructs a new Llama provider.
/// </summary>
public LlamaAiProvider(IOptions<AiEndpointConfig> opts)
public LlamaCppBackendProvider(IOptions<AiBackendConfig> opts)
{
_client = new HttpClient { BaseAddress = new(opts.Value.EndpointUrl) };
}
Expand Down
19 changes: 19 additions & 0 deletions Shared/Game.Shared.External/Providers/Ai/PromptTemplate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Game.Shared.External.Providers.Ai
{
public class PromptTemplate
{
public required string System { get; set; }

public required string Input { get; set; }

public required string Response { get; set; }

public string? ResponseParams { get; set; }
}
}
Loading