AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


【LLM & RAG 】大模型在知识图谱问答上的核心算法详细思路及实践
发布日期:2024-08-27 07:32:30 浏览次数: 1840 来源:大模型自然语言处理


前言

本文介绍了一个融合RAG(Retrieval-Augmented Generation)思路的KBQA(Knowledge-Based Question Answering)系统的核心算法及实现步骤。KBQA系统的目标是通过自然语言处理技术,从知识图谱中提取和生成精确的答案。系统的实现包括多个关键步骤:mention识别、实体链接及排序、属性选择及排序、文本拼接以及最终的Text2SQL生成。通过这些步骤,系统能够准确识别用户提出的问题中的关键实体和属性,并生成相应的查询语句,从而从知识图谱或数据库中检索所需的信息。本文将详细介绍每个步骤的实现思路和技术细节,并提供核心算法具体的代码示例和开源地址供参考。

一、mention识别

KBQA中的mention识别是指在用户提出的问题中,识别出与知识库中的实体或概念相对应的词语或短语。mention识别是KBQA系统中至关重要的一步,因为准确的mention识别直接影响后续的实体链接、关系抽取和答案生成等步骤。KBQA中的mention识别的主要方法和技术:

  1. 规则方法

    基于规则的方法通常使用手工设计的规则和模式来识别mention。这些规则可以包括命名实体识别(NER)工具、正则表达式、词典匹配等。

  • NER工具:使用现有的NER工具(如Stanford NER、spaCy)来识别出问题中的实体。
  • 正则表达式:设计特定的正则表达式模式来匹配特定类型的mention。
  • 词典匹配:使用预先构建的实体词典进行匹配,通过查找词典中的条目来识别mention。
  • 统计方法

    基于统计的方法利用大规模的训练数据,通过统计特征来识别mention。这些方法通常需要预处理步骤,如词频统计和n-gram分析。

    • n-gram分析:将问题分割成n-gram(如单词、双词、三词短语等),并计算每个n-gram在知识库中的匹配情况。
    • 词频统计:统计每个词或短语在知识库中的出现频率,并根据频率高低来判断其是否为mention。
  • 机器学习方法

    基于机器学习的方法利用有标签的数据,通过训练分类器来识别mention。这些方法通常需要特征工程和模型训练。

    • 特征工程:提取文本的各种特征,如词性、词向量、上下文信息等。
    • 分类模型:使用机器学习算法(如SVM、随机森林、逻辑回归等)训练分类器,判断一个词或短语是否为mention。
  • 深度学习方法

    基于深度学习的方法利用神经网络模型,通过端到端的方式来识别mention。这些方法可以避免复杂的特征工程,通过大量的数据训练模型来自动提取特征。如:BERT-CRF、LLM等模型等。

  • 本文结合大模型的方法进行mention识别。主要流程如下:

    mention识别SFT数据构造

    原始数据

    q1:莫妮卡·贝鲁奇的代表作?
    select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. }
    <西西里的美丽传说> 

    通过一些规则方式,构建sft数据如下:

    [
      {
        "instruction""你是一个实体抽取的专家,请你抽取问句:“莫妮卡·贝鲁奇的代表作?”中的实体。",
        "input""",
        "output""莫妮卡·贝鲁奇"
      },
      {
        "instruction""你是一个实体抽取的专家,请你抽取问句:“《湖上草》是谁的诗?”中的实体。",
        "input""",
        "output""湖上草"
      },
      ...
    ]

    LLM微调mention识别

    本文以LLaMA-Factory框架进行微调,微调脚本如下:

    import json
    import os

    model_name_or_path = "ZhipuAI/glm-4-9b-chat"
    template = "glm4"
    cutoff_len = 256
    num_train_epochs = 8
    train_dataset = "train_ner"
    predict_dataset = "test_ner"
    output_dir = f"saves/{train_dataset}-{predict_dataset}-ep{num_train_epochs}-{cutoff_len}-{template}"
    adapter_name_or_path = output_dir


    do_train = True
    do_predict = True

    train_args = dict(
        stage="sft",  # 进行指令监督微调
        do_train=do_train,
        model_name_or_path=model_name_or_path,
        dataset=train_dataset,
        template=template,  
        finetuning_type="lora",  
        cutoff_len=cutoff_len,
        lora_target="all",  
        output_dir=output_dir,  
        per_device_train_batch_size=4,  
        gradient_accumulation_steps=2,  
        lr_scheduler_type="cosine"
        logging_steps=10,  
        warmup_ratio=0.1
        save_steps=1000
        learning_rate=1e-4,  
        num_train_epochs=num_train_epochs,  
        max_samples=7625,  
        max_grad_norm=1.0
        fp16=True,  
        temperature=0.1,
        ddp_timeout=180000000,
        overwrite_cache=True,
        overwrite_output_dir=True
    )

    predict_args = dict(
        stage="sft",
        do_predict=do_predict,
        model_name_or_path=model_name_or_path,
        adapter_name_or_path=adapter_name_or_path,
        dataset=predict_dataset,
        template=template,
        finetuning_type="lora",
        cutoff_len=cutoff_len,
        per_device_eval_batch_size=2,
        overwrite_cache=True,
        preprocessing_num_workers=16,
        output_dir=f'{output_dir}/predict',
        overwrite_output_dir=True,
        ddp_timeout=180000000,
        temperature=0.1,
        max_samples=1292,
        predict_with_generate=True
    )

    train_args_file = f"config/{train_dataset}-{predict_dataset}-ep{num_train_epochs}-{cutoff_len}-{template}-train.json"
    predict_args_file = f"config/{train_dataset}-{predict_dataset}-ep{num_train_epochs}-{cutoff_len}-{template}-pred.json"

    json.dump(train_args, open(train_args_file, "w", encoding="utf-8"), indent=2)
    json.dump(predict_args, open(predict_args_file, "w", encoding="utf-8"), indent=2)


    if __name__ == '__main__':
        os.system(f'llamafactory-cli train {train_args_file}')
        os.system(f'llamafactory-cli train {predict_args_file}')

    输出示例如:

    question:<篝火圆舞曲>的作曲家属于什么民族?

    mention:篝火圆舞曲

    二、实体链接及实体排序

    中文短文本的实体链指,简称 EL(Entity Linking),是NLP、知识图谱领域的基础任务之一,即对于给定的一个中文短文本(如搜索 Query、微博、对话内容、文章/视频/图片的标题等),EL将其中的实体与给定知识库中对应的实体进行关联。

    针对中文短文本的实体链指存在很大的挑战,主要原因如下:

    1. 口语化严重,导致实体歧义消解困难;
    2. 短文本上下文语境不丰富,须对上下文语境进行精准理解;
    3. 相比英文,中文由于语言自身的特点,在短文本的链指问题上更有挑战。

    EL实现思路-基于“粗排-精排”的两阶段方案

    思路1:

    主要流程描述:

    1. 候选实体召回。通过ES知识库,召回相关实体,把知识库实体的关系转化为:“实体id-实体信息” 和 “实体指称-实体id” 的映射。从原文本的mention文本出发,根据“实体指称-实体id”匹配实体文本召回候选实体。
    2. 候选实体特征提取。首先用指称项分类模型,来预测输入数据的指称项的实体类型。根据候选实体召回结果,对于有召回的实体:用“实体id-实体信息”提取处实体信息,按顺序组织实体信息的文本内容后拼接原始文本丰富实体的语义信息,最后把指称项的实体类型加入构成完整的实体候选集合。对于无召回的实体,就无需进行候选实体排序,直接与排序结果进行后处理整合即可。
    3. 候选实体排序模型。输入标记指称项的原始文本和候选实体信息的拼接,输出指称项和候选实体的匹配程度。
    4. 后处理。“候选实体排序模型”的输出结果

    粗排方式简单,使用ES库进行粗排即可,精排构建一个二分类模型

    训练数据构造形式(正负样本比例1:5):

    {"query""莫妮卡·贝鲁奇的代表作?""query_rewrite""#莫妮卡·贝鲁奇#的代表作?""entity""<莫妮卡·贝鲁奇>""desc""母亲|毕业院校|类型|主演|别名|相关人物|中文名|国籍|作者|外文名|体重|职业|代表作品|出生日期|导演|身高|朋友""label": 1}
    {"query""莫妮卡·贝鲁奇的代表作?""query_rewrite""#莫妮卡·贝鲁奇#的代表作?""entity""\"莫妮卡·贝鲁\" ""desc""中文名""label": 0}
    {"query""莫妮卡·贝鲁奇的代表作?""query_rewrite""#莫妮卡·贝鲁奇#的代表作?""entity""<莫妮卡·贝鲁>""desc""类型|游戏大小|中文名|原版名称|游戏类型""label": 0}
    {"query""莫妮卡·贝鲁奇的代表作?""query_rewrite""#莫妮卡·贝鲁奇#的代表作?""entity""\"莫妮卡贝鲁齐\" ""desc""中文名""label": 0}
    {"query""莫妮卡·贝鲁奇的代表作?""query_rewrite""#莫妮卡·贝鲁奇#的代表作?""entity""<莫妮卡贝鲁齐>""desc""类型|操作指南|基本介绍|中文名|原版名称""label": 0}
    {"query""莫妮卡·贝鲁奇的代表作?""query_rewrite""#莫妮卡·贝鲁奇#的代表作?""entity""\"莫妮卡·安娜·玛丽亚·贝鲁奇\" ""desc""中文名""label": 0}

    字段解释:

    • query:原始问句
    • query_rewrite:重写后的问句
    • entity:链接的实体
    • desc:属性的拼接
    • label:类别标签

    在训练时,我们只选择query_rewritedesc进行拼接,拼接形式如下:

    query_rewrite[SEP]desc
    示例:#莫妮卡·贝鲁奇#的代表作?[SEP]母亲|毕业院校|类型|主演|别名|相关人物|中文名|国籍|作者|外文名|体重|职业|代表作品|出生日期|导演|身高|朋友

    精排模型结构如下:

    计算实体链接得分:

    思路2:

    粗排仍然使用ES进行召回,精排使用一个与思路一相同结构的模型,区别就是数据构造方式不同。训练数据构造形式(正负样本比例1:5):

    {'query''莫妮卡·贝鲁奇的代表作?''mention''莫妮卡·贝鲁奇''label': 1}
    {'query''莫妮卡·贝鲁奇的代表作?','mention''低钙血症''label': 0}
    {'query''莫妮卡·贝鲁奇的代表作?''mention''同居损友''label': 0}
    {'query''莫妮卡·贝鲁奇的代表作?','mention''"1964-09-22"''label': 0}
    {'query''莫妮卡·贝鲁奇的代表作?''mention''夏侯瑾轩''label': 0}
    {'query''莫妮卡·贝鲁奇的代表作?''mention''"日历"''label': 0}

    选择querymention进行拼接,拼接形式如下:

    query[SEP]mention
    示例:莫妮卡·贝鲁奇的代表作?[SEP]莫妮卡·贝鲁奇

    计算实体链接得分:

    为了避免噪声影响,最后,根据得分获取top5的链接实体作为候选实体。

    模型结构代码示例

    import torch
    from torch import nn
    from transformers import BertModel, BertPreTrainedModel

    class BertForSequenceClassification(BertPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels

            self.bert = BertModel(config)
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

            self.init_weights()

        def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)

            loss = None
            if labels is not None:
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

    思路3:

    向量模型的短文本匹配效果也不错,不微调直接使用向量模型配合ES进行实体链接:

    from sentence_transformers import SentenceTransformer
    sentences_1 = "实体"
    sentences_2 = ["es召回的实体1""es召回的实体2",...,"es召回的实体n"]
    model = SentenceTransformer('lier007/xiaobu-embedding-v2')
    embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
    embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
    similarity = embeddings_1 @ embeddings_2.T
    print(similarity)

    当然,更进一步的可以根据自己场景进行微调。

    三、属性选择及属性排序

    属性识别及属性排序是从用户提出的问题中识别出与知识库中相关的属性,并根据某些标准对这些属性进行排序。这两个步骤在KBQA系统中非常重要,因为它们直接影响最终答案的准确性和相关性。

    通过上一节中获取到的top5的链接实体,通过ES索引或者mysql等工具建立的知识图谱,召回出对应实体的所有的属性集合。但召回的所有的属性集合存在一个问题,存在着大量的不相关的属性内容,因此,需要训练的一个属性的排序模型选择TOP5的属性保留。

    属性排序训练数据构造形式(正负样本比例1:5):

    {"entity""<莫妮卡·贝鲁奇>""query""莫妮卡·贝鲁奇的代表作?""attr""<代表作品>""label": 1}
    {"entity""<莫妮卡·贝鲁奇>""query""莫妮卡·贝鲁奇的代表作?""attr""<体重>""label": 0}
    {"entity""<莫妮卡·贝鲁奇>""query""莫妮卡·贝鲁奇的代表作?""attr""<出生日期>""label": 0}
    {"entity""<莫妮卡·贝鲁奇>""query""莫妮卡·贝鲁奇的代表作?""attr""<导演>""label": 0}
    {"entity""<莫妮卡·贝鲁奇>""query""莫妮卡·贝鲁奇的代表作?""attr""<职业>""label": 0}
    {"entity""<莫妮卡·贝鲁奇>""query""莫妮卡·贝鲁奇的代表作?""attr""<类型>""label": 0}

    选择queryattr进行拼接,拼接形式如下:

    query[SEP]attr
    示例:莫妮卡·贝鲁奇的代表作?[SEP]<代表作品>

    计算属性得分:

    为了避免噪声影响,最后,根据得分获取top5的属性作为候选属性。

    属性排序模型也是BERT+Linear一个二分类模型:

    四、文本拼接

    因为本文介绍的是结合大模型的思想进行查询语句的生成,本文的链路与RAG的思想非常相似,通过上述路径检索相关文本(相关实体片段相关属性片段),进行组合,组合方式如下:

    prompt+question+候选实体+属性结合

    五、LLM for Text2SQL

    在KBQA系统中,使用大语言模型(LLM)生成SQL查询是至关重要的一步。Text2SQL的任务是将自然语言问题转换为结构化查询语言(SQL),以便从数据库或知识图谱中检索信息。

    根据上节的介绍,对于训练LLM的SFT数据构造示例如下:

    [
      {
        "instruction""你是一个Sparql生成专家,请根据给定的内容,生成Sparql语句。\n问题:“莫妮卡·贝鲁奇的代表作?”,和候选实体信息:[0]名称:<莫妮卡·贝鲁奇>,属性集:<代表作品>,<中文名>,<作者>,<外文名>,<别名>。对应查询图谱的Sparql的语句为:",
        "input""",
        "output""select ?x where { <莫妮卡·贝鲁奇> <代表作品> ?x. }"
      },
      {
        "instruction""你是一个Sparql生成专家,请根据给定的内容,生成Sparql语句。\n问题:“《湖上草》是谁的诗?”,和候选实体信息:[0]名称:<湖上草>,属性集:<主要作品>,<中文名>,<传世之作>,<所著>,<其丈夫>。对应查询图谱的Sparql的语句为:",
        "input""",
        "output""select ?x where { ?x <主要作品> <湖上草>. }"
      },
      ...
    ]

    使用预训练的大语言模型(例如GLM-4-9B)进行微调,使其能够生成正确的SQL查询(本文使用的sparql查询,配合gstore图数据库使用,因gstore问题太多,个人不推荐使用,可以转成其他的查询语句,如:neo4j等)。本文以LLaMA-Factory为例进行微调,微调脚本如下:


    import json
    import os

    model_name_or_path = "ZhipuAI/glm-4-9b-chat"
    template = "glm4"
    cutoff_len = 4096
    num_train_epochs = 8
    train_dataset = "train_data"
    predict_dataset = "test_data"
    output_dir = f"saves/{train_dataset}-{predict_dataset}-ep{num_train_epochs}-{cutoff_len}-{template}"
    adapter_name_or_path = output_dir


    do_train = True
    do_predict = True

    train_args = dict(
        stage="sft",  # 进行指令监督微调
        do_train=do_train,
        model_name_or_path=model_name_or_path,
        dataset=train_dataset,
        template=template,
        finetuning_type="lora",
        cutoff_len=cutoff_len,
        lora_target="all",
        output_dir=output_dir,
        per_device_train_batch_size=2,  # 批处理大小
        gradient_accumulation_steps=4,  # 梯度累积步数
        lr_scheduler_type="cosine",  # 使用余弦学习率退火算法
        logging_steps=10,  # 每 10 步输出一个记录
        warmup_ratio=0.1,  # 使用预热学习率
        save_steps=1000,  # 每 1000 步保存一个检查点
        learning_rate=1e-4,  # 学习率大小
        num_train_epochs=num_train_epochs,  # 训练轮数
        max_samples=7625,  # 使用每个数据集中的 300 条样本
        max_grad_norm=1.0,  # 将梯度范数裁剪至 1.0
        fp16=True,  # 使用 float16 混合精度训练
        temperature=0.1,
        ddp_timeout=180000000,
        overwrite_cache=True,
        overwrite_output_dir=True
    )

    predict_args = dict(
        stage="sft",
        do_predict=do_predict,
        model_name_or_path=model_name_or_path,
        adapter_name_or_path=adapter_name_or_path,
        dataset=predict_dataset,
        template=template,
        finetuning_type="lora",
        cutoff_len=cutoff_len,
        per_device_eval_batch_size=1,
        overwrite_cache=True,
        preprocessing_num_workers=16,
        output_dir=f'{output_dir}/predict',
        overwrite_output_dir=True,
        ddp_timeout=180000000,
        temperature=0.1,
        max_samples=1292,
        predict_with_generate=True
    )


    train_args_file = f"config/{train_dataset}-{predict_dataset}-ep{num_train_epochs}-{cutoff_len}-{template}-train.json"
    predict_args_file = f"config/{train_dataset}-{predict_dataset}-ep{num_train_epochs}-{cutoff_len}-{template}-pred.json"

    json.dump(train_args, open(train_args_file, "w", encoding="utf-8"), indent=2)

    json.dump(predict_args, open(predict_args_file, "w", encoding="utf-8"), indent=2)


    os.system(f'llamafactory-cli train {train_args_file}')
    os.system(f'llamafactory-cli train {predict_args_file}')

    小结

    • 优点:借助大模型的优势,实现文本转化为sparql查询语句,实现单挑、多跳的从kg中查询答案。
    • 缺点:在实践过程中发现,由于大模型的幻觉因素,生成的查询语句“看上去对,实际上错误”,导致查询答案不是准确的答案。

    总结

    本文详细介绍了KBQA(知识图谱问答)系统融合了RAG的思路,分为多个步骤。首先进行mention识别,使用大模型提取文本中的关键实体;接着进行实体链接,将识别到的实体提及与知识图谱中的具体实体匹配和链接;然后对所有可能的实体进行排序,找出最相关的实体;在此基础上进行属性选择及排序,提取与用户问题相关的属性并进行排序,确保返回的结果最符合用户需求;接下来将上述步骤得到的文本内容拼接成完整的上下文;最后,将结构化的文本内容转化为SQL查询,以便从知识图谱或数据库中检索信息。



    53AI,企业落地应用大模型首选服务商

    产品:大模型应用平台+智能体定制开发+落地咨询服务

    承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

    联系我们

    售前咨询
    186 6662 7370
    预约演示
    185 8882 0121

    微信扫码

    与创始人交个朋友

    回到顶部

     
    扫码咨询