diff --git a/Directory.Build.props b/Directory.Build.props index c24c4f7..e462741 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,7 +1,7 @@ enable - 0.5.3 + 0.6.0 diff --git a/PromptPlayground/Converters/NullToVisibilityConverter.cs b/PromptPlayground/Converters/NullToVisibilityConverter.cs index 3d7e2e2..1b6e358 100644 --- a/PromptPlayground/Converters/NullToVisibilityConverter.cs +++ b/PromptPlayground/Converters/NullToVisibilityConverter.cs @@ -10,12 +10,12 @@ namespace PromptPlayground.Converters { public class NullToVisibilityConverter : IValueConverter { - public object Convert(object value, Type targetType, object parameter, CultureInfo culture) + public object? Convert(object? value, Type targetType, object? parameter, CultureInfo culture) { return value != null; } - public object ConvertBack(object value, Type targetType, object parameter, CultureInfo culture) + public object? ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture) { throw new NotImplementedException(); } diff --git a/PromptPlayground/Converters/StringToBooleanConverter.cs b/PromptPlayground/Converters/StringToBooleanConverter.cs index d0feb50..cf5c665 100644 --- a/PromptPlayground/Converters/StringToBooleanConverter.cs +++ b/PromptPlayground/Converters/StringToBooleanConverter.cs @@ -10,7 +10,7 @@ namespace PromptPlayground.Converters { public class StringToBooleanConverter : IValueConverter { - public object Convert(object value, Type targetType, object parameter, CultureInfo culture) + public object? Convert(object? value, Type targetType, object? parameter, CultureInfo culture) { if (value == null || parameter == null || !(value is string) || !(parameter is string)) { @@ -23,7 +23,7 @@ public object Convert(object value, Type targetType, object parameter, CultureIn return strValue.Equals(strParameter, StringComparison.OrdinalIgnoreCase); } - public object ConvertBack(object value, Type targetType, object parameter, CultureInfo culture) + public object? ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture) { throw new NotSupportedException(); } diff --git a/PromptPlayground/Messages/ResultCountRequestMessage.cs b/PromptPlayground/Messages/ConfigurationRequestMessage.cs similarity index 50% rename from PromptPlayground/Messages/ResultCountRequestMessage.cs rename to PromptPlayground/Messages/ConfigurationRequestMessage.cs index 3c9c52f..043a857 100644 --- a/PromptPlayground/Messages/ResultCountRequestMessage.cs +++ b/PromptPlayground/Messages/ConfigurationRequestMessage.cs @@ -7,7 +7,13 @@ namespace PromptPlayground.Messages { - public class ResultCountRequestMessage : RequestMessage + public class ConfigurationRequestMessage : RequestMessage { + public ConfigurationRequestMessage(string config) + { + Config = config; + } + + public string Config { get; } } } diff --git a/PromptPlayground/PromptPlayground.csproj b/PromptPlayground/PromptPlayground.csproj index b6956df..6098a81 100644 --- a/PromptPlayground/PromptPlayground.csproj +++ b/PromptPlayground/PromptPlayground.csproj @@ -9,9 +9,9 @@ AnyCPU;x64 SKEXP0001;SKEXP0002;SKEXP0004;SKEXP0050 - - $(DefineConstants);WINDOWS - + + $(DefineConstants);WINDOWS + @@ -33,7 +33,7 @@ - + @@ -47,8 +47,8 @@ - - PluginsView.axaml - + + PluginsView.axaml + diff --git a/PromptPlayground/Services/PromptService.cs b/PromptPlayground/Services/PromptService.cs index 414f3ce..f07f69b 100644 --- a/PromptPlayground/Services/PromptService.cs +++ b/PromptPlayground/Services/PromptService.cs @@ -3,8 +3,11 @@ using PromptPlayground.ViewModels.ConfigViewModels; using PromptPlayground.ViewModels.ConfigViewModels.LLM; using System; +using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -32,6 +35,35 @@ private Kernel Build() return _kernel; } + public async IAsyncEnumerable RunStreamAsync(string prompt, PromptExecutionSettings? config, + KernelArguments arguments, + [EnumeratorCancellation] + CancellationToken cancellationToken = default) + { + var _kernel = Build(); + var promptFilter = new KernelFilter(); + _kernel.PromptFilters.Add(promptFilter); + _kernel.FunctionFilters.Add(promptFilter); + + var sw = Stopwatch.StartNew(); + var func = _kernel.CreateFunctionFromPrompt(prompt, config); + + var results = func.InvokeStreamingAsync(_kernel, arguments, cancellationToken); + + var sb = new StringBuilder(); + await foreach (var result in results) + { + sb.Append(result.ToString()); + yield return new GenerateResult() + { + Text = sb.ToString(), + Elapsed = sw.Elapsed, + PromptRendered = promptFilter.PromptRendered + }; + } + sw.Stop(); + } + public async Task RunAsync(string prompt, PromptExecutionSettings? config, KernelArguments arguments, diff --git a/PromptPlayground/ViewModels/ConfigViewModel.cs b/PromptPlayground/ViewModels/ConfigViewModel.cs index cdbf740..cda5b81 100644 --- a/PromptPlayground/ViewModels/ConfigViewModel.cs +++ b/PromptPlayground/ViewModels/ConfigViewModel.cs @@ -17,135 +17,147 @@ namespace PromptPlayground.ViewModels; public partial class ConfigViewModel : ViewModelBase, IConfigAttributesProvider, - IRecipient, - IRecipient> + IRecipient, + IRecipient> { - private string[] RequiredAttributes = + private string[] RequiredAttributes = [ #region LLM Config ConfigAttribute.AzureDeployment, - ConfigAttribute.AzureEndpoint, - ConfigAttribute.AzureSecret, - ConfigAttribute.BaiduClientId, - ConfigAttribute.BaiduSecret, - ConfigAttribute.BaiduModel, - ConfigAttribute.OpenAIApiKey, - ConfigAttribute.OpenAIModel, - ConfigAttribute.DashScopeApiKey, - ConfigAttribute.DashScopeModel, - ConfigAttribute.LlamaModelPath + ConfigAttribute.AzureEndpoint, + ConfigAttribute.AzureSecret, + ConfigAttribute.BaiduClientId, + ConfigAttribute.BaiduSecret, + ConfigAttribute.BaiduModel, + ConfigAttribute.OpenAIApiKey, + ConfigAttribute.OpenAIModel, + ConfigAttribute.DashScopeApiKey, + ConfigAttribute.DashScopeModel, + ConfigAttribute.LlamaModelPath #endregion ]; - public List AllAttributes { get; set; } = []; - - [ObservableProperty] - private int maxCount = 3; - - #region Model - private int modelSelectedIndex = 0; - private ProfileService _profile; - - public int ModelSelectedIndex - { - get => modelSelectedIndex; set - { - if (modelSelectedIndex != value) - { - modelSelectedIndex = value; - OnPropertyChanged(nameof(ModelSelectedIndex)); - OnPropertyChanged(nameof(SelectedModel)); - OnPropertyChanged(nameof(ModelAttributes)); - OnPropertyChanged(nameof(SelectedModel.Name)); - } - } - } - - [JsonIgnore] - public List ModelLists => LLMs.Select(_ => _.Name).ToList(); - [JsonIgnore] - public IList ModelAttributes => SelectedModel.SelectAttributes(this.AllAttributes); - [JsonIgnore] - public ILLMConfigViewModel SelectedModel => LLMs[ModelSelectedIndex]; - [JsonIgnore] - private readonly List LLMs = []; - #endregion - - #region IConfigAttributesProvider - IList IConfigAttributesProvider.AllAttributes => this.AllAttributes; - public ILLMConfigViewModel GetLLM() - { - return this.SelectedModel; - } - #endregion - - public ConfigViewModel(bool requireLoadConfig = false) : this() - { - if (requireLoadConfig) - { - WeakReferenceMessenger.Default.RegisterAll(this); - LoadConfigFromUserProfile(); - } - } - - public ConfigViewModel() - { - this.AllAttributes = CheckAttributes(this.AllAttributes); - - LLMs.Add(new AzureOpenAIConfigViewModel(this)); - LLMs.Add(new BaiduConfigViewModel(this)); - LLMs.Add(new OpenAIConfigViewModel(this)); - LLMs.Add(new DashScopeConfigViewModel(this)); - LLMs.Add(new LlamaSharpConfigViewModel(this)); - - this._profile = new ProfileService("user.config"); - } - - private void LoadConfigFromUserProfile() - { - var vm = this._profile.Get(); - if (vm != null) - { - this.AllAttributes = CheckAttributes(vm.AllAttributes); - - this.MaxCount = vm.MaxCount; - this.ModelSelectedIndex = vm.ModelSelectedIndex; - } - } - private List CheckAttributes(List list) - { - foreach (var item in RequiredAttributes) - { - if (!list.Any(_ => _.Name == item)) - { - list.Add(new ConfigAttribute(item)); - } - } - return list; - } - - private void SaveConfigToUserProfile() - { - this._profile.Save(this); - } - - public void SaveConfig() - { - SaveConfigToUserProfile(); - } - - public void ReloadConfig() - { - this.LoadConfigFromUserProfile(); - } - - public void Receive(ResultCountRequestMessage message) - { - message.Reply(this.MaxCount); - } - - public void Receive(RequestMessage message) - { - message.Reply(this); - } + public List AllAttributes { get; set; } = []; + + [ObservableProperty] + private int maxCount = 3; + + [ObservableProperty] + private bool runStream = false; + + #region Model + private int modelSelectedIndex = 0; + private ProfileService _profile; + + public int ModelSelectedIndex + { + get => modelSelectedIndex; set + { + if (modelSelectedIndex != value) + { + modelSelectedIndex = value; + OnPropertyChanged(nameof(ModelSelectedIndex)); + OnPropertyChanged(nameof(SelectedModel)); + OnPropertyChanged(nameof(ModelAttributes)); + OnPropertyChanged(nameof(SelectedModel.Name)); + } + } + } + + [JsonIgnore] + public List ModelLists => LLMs.Select(_ => _.Name).ToList(); + [JsonIgnore] + public IList ModelAttributes => SelectedModel.SelectAttributes(this.AllAttributes); + [JsonIgnore] + public ILLMConfigViewModel SelectedModel => LLMs[ModelSelectedIndex]; + [JsonIgnore] + private readonly List LLMs = []; + #endregion + + #region IConfigAttributesProvider + IList IConfigAttributesProvider.AllAttributes => this.AllAttributes; + public ILLMConfigViewModel GetLLM() + { + return this.SelectedModel; + } + #endregion + + public ConfigViewModel(bool requireLoadConfig = false) : this() + { + if (requireLoadConfig) + { + WeakReferenceMessenger.Default.RegisterAll(this); + LoadConfigFromUserProfile(); + } + } + + public ConfigViewModel() + { + this.AllAttributes = CheckAttributes(this.AllAttributes); + + LLMs.Add(new AzureOpenAIConfigViewModel(this)); + LLMs.Add(new BaiduConfigViewModel(this)); + LLMs.Add(new OpenAIConfigViewModel(this)); + LLMs.Add(new DashScopeConfigViewModel(this)); + // LLMs.Add(new LlamaSharpConfigViewModel(this)); + + this._profile = new ProfileService("user.config"); + } + + private void LoadConfigFromUserProfile() + { + var vm = this._profile.Get(); + if (vm != null) + { + this.AllAttributes = CheckAttributes(vm.AllAttributes); + + this.MaxCount = vm.MaxCount; + this.RunStream = vm.RunStream; + this.ModelSelectedIndex = vm.ModelSelectedIndex; + } + } + private List CheckAttributes(List list) + { + foreach (var item in RequiredAttributes) + { + if (!list.Any(_ => _.Name == item)) + { + list.Add(new ConfigAttribute(item)); + } + } + return list; + } + + private void SaveConfigToUserProfile() + { + this._profile.Save(this); + } + + public void SaveConfig() + { + SaveConfigToUserProfile(); + } + + public void ReloadConfig() + { + this.LoadConfigFromUserProfile(); + } + + public void Receive(RequestMessage message) + { + message.Reply(this); + } + + public void Receive(ConfigurationRequestMessage request) + { + switch (request.Config) + { + case nameof(MaxCount): + request.Reply(this.MaxCount.ToString()); + break; + case nameof(RunStream): + request.Reply(this.RunStream.ToString()); + break; + } + } } diff --git a/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs b/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs index 41a0d33..b344206 100644 --- a/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs +++ b/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs @@ -1,5 +1,6 @@ using CommunityToolkit.Mvvm.ComponentModel; using DashScope; +using ERNIE_Bot.SDK; using Humanizer; using System; using System.Collections.Generic; @@ -64,7 +65,7 @@ public string Value public const string BaiduClientId = nameof(BaiduClientId); public const string BaiduSecret = nameof(BaiduSecret); - [ConfigType("select", "Ernie-Bot", "Ernie-Bot-turbo", "BLOOMZ_7B")] + [ConfigType("select", "Ernie-Bot", "Ernie-Bot-turbo", "Ernie-Bot-4", "Ernie-Bot 8k", "Ernie-Bot-speed", "BLOOMZ_7B")] public const string BaiduModel = nameof(BaiduModel); public const string OpenAIApiKey = nameof(OpenAIApiKey); diff --git a/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs b/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs index b27c448..a0385e5 100644 --- a/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs +++ b/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs @@ -39,6 +39,9 @@ public IKernelBuilder CreateKernelBuilder() { "BLOOMZ_7B" => ModelEndpoints.BLOOMZ_7B, "Ernie-Bot-turbo" => ModelEndpoints.ERNIE_Bot_Turbo, + "Ernie-Bot-4" => ModelEndpoints.ERNIE_Bot_4, + "Ernie-Bot 8k" => ModelEndpoints.ERNIE_Bot_8K, + "Ernie-Bot-speed" => ModelEndpoints.ERNIE_Bot_Speed, _ => ModelEndpoints.ERNIE_Bot }; } diff --git a/PromptPlayground/ViewModels/GenerateResult.cs b/PromptPlayground/ViewModels/GenerateResult.cs index b3e0c82..509dabb 100644 --- a/PromptPlayground/ViewModels/GenerateResult.cs +++ b/PromptPlayground/ViewModels/GenerateResult.cs @@ -3,6 +3,7 @@ using CommunityToolkit.Mvvm.Messaging; using PromptPlayground.ViewModels; using System; +using System.Collections.Generic; namespace PromptPlayground.Services { @@ -25,7 +26,7 @@ public partial class GenerateResult : ViewModelBase private string text = string.Empty; [ObservableProperty] - private TimeSpan elapsed; + private TimeSpan? elapsed; [ObservableProperty] private string? error; diff --git a/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs b/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs index 551b207..02aed09 100644 --- a/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs +++ b/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs @@ -224,26 +224,41 @@ public async Task GenerateResultAsync(CancellationToken cancellationToken) var results = Enumerable.Range(0, maxCount) .Select(_ => new GenerateResult() { - Text = "......" + Text = "🤖" }).ToList(); var tasks = results .Select(async r => { Results.Add(r); - var result = await service.RunAsync(Prompt, PromptConfig.DefaultExecutionSettings, new KernelArguments(arguments), cancellationToken); - r.PromptRendered = result.PromptRendered; - r.Text = result.Text; - r.Elapsed = result.Elapsed; - r.Error = result.Error; - r.TokenUsage = result.TokenUsage; + if (GetRunStream()) + { + var results = service.RunStreamAsync(Prompt, PromptConfig.DefaultExecutionSettings, new KernelArguments(arguments), cancellationToken); + await foreach (var result in results) + { + r.PromptRendered = result.PromptRendered; + r.Text = result.Text; + r.Elapsed = result.Elapsed; + r.Error = result.Error; + r.TokenUsage = result.TokenUsage; + } + } + else + { + var result = await service.RunAsync(Prompt, PromptConfig.DefaultExecutionSettings, new KernelArguments(arguments), cancellationToken); + r.PromptRendered = result.PromptRendered; + r.Text = result.Text; + r.Elapsed = result.Elapsed; + r.Error = result.Error; + r.TokenUsage = result.TokenUsage; + } }) .ToList(); await Task.WhenAll(tasks); Average.HasResults = true; - Average.Elapsed = TimeSpan.FromMilliseconds(Results.Where(_ => !_.HasError).Average(_ => _.Elapsed.TotalMilliseconds)); + Average.Elapsed = TimeSpan.FromMilliseconds(Results.Where(_ => !_.HasError).Where(_ => _.Elapsed.HasValue).Average(_ => _.Elapsed!.Value.TotalMilliseconds)); Average.TokenUsage = new ResultTokenUsage( (int)Results.Where(_ => !_.HasError).Average(_ => _.TokenUsage?.Total ?? 0), (int)Results.Where(_ => !_.HasError).Average(_ => _.TokenUsage?.Prompt ?? 0), @@ -266,8 +281,16 @@ public async Task GenerateResultAsync(CancellationToken cancellationToken) private int GetMaxCount() { - var result = WeakReferenceMessenger.Default.Send(new ResultCountRequestMessage()); - return result.Response; + var result = WeakReferenceMessenger.Default.Send(new ConfigurationRequestMessage("MaxCount")); + + return int.Parse(result.Response); + } + + private bool GetRunStream() + { + var result = WeakReferenceMessenger.Default.Send(new ConfigurationRequestMessage("RunStream")); + + return bool.Parse(result.Response); } public ObservableCollection Results { get; set; } = new(); diff --git a/PromptPlayground/Views/ConfigWindow.axaml b/PromptPlayground/Views/ConfigWindow.axaml index 29ed013..a8347ac 100644 --- a/PromptPlayground/Views/ConfigWindow.axaml +++ b/PromptPlayground/Views/ConfigWindow.axaml @@ -29,6 +29,10 @@ + + + + diff --git a/PromptPlayground/Views/ResultsView.axaml b/PromptPlayground/Views/ResultsView.axaml index 0d39474..2d5653d 100644 --- a/PromptPlayground/Views/ResultsView.axaml +++ b/PromptPlayground/Views/ResultsView.axaml @@ -39,7 +39,8 @@ -