AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


讲解 LangGraph 构造中的进阶用法
发布日期:2024-04-26 07:58:19 浏览次数: 2956 来源:同学小张




书接上文(【AI Agent】【LangGraph】0. 快速上手:协同LangChain,LangGraph帮你用图结构轻松构建多智能体),前面我们了解了 LangGraph 的概念和基本构造方法,今天我们来看下 LangGraph 构造中的进阶用法:给边加个条件 - 条件分支(Conditional edges)。

LangGraph 构造的是个图的数据结构,有节点(node) 和边(edge),那它的边也可以是带条件的。如何给边加入条件呢?可以通过 add_conditional_edges 函数添加带条件的边。

1. 完整代码及运行

废话不多说,先上完整代码,和运行结果。先跑起来看看效果再说。

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, BaseMessage
from langgraph.graph import END, MessageGraph
import json
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from typing import List

@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together."""
    return first_number * second_number

model = ChatOpenAI(temperature=0)
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])

graph = MessageGraph()

def invoke_model(state: List[BaseMessage]):
    return model_with_tools.invoke(state)

graph.add_node("oracle", invoke_model)

def invoke_tool(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    multiply_call = None

    for tool_call in tool_calls:
        if tool_call.get("function").get("name") == "multiply":
            multiply_call = tool_call

    if multiply_call is None:
        raise Exception("No adder input found.")

    res = multiply.invoke(
        json.loads(multiply_call.get("function").get("arguments"))
    )

    return ToolMessage(
        tool_call_id=multiply_call.get("id"),
        content=res
    )

graph.add_node("multiply", invoke_tool)

graph.add_edge("multiply", END)

graph.set_entry_point("oracle")

def router(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    if len(tool_calls):
        return "multiply"
    else:
        return "end"

graph.add_conditional_edges("oracle", router, {
    "multiply": "multiply",
    "end": END,
})

runnable = graph.compile()

response = runnable.invoke(HumanMessage("What is 123 * 456?"))
print(response)

运行结果如下:

2. 代码详解

下面对上面的代码进行详细解释。

2.1 add_conditional_edges

首先,我们知道了可以通过 add_conditional_edges 来对边进行条件添加。这部分代码如下:

graph.add_conditional_edges("oracle", router, {
    "multiply": "multiply",
    "end": END,
})

add_conditional_edges接收三个参数:

  • • 第一个为这条边的第一个node的名称

  • • 第二个为这条边的条件

  • • 第三个为条件返回结果的映射(根据条件结果映射到相应的node)

如上面的代码,意思就是往 “oracle” node上添加边,这个node有两条边,一条是往“multiply” node上走,一条是往“END”上走。怎么决定往哪个方向去:条件是 router(后面解释),如果 router 返回的是“multiply”,则往“multiply”方向走,如果 router 返回的是 “end”,则走“END”。

来看下这个函数的源码:

def add_conditional_edges(
    self,
    start_key: str,
    condition: Callable[..., str],
    conditional_edge_mapping: Optional[Dict[str, str]] = None,
) -> None:
    if self.compiled:
        logger.warning(
            "Adding an edge to a graph that has already been compiled. This will "
            "not be reflected in the compiled graph."
        )
    if start_key not in self.nodes:
        raise ValueError(f"Need to add_node `{start_key}` first")
    if iscoroutinefunction(condition):
        raise ValueError("Condition cannot be a coroutine function")
    if conditional_edge_mapping and set(
        conditional_edge_mapping.values()
    ).difference([END]).difference(self.nodes):
        raise ValueError(
            f"Missing nodes which are in conditional edge mapping. Mapping "
            f"contains possible destinations: "
            f"{list(conditional_edge_mapping.values())}. Possible nodes are "
            f"{list(self.nodes.keys())}."
        )

    self.branches[start_key].append(Branch(condition, conditional_edge_mapping))

重点是这一句:self.branches[start_key].append(Branch(condition, conditional_edge_mapping)),给当前node添加分支Branch。

2.2 条件 router

条件代码如下:判断执行结果中是否有 tool_calls 参数,如果有,则返回"multiply",没有,则返回“end”。

def router(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    if len(tool_calls):
        return "multiply"
    else:
        return "end"

2.3 各node的定义

(1)起始node:oracle

@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together."""
    return first_number * second_number

model = ChatOpenAI(temperature=0)
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])

graph = MessageGraph()

def invoke_model(state: List[BaseMessage]):
    return model_with_tools.invoke(state)

graph.add_node("oracle", invoke_model)

这个node是一个带有Tools 的 ChatOpenAI。在LangChain中使用Tools的详细教程请看这篇文章:【AI大模型应用开发】【LangChain系列】5. LangChain入门:智能体Agents模块的实战详解。简单解释就是:这个node的执行结果,将返回是否应该使用绑定的Tools。

(2)multiply

def invoke_tool(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    multiply_call = None

    for tool_call in tool_calls:
        if tool_call.get("function").get("name") == "multiply":
            multiply_call = tool_call

    if multiply_call is None:
        raise Exception("No adder input found.")

    res = multiply.invoke(
        json.loads(multiply_call.get("function").get("arguments"))
    )

    return ToolMessage(
        tool_call_id=multiply_call.get("id"),
        content=res
    )

graph.add_node("multiply", invoke_tool)

这个node的作用就是执行Tools。

2.4 总体流程






53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询