-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.py
59 lines (49 loc) · 1.56 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from langchain_openai import ChatOpenAI
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages
from websearch_tool import websearch_tool
from rag_tool import rag_tool
# Shared chat model: handed to the web-search tool factory as-is and, once
# tools are bound to it, used by the chatbot node.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
class State(TypedDict):
    """State schema handed to ``StateGraph``; it has a single ``messages`` key."""

    # ``add_messages`` is attached as annotation metadata. Per the LangGraph
    # docs it is the reducer that merges a node's returned messages into the
    # existing list instead of replacing it -- confirm against the installed
    # langgraph version.
    messages: Annotated[list, add_messages]
# Tools offered to the model. ``websearch_tool`` is called with the LLM and
# its return value is used as the tool; ``rag_tool`` is passed as imported.
tools = [websearch_tool(llm), rag_tool]
# Model handle with the tool list bound; this is what the chatbot node invokes.
llm_with_tools = llm.bind_tools(tools)
# Builder for the graph, using ``State`` as its state schema.
graph_builder = StateGraph(State)
def chatbot(state: State):
    """Run the tool-enabled model over the messages currently in the state.

    Returns a partial state update whose ``"messages"`` entry is a
    one-element list containing the model's reply.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
# Prebuilt LangGraph ``ToolNode`` constructed over the same tool list that is
# bound to the model; registered below as the "tools" node.
tool_node = ToolNode(tools=tools)
def route_tools(
    state: State,
):
    """Choose the next node after the chatbot has produced a message.

    Used as the conditional-edge function: routes to the tool node when the
    most recent message carries at least one tool call, otherwise ends the run.

    Args:
        state: Either the graph state mapping (with a ``"messages"`` list) or
            a bare list of messages.

    Returns:
        ``"tools"`` if the last message has tool calls, else ``END``.

    Raises:
        ValueError: If the state contains no messages, in either input shape.
    """
    # Normalise both accepted shapes to one list so the "no messages" check
    # covers an empty bare list too (previously that case escaped as an
    # IndexError from ``state[-1]``).
    messages = state if isinstance(state, list) else state.get("messages", [])
    if not messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    ai_message = messages[-1]
    # Truthiness handles a missing attribute, ``None`` and an empty list alike,
    # so a ``None`` value no longer blows up inside ``len()``.
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return END
# Register the two nodes: the model step and the prebuilt tool node.
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)
# Every run enters at the chatbot node.
graph_builder.add_edge(START, "chatbot")
# After the chatbot, ``route_tools`` picks the next hop; the mapping translates
# its return value ("tools" or END) into the destination.
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    {
        "tools": "tools",
        END: END
    },
)
# After the tool node runs, control always returns to the chatbot.
graph_builder.add_edge("tools", "chatbot")
# Compile the finished builder; the result is exposed as ``agent``.
agent = graph_builder.compile()