AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


SELF-RAG:通过自我反思进行批判学习如何检索、生成和整理
发布日期:2024-05-14 03:15:19 浏览次数: 2066 来源:人工智障与神经病网络研究所


  • 机构:华盛顿大学、艾伦人工智能研究所、IBM
  • 代码:https://github.com/AkariAsai/self-rag
  • 论文:https://arxiv.org/abs/2310.11511
  • 发布:2023.10

现有问题:LLM的事实不准确性

LLM经常产生幻觉,特别是在长尾情况下,它们的知识变得过时,并且缺乏归因。

检索增强生成是否是万能药?

传统RAG可以无差别地检索和合并一定数量的检索段落,无论检索是否必要或段落是否相关,都可能导致无用的生成。

1. Self-RAG?

自我反思检索增强生成(Self-RAG),本质是微调两个大模型。一个是评估大模型(critic model),另一个是生成大模型(generator model)。微调的内容不是领域知识,而是作为RAG应用所应具备的技能,比如什么时候去检索、生成内容是否有幻觉、如何确保生成内容的真实可靠可用。我分别从模型训练和推理去介绍Self-RAG。

Critic模型训练

学习面对各种各样的query时,1.是否需要检索,2.检索知识是否和Query相关(准确性),3.基于检索生成的内容是否真的来自检索知识还是大模型自己yy的(事实支持性),以及4.检索生成的内容是否真的对用户Query有帮助(有用性)。训练critic模型的目标函数是最大化似然度,其中 是数据集, 是自省标记(reflection tokens)。从公式可以看到,根据x和y来计算r的条件生成概率。而自省标记r包括4种:粗体的文本表示最理想的自省标记。x、y、d分别表示输入、输出和相关文档段落。

Generator模型训练

有了Critic模型生成自省tokens能力作为基础,进一步构建增强(Augmented)训练数据,下面是构建流程,其中 表示critic模型, 是检索模块, 是相关文档段落。

  • 用Critic模型判断x是否需要检索,并预测 [Retrieve] token值,并把值拼接到x后面;如果是Yes再通过检索模块找出 K 个最相关的文档段落集合
  • 对于每个段落,Critic模型会进一步评估段落和x是否相关,并预测 [IsREL] token值;如果段落是相关的,Critic模型又会进一步评估,段落是否能支持模型的生成,并预测 [IsSUP] token值;最后把这两个token值拼接在检索生成内容y后面
  • 当整个y生成出来后,再预测y的 [IsUSE] token值,并把值拼接到y后,

下面是增强数据的样例备注:文本chunk之间用 <p></p> 包住。

以此,生成整个数据集 ,并基于次数据集进行生成器模型训,目标函数即求x预测【y和r】的条件生成概率的最大对数似然估计因为纪要预测y,也要预测自省标记r,因此需要将r扩进词表中。

汇总

整个Self-RAG的训练过程伪代码如下:

2. 推理流程

大致推理流程如上,我们展开描述一下:

  1. 判断是否需要检索时,当 时,再基于 去检索相关知识片段

  2. 如果需要检索,假设检索出的知识片段集合为 ,对于每个

    注意,在每个时间步都用LLM进行并行推理输出 个不同的 候选集,并且记录他们的得分 然后进行都进行Beam Search(设置Beam大小为),如下简图所示,最终获取 个最优的候选片段序列 分数是 的加权,权重可以认为调整。而对应的自省token的得分也比较简单,看A.3附录即可。

  • 预测判断 的相关性:
  • 当前时间步的检索生成:
  • 预测判断当前生成是否满足事实支持性和可用性:
  • 基于三个自省标记( [IsREL]、[IsSUP]、[IsUSE] )的预测结果,对 进行打分
  • 如果不需要检索,

    • 预测
    • 评估 分数:

    注意:Critic模型是不参与Self-RAG的推理,但它在训练阶段的作用是至关重要的。它确保了Generator模型能够学习到如何生成高质量的输出,并在需要时进行有效的自我评估和批判。

    实践

    论文也开源了微调模型,可以下载一个GGUF版本,并使用llama.cpp进行推理。先安装,

    pip install llama_cpp_python
    pip install huggingface-hub

    然后下载模型

    huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir ./model --local-dir-use-symlinks False

    给一个简单的运行示例

    from llama_cpp import Llama

    # 定义模型参数和生成参数
    MODEL_KWARGS = {
        "logits_all": True,
        "n_ctx": 2048,
        "n_gpu_layers": 200
    }
    GENERATE_KWARGS = {
        "temperature": 0.0,
        "top_p": 1.0,
        "max_tokens": 1024,
        "logprobs": 1000
    }

    # 初始化模型
    llm = Llama(model_path="selfrag_llama2_7b.q4_k_m.gguf", **MODEL_KWARGS)

    # 格式化Prompt函数
    def format_prompt(query, paragraph=None):
        """
        格式化查询为模型所需的prompt格式。
        
        :param query: 输入的问题或指令。
        :param paragraph: 可选的,与查询相关的段落信息,用于检索。
        :return: 格式化后的prompt字符串。
        "
    ""
        prompt = "### Instruction:\n{0}\n\n### Response:\n".format(query)
        if paragraph:
            prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
        return prompt

    # 测试问题
    queries = [
        "撰写一首表达对老师的感激之情的短诗",
        "简述一下人工智能在医疗领域的应用"
    ]

    # 测试并打印结果
    for query in queries:
        prompt = format_prompt(query)
        result = llm(prompt, **GENERATE_KWARGS)
        
        # 提取并打印生成的文本
        generated_text = result["choices"][0]["text"]
        print("\nResponse:\n{0}".format(generated_text))
        
        # 如果需要,打印详细信息
        # print(result["choices"][0])



    53AI,企业落地应用大模型首选服务商

    产品:大模型应用平台+智能体定制开发+落地咨询服务

    承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

    联系我们

    售前咨询
    186 6662 7370
    预约演示
    185 8882 0121

    微信扫码

    与创始人交个朋友

    回到顶部

     
    扫码咨询