Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: Enable native tool calling for Granite #364

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/beeai_framework/agents/runners/default/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@


class DefaultRunner(BaseRunner):
# TODO: 333 - Currently a global variable is used to determine if the tool calling should be used or not
use_native_tool_calling: bool = False

def default_templates(self) -> BeeAgentTemplates:
return BeeAgentTemplates(
system=SystemPromptTemplate,
Expand Down Expand Up @@ -98,8 +101,13 @@ async def new_token(value: tuple[ChatModelOutput, Callable], event: EventMeta) -
def observe(llm_emitter: Emitter) -> None:
llm_emitter.on("newToken", new_token)

        # TODO: 333 - check this exists; shallow copy
tools: list[Tool] = self._input.tools[:]
output: ChatModelOutput = await self._input.llm.create(
ChatModelInput(messages=self.memory.messages[:], stream=True)
# For native tool calling we pass the tools to the llm call.
ChatModelInput(
messages=self.memory.messages[:], stream=True, tools=tools if self.use_native_tool_calling else None
)
).observe(fn=observe)

# Pick up any remaining lines in parser buffer
Expand Down
24 changes: 22 additions & 2 deletions python/beeai_framework/agents/runners/granite/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

GraniteAssistantPromptTemplate = PromptTemplate(
schema=AssistantPromptTemplateInput,
template="{{#thought}}Thought: {{.}}\n{{/thought}}{{#tool_name}}Tool Name: {{.}}\n{{/tool_name}}{{#tool_input}}Tool Input: {{&.}}\n{{/tool_input}}{{#tool_output}}Tool Output: {{&.}}\n{{/tool_output}}{{#final_answer}}Final Answer: {{.}}{{/final_answer}}", # noqa: E501
# TODO: 333 updated to match ts - not passing tools
template="{{#thought}}Thought: {{.}}\n{{/thought}}{{#tool_name}}Tool Name: {{.}}\n{{/tool_name}}{{#tool_input}}Tool Input: {{.}}\n{{/tool_input}}{{#final_answer}}Final Answer: {{.}}{{/final_answer}}", # noqa: E501
)

GraniteSystemPromptTemplate = PromptTemplate(
Expand Down Expand Up @@ -78,14 +79,33 @@
""", # noqa: E501
)

# TODO: 333 Additional prompts in ts
# export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => {
# config.template = `Error: The generated response does not adhere to the communication structure mentioned in the
# system prompt.
# You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then
# 'Tool Input' or 'Thought' followed by 'Final Answer'.`;
# });

# export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => {
# config.template = `{{input}}`;
# });
# TODO: 333 Updated text string to match typescript (minor)
GraniteToolNotFoundErrorTemplate = PromptTemplate(
schema=ToolNotFoundErrorTemplateInput,
template="""The tool does not exist!
template="""Tool does not exist!
{{#tools.length}}
Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}}
{{/tools.length}}""",
)

# TODO: 333 Additional prompt in ts
# export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => {
# config.template = `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want,
# use a different tool or explain why you can't use it.

# {{reason}}`;
# });
GraniteToolInputErrorTemplate = PromptTemplate(
schema=ToolInputErrorTemplateInput,
template="""{{reason}}
Expand Down
3 changes: 3 additions & 0 deletions python/beeai_framework/agents/runners/granite/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@


class GraniteRunner(DefaultRunner):
    # TODO: 333 - Currently a global variable is used to determine if the tool calling should be used or not
use_native_tool_calling: bool = True

def create_parser(self) -> LinePrefixParser:
"""Prefixes are renamed for granite"""
prefixes = [
Expand Down