Create multi_agent.py
Browse files- multi_agent.py +137 -0
multi_agent.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools, operator
|
2 |
+
|
3 |
+
from datetime import date
|
4 |
+
|
5 |
+
from typing import Annotated, Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
6 |
+
|
7 |
+
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
8 |
+
from langchain_community.tools.tavily_search import TavilySearchResults
|
9 |
+
from langchain_core.messages import BaseMessage, HumanMessage
|
10 |
+
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
|
11 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
12 |
+
from langchain_core.tools import tool
|
13 |
+
from langchain_openai import ChatOpenAI
|
14 |
+
|
15 |
+
from langgraph.graph import StateGraph, END
|
16 |
+
|
17 |
+
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph workflow.

    LangGraph merges each node's return value into this state; the
    Annotated reducer below controls how the merge happens per field.
    """
    # Conversation history. operator.add tells LangGraph to APPEND a node's
    # returned messages to the existing sequence instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing decision produced by the supervisor: an agent name or "FINISH".
    next: str
|
20 |
+
|
21 |
+
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an AgentExecutor wrapping an OpenAI-tools agent.

    The chat prompt is the given system instruction, followed by the
    running conversation (``messages``) and the agent's tool-call
    scratchpad placeholder required by ``create_openai_tools_agent``.

    Args:
        llm: chat model the agent reasons with.
        tools: tools the agent may call.
        system_prompt: system instruction for this agent.

    Returns:
        An ``AgentExecutor`` ready to be invoked with graph state.
    """
    chat_template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    return AgentExecutor(
        agent=create_openai_tools_agent(llm, tools, chat_template),
        tools=tools,
    )
|
34 |
+
|
35 |
+
def agent_node(state, agent, name):
    """Run *agent* on the graph state and wrap its output as a named message.

    Returning a dict with a ``messages`` list lets LangGraph append the
    result to the shared conversation (see AgentState's reducer).
    """
    output_text = agent.invoke(state)["output"]
    reply = HumanMessage(content=output_text, name=name)
    return {"messages": [reply]}
|
38 |
+
|
39 |
+
@tool
def today_tool(text: str) -> str:
    """Returns today's date. Use this for any questions related to knowing today's date.
    The input should always be an empty string, and this function will always return today's date.
    Any date mathematics should occur outside this function."""
    # NOTE: the docstring doubles as the tool description shown to the LLM,
    # so it is preserved verbatim. The trailing instruction nudges the agent
    # to terminate once its task list is done.
    completion_hint = "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    return str(date.today()) + completion_hint
|
45 |
+
|
46 |
+
def create_graph(model, topic):
    """Assemble and compile the supervisor/worker LangGraph workflow.

    A "Manager" (supervisor) node repeatedly chooses which worker agent
    acts next; the single "Researcher" worker searches the web and writes
    an article on *topic*. The graph ends when the Manager picks FINISH.

    Args:
        model: OpenAI chat model name, forwarded to ``ChatOpenAI``.
        topic: subject the Researcher is instructed to write about.

    Returns:
        The compiled LangGraph application.
    """
    search_tool = TavilySearchResults(max_results=10)

    # Worker roster and the full set of routing choices the Manager has.
    members = ["Researcher"]
    options = ["FINISH"] + members

    system_prompt = (
        "You are a Manager tasked with managing a conversation between the "
        "following agent(s): {members}. Given the following user request, "
        "respond with the agent to act next. Each agent will perform a "
        "task and respond with their results and status. When finished, "
        "respond with FINISH."
    )

    # OpenAI function schema that constrains the model's reply to a single
    # routing choice drawn from *options*.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                }
            },
            "required": ["next"],
        },
    }

    supervisor_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next? "
                "Or should we FINISH? Select one of: {options}.",
            ),
        ]
    ).partial(options=str(options), members=", ".join(members))

    llm = ChatOpenAI(model=model)

    # prompt -> model forced to call "route" -> parsed function arguments,
    # yielding a dict like {"next": "Researcher"} that feeds the router below.
    supervisor_chain = (
        supervisor_prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    researcher_instructions = (
        "1. Research content on topic: " + topic + ". "
        "2. Based on your research, write an in-depth article on the topic. "
        "3. The output must be in markdown format (omit the triple backticks). "
        "4. At the beginning of the article, add current date and author: Multi-Agent AI System. "
        "5. Also at the beginning of the article, add a references section with links to relevant content."
    )
    researcher_agent = create_agent(
        llm, [search_tool, today_tool], system_prompt=researcher_instructions
    )
    researcher_node = functools.partial(
        agent_node, agent=researcher_agent, name="Researcher"
    )

    workflow = StateGraph(AgentState)
    workflow.add_node("Manager", supervisor_chain)
    workflow.add_node("Researcher", researcher_node)

    # Every worker reports back to the Manager after acting.
    for member in members:
        workflow.add_edge(member, "Manager")

    # The Manager's "next" field routes to the named worker, or ends the run.
    routing = {member: member for member in members}
    routing["FINISH"] = END
    workflow.add_conditional_edges("Manager", lambda x: x["next"], routing)

    workflow.set_entry_point("Manager")

    return workflow.compile()
|
121 |
+
|
122 |
+
def run_multi_agent(llm, topic):
    """Run the full multi-agent workflow on *topic* and return the article.

    NOTE(review): despite its name, *llm* is forwarded to ``create_graph``'s
    ``model`` parameter, i.e. it appears to be a model-name string rather
    than a chat-model instance — confirm against callers.
    """
    app = create_graph(llm, topic)

    final_state = app.invoke({"messages": [HumanMessage(content=topic)]})

    # The last message in the accumulated conversation is the finished article.
    article = final_state["messages"][-1].content

    # Echo the article between === markers for console inspection.
    for line in ("===", article, "==="):
        print(line)

    return article
|