Skip to content

Commit

Permalink
Merge pull request #175 from Puchaczov/bugfix/no_max_output_tokens
Browse files Browse the repository at this point in the history
Fixed missing max_output_tokens option mapping
  • Loading branch information
awaescher authored Jan 22, 2025
2 parents 7d5ffbb + 7478092 commit 71ccc5e
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/Constants/Application.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ internal static class Application
public const string Seed = "seed";
public const string TfsZ = "tfs_z";
public const string NumPredict = "num_predict";
public const string MaxOutputTokens = "max_output_tokens";
public const string TopK = "top_k";
public const string TopP = "top_p";
public const string MinP = "min_p";
Expand Down
71 changes: 37 additions & 34 deletions src/MicrosoftAi/AbstractionMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,46 +65,49 @@ public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessag
Temperature = options?.Temperature,
TopP = options?.TopP,
TopK = options?.TopK,
NumPredict = options?.MaxOutputTokens
},
Stream = stream,
Template = null,
Tools = ToOllamaSharpTools(options?.Tools)
};

if (options?.AdditionalProperties?.Any() ?? false)
{
TryAddOllamaOption<bool?>(options, OllamaOption.F16kv, v => request.Options.F16kv = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.FrequencyPenalty, v => request.Options.FrequencyPenalty = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LogitsAll, v => request.Options.LogitsAll = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LowVram, v => request.Options.LowVram = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MainGpu, v => request.Options.MainGpu = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MinP, v => request.Options.MinP = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MiroStat, v => request.Options.MiroStat = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatEta, v => request.Options.MiroStatEta = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatTau, v => request.Options.MiroStatTau = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.Numa, v => request.Options.Numa = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumBatch, v => request.Options.NumBatch = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumCtx, v => request.Options.NumCtx = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGpu, v => request.Options.NumGpu = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGqa, v => request.Options.NumGqa = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumKeep, v => request.Options.NumKeep = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumPredict, v => request.Options.NumPredict = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumThread, v => request.Options.NumThread = (int?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.PenalizeNewline, v => request.Options.PenalizeNewline = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.PresencePenalty, v => request.Options.PresencePenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.RepeatLastN, v => request.Options.RepeatLastN = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.RepeatPenalty, v => request.Options.RepeatPenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.Seed, v => request.Options.Seed = (int?)v);
TryAddOllamaOption<string[]?>(options, OllamaOption.Stop, v => request.Options.Stop = (v as IEnumerable<string>)?.ToArray());
TryAddOllamaOption<float?>(options, OllamaOption.Temperature, v => request.Options.Temperature = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TfsZ, v => request.Options.TfsZ = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.TopK, v => request.Options.TopK = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TopP, v => request.Options.TopP = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TypicalP, v => request.Options.TypicalP = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMlock, v => request.Options.UseMlock = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMmap, v => request.Options.UseMmap = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.VocabOnly, v => request.Options.VocabOnly = (bool?)v);
}
var hasAdditionalProperties = options?.AdditionalProperties?.Any() ?? false;
if (!hasAdditionalProperties)
return request;

TryAddOllamaOption<bool?>(options, OllamaOption.F16kv, v => request.Options.F16kv = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.FrequencyPenalty, v => request.Options.FrequencyPenalty = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LogitsAll, v => request.Options.LogitsAll = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LowVram, v => request.Options.LowVram = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MainGpu, v => request.Options.MainGpu = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MinP, v => request.Options.MinP = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MiroStat, v => request.Options.MiroStat = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatEta, v => request.Options.MiroStatEta = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatTau, v => request.Options.MiroStatTau = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.Numa, v => request.Options.Numa = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumBatch, v => request.Options.NumBatch = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumCtx, v => request.Options.NumCtx = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGpu, v => request.Options.NumGpu = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGqa, v => request.Options.NumGqa = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumKeep, v => request.Options.NumKeep = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumPredict, v => request.Options.NumPredict = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MaxOutputTokens, v => request.Options.NumPredict = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumThread, v => request.Options.NumThread = (int?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.PenalizeNewline, v => request.Options.PenalizeNewline = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.PresencePenalty, v => request.Options.PresencePenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.RepeatLastN, v => request.Options.RepeatLastN = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.RepeatPenalty, v => request.Options.RepeatPenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.Seed, v => request.Options.Seed = (int?)v);
TryAddOllamaOption<string[]?>(options, OllamaOption.Stop, v => request.Options.Stop = (v as IEnumerable<string>)?.ToArray());
TryAddOllamaOption<float?>(options, OllamaOption.Temperature, v => request.Options.Temperature = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TfsZ, v => request.Options.TfsZ = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.TopK, v => request.Options.TopK = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TopP, v => request.Options.TopP = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TypicalP, v => request.Options.TypicalP = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMlock, v => request.Options.UseMlock = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMmap, v => request.Options.UseMmap = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.VocabOnly, v => request.Options.VocabOnly = (bool?)v);

return request;
}
Expand Down
6 changes: 6 additions & 0 deletions src/Models/OllamaOption.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ public class OllamaOption(string name)
/// (Default: 128, -1 = infinite generation, -2 = fill context)
/// </summary>
public static OllamaOption NumPredict { get; } = new(Application.NumPredict);

/// <summary>
/// The maximum number of tokens to generate in the output.
/// Alias for <see cref="NumPredict"/>: the abstraction mapper writes this
/// value into the request's num_predict field, and because it is applied
/// after <see cref="NumPredict"/>, it wins when both are supplied.
/// (Default: -1, infinite generation)
/// </summary>
public static OllamaOption MaxOutputTokens { get; } = new(Application.MaxOutputTokens);

/// <summary>
/// Sets the number of threads to use during computation. By default,
Expand Down
101 changes: 100 additions & 1 deletion test/AbstractionMapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,105 @@ public void Maps_Messages_With_Tools()
tool.Type.Should().Be("function");
}

[Test]
public void Maps_All_Options_With_AdditionalProperties()
{
	// Arrange: supply every supported Ollama option through the
	// Microsoft.Extensions.AI AdditionalProperties bag.
	var messages = new List<ChatMessage>();

	var additionalProperties = new AdditionalPropertiesDictionary
	{
		// Boolean options
		[OllamaOption.F16kv.Name] = true,
		[OllamaOption.LogitsAll.Name] = true,
		[OllamaOption.LowVram.Name] = true,
		[OllamaOption.Numa.Name] = true,
		[OllamaOption.PenalizeNewline.Name] = true,
		[OllamaOption.UseMlock.Name] = true,
		[OllamaOption.UseMmap.Name] = true,
		[OllamaOption.VocabOnly.Name] = true,

		// Float options
		[OllamaOption.FrequencyPenalty.Name] = 0.5f,
		[OllamaOption.MinP.Name] = 0.1f,
		[OllamaOption.MiroStatEta.Name] = 0.1f,
		[OllamaOption.MiroStatTau.Name] = 0.2f,
		[OllamaOption.PresencePenalty.Name] = 0.3f,
		[OllamaOption.RepeatPenalty.Name] = 0.4f,
		[OllamaOption.Temperature.Name] = 0.7f,
		[OllamaOption.TfsZ.Name] = 0.8f,
		[OllamaOption.TopP.Name] = 0.9f,
		[OllamaOption.TypicalP.Name] = 0.95f,

		// Integer options
		[OllamaOption.MainGpu.Name] = 0,
		[OllamaOption.MiroStat.Name] = 1,
		[OllamaOption.NumBatch.Name] = 512,
		[OllamaOption.NumCtx.Name] = 4096,
		[OllamaOption.NumGpu.Name] = 1,
		[OllamaOption.NumGqa.Name] = 8,
		[OllamaOption.NumKeep.Name] = 64,
		[OllamaOption.NumPredict.Name] = 1024,
		[OllamaOption.MaxOutputTokens.Name] = 2048,
		[OllamaOption.NumThread.Name] = 8,
		[OllamaOption.RepeatLastN.Name] = 64,
		[OllamaOption.Seed.Name] = 42,
		[OllamaOption.TopK.Name] = 40,

		// String array options
		[OllamaOption.Stop.Name] = new[] { "stop1", "stop2" }
	};

	var options = new ChatOptions { AdditionalProperties = additionalProperties };

	// Act
	var request = AbstractionMapper.ToOllamaSharpChatRequest(messages, options, stream: true, JsonSerializerOptions.Default);

	// Assert
	var mapped = request.Options;
	mapped.Should().NotBeNull();

	// Boolean assertions
	mapped!.F16kv.Should().BeTrue();
	mapped.LogitsAll.Should().BeTrue();
	mapped.LowVram.Should().BeTrue();
	mapped.Numa.Should().BeTrue();
	mapped.PenalizeNewline.Should().BeTrue();
	mapped.UseMlock.Should().BeTrue();
	mapped.UseMmap.Should().BeTrue();
	mapped.VocabOnly.Should().BeTrue();

	// Float assertions
	mapped.FrequencyPenalty.Should().Be(0.5f);
	mapped.MinP.Should().Be(0.1f);
	mapped.MiroStatEta.Should().Be(0.1f);
	mapped.MiroStatTau.Should().Be(0.2f);
	mapped.PresencePenalty.Should().Be(0.3f);
	mapped.RepeatPenalty.Should().Be(0.4f);
	mapped.Temperature.Should().Be(0.7f);
	mapped.TfsZ.Should().Be(0.8f);
	mapped.TopP.Should().Be(0.9f);
	mapped.TypicalP.Should().Be(0.95f);

	// Integer assertions
	mapped.MainGpu.Should().Be(0);
	mapped.MiroStat.Should().Be(1);
	mapped.NumBatch.Should().Be(512);
	mapped.NumCtx.Should().Be(4096);
	mapped.NumGpu.Should().Be(1);
	mapped.NumGqa.Should().Be(8);
	mapped.NumKeep.Should().Be(64);
	// MaxOutputTokens is an alias for num_predict and the mapper applies it
	// after NumPredict, so the 2048 value overrides the 1024 set above.
	mapped.NumPredict.Should().Be(2048);
	mapped.NumThread.Should().Be(8);
	mapped.RepeatLastN.Should().Be(64);
	mapped.Seed.Should().Be(42);
	mapped.TopK.Should().Be(40);

	// String array assertions
	mapped.Stop.Should().NotBeNull();
	mapped.Stop.Should().BeEquivalentTo("stop1", "stop2");
}

[TestCaseSource(nameof(StopSequencesTestData))]
public void Maps_Messages_With_IEnumerable_StopSequences(object? enumerable)
{
Expand Down Expand Up @@ -435,6 +534,7 @@ public void Maps_Options()
chatRequest.Options.Seed.Should().Be(11);
chatRequest.Stream.Should().BeTrue();
chatRequest.Template.Should().BeNull();
chatRequest.Options.NumPredict.Should().Be(1000);

// not defined in ChatOptions
chatRequest.CustomHeaders.Should().BeEmpty();
Expand All @@ -453,7 +553,6 @@ public void Maps_Options()
chatRequest.Options.NumGpu.Should().BeNull();
chatRequest.Options.NumGqa.Should().BeNull();
chatRequest.Options.NumKeep.Should().BeNull();
chatRequest.Options.NumPredict.Should().BeNull();
chatRequest.Options.NumThread.Should().BeNull();
chatRequest.Options.PenalizeNewline.Should().BeNull();
chatRequest.Options.RepeatLastN.Should().BeNull();
Expand Down

0 comments on commit 71ccc5e

Please sign in to comment.