微信扫码
与创始人交个朋友
我要投稿
RAG 代理工作流程
在这里,我们将使用 LangGraph、Groq-Llama-3 和 Chroma 构建可靠的 RAG 代理。我们将结合以下概念来构建 RAG 代理。
自适应 RAG (论文**)**。我们已经实现了本文中描述的概念,构建了一个路由器,用于将问题路由到不同的检索方法。
校正 RAG (论文**)**。我们已经实现了本文中描述的概念,开发了一个回退机制,用于在检索到的上下文与所问问题不相关时继续进行。
自身 RAG (论文**)**。我们已经实现了本文中描述的概念,开发了一个幻觉评分器,即修正那些产生幻觉或未回答所问问题的答案。
代理背后的基本概念涉及使用语言模型来选择一系列动作。在链中,这个序列被硬编码在代码中。相反,代理利用语言模型作为推理引擎来决定要采取的动作及其顺序。
它包括 3 个组件:
代理可以通过使用 Langchain 的 ReAct 概念或使用 LangGraph 来体现。
Langchain 代理和 LangGraph 之间的权衡:
*可靠性*
ReAct / Langchain 代理:可靠性较低,因为 LLM 需要在每个步骤上做出正确的决策
LangGraph:可靠性更高,因为控制流已经设置好,LLM 在每个节点上有具体的任务
*灵活性*
ReAct / Langchain 代理:更灵活,因为 LLM 可以选择任何动作序列
LangGraph:灵活性较低,因为动作受限于在每个节点上设置控制流
*与较小 LLM 的兼容性*
ReAct / Langchain 代理:兼容性较差
LangGraph:兼容性较好
在这里,我们使用 LangGraph 创建了代理。
LangChain 是一个用于开发由语言模型驱动的应用程序的框架。它支持以下应用程序:
LangGraph 是一个扩展 LangChain 的库,为 LLM 应用程序提供了循环计算功能。虽然 LangChain 支持定义计算链(有向无环图或 DAG),但 LangGraph 允许包含循环。这允许更复杂、更像代理的行为,其中 LLM 可以在循环中被调用以确定下一步要采取的动作。
有状态图:LangGraph 围绕着有状态图的概念展开,图中的每个节点代表我们计算的一个步骤,并且图保持一个状态,该状态随着计算的进行而传递和更新。
节点:节点是 LangGraph 的构建块。每个节点代表一个功能或一个计算步骤。我们定义节点来执行特定的任务,例如处理输入、做出决策或与外部 API 进行交互。
边:边连接图中的节点,定义计算的流程。LangGraph 支持条件边,允许您根据图的当前状态动态确定要执行的下一个节点。
Tavily 搜索 API 是针对 LLM 进行优化的搜索引擎,旨在实现高效、快速和持久的搜索结果。与其他搜索 API(如 Serp 或 Google)不同,Tavily 专注于优化搜索,以满足 AI 开发人员和自主 AI 代理的需求。
Groq 提供了针对开发人员的高性能 AI 模型和 API 访问,具有比竞争对手更快的推理速度和更低的成本。
支持的模型
![img](https://miro.medium.com/v2/resize
根据问题,路由器决定是从向量存储中检索上下文还是进行网页搜索。
如果路由器决定将问题定向到向量存储以进行检索,则从向量存储中检索匹配的文档;否则,使用 tavily-api 进行网页搜索。
文档评分器然后将文档评分为相关或不相关。
如果检索到的上下文被评为相关,则使用幻觉评分器检查是否存在幻觉。如果评分器决定响应缺乏幻觉,则将响应呈现给用户。
如果上下文被评为不相关,则进行网页搜索以检索内容。
检索后,文档评分器对从网页搜索生成的内容进行评分。如果发现相关,则使用 LLM 进行综合,然后呈现响应。
嵌入模型:BAAI/bge-base-en-v1.5
LLM:Llama-3-8B
向量存储:Chroma
图/代理:LangGraph
安装所需库
! pip install -U langchain-nomic langchain_community tiktoken langchainhub chromadb langchain langgraph tavily-python gpt4all fastembed langchain-groq
导入所需库
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
实例化嵌入模型
embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
实例化 LLM
from groq import Groq
from langchain_groq import ChatGroq
from google.colab import userdata
llm = ChatGroq(temperature=0,
model_name="Llama3-8b-8192",
api_key=userdata.get("GROQ_API_KEY"),)
下载数据
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
print(f"len of documents :{len(docs_list)}")
将文档分块以与 LLM 上下文窗口同步
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=512, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
print(f"length of document chunks generated :{len(doc_splits)}")
加载文档到向量存储
vectorstore = Chroma.from_documents(documents=doc_splits,
embedding=embed_model,
collection_name="local-rag")
实例化检索器
retriever = vectorstore.as_retriever(search_kwargs={"k":2})
实现路由器
import time
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers import StrOutputParser
prompt = PromptTemplate(
template="""system You are an expert at routing a
user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents,
prompt engineering, and adversarial attacks. You do not need to be stringent with the keywords
in the question related to these topics. Otherwise, use web-search. Give a binary choice 'web_search'
or 'vectorstore' based on the question. Return the a JSON with a single key 'datasource' and
no premable or explaination. Question to route: {question} assistant""",
input_variables=["question"],
)
start = time.time()
question_router = prompt | llm | JsonOutputParser()
question = "llm agent memory"
print(question_router.invoke({"question": question}))
end = time.time()
print(f"The time required to generate response by Router Chain in seconds:{end - start}")
#############################RESPONSE ###############################
{'datasource': 'vectorstore'}
The time required to generate response by Router Chain in seconds:0.34175705909729004
实现生成链
prompt = PromptTemplate(
template="""system You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise user
Question: {question}
Context: {context}
Answer: assistant""",
input_variables=["question", "document"],
)
# Post-processing
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# Chain
start = time.time()
rag_chain = prompt | llm | StrOutputParser()
#############################RESPONSE##############################
The time required to generate response by the generation chain in seconds:1.0384225845336914
The agent memory in the context of LLM-powered autonomous agents refers to the ability of the agent to learn from its past experiences and adapt to new situations.
实现检索评分器
#
prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 你是一个评分员,评估检索到的文档与用户问题的相关性。如果文档包含与用户问题相关的关键词,将其评分为相关。这不需要是一个严格的测试。目标是过滤掉错误的检索结果。\n
给出一个二进制分数 'yes' 或 'no' 来指示文档是否与问题相关。\n
将二进制分数作为一个带有单个键 'score' 的 JSON 提供,不包含任何前言或解释。\n
<|eot_id|><|start_header_id|>user<|end_header_id|>
这是检索到的文档:\n\n {document} \n\n
这是用户问题:{question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
""",
input_variables=["question", "document"],
)
start = time.time()
retrieval_grader = prompt | llm | JsonOutputParser()
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
end = time.time()
print(f"检索评分器生成响应所需的时间(秒):{end - start}")
############################响应###############################
{'score': 'yes'}
检索评分器生成响应所需的时间(秒):0.8115921020507812
实现幻觉评分器
# 提示
prompt = PromptTemplate(
template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> 你是一个评分员,评估答案是否基于一组事实。给出一个二进制 'yes' 或 'no' 分数来指示答案是否基于一组事实。将二进制分数作为一个带有单个键 'score' 的 JSON 提供,不包含任何前言或解释。 <|eot_id|><|start_header_id|>user<|end_header_id|>
这些是事实:
\n ------- \n
{documents}
\n ------- \n
这是答案:{generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["generation", "documents"],
)
start = time.time()
hallucination_grader = prompt | llm | JsonOutputParser()
hallucination_grader_response = hallucination_grader.invoke({"documents": docs, "generation": generation})
end = time.time()
print(f"生成链生成响应所需的时间(秒):{end - start}")
print(hallucination_grader_response)
####################################响应#################################
生成链生成响应所需的时间(秒):1.020448923110962
{'score': 'yes'}
实现答案评分器
# 提示
prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 你是一个评分员,评估答案是否有助于解决问题。给出一个二进制分数 'yes' 或 'no' 来指示答案是否有助于解决问题。将二进制分数作为一个带有单个键 'score' 的 JSON 提供,不包含任何前言或解释。
<|eot_id|><|start_header_id|>user<|end_header_id|> 这是答案:
\n ------- \n
{generation}
\n ------- \n
这是问题:{question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["generation", "question"],
)
start = time.time()
answer_grader = prompt | llm | JsonOutputParser()
answer_grader_response = answer_grader.invoke({"question": question,"generation": generation})
end = time.time()
print(f"答案评分器生成响应所需的时间(秒):{end - start}")
print(answer_grader_response)
##############################响应###############################
答案评分器生成响应所需的时间(秒):0.2455885410308838
{'score': 'yes'}
实现网络搜索工具
import os
from langchain_community.tools.tavily_search import TavilySearchResults
os.environ['TAVILY_API_KEY'] = "YOUR API KEY"
web_search_tool = TavilySearchResults(k=3)
定义图状态:表示图的状态。
定义以下属性:
问题
生成:LLM 生成
网络搜索:是否添加搜索
文档:文档列表
from typing_extensions import TypedDict
from typing import List
### 状态
class GraphState(TypedDict):
question : str
generation : str
web_search : str
documents : List[str]
定义节点
from langchain.schema import Document
def retrieve(state):
"""
从向量存储中检索文档
Args:
state (dict): 当前图状态
返回:
state (dict): 新增了一个名为 documents 的键到 state 字典中,其中包含检索到的文档
“”“
print("---RETRIEVE---")
question = state["question"]
# Retrieval
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
def generate(state):
"""
使用 RAG 在检索到的文档上生成答案
Args:
state (dict): 当前图状态
Returns:
state (dict): 新增了一个名为 generation 的键到 state 字典中,其中包含 LLM 生成的内容
"""
print("---生成---")
question = state["question"]
documents = state["documents"]
# RAG 生成
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""
确定检索到的文档是否与问题相关
如果任何文档不相关,我们将设置一个标志来运行网络搜索
Args:
state (dict): 当前图状态
Returns:
state (dict): 过滤掉不相关文档并更新 web_search 状态
"""
print("---检查文档是否与问题相关---")
question = state["question"]
documents = state["documents"]
# 对每个文档进行评分
filtered_docs = []
web_search = "否"
for d in documents:
score = retrieval_grader.invoke({"question": question, "document": d.page_content})
grade = score['score']
# 文档相关
if grade.lower() == "是":
print("---评分:文档相关---")
filtered_docs.append(d)
# 文档不相关
else:
print("---评分:文档不相关---")
# 我们不将文档包括在 filtered_docs 中
# 我们设置一个标志来指示我们要运行网络搜索
web_search = "是"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
def web_search(state):
"""
基于问题进行网络搜索
Args:
state (dict): 当前图状态
Returns:
state (dict): 将网络搜索结果附加到文档中
"""
print("---网络搜索---")
question = state["question"]
documents = state["documents"]
# 网络搜索
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
if documents is not None:
documents.append(web_results)
else:
documents = [web_results]
return {"documents": documents, "question": question}
定义边的条件
def route_question(state):
"""
将问题路由到网络搜索或 RAG。
Args:
state (dict): 当前图状态
Returns:
str: 下一个要调用的节点
"""
print("---路由问题---")
question = state["question"]
source = question_router.invoke({"question": question})
if source['datasource'] == 'web_search':
print("---将问题路由到网络搜索---")
return "websearch"
elif source['datasource'] == 'vectorstore':
print("---将问题路由到 RAG---")
return "vectorstore"
def decide_to_generate(state):
"""
确定是否生成答案,或添加网络搜索
Args:
state (dict): 当前图状态
Returns:
str: 下一个要调用的节点的二进制决策
"""
print("---评估评分文档---")
question = state["question"]
web_search = state["web_search"]
filtered_documents = state["documents"]
if web_search == "是":
# 所有文档都已经过滤,检查相关性
# 我们将生成一个新的查询
print("---决策:所有文档与问题不相关,包括网络搜索---")
return "websearch"
else:
# 我们有相关的文档,所以生成答案
print("---决策:生成---")
return "generate"
def grade_generation_v_documents_and_question(state):
"""
确定生成的内容是否基于文档并回答问题。
Args:
state (dict): 当前图状态
Returns:
str: 下一个要调用的节点的决策
"""
print("---检查幻觉---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke({"documents": documents, "generation": generation})
grade = score['score']
# 检查幻觉
if grade == "yes":
print("---决策:生成内容基于文档---")
# 检查问答
print("---生成内容评分 vs 问题---")
score = answer_grader.invoke({"question": question,"generation": generation})
grade = score['score']
if grade == "yes":
print("---决策:生成内容回答了问题---")
return "有用的"
else:
print("---决策:生成内容未回答问题---")
return "无用的"
else:
pprint("---决策:生成内容不基于文档,重新尝试---")
return "不支持的"
添加节点
from langgraph.graph import END, StateGraph
workflow = StateGraph(GraphState)
# 定义节点
workflow.add_node("websearch", web_search) # 网络搜索
workflow.add_node("retrieve", retrieve) # 检索
workflow.add_node("grade_documents", grade_documents) # 评分文档
workflow.add_node("generate", generate) # 生成
设置入口点和结束点
workflow.set_conditional_entry_point(
route_question,
{
"websearch": "websearch",
"vectorstore": "retrieve",
},
)
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"websearch": "websearch",
"generate": "generate",
},
)
workflow.add_edge("websearch", "generate")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "websearch",
},
)
编译工作流程
app = workflow.compile()
测试工作流程
from pprint import pprint
inputs = {"question": "什么是提示工程?"}
for output in app.stream(inputs):
for key, value in output.items():
pprint(f"运行完成:{key}:")
pprint(value["generation"])
########################回应##############################
---路由问题---
什么是提示工程?
{'数据源': '向量存储'}
向量存储
---路由问题到RAG---
---检索---
'运行完成:检索:'
---检查文档与问题的相关性---
---评分:文档相关---
---评分:文档相关---
---评估已评分的文档---
---决策:生成---
'运行完成:评分文档:'
---生成---
---检查幻觉---
---决策:生成内容基于文档---
---生成内容评分 vs 问题---
---决策:生成内容回答了问题---
'运行完成:生成:'
('提示工程是指通过与大型语言模型交流来引导其行为以实现期望的结果,而无需更新模型权重。这是一门需要大量实验和启发式的经验科学。')
针对不同问题测试工作流程
app = workflow.compile()
# 测试
from pprint import pprint
inputs = {"question": "熊队在NFL选秀中预计首轮选秀谁?"}
for output in app.stream(inputs):
for key, value in output.items():
pprint(f"运行完成:{key}:")
pprint(value["generation"])
#############################回应##############################
---路由问题---
熊队在NFL选秀中预计首轮选秀谁?
{'数据源': '网络搜索'}
网络搜索
---路由问题到网络搜索---
---网络搜索---
'运行完成:websearch:'
---生成---
---检查幻觉---
---决策:生成内容基于文档---
---生成内容评分 vs 问题---
---决策:生成内容回答了问题---
'运行完成:生成:'
('根据提供的背景,芝加哥熊队预计将在NFL选秀中用第一顺位选秀南加州大学的四分卫Caleb Williams。')
针对不同问题测试工作流程
app = workflow.compile()
#
inputs = {"question": "代理记忆有哪些类型?"}
for output in app.stream(inputs):
for key, value in output.items():
pprint(f"运行完成:{key}:")
pprint(value["generation"])
###########################回应############################
---路由问题---
代理记忆有哪些类型?
{'数据源': '向量存储'}
向量存储
---路由问题到RAG---
---检索---
'运行完成:检索:'
---检查文档与问题的相关性---
---评分:文档相关---
---评分:文档不相关---
---评估已评分的文档---
---决策:所有文档与问题不相关,包括网络搜索---
'运行完成:评分文档:'
---网络搜索---
'运行完成:websearch:'
---生成---
---检查幻觉---
---决策:生成内容基于文档---
---生成等级与问题---
---决策:生成解决问题---
'完成运行:生成:'
('文本提到以下类型的代理记忆:\n'
'\n'
'1. 短期记忆(STM)或工作记忆:它存储代理当前意识到并需要执行复杂认知任务所需的信息。\n'
'2. 长期记忆(LTM):它可以存储信息长达数天至数十年,具有基本无限的存储容量。')
可视化代理/图
!apt-get install python3-dev graphviz libgraphviz-dev pkg-config
!pip install pygraphviz
from IPython.display import Image
Image(app.get_graph().draw_png())
LangGraph 是一个灵活的工具,旨在利用LLM构建复杂的、有状态的应用程序。初学者可以通过掌握其基本原理并参与基本示例来利用其功能进行项目开发。重点是要专注于管理状态、处理条件边缘,并确保图中没有死胡同节点。
在我看来,与 ReAct 代理相比,这更有益,因为我们可以完全控制工作流程,而不是让代理做决定。
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-03-30
2024-04-26
2024-05-10
2024-04-12
2024-05-28
2024-05-14
2024-04-25
2024-07-18
2024-04-26
2024-05-06
2024-12-22
2024-12-21
2024-12-21
2024-12-21
2024-12-21
2024-12-20
2024-12-20
2024-12-19