微信扫码
与创始人交个朋友
我要投稿
前言
在上篇文章《AI大模型实战篇:Reflexion,通过强化学习提升模型推理能力》中,风叔结合原理和具体源代码,详细介绍了Reflexion这种本质是强化学习的AI Agent设计模式。Reflexion已经算是一种非常高级的设计框架,在解决很多复杂问题时,也能有比较好的表现。
在这篇文章中,风叔将为大家介绍可能是目前最强大的AI Agent设计框架,集多种规划和反思技术的集大成者,LATS。文章内容会相对比较复杂难懂,值得收藏和反复研读。
LATS的概念
LATS,全称是Language Agent Tree Search,说的更直白一些,LATS = Tree search + ReAct + Plan&Execute + Reflexion。这么来看,LATS确实非常高级和复杂,下面我们根据上面的等式,先从宏观上拆解一下LATS。
1. Tree Search
Tree Search是一种树搜索算法,LATS 使用蒙特卡罗树搜索(MCTS)算法,通过平衡探索和利用,找到最优决策路径。
蒙特卡罗方法可能大家都比较熟悉了,是一种通过随机采样模拟来求解问题的方法。通过生成随机数,建立概率模型,以解决难以通过其他方法解决的数值问题。蒙特卡罗方法的一个典型应用是求定积分。假设我们要计算函数 f(x) 在[a, b]之间的积分,即阴影部分面积。
蒙特卡罗方法的解法如下:在[a, b]之间取一个随机数 x,用 f(x)⋅(b−a) 来估计阴影部分的面积。为了提高估计精度,可以取多个随机数 x,然后取这些估计值的平均值作为最终结果。当取的随机数 x 越多,结果将越准确,估计值将越接近真实值。
蒙特卡罗树搜索(MCTS)则是一种基于树结构的蒙特卡罗方法。它在整个 2^N(N 为决策次数,即树深度)空间中进行启发式搜索,通过反馈机制寻找最优路径。MCTS 的五个主要核心部分是:
MCTS 的每个循环包括四个步骤:
2. ReAct
ReAct的概念和设计模式,风叔在此前的文章中《AI大模型实战篇:AI Agent设计模式 - ReAct》已做过详细介绍。
它的典型流程如下图所示,可以用一个有趣的循环来描述:思考(Thought)→ 行动(Action)→ 观察(Observation),简称TAO循环。
思考(Thought):面对一个问题,我们需要进行深入的思考。这个思考过程是关于如何定义问题、确定解决问题所需的关键信息和推理步骤。
行动(Action):确定了思考的方向后,接下来就是行动的时刻。根据我们的思考,采取相应的措施或执行特定的任务,以期望推动问题向解决的方向发展。
观察(Observation):行动之后,我们必须仔细观察结果。这一步是检验我们的行动是否有效,是否接近了问题的答案。
循环迭代
3. Plan & Execute
Plan & Execute的概念和设计模式,风叔同样在此前的文章中《AI大模型实战篇:AI Agent设计模式 - Plan & Execute》已做过详细介绍,因此不再赘述。
Plan-and-Execute这个方法的本质是先计划再执行,即先把用户的问题分解成一个个的子任务,然后再执行各个子任务,并根据执行情况调整计划。
4. Reflexion
Reflexion的概念和设计模式,风叔在上篇文章《AI大模型实战篇:Reflexion,通过强化学习提升模型推理能力》做了详细介绍。
Reflexion的本质是Basic Reflection加上强化学习,完整的Reflexion框架由三个部分组成:
因此,融合了Tree Search、ReAct、Plan & Execute、Reflexion的能力于一身之后,LATS成为AI Agent设计模式中,集反思模式和规划模式的大成者。
LATS的工作流程
LATS的工作流程如下图所示,包括以下步骤:
下图是在langchain中实现LATS的过程:
第一步,选择:根据下面步骤中的总奖励选择最佳的下一步行动,如果找到解决方案或达到最大搜索深度,做出响应;否则就继续搜索。
第二步,扩展和执行:生成N个潜在操作,并且并行执行。
第三步,反思和评估:观察行动的结果,并根据反思和外部反馈对决策评分。
第四步,反向传播:根据结果更新轨迹的分数。
LATS的实现过程
下面,风叔通过实际的源码,详细介绍LATS模式的实现方法,具体的源代码地址可以在文末获取。
第一步 构建树节点
LATS 基于蒙特卡罗树搜索。对于每个搜索步骤,它都会选择具有最高“置信上限”的节点,这是一个平衡开发(最高平均奖励)和探索(最低访问量)的指标。从该节点开始,它会生成 N(在本例中为 5)个新的候选操作,并将它们添加到树中。当它生成有效解决方案或达到最大次数(搜索树深度)时,会停止搜索。
在Node节点中,我们定义了几个关键的函数:
import math
from collections import deque
from typing import Optional
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
class Node:
def __init__(
self,
messages: list[BaseMessage],
reflection: Reflection,
parent: Optional[Node] = None,
):
self.messages = messages
self.parent = parent
self.children = []
self.value = 0
self.visits = 0
self.reflection = reflection
self.depth = parent.depth + 1 if parent is not None else 1
self._is_solved = reflection.found_solution if reflection else False
if self._is_solved:
self._mark_tree_as_solved()
self.backpropagate(reflection.normalized_score)
def __repr__(self) -> str:
return (
f"<Node value={self.value}, visits={self.visits},"
f" solution={self.messages} reflection={self.reflection}/>"
)
@property
def is_solved(self):
"""If any solutions exist, we can end the search."""
return self._is_solved
@property
def is_terminal(self):
return not self.children
@property
def best_child(self):
"""Select the child with the highest UCT to search next."""
if not self.children:
return None
all_nodes = self._get_all_children()
return max(all_nodes, key=lambda child: child.upper_confidence_bound())
@property
def best_child_score(self):
"""Return the child with the highest value."""
if not self.children:
return None
return max(self.children, key=lambda child: int(child.is_solved) * child.value)
@property
def height(self) -> int:
"""Check for how far we've rolled out the tree."""
if self.children:
return 1 + max([child.height for child in self.children])
return 1
def upper_confidence_bound(self, exploration_weight=1.0):
"""Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
if self.parent is None:
raise ValueError("Cannot obtain UCT from root node")
if self.visits == 0:
return self.value
# Encourages exploitation of high-value trajectories
average_reward = self.value / self.visits
# Encourages exploration of less-visited trajectories
exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
return average_reward + exploration_weight * exploration_term
def backpropagate(self, reward: float):
"""Update the score of this node and its parents."""
node = self
while node:
node.visits += 1
node.value = (node.value * (node.visits - 1) + reward) / node.visits
node = node.parent
def get_messages(self, include_reflections: bool = True):
if include_reflections:
return self.messages + [self.reflection.as_message()]
return self.messages
def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
"""Get messages representing this search branch."""
messages = []
node = self
while node:
messages.extend(
node.get_messages(include_reflections=include_reflections)[::-1]
)
node = node.parent
# Reverse the final back-tracked trajectory to return in the correct order
return messages[::-1]# root solution, reflection, child 1, ...
def _get_all_children(self):
all_nodes = []
nodes = deque()
nodes.append(self)
while nodes:
node = nodes.popleft()
all_nodes.extend(node.children)
for n in node.children:
nodes.append(n)
return all_nodes
def get_best_solution(self):
"""Return the best solution from within the current sub-tree."""
all_nodes = [self] + self._get_all_children()
best_node = max(
all_nodes,
# We filter out all non-terminal, non-solution trajectories
key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
)
return best_node
def _mark_tree_as_solved(self):
parent = self.parent
while parent:
parent._is_solved = True
parent = parent.parent
第二步 构建Agent
Agent将主要处理三个事项:
对于更多实际的应用,比如代码生成,可以将代码执行结果集成到反馈或奖励中,这种外部反馈对Agent效果的提升将非常有用。
对于Agent,首先构建工具Tools,我们只使用了一个搜索引擎工具。
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]
tool_executor = ToolExecutor(tools=tools)
然后,构建反射系统,反射系统将根据决策和工具使用结果,对Agent的输出进行打分,我们将在其他两个节点中调用此方法。
class Reflection(BaseModel):
reflections: str = Field(
description="The critique and reflections on the sufficiency, superfluency,"
" and general quality of the response"
)
score: int = Field(
description="Score from 0-10 on the quality of the candidate response.",
gte=0,
lte=10,
)
found_solution: bool = Field(
description="Whether the response has fully solved the question or task."
)
def as_message(self):
return HumanMessage(
content=f"Reasoning: {self.reflections}\nScore: {self.score}"
)
@property
def normalized_score(self) -> float:
return self.score / 10.0
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Reflect and grade the assistant response to the user question below.",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="candidate"),
]
)
reflection_llm_chain = (
prompt
| llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
run_name="Reflection"
)
| PydanticToolsParser(tools=[Reflection])
)
@as_runnable
def reflection_chain(inputs) -> Reflection:
tool_choices = reflection_llm_chain.invoke(inputs)
reflection = tool_choices[0]
if not isinstance(inputs["candidate"][-1], AIMessage):
reflection.found_solution = False
return reflection
接下来,我们从根节点开始,根据用户输入进行响应
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnableConfig
prompt_template = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an AI assistant.",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="messages", optional=True),
]
)
initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(
run_name="GenerateInitialCandidate"
)
parser = JsonOutputToolsParser(return_id=True)
initial_response = initial_answer_chain.invoke(
{"input": "Write a research report on lithium pollution."}
)
然后开始根节点,我们将候选节点生成和reflection打包到单个节点中。
import json
# Define the node we will add to the graph
def generate_initial_response(state: TreeState) -> dict:
"""Generate the initial candidate response."""
res = initial_answer_chain.invoke({"input": state["input"]})
parsed = parser.invoke(res)
tool_responses = tool_executor.batch(
[ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in parsed]
)
output_messages = [res] + [
ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
for resp, tool_call in zip(tool_responses, parsed)
]
reflection = reflection_chain.invoke(
{"input": state["input"], "candidate": output_messages}
)
root = Node(output_messages, reflection=reflection)
return {
**state,
"root": root,
}
第三步 生成候选节点
对于每个节点,生成5个待探索的候选节点。
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
n = config["configurable"].get("N", 5)
bound_kwargs = llm.bind_tools(tools=tools).kwargs
chat_result = llm.generate(
[messages.to_messages()],
n=n,
callbacks=config["callbacks"],
run_name="GenerateCandidates",
**bound_kwargs,
)
return [gen.message for gen in chat_result.generations[0]]
expansion_chain = prompt_template | generate_candidates
res = expansion_chain.invoke({"input": "Write a research report on lithium pollution."})
将候选节点生成和refleciton步骤打包在下面的扩展节点中,所有操作都以批处理的方式进行,以加快执行速度。
from collections import defaultdict
def expand(state: TreeState, config: RunnableConfig) -> dict:
"""Starting from the "best" node in the tree, generate N candidates for the next step."""
root = state["root"]
best_candidate: Node = root.best_child if root.children else root
messages = best_candidate.get_trajectory()
# Generate N candidates from the single child candidate
new_candidates = expansion_chain.invoke(
{"input": state["input"], "messages": messages}, config
)
parsed = parser.batch(new_candidates)
flattened = [
(i, tool_call)
for i, tool_calls in enumerate(parsed)
for tool_call in tool_calls
]
tool_responses = tool_executor.batch(
[
ToolInvocation(tool=tool_call["type"], tool_input=tool_call["args"])
for _, tool_call in flattened
]
)
collected_responses = defaultdict(list)
for (i, tool_call), resp in zip(flattened, tool_responses):
collected_responses[i].append(
ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
)
output_messages = []
for i, candidate in enumerate(new_candidates):
output_messages.append([candidate] + collected_responses[i])
# Reflect on each candidate
# For tasks with external validation, you'd add that here.
reflections = reflection_chain.batch(
[{"input": state["input"], "candidate": msges} for msges in output_messages],
config,
)
# Grow tree
child_nodes = [
Node(cand, parent=best_candidate, reflection=reflection)
for cand, reflection in zip(output_messages, reflections)
]
best_candidate.children.extend(child_nodes)
# We have already extended the tree directly, so we just return the state
return state
第四步 构建流程图
下面,我们构建流程图,将根节点和扩展节点加入进来
from typing import Literal
from langgraph.graph import END, StateGraph, START
def should_loop(state: TreeState) -> Literal["expand", "__end__"]:
"""Determine whether to continue the tree search."""
root = state["root"]
if root.is_solved:
return END
if root.height > 5:
return END
return "expand"
builder = StateGraph(TreeState)
builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)
builder.add_edge(START, "start")
builder.add_conditional_edges(
"start",
# Either expand/rollout or finish
should_loop,
)
builder.add_conditional_edges(
"expand",
# Either continue to rollout or finish
should_loop,
)
graph = builder.compile()
至此,整个LATS的核心逻辑就介绍完了。大家可以关注公众号【风叔云】,回复关键词【LATS源码】,获取LATS设计模式的完整源代码。
总结
与其他基于树的方法相比,LATS实现了自我反思的推理步骤,显著提升了性能。当采取行动后,LATS不仅利用环境反馈,还结合来自语言模型的反馈,以判断推理中是否存在错误并提出替代方案。这种自我反思的能力与其强大的搜索算法相结合,使得LATS更适合处理一些相对复杂的任务。
然而,由于算法本身的复杂性以及涉及的反思步骤,LATS通常比其他单智能体方法使用更多的计算资源,并且完成任务所需的时间更长。
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-01-11
人工智能:大模型从技术到业务应用
2025-01-11
深度长文|Agentic AI 时代:NVIDIA 的技术革命与雄心
2025-01-11
AI是否会终结传统搜索引擎?
2025-01-11
亚马逊云科技:LLMOps驱动生成式 AI 应用的运营化
2025-01-11
蚂蚁集团基于 Ray 构建的分布式 AI Agent 框架
2025-01-10
我们即将进入 Agentic AI 时代 ,而第一个落地就是 Coding Agent
2025-01-10
2025 AI Agent迷局:谁在玩真的,谁在演戏?
2025-01-10
AGI 通用人工智能模型:基础理论与实现路径
2024-08-13
2024-05-28
2024-04-26
2024-08-21
2024-06-13
2024-08-04
2024-07-09
2024-09-23
2024-07-18
2024-04-11