diff --git a/Directory.Build.props b/Directory.Build.props
index c24c4f7..e462741 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -1,7 +1,7 @@
enable
- 0.5.3
+ 0.6.0
diff --git a/PromptPlayground/Converters/NullToVisibilityConverter.cs b/PromptPlayground/Converters/NullToVisibilityConverter.cs
index 3d7e2e2..1b6e358 100644
--- a/PromptPlayground/Converters/NullToVisibilityConverter.cs
+++ b/PromptPlayground/Converters/NullToVisibilityConverter.cs
@@ -10,12 +10,12 @@ namespace PromptPlayground.Converters
{
public class NullToVisibilityConverter : IValueConverter
{
- public object Convert(object value, Type targetType, object parameter, CultureInfo culture)
+ public object? Convert(object? value, Type targetType, object? parameter, CultureInfo culture)
{
return value != null;
}
- public object ConvertBack(object value, Type targetType, object parameter, CultureInfo culture)
+ public object? ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture)
{
throw new NotImplementedException();
}
diff --git a/PromptPlayground/Converters/StringToBooleanConverter.cs b/PromptPlayground/Converters/StringToBooleanConverter.cs
index d0feb50..cf5c665 100644
--- a/PromptPlayground/Converters/StringToBooleanConverter.cs
+++ b/PromptPlayground/Converters/StringToBooleanConverter.cs
@@ -10,7 +10,7 @@ namespace PromptPlayground.Converters
{
public class StringToBooleanConverter : IValueConverter
{
- public object Convert(object value, Type targetType, object parameter, CultureInfo culture)
+ public object? Convert(object? value, Type targetType, object? parameter, CultureInfo culture)
{
if (value == null || parameter == null || !(value is string) || !(parameter is string))
{
@@ -23,7 +23,7 @@ public object Convert(object value, Type targetType, object parameter, CultureIn
return strValue.Equals(strParameter, StringComparison.OrdinalIgnoreCase);
}
- public object ConvertBack(object value, Type targetType, object parameter, CultureInfo culture)
+ public object? ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture)
{
throw new NotSupportedException();
}
diff --git a/PromptPlayground/Messages/ResultCountRequestMessage.cs b/PromptPlayground/Messages/ConfigurationRequestMessage.cs
similarity index 50%
rename from PromptPlayground/Messages/ResultCountRequestMessage.cs
rename to PromptPlayground/Messages/ConfigurationRequestMessage.cs
index 3c9c52f..043a857 100644
--- a/PromptPlayground/Messages/ResultCountRequestMessage.cs
+++ b/PromptPlayground/Messages/ConfigurationRequestMessage.cs
@@ -7,7 +7,13 @@
namespace PromptPlayground.Messages
{
- public class ResultCountRequestMessage : RequestMessage
+ public class ConfigurationRequestMessage : RequestMessage
{
+ public ConfigurationRequestMessage(string config)
+ {
+ Config = config;
+ }
+
+ public string Config { get; }
}
}
diff --git a/PromptPlayground/PromptPlayground.csproj b/PromptPlayground/PromptPlayground.csproj
index b6956df..6098a81 100644
--- a/PromptPlayground/PromptPlayground.csproj
+++ b/PromptPlayground/PromptPlayground.csproj
@@ -9,9 +9,9 @@
AnyCPU;x64
SKEXP0001;SKEXP0002;SKEXP0004;SKEXP0050
-
- $(DefineConstants);WINDOWS
-
+
+ $(DefineConstants);WINDOWS
+
@@ -33,7 +33,7 @@
-
+
@@ -47,8 +47,8 @@
-
- PluginsView.axaml
-
+
+ PluginsView.axaml
+
diff --git a/PromptPlayground/Services/PromptService.cs b/PromptPlayground/Services/PromptService.cs
index 414f3ce..f07f69b 100644
--- a/PromptPlayground/Services/PromptService.cs
+++ b/PromptPlayground/Services/PromptService.cs
@@ -3,8 +3,11 @@
using PromptPlayground.ViewModels.ConfigViewModels;
using PromptPlayground.ViewModels.ConfigViewModels.LLM;
using System;
+using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -32,6 +35,35 @@ private Kernel Build()
return _kernel;
}
+ public async IAsyncEnumerable RunStreamAsync(string prompt, PromptExecutionSettings? config,
+ KernelArguments arguments,
+ [EnumeratorCancellation]
+ CancellationToken cancellationToken = default)
+ {
+ var _kernel = Build();
+ var promptFilter = new KernelFilter();
+ _kernel.PromptFilters.Add(promptFilter);
+ _kernel.FunctionFilters.Add(promptFilter);
+
+ var sw = Stopwatch.StartNew();
+ var func = _kernel.CreateFunctionFromPrompt(prompt, config);
+
+ var results = func.InvokeStreamingAsync(_kernel, arguments, cancellationToken);
+
+ var sb = new StringBuilder();
+ await foreach (var result in results)
+ {
+ sb.Append(result.ToString());
+ yield return new GenerateResult()
+ {
+ Text = sb.ToString(),
+ Elapsed = sw.Elapsed,
+ PromptRendered = promptFilter.PromptRendered
+ };
+ }
+ sw.Stop();
+ }
+
public async Task RunAsync(string prompt,
PromptExecutionSettings? config,
KernelArguments arguments,
diff --git a/PromptPlayground/ViewModels/ConfigViewModel.cs b/PromptPlayground/ViewModels/ConfigViewModel.cs
index cdbf740..cda5b81 100644
--- a/PromptPlayground/ViewModels/ConfigViewModel.cs
+++ b/PromptPlayground/ViewModels/ConfigViewModel.cs
@@ -17,135 +17,147 @@
namespace PromptPlayground.ViewModels;
public partial class ConfigViewModel : ViewModelBase, IConfigAttributesProvider,
- IRecipient,
- IRecipient>
+ IRecipient,
+ IRecipient>
{
- private string[] RequiredAttributes =
+ private string[] RequiredAttributes =
[
#region LLM Config
ConfigAttribute.AzureDeployment,
- ConfigAttribute.AzureEndpoint,
- ConfigAttribute.AzureSecret,
- ConfigAttribute.BaiduClientId,
- ConfigAttribute.BaiduSecret,
- ConfigAttribute.BaiduModel,
- ConfigAttribute.OpenAIApiKey,
- ConfigAttribute.OpenAIModel,
- ConfigAttribute.DashScopeApiKey,
- ConfigAttribute.DashScopeModel,
- ConfigAttribute.LlamaModelPath
+ ConfigAttribute.AzureEndpoint,
+ ConfigAttribute.AzureSecret,
+ ConfigAttribute.BaiduClientId,
+ ConfigAttribute.BaiduSecret,
+ ConfigAttribute.BaiduModel,
+ ConfigAttribute.OpenAIApiKey,
+ ConfigAttribute.OpenAIModel,
+ ConfigAttribute.DashScopeApiKey,
+ ConfigAttribute.DashScopeModel,
+ ConfigAttribute.LlamaModelPath
#endregion
];
- public List AllAttributes { get; set; } = [];
-
- [ObservableProperty]
- private int maxCount = 3;
-
- #region Model
- private int modelSelectedIndex = 0;
- private ProfileService _profile;
-
- public int ModelSelectedIndex
- {
- get => modelSelectedIndex; set
- {
- if (modelSelectedIndex != value)
- {
- modelSelectedIndex = value;
- OnPropertyChanged(nameof(ModelSelectedIndex));
- OnPropertyChanged(nameof(SelectedModel));
- OnPropertyChanged(nameof(ModelAttributes));
- OnPropertyChanged(nameof(SelectedModel.Name));
- }
- }
- }
-
- [JsonIgnore]
- public List ModelLists => LLMs.Select(_ => _.Name).ToList();
- [JsonIgnore]
- public IList ModelAttributes => SelectedModel.SelectAttributes(this.AllAttributes);
- [JsonIgnore]
- public ILLMConfigViewModel SelectedModel => LLMs[ModelSelectedIndex];
- [JsonIgnore]
- private readonly List LLMs = [];
- #endregion
-
- #region IConfigAttributesProvider
- IList IConfigAttributesProvider.AllAttributes => this.AllAttributes;
- public ILLMConfigViewModel GetLLM()
- {
- return this.SelectedModel;
- }
- #endregion
-
- public ConfigViewModel(bool requireLoadConfig = false) : this()
- {
- if (requireLoadConfig)
- {
- WeakReferenceMessenger.Default.RegisterAll(this);
- LoadConfigFromUserProfile();
- }
- }
-
- public ConfigViewModel()
- {
- this.AllAttributes = CheckAttributes(this.AllAttributes);
-
- LLMs.Add(new AzureOpenAIConfigViewModel(this));
- LLMs.Add(new BaiduConfigViewModel(this));
- LLMs.Add(new OpenAIConfigViewModel(this));
- LLMs.Add(new DashScopeConfigViewModel(this));
- LLMs.Add(new LlamaSharpConfigViewModel(this));
-
- this._profile = new ProfileService("user.config");
- }
-
- private void LoadConfigFromUserProfile()
- {
- var vm = this._profile.Get();
- if (vm != null)
- {
- this.AllAttributes = CheckAttributes(vm.AllAttributes);
-
- this.MaxCount = vm.MaxCount;
- this.ModelSelectedIndex = vm.ModelSelectedIndex;
- }
- }
- private List CheckAttributes(List list)
- {
- foreach (var item in RequiredAttributes)
- {
- if (!list.Any(_ => _.Name == item))
- {
- list.Add(new ConfigAttribute(item));
- }
- }
- return list;
- }
-
- private void SaveConfigToUserProfile()
- {
- this._profile.Save(this);
- }
-
- public void SaveConfig()
- {
- SaveConfigToUserProfile();
- }
-
- public void ReloadConfig()
- {
- this.LoadConfigFromUserProfile();
- }
-
- public void Receive(ResultCountRequestMessage message)
- {
- message.Reply(this.MaxCount);
- }
-
- public void Receive(RequestMessage message)
- {
- message.Reply(this);
- }
+ public List AllAttributes { get; set; } = [];
+
+ [ObservableProperty]
+ private int maxCount = 3;
+
+ [ObservableProperty]
+ private bool runStream = false;
+
+ #region Model
+ private int modelSelectedIndex = 0;
+ private ProfileService _profile;
+
+ public int ModelSelectedIndex
+ {
+ get => modelSelectedIndex; set
+ {
+ if (modelSelectedIndex != value)
+ {
+ modelSelectedIndex = value;
+ OnPropertyChanged(nameof(ModelSelectedIndex));
+ OnPropertyChanged(nameof(SelectedModel));
+ OnPropertyChanged(nameof(ModelAttributes));
+ OnPropertyChanged(nameof(SelectedModel.Name));
+ }
+ }
+ }
+
+ [JsonIgnore]
+ public List ModelLists => LLMs.Select(_ => _.Name).ToList();
+ [JsonIgnore]
+ public IList ModelAttributes => SelectedModel.SelectAttributes(this.AllAttributes);
+ [JsonIgnore]
+ public ILLMConfigViewModel SelectedModel => LLMs[ModelSelectedIndex];
+ [JsonIgnore]
+ private readonly List LLMs = [];
+ #endregion
+
+ #region IConfigAttributesProvider
+ IList IConfigAttributesProvider.AllAttributes => this.AllAttributes;
+ public ILLMConfigViewModel GetLLM()
+ {
+ return this.SelectedModel;
+ }
+ #endregion
+
+ public ConfigViewModel(bool requireLoadConfig = false) : this()
+ {
+ if (requireLoadConfig)
+ {
+ WeakReferenceMessenger.Default.RegisterAll(this);
+ LoadConfigFromUserProfile();
+ }
+ }
+
+ public ConfigViewModel()
+ {
+ this.AllAttributes = CheckAttributes(this.AllAttributes);
+
+ LLMs.Add(new AzureOpenAIConfigViewModel(this));
+ LLMs.Add(new BaiduConfigViewModel(this));
+ LLMs.Add(new OpenAIConfigViewModel(this));
+ LLMs.Add(new DashScopeConfigViewModel(this));
+ // LLMs.Add(new LlamaSharpConfigViewModel(this));
+
+ this._profile = new ProfileService("user.config");
+ }
+
+ private void LoadConfigFromUserProfile()
+ {
+ var vm = this._profile.Get();
+ if (vm != null)
+ {
+ this.AllAttributes = CheckAttributes(vm.AllAttributes);
+
+ this.MaxCount = vm.MaxCount;
+ this.RunStream = vm.RunStream;
+ this.ModelSelectedIndex = vm.ModelSelectedIndex;
+ }
+ }
+ private List CheckAttributes(List list)
+ {
+ foreach (var item in RequiredAttributes)
+ {
+ if (!list.Any(_ => _.Name == item))
+ {
+ list.Add(new ConfigAttribute(item));
+ }
+ }
+ return list;
+ }
+
+ private void SaveConfigToUserProfile()
+ {
+ this._profile.Save(this);
+ }
+
+ public void SaveConfig()
+ {
+ SaveConfigToUserProfile();
+ }
+
+ public void ReloadConfig()
+ {
+ this.LoadConfigFromUserProfile();
+ }
+
+ public void Receive(RequestMessage message)
+ {
+ message.Reply(this);
+ }
+
+ public void Receive(ConfigurationRequestMessage request)
+ {
+ switch (request.Config)
+ {
+ case nameof(MaxCount):
+ request.Reply(this.MaxCount.ToString());
+ break;
+ case nameof(RunStream):
+ request.Reply(this.RunStream.ToString());
+ break;
+ }
+ }
}
diff --git a/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs b/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs
index 41a0d33..b344206 100644
--- a/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs
+++ b/PromptPlayground/ViewModels/ConfigViewModels/ConfigAttribute.cs
@@ -1,5 +1,6 @@
using CommunityToolkit.Mvvm.ComponentModel;
using DashScope;
+using ERNIE_Bot.SDK;
using Humanizer;
using System;
using System.Collections.Generic;
@@ -64,7 +65,7 @@ public string Value
public const string BaiduClientId = nameof(BaiduClientId);
public const string BaiduSecret = nameof(BaiduSecret);
- [ConfigType("select", "Ernie-Bot", "Ernie-Bot-turbo", "BLOOMZ_7B")]
+ [ConfigType("select", "Ernie-Bot", "Ernie-Bot-turbo", "Ernie-Bot-4", "Ernie-Bot 8k", "Ernie-Bot-speed", "BLOOMZ_7B")]
public const string BaiduModel = nameof(BaiduModel);
public const string OpenAIApiKey = nameof(OpenAIApiKey);
diff --git a/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs b/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs
index b27c448..a0385e5 100644
--- a/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs
+++ b/PromptPlayground/ViewModels/ConfigViewModels/LLM/BaiduConfigViewModel.cs
@@ -39,6 +39,9 @@ public IKernelBuilder CreateKernelBuilder()
{
"BLOOMZ_7B" => ModelEndpoints.BLOOMZ_7B,
"Ernie-Bot-turbo" => ModelEndpoints.ERNIE_Bot_Turbo,
+ "Ernie-Bot-4" => ModelEndpoints.ERNIE_Bot_4,
+ "Ernie-Bot 8k" => ModelEndpoints.ERNIE_Bot_8K,
+ "Ernie-Bot-speed" => ModelEndpoints.ERNIE_Bot_Speed,
_ => ModelEndpoints.ERNIE_Bot
};
}
diff --git a/PromptPlayground/ViewModels/GenerateResult.cs b/PromptPlayground/ViewModels/GenerateResult.cs
index b3e0c82..509dabb 100644
--- a/PromptPlayground/ViewModels/GenerateResult.cs
+++ b/PromptPlayground/ViewModels/GenerateResult.cs
@@ -3,6 +3,7 @@
using CommunityToolkit.Mvvm.Messaging;
using PromptPlayground.ViewModels;
using System;
+using System.Collections.Generic;
namespace PromptPlayground.Services
{
@@ -25,7 +26,7 @@ public partial class GenerateResult : ViewModelBase
private string text = string.Empty;
[ObservableProperty]
- private TimeSpan elapsed;
+ private TimeSpan? elapsed;
[ObservableProperty]
private string? error;
diff --git a/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs b/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs
index 551b207..02aed09 100644
--- a/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs
+++ b/PromptPlayground/ViewModels/SemanticFunctionViewModel.cs
@@ -224,26 +224,41 @@ public async Task GenerateResultAsync(CancellationToken cancellationToken)
var results = Enumerable.Range(0, maxCount)
.Select(_ => new GenerateResult()
{
- Text = "......"
+ Text = "🤖"
}).ToList();
var tasks = results
.Select(async r =>
{
Results.Add(r);
- var result = await service.RunAsync(Prompt, PromptConfig.DefaultExecutionSettings, new KernelArguments(arguments), cancellationToken);
- r.PromptRendered = result.PromptRendered;
- r.Text = result.Text;
- r.Elapsed = result.Elapsed;
- r.Error = result.Error;
- r.TokenUsage = result.TokenUsage;
+ if (GetRunStream())
+ {
+ var results = service.RunStreamAsync(Prompt, PromptConfig.DefaultExecutionSettings, new KernelArguments(arguments), cancellationToken);
+ await foreach (var result in results)
+ {
+ r.PromptRendered = result.PromptRendered;
+ r.Text = result.Text;
+ r.Elapsed = result.Elapsed;
+ r.Error = result.Error;
+ r.TokenUsage = result.TokenUsage;
+ }
+ }
+ else
+ {
+ var result = await service.RunAsync(Prompt, PromptConfig.DefaultExecutionSettings, new KernelArguments(arguments), cancellationToken);
+ r.PromptRendered = result.PromptRendered;
+ r.Text = result.Text;
+ r.Elapsed = result.Elapsed;
+ r.Error = result.Error;
+ r.TokenUsage = result.TokenUsage;
+ }
})
.ToList();
await Task.WhenAll(tasks);
Average.HasResults = true;
- Average.Elapsed = TimeSpan.FromMilliseconds(Results.Where(_ => !_.HasError).Average(_ => _.Elapsed.TotalMilliseconds));
+ Average.Elapsed = TimeSpan.FromMilliseconds(Results.Where(_ => !_.HasError).Where(_ => _.Elapsed.HasValue).Average(_ => _.Elapsed!.Value.TotalMilliseconds));
Average.TokenUsage = new ResultTokenUsage(
(int)Results.Where(_ => !_.HasError).Average(_ => _.TokenUsage?.Total ?? 0),
(int)Results.Where(_ => !_.HasError).Average(_ => _.TokenUsage?.Prompt ?? 0),
@@ -266,8 +281,16 @@ public async Task GenerateResultAsync(CancellationToken cancellationToken)
private int GetMaxCount()
{
- var result = WeakReferenceMessenger.Default.Send(new ResultCountRequestMessage());
- return result.Response;
+ var result = WeakReferenceMessenger.Default.Send(new ConfigurationRequestMessage("MaxCount"));
+
+ return int.Parse(result.Response);
+ }
+
+ private bool GetRunStream()
+ {
+ var result = WeakReferenceMessenger.Default.Send(new ConfigurationRequestMessage("RunStream"));
+
+ return bool.Parse(result.Response);
}
public ObservableCollection Results { get; set; } = new();
diff --git a/PromptPlayground/Views/ConfigWindow.axaml b/PromptPlayground/Views/ConfigWindow.axaml
index 29ed013..a8347ac 100644
--- a/PromptPlayground/Views/ConfigWindow.axaml
+++ b/PromptPlayground/Views/ConfigWindow.axaml
@@ -29,6 +29,10 @@
+
+
+
+
diff --git a/PromptPlayground/Views/ResultsView.axaml b/PromptPlayground/Views/ResultsView.axaml
index 0d39474..2d5653d 100644
--- a/PromptPlayground/Views/ResultsView.axaml
+++ b/PromptPlayground/Views/ResultsView.axaml
@@ -39,7 +39,8 @@
-