Skip to content

Commit

Permalink
tool-call: support Command R7B (+ return tool_plan "thoughts" in AP…
Browse files Browse the repository at this point in the history
…I) (#11585)

* `tool-call`: support Command R7B (w/ tool_plan return)

* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override

* `tool-call`: test cleanup / handle lazy grammar triggers
  • Loading branch information
ochafik authored Feb 2, 2025
1 parent 6980448 commit bfcce4d
Show file tree
Hide file tree
Showing 8 changed files with 420 additions and 56 deletions.
86 changes: 84 additions & 2 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
default:
throw std::runtime_error("Unknown chat format");
}
Expand Down Expand Up @@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}

static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"tool_call_id", {
{"type", "string"},
// Command-R's template expects an integer string.
{"pattern", "^[0-9]{1,10}$"},
}},
{"tool_name", {
{"type", "string"},
{"const", function["name"]},
}},
{"parameters", function["parameters"]},
}},
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
}, grammar_options);
data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
data.preserved_tokens = {
"<|START_RESPONSE|>",
"<|END_RESPONSE|>",
"<|START_THINKING|>",
"<|END_THINKING|>",
"<|END_ACTION|>",
};
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
return data;
}
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
static std::regex response_regex("<\\|START_RESPONSE\\|>(.*?)<\\|END_RESPONSE\\|>");
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
std::smatch match;

common_chat_msg result;
result.role = "assistant";
if (std::regex_match(input, match, response_regex)) {
result.content = match[1].str();
} else if (std::regex_match(input, match, thought_action_regex)) {
result.tool_plan = match[1].str();
auto actions_str = match[2].str();
auto actions = json::parse(actions_str);
for (const auto & action : actions) {
result.tool_calls.push_back({
/* .name = */ action["tool_name"],
/* .arguments = */ action["parameters"].dump(),
/* .id = */ action["tool_call_id"],
});
}
} else {
LOG_ERR("Failed to parse command_r output");
result.content = input;
}
return result;
}

static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
Expand Down Expand Up @@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
});
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
data.preserved_tokens = {
"<|tool▁sep|>",
"<|tool▁call▁end|>",
};
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
Expand Down Expand Up @@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
// Not really a trigger but need to print this special token to get a successful parse.
data.grammar_triggers.push_back({"</tool_call>", /* .at_start = */ false});
data.preserved_tokens = { "</tool_call>" };
}, grammar_options);

data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
Expand Down Expand Up @@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs);
}
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
return common_chat_params_init_command_r7b(tmpl, inputs);
}
return common_chat_params_init_generic(tmpl, inputs);
}

Expand Down Expand Up @@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
return common_chat_parse_hermes_2_pro(input);
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input);
case COMMON_CHAT_FORMAT_COMMAND_R7B:
return common_chat_parse_command_r7b(input);
default:
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
}
Expand Down
2 changes: 2 additions & 0 deletions common/chat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
Expand All @@ -42,6 +43,7 @@ struct common_chat_params {
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};

Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "llama-cpp.h"

#include <set>
#include <string>
#include <vector>
#include <sstream>
Expand Down Expand Up @@ -163,6 +164,7 @@ struct common_params_sampling {
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::set<llama_token> preserved_tokens;

std::vector<llama_logit_bias> logit_bias; // logit biases to apply

Expand Down Expand Up @@ -621,6 +623,7 @@ struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string tool_plan = "";
};

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
Expand Down
22 changes: 15 additions & 7 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \
- Hermes 2/3, Qwen 2.5
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)

<details>
Expand Down Expand Up @@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B )
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
- Test in CLI:
Expand Down
52 changes: 38 additions & 14 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}

std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
}

return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
Expand Down Expand Up @@ -165,8 +170,9 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
// {"grammar_trigger_words", sampling.grammar_trigger_words},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"preserved_tokens", sampling.preserved_tokens},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
Expand Down Expand Up @@ -363,12 +369,26 @@ struct server_task {
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
}
Expand Down Expand Up @@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {

json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg message;
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
LOG_DBG("Parsing chat message: %s\n", content.c_str());
message = common_chat_parse(content, oaicompat_chat_format);
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
message.content = content;
msg.content = content;
}

json tool_calls;
if (!message.tool_calls.empty()) {
if (!msg.tool_calls.empty()) {
tool_calls = json::array();
for (const auto & tc : message.tool_calls) {
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
Expand All @@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
}
}

json message {
{"content", msg.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
};
if (!msg.tool_plan.empty()) {
message["tool_plan"] = msg.tool_plan;
}

json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", json {
{"content", message.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
}},
{"message", message},
};

if (!stream && probs_output.size() > 0) {
Expand Down Expand Up @@ -2833,8 +2858,7 @@ struct server_context {
server_slot * slot_batched = nullptr;

auto accept_special_token = [&](server_slot & slot, llama_token token) {
const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};

// frist, add sampled tokens from any ongoing sequences
Expand Down
1 change: 1 addition & 0 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ static json oaicompat_completion_params_parse(
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
Expand Down
Loading

0 comments on commit bfcce4d

Please sign in to comment.