微信扫码
与创始人交个朋友
我要投稿
有多少次你遵循建议优化机器学习模型,却发现这些解决方案并不完全符合你的特定需求?我知道答案是:很多次,或者更确切地说,一直都是这样。原因在于,一切都取决于你的数据。你需要不断测试、失败、再测试,直到找到适合你特定情况的最佳方法。在本文中,我将介绍四种使用私有数据和私有基础设施来优化检索增强生成 (RAG) 以增强 AI 解决方案的策略。
“我没有失败,我只是找到了 10,000 种行不通的方法。” —— 托马斯·爱迪生
在我之前的文章中,我解释了如何使用检索增强生成 (RAG) 策略将你的私有知识包含到像 LLaMA 3 这样的公共模型中,而无需与他人共享敏感信息。在你的基础设施上使用包含你的私有数据的 RAG 的优势是显而易见的;然而,它们的实现需要在几个方面进行大量的调整。
我将首先快速概述一下 RAG,我在上一篇文章中已经解释过。它主要有两个过程。第一个是“数据收集过程”,从不同的来源收集数据,将其转换为文本,将其分割成更小、更连贯、语义相关的片段,并将结果存储在向量数据库中。第二个是“推理过程”,它首先从用户查询开始,然后使用第一个过程的结果来识别相关的数据片段,最后丰富模型的上下文以获得输出。
你可以在下图中详细查看这两个过程:
RAG 过程。图片由作者提供
在我的上一篇文章中,我使用了 ColdF 这家虚构公司的数据来创建“数据收集”和“推理”过程。在本文中,我将解释评估和优化这些过程结果的基本方法。
首先,让我们确定 RAG 过程中的关键点:
分块方法:优化块大小可确保数据段有意义且与上下文相关。
嵌入模型:选择和微调模型以改进语义表示。
向量搜索方法:选择有效的相似性度量和搜索参数。
馈送到模型的最终提示:精心设计有效的提示以提高输出质量。
在确定每个改进组件后,该策略会比较每个组件的不同配置版本,以确定哪个版本性能更好。它涉及运行这两个版本并根据预定义的指标衡量它们的性能。但是我们如何衡量性能?用什么指标?为了回答这个问题,我们参考了论文 “RAGAS:检索增强生成的自动评估 ”¹. 该论文提出了三个关键指标:
真实性(Faithfulness): 检查答案中的信息是否与上下文中提供的信息相匹配。如果答案中所说的所有内容都可以直接从上下文中找到或推断出来,则该答案是真实的。例如,如果上下文是“在我们 5 月份访问里斯本期间,我和爱丽丝去了阿尔法玛、拜罗阿尔托、贝伦塔以及许多其他地方。”,而答案是“5 月,爱丽丝去了阿尔法玛、拜罗阿尔托、贝伦塔以及许多其他地方。”我们可以说上下文支持所有提取的陈述,因此真实性得分为 100%。但是,如果答案是“5 月,爱丽丝去了阿尔法玛和圣乔治城堡。”,则从答案中提取的两条陈述(例如,“爱丽丝去了阿尔法玛”和“爱丽丝去了圣若热城堡”)只有一条陈述得到上下文的支持,这意味着真实性得分为 50%。
答案相关性(Answer Relevance): 检查生成的答案是否完整并直接回答了所提出的问题。信息是否正确并不重要。例如,如果问题是“葡萄牙的首都是哪里?”,而答案是“里斯本是葡萄牙的首都”,则该答案是相关的,因为它直接回答了问题。如果答案是“里斯本是一个美丽的城市,有许多景点”,则它可能部分相关,但包含了不需要直接回答问题的额外信息。该指标可确保答案保持重点突出。
上下文相关性(Context Relevance): 检查上下文中提供的信息对回答问题的帮助程度。它确保仅包含必要且相关的详细信息,并删除任何对直接回答问题没有帮助的额外、无关信息。例如,如果问题是“5 月,爱丽丝在里斯本参观了哪些地方?”,而上下文是“在我们 5 月份访问里斯本期间,爱丽丝去了阿尔法玛、拜罗阿尔托、贝伦塔以及许多其他地方。”,则该上下文高度相关,因为它仅提供了有关爱丽丝 5 月份访问过哪些地方的必要信息。但是,如果上下文是“在我们 5 月份访问里斯本期间,爱丽丝遇到了许多有趣的人,吃了美味的食物,并去了许多地方。”,则该上下文包含了对回答问题不必要的额外细节,被认为是不相关的。该指标可确保提供的信息直接有助于回答问题,避免不必要的细节。此指标也称为上下文精度(Context Precision)。
该论文还解释了如何通过提示大型语言模型 (LLM) 以全自动的方式测量这些指标。
我将在本次评估中使用的库 Ragas 对这些关键指标进行了改进,添加了一个新指标:
上下文召回率(Context Recall): 该指标以与上下文相关性相同的方式衡量上下文与实际答案之间的一致性;然而,它使用的是实际答案而不是生成的答案。需要有一个真实值才能获得此指标。为了评估这些策略的有效性,我准备了一组 10 个问题,以及基于 ColdF 数据的实际答案。
真实性和答案相关性是生成器指标(Generator Metrics),分别用于衡量幻觉和答案与问题的直接程度。
上下文相关性和上下文召回率是检索器指标(Retriever Metrics),分别用于衡量从向量数据库中检索正确数据块的能力和获取所有必要信息的能力。
基本上,要评估我们之前提出的四个指标,我们需要问题、生成的答案、上下文和实际答案。
我将使用 LangChain 来实现 RAG 流程。要运行代码,我们需要安装 Python(版本 3.11.9)和以下库:
ollama==0.2.1
chromadb==0.5.0
transformers==4.41.2
torch==2.3.1
langchain==0.2.0
ragas==0.1.9
以下是使用 LangChain 的代码片段:
# Import necessary libraries and modules
from langchain.embeddings.base import Embeddings
from transformers import BertModel, BertTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaModel, RobertaTokenizer
from langchain.prompts import ChatPromptTemplate
from langchain_text_splitters import MarkdownHeaderTextSplitter
import requests
from langchain_chroma import Chroma
from langchain import hub
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_models import ChatOllama
from operator import itemgetter
# Define a custom embedding class using the DPRQuestionEncoder
class DPRQuestionEncoderEmbeddings(Embeddings):
show_progress: bool = False
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
def __init__(self, model_name: str = 'facebook/dpr-question-encoder-single-nq-base'):
# Initialize the tokenizer and model with the specified model name
self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_name)
self.model = DPRQuestionEncoder.from_pretrained(model_name)
def embed(self, texts):
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
embeddings = []
if self.show_progress:
try:
from tqdm import tqdm
iter_ = tqdm(texts, desc="Embeddings")
except ImportError:
logger.warning(
"Unable to show progress bar because tqdm could not be imported. "
"Please install with `pip install tqdm`."
)
iter_ = texts
else:
iter_ = texts
for text in iter_:
# Tokenize the input text
inputs = self.tokenizer(text, return_tensors='pt')
# Generate embeddings using the model
outputs = self.model(**inputs)
# Extract the embedding and convert it to a list
embedding = outputs.pooler_output.detach().numpy()[0]
embeddings.append(embedding.tolist())
return embeddings
def embed_documents(self, documents):
return self.embed(documents)
def embed_query(self, query):
return self.embed([query])[0]
# Define a template for generating prompts
template = """
### CONTEXT
{context}
### QUESTION
Question: {question}
### INSTRUCTIONS
使用上方 CONTEXT markdown 文本回答用户问题。
提供简短的答案。
仅根据 CONTEXT 中的事实回答问题。
如果 CONTEXT 不包含回答问题的必要信息,则返回“NONE”。
"""
# 使用模板创建一个 ChatPromptTemplate 实例
prompt = ChatPromptTemplate.from_template(template)
# 从 URL 获取文本数据
url = "https://raw.githubusercontent.com/cgrodrigues/rag-intro/main/coldf_secret_experiments.txt"
response = requests.get(url)
if response.status_code == 200:
text = response.text
else:
raise Exception(f"Failed to fetch the file: {response.status_code}")
# 定义用于分割 markdown 文本的标题
headers_to_split_on = [
("#", "Header 1")
]
# 使用指定的标题创建一个 MarkdownHeaderTextSplitter 实例
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on, strip_headers=False
)
# 使用 markdown 分割器分割文本
docs_splits = markdown_splitter.split_text(text)
# 初始化一个聊天模型
llm = ChatOllama(model="llama3")
# 使用自定义嵌入从文档中创建一个 Chroma 向量存储
vectorstore = Chroma.from_documents(documents=docs_splits, embedding=DPRQuestionEncoderEmbeddings())
# 从向量存储中创建一个检索器
retriever = vectorstore.as_retriever()
# 定义一个格式化文档以供显示的函数
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# 创建一个检索增强生成 (RAG) 链
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| RunnablePassthrough.assign(context=itemgetter("context"))
| {"answer": prompt | llm | StrOutputParser(),
"context": itemgetter("context")}
)
# 使用问题调用 RAG 链
result = rag_chain.invoke("Who led the Experiment 1?")
print(result)
在这段代码的末尾,定义了一个 RAG 链,可以使用以下代码来评估指标:
# 导入必要的库和模块
import pandas as pd
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
context_precision,
faithfulness,
answer_relevancy,
context_recall
)
from langchain_community.chat_models import ChatOllama
def get_questions_answers_contexts(rag_chain):
""" 读取问题和答案列表并返回一个
用于评估的 ragas 数据集 """
# 文件的 URL
url = 'https://raw.githubusercontent.com/cgrodrigues/rag-intro/main/coldf_question_and_answer.psv'
# 从 URL 获取文件
response = requests.get(url)
data = response.text
# 将数据按行分割
lines = data.split('\n')
# 通过管道符号分割每一行并创建元组
rag_dataset = []
for line in lines[1:10]: # 仅限前 10 个问题
if line.strip():# 确保该行不为空
question, reference_answer = line.split('|')
result = rag_chain.invoke(question)
generated_answer = result['answer']
contexts = result['context']
rag_dataset.append({
"question": question,
"answer": generated_answer,
"contexts": [contexts],
"ground_truth": reference_answer
})
rag_df = pd.DataFrame(rag_dataset)
rag_eval_datset = Dataset.from_pandas(rag_df)
# 返回 lragas 数据集
return rag_eval_datset
def get_metrics(rag_dataset):
""" 对于 RAG 数据集,计算指标的真实性、
答案相关性、上下文精度和上下文召回率 """
# 我们要评估的指标列表
metrics = [
faithfulness,
answer_relevancy,
context_precision,
context_recall
]
# 我们将使用带有 LLaMA 3 模型的本地 ollama
langchain_llm =ChatOllama(model="llama3")
langchain_embeddings = DPRQuestionEncoderEmbeddings('facebook/dpr-question_encoder-single-nq-base')
# 返回指标
results = evaluate(rag_dataset, metrics=metrics, llm=langchain_llm, embeddings=langchain_embeddings)
return results
# 获取 RAG 数据集
rag_dataset = get_questions_answers_contexts(rag_chain)
# 计算指标
results = get_metrics(rag_dataset)
print(results)
作为此代码结果的示例:
{
'faithfulness': 0.8611,
'answer_relevancy': 0.8653,
'context_precision': 0.7778,
'context_recall': 0.8889
}
如前所述,前两个指标(例如,真实性和答案相关性)与生成相关联。这意味着要改进这些指标,有必要更改语言模型或提供给模型的提示。最后两个指标(例如,上下文精度和上下文召回率)与检索相关,这意味着要改进这些指标,有必要研究如何存储、索引和选择文档。
分块方法确保数据被分割成最优的片段以便检索。该范式涉及尝试不同的块大小,以找到在太小(丢失上下文)和太大(压垮检索系统)之间取得平衡。在基线中,我们基于每个实验对文档进行分块;这意味着实验的某些部分可能会被稀释,并且最终嵌入中没有体现出来。解决这种情况的一种可能方法是使用父文档检索器。此方法不仅检索特定的相关文档片段或段落,还检索其父文档。此方法确保保留相关片段周围的上下文。以下代码用于测试此方法:
# Import necessary libraries and modules
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Create the parent document retriever
parent_document_retriever = ParentDocumentRetriever(
vectorstore = Chroma(collection_name="parents",
embedding_function=DPRQuestionEncoderEmbeddings('facebook/dpr-question_encoder-single-nq-base')),
docstore = InMemoryStore(),
child_splitter = RecursiveCharacterTextSplitter(chunk_size=200),
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1500),
)
parent_document_retriever.add_documents(docs_splits)
# Create a retrieval-augmented generation (RAG) chain
rag_chain_pr = (
{"context": parent_document_retriever | format_docs, "question": RunnablePassthrough()}
| RunnablePassthrough.assign(context=itemgetter("context"))
| {"answer": prompt | llm | StrOutputParser(),
"context": itemgetter("context")}
)
# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain_pr)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
结果如下:
{
'faithfulness': 0.6667,
'answer_relevancy': 0.4867,
'context_precision': 0.7778,
'context_recall': 0.6574
}
结果表明,这种变化无助于提高性能。上下文召回率的下降表明检索过程没有正常工作,并且上下文没有完整的信息。忠实度和答案相关性指标的变化是由于上下文不佳造成的。在这种情况下,我们可以尝试评估另一种分块和检索方法。
嵌入模型将文本块转换为密集向量表示。不同的模型可以针对不同的主题进行训练,有时可以改进嵌入效果。选择嵌入方法时应考虑计算效率和嵌入质量之间的平衡。
我们比较了不同的嵌入模型,例如密集段落检索("facebook/dpr-question_encoder-single-nq-base")、Sentence-BERT("paraphrase-MiniLM-L6-v2")或 Chroma 的默认模型("all-MiniLM-L6-v2")。每个模型都有其优势,在特定领域的数据上对其进行评估有助于确定哪种模型能够提供最准确的语义表示。
要更改嵌入模型,需要定义一个新类 "SentenceBertEncoderEmbeddings"。这个新类实现了 Sentence-BERT 模型。新类将替换我们之前的嵌入 "DPRQuestionEncoderEmbeddings",后者实现了密集段落检索模型。以下是使用 Sentence-BERT 模型进行测试的代码:
# Import necessary libraries and modules
import pandas as pd
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
context_precision,
faithfulness,
answer_relevancy,
context_recall
)
from langchain_community.chat_models import ChatOllama
from sentence_transformers import SentenceTransformer
# Define a custom embedding class using the DPRQuestionEncoder
class SentenceBertEncoderEmbeddings(Embeddings):
show_progress: bool = False
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
def __init__(self, model_name: str = 'paraphrase-MiniLM-L6-v2'):
# Initialize the tokenizer and model with the specified model name
self.model = SentenceTransformer(model_name)
def embed(self, texts):
# Ensure texts is a list
if isinstance(texts, str):
texts = [texts]
embeddings = []
if self.show_progress:
try:
from tqdm import tqdm
iter_ = tqdm(texts, desc="Embeddings")
except ImportError:
logger.warning(
"Unable to show progress bar because tqdm could not be imported. "
"Please install with `pip install tqdm`."
)
iter_ = texts
else:
iter_ = texts
for text in iter_:
embeddings.append(self.model.encode(text).tolist())
return embeddings
def embed_documents(self, documents):
return self.embed(documents)
def embed_query(self, query):
return self.embed([query])[0]
# Create a Chroma vector store from the documents using the custom embeddings
vectorstore = Chroma.from_documents(documents=docs_splits, embedding=SentenceBertEncoderEmbeddings())
# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
# Create a retrieval-augmented generation (RAG) chain
rag_chain_ce = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| RunnablePassthrough.assign(context=itemgetter("context"))
| {"answer": prompt | llm | StrOutputParser(),
"context": itemgetter("context")})
# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain_ce)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
结果如下:
{
'faithfulness': 0.5278,
'answer_relevancy': 0.5306,
'context_precision': 0.5556,
'context_recall': 0.7997
}
在这种情况下,编码器的变化表示指标性能下降。这是预料之中的,因为 DPR 的检索精度高于 Sentence-BERT,这使得它更适合我们的情况,即精确的文档检索至关重要。当切换到 Sentence-BERT 时,'faithfulness' 和 'answer relevancy' 指标的显著下降突出了为需要高检索精度的任务选择合适的嵌入模型的重要性。
向量搜索方法根据相似度度量检索最相关的块。常见的方法包括欧氏距离 (L2)、余弦相似度等。更改此搜索方法可以提高最终输出质量。
代码如下:
# Import necessary libraries and modules
import pandas as pd
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
context_precision,
faithfulness,
answer_relevancy,
context_recall
)
from langchain_community.chat_models import ChatOllama
# Create a Chroma vector store from the documents
# using the custom embeddings and also changing to
# cosine similarity search
vectorstore = Chroma.from_documents(collection_name="dist",
documents=docs_splits,
embedding=DPRQuestionEncoderEmbeddings(),
collection_metadata={"hnsw:space": "cosine"})
# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
# Create a retrieval-augmented generation (RAG) chain
rag_chain_dist = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| RunnablePassthrough.assign(context=itemgetter("context"))
| {"answer": prompt | llm | StrOutputParser(),
"context": itemgetter("context")})
# Get the RAG dataset
rag_dataset = get_questions_answers_contexts(rag_chain_dist)
# Calculate the metrics
results = get_metrics(rag_dataset)
print(results)
结果如下:
{
'faithfulness': 0.9444,
'answer_relevancy': 0.8504,
'context_precision': 0.6667,
'context_recall': 0.8889
}
'faithfulness' 的提高表明,即使 'context precision' 下降了,使用余弦相似度进行向量搜索也增强了检索到的文档与查询的一致性。总体较高的 'faithfulness' 和 'context recall' 表明,在这种情况下,余弦相似度是一种更有效的向量搜索方法,这支持了向量搜索方法选择在优化检索性能方面的重要性。
最终提示的构建涉及将检索到的数据集成到模型的查询中。提示中的微小变化会对结果产生重大影响,使其成为一个反复试验的过程。在提示中提供示例可以引导模型生成更准确、更相关的输出。
优化检索增强生成 (RAG) 管道是一个迭代过程,很大程度上取决于应用程序的特定数据和上下文。在本文中,我们探讨了四种关键策略:改进分块方法、选择和微调嵌入模型、选择有效的向量搜索方法以及创建精确的提示。这些组件中的每一个都在提高 RAG 系统的性能方面发挥着至关重要的作用。
结果表明,没有万能的解决方案。例如,虽然密集段落检索 (DPR) 在我们的上下文中优于 Sentence-BERT,但这可能会因数据集或要求的不同而异。类似地,切换到余弦相似度进行向量搜索会产生更好的置信度和上下文回忆,这表明即使在检索过程中进行细微的更改也会产生影响。
优化 RAG 管道的旅程包括持续测试、从失败中学习以及进行明智的调整。通过采用这种迭代方法,您可以定制您的 AI 解决方案,以更有效地满足您的特定需求。请记住,成功的关键在于了解您的数据、试验不同的策略并持续改进您的流程。
订阅我的个人资料和电子邮件列表,以随时了解我的最新作品。我们可以共同应对 AI 优化的复杂性,并充分发挥您的数据驱动型解决方案的潜力。
[1]Es, S., James, J., Espinosa-Anke, L., & Schockaert, S. (2023). RAGAS: Automated Evaluation of Retrieval Augmented Generation. Exploding Gradients, CardiffNLP, Cardiff University, AMPLYFI.
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-05-14
2024-04-26
2024-03-30
2024-04-12
2024-05-10
2024-05-22
2024-07-18
2024-04-25
2024-05-28
2024-04-26
2024-11-05
2024-11-05
2024-11-04
2024-11-04
2024-11-04
2024-11-01
2024-11-01
2024-10-31