diff --git a/docs/source/en/reference/agents.md b/docs/source/en/reference/agents.md
index 8e33bcb7f..7cf44dd11 100644
--- a/docs/source/en/reference/agents.md
+++ b/docs/source/en/reference/agents.md
@@ -57,3 +57,11 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
 > You must have `gradio` installed to use the UI. Please run `pip install smolagents[gradio]` if it's not the case.
 
 [[autodoc]] GradioUI
+
+## Prompts
+
+[[autodoc]] smolagents.agents.PromptTemplates
+
+[[autodoc]] smolagents.agents.PlanningPromptTemplate
+
+[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
diff --git a/docs/source/hi/reference/agents.md b/docs/source/hi/reference/agents.md
index d49ecae62..dc3a18ce3 100644
--- a/docs/source/hi/reference/agents.md
+++ b/docs/source/hi/reference/agents.md
@@ -154,4 +154,12 @@ model = OpenAIServerModel(
     api_base="https://api.openai.com/v1",
     api_key=os.environ["OPENAI_API_KEY"],
 )
-```
\ No newline at end of file
+```
+
+## Prompts
+
+[[autodoc]] smolagents.agents.PromptTemplates
+
+[[autodoc]] smolagents.agents.PlanningPromptTemplate
+
+[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
diff --git a/docs/source/zh/reference/agents.md b/docs/source/zh/reference/agents.md
index b8fdea376..471d245d5 100644
--- a/docs/source/zh/reference/agents.md
+++ b/docs/source/zh/reference/agents.md
@@ -146,4 +146,12 @@ model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2, max_
 print(model(messages))
 ```
 
-[[autodoc]] LiteLLMModel
\ No newline at end of file
+[[autodoc]] LiteLLMModel
+
+## Prompts
+
+[[autodoc]] smolagents.agents.PromptTemplates
+
+[[autodoc]] smolagents.agents.PlanningPromptTemplate
+
+[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index f9fb74821..ca3f557e1 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -14,6 +14,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
+
 import importlib.resources
 import inspect
 import re
@@ -21,7 +24,7 @@
 import time
 from collections import deque
 from logging import getLogger
-from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
 
 import yaml
 from jinja2 import StrictUndefined, Template
@@ -80,6 +83,69 @@ def populate_template(template: str, variables: Dict[str, Any]) -> str:
         raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}")
 
 
+class PlanningPromptTemplate(TypedDict):
+    """
+    Prompt templates for the planning step.
+
+    Args:
+        initial_facts (`str`): Initial facts prompt.
+        initial_plan (`str`): Initial plan prompt.
+        update_facts_pre_messages (`str`): Update facts pre-messages prompt.
+        update_facts_post_messages (`str`): Update facts post-messages prompt.
+        update_plan_pre_messages (`str`): Update plan pre-messages prompt.
+        update_plan_post_messages (`str`): Update plan post-messages prompt.
+    """
+
+    initial_facts: str
+    initial_plan: str
+    update_facts_pre_messages: str
+    update_facts_post_messages: str
+    update_plan_pre_messages: str
+    update_plan_post_messages: str
+
+
+class ManagedAgentPromptTemplate(TypedDict):
+    """
+    Prompt templates for the managed agent.
+
+    Args:
+        task (`str`): Task prompt.
+        report (`str`): Report prompt.
+    """
+
+    task: str
+    report: str
+
+
+class PromptTemplates(TypedDict):
+    """
+    Prompt templates for the agent.
+
+    Args:
+        system_prompt (`str`): System prompt.
+        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template.
+        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template.
+    """
+
+    system_prompt: str
+    planning: PlanningPromptTemplate
+    managed_agent: ManagedAgentPromptTemplate
+
+
+EMPTY_PROMPT_TEMPLATES = PromptTemplates(
+    system_prompt="",
+    planning=PlanningPromptTemplate(
+        initial_facts="",
+        initial_plan="",
+        update_facts_pre_messages="",
+        update_facts_post_messages="",
+        update_plan_pre_messages="",
+        update_plan_post_messages="",
+    ),
+    managed_agent=ManagedAgentPromptTemplate(task="", report=""),
+)
+
+
 class MultiStepAgent:
     """
     Agent class that solves the given task step by step, using the ReAct framework:
@@ -88,7 +154,7 @@ class MultiStepAgent:
     Args:
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
-        prompt_templates (`dict`, *optional*): Prompt templates.
+        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
         max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
         tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
@@ -107,7 +173,7 @@ def __init__(
         self,
         tools: List[Tool],
         model: Callable[[List[Dict[str, str]]], ChatMessage],
-        prompt_templates: Optional[dict] = None,
+        prompt_templates: Optional[PromptTemplates] = None,
         max_steps: int = 6,
         tool_parser: Optional[Callable] = None,
         add_base_tools: bool = False,
@@ -125,7 +191,7 @@ def __init__(
             tool_parser = parse_json_tool_call
         self.agent_name = self.__class__.__name__
         self.model = model
-        self.prompt_templates = prompt_templates or {}
+        self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
         self.max_steps = max_steps
         self.step_number: int = 0
         self.tool_parser = tool_parser
@@ -634,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent):
     Args:
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
-        prompt_templates (`dict`, *optional*): Prompt templates.
+        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
         **kwargs: Additional keyword arguments.
     """
@@ -643,7 +709,7 @@ def __init__(
         self,
         tools: List[Tool],
         model: Callable[[List[Dict[str, str]]], ChatMessage],
-        prompt_templates: Optional[dict] = None,
+        prompt_templates: Optional[PromptTemplates] = None,
         planning_interval: Optional[int] = None,
         **kwargs,
     ):
@@ -756,7 +822,7 @@ class CodeAgent(MultiStepAgent):
     Args:
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
-        prompt_templates (`dict`, *optional*): Prompt templates.
+        prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
         grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
         additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
         planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
@@ -770,7 +836,7 @@ def __init__(
         self,
         tools: List[Tool],
         model: Callable[[List[Dict[str, str]]], ChatMessage],
-        prompt_templates: Optional[dict] = None,
+        prompt_templates: Optional[PromptTemplates] = None,
         grammar: Optional[Dict[str, str]] = None,
         additional_authorized_imports: Optional[List[str]] = None,
         planning_interval: Optional[int] = None,
@@ -922,6 +988,3 @@ def step(self, memory_step: ActionStep) -> Union[None, Any]:
         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
         memory_step.action_output = output
         return output if is_final_answer else None
-
-
-__all__ = ["MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]
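Usage sketch (not part of the diff above): since the new template types are `TypedDict`s, they are plain dictionaries at runtime and can be constructed directly, then passed through the `prompt_templates` argument that this change retypes. The model class and the Jinja placeholder names below are illustrative assumptions, not taken from the PR.

```python
# Minimal sketch of overriding the prompt templates introduced in this diff.
# HfApiModel is assumed to be an available smolagents model class; the
# template strings and {{...}} placeholder names are hypothetical.
from smolagents import CodeAgent, HfApiModel
from smolagents.agents import (
    ManagedAgentPromptTemplate,
    PlanningPromptTemplate,
    PromptTemplates,
)

# Provide every key: the agent only falls back to EMPTY_PROMPT_TEMPLATES
# when no prompt_templates value is passed at all.
templates = PromptTemplates(
    system_prompt="You are an agent that solves tasks by writing Python code.",
    planning=PlanningPromptTemplate(
        initial_facts="List what you know and what you still need to find out.",
        initial_plan="Write a step-by-step plan for the task.",
        update_facts_pre_messages="Review the conversation so far.",
        update_facts_post_messages="Now update your list of known facts.",
        update_plan_pre_messages="Your earlier plan may be outdated.",
        update_plan_post_messages="Write an updated plan.",
    ),
    managed_agent=ManagedAgentPromptTemplate(
        # Rendered with jinja2, as populate_template() in agents.py does for
        # the other templates; these variable names are assumptions.
        task="You are '{{name}}'. Your manager needs: {{task}}",
        report="Report from '{{name}}': {{final_answer}}",
    ),
)

agent = CodeAgent(tools=[], model=HfApiModel(), prompt_templates=templates)
```

Note that the TypedDicts only constrain static type checking; at runtime nothing validates that all keys are present, so a partially filled dict will fail later, when a missing template is rendered.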