AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


手把手教程,改造 GraphRAG 支持自定义 LLM
发布日期:2024-07-24 18:16:54 浏览次数: 3650



导语:最近 GraphRAG 在社区很火,作者亲自体验后,发现了一些可以探讨和改进的地方,本文主要介绍了如何改造 GraphRAG 以支持自定义的 LLM。

01

为什么在 RAG 中引入知识图谱?

传统的 RAG 在处理复杂问题时往往表现不理想,主要是传统 RAG 未能有效捕捉实体间的复杂关系和层次结构,且通常只检索固定数量的最相关文本块:

  • 缺少事情之间关系的理解:当需要关联不同信息以提供综合见解时,传统 RAG 很难将这些点连接起来。
  • 缺乏整体视角:当要求 RAG 全面理解大型数据集甚至单个大型文档的整体语义概念时,缺乏宏观视角,例如,当给它一本小说并问它“这本书的主旨是什么”时,十有八九会给不出靠谱的答案。
这个问题在我们上一篇文章《为什么说知识图谱 + RAG > 传统 RAG?》也有详细分析,感兴趣可以点击上面的链接查看。
微软的 GraphRAG 通过引入知识图谱来解决传统 RAG 的局限性,在索引数据集时,GraphRAG 提取实体和实体间的关系,构建知识图谱,这让 GraphRAG 能够更全面地理解文档的语义,捕捉实体间的复杂关联,从而在处理复杂查询时表现出色。
  • 这种方法适合处理需要对整个数据集进行综合理解的问题,如“数据集中的主要主题是什么?”这类问题;
  • 相比传统的 RAG 方法,Graph RAG 在处理全局性问题时表现出更好;

02

GraphRAG 改造计划

设计的理念很不错,但是真的去体验使用的时候,发现几个问题:

  1. 强依赖于 OpenAI 或 Azure 的服务。对于国内用户来说,OpenAI 的 key 还是需要国外银行卡,Azure 的 API 申请也比较繁琐,还有国外的云一般都是绑定信用卡,可能不小心用超了,上次体验 AWS 的产品,忘了删除了,后面发现扣了我快 1000 块钱,我只是体验下产品而已...
  2. GraphRAG 目前更像是一个 Demo 产品,想和业务结合现在也没什么可以操作的地方,肯定是需要自定义的。

想着能不能让 GraphRAG 集成到业务中,准备对 GraphRAG 做一些改造,主要从以下几个方向进行:
  1. 支持自定义 LLM,OpenAI 也比较贵,换成一些更便宜的模型。我首先选择了自家的 Qwen 模型,大家可以在我的基础上扩展其他模型的支持。Qwen 默认给 50W 的 Token 使用量,够玩一段时间的,而且可以用更便宜的 turbo 模型;
  2. 支持自定义向量数据库,方便线上使用;
  3. 引入一些业务属性,看看如何能和业务结合在一起;
  4. 优化下使用体验,实现生成的知识图谱可视化。

这篇文章我会首先介绍下如何改造 GraphRAG 以支持自定义的 LLM,同时我把修改 GraphRAG 的代码也开源在 GitHub 上了,也欢迎感兴趣的朋友共同建设...

03

环境准备

3.1 安装依赖

因为我们是修改 GraphRAG 的代码,就不从 pip 进行安装了,另外对版本有一定的要求:
  • Python 3.10 ~ 3.12版本
git clone git@github.com:microsoft/graphrag.git
安装 poetry:
# 先安装pipx
brew install pipx
pipx ensurepath
sudo pipx ensurepath --global # optional to allow pipx actions in global scope. See "Global installation" section below.

# 安装poetry
pipx install poetry
poetry completions zsh > ~/.zfunc/_poetry
mkdir $ZSH_CUSTOM/plugins/poetry
poetry completions zsh > $ZSH_CUSTOM/plugins/poetry/_poetry

graphrag 仓库下安装依赖:
poetry install
另外在 PyCharm 中安装下 BigData 的文件预览插件,可以看到 index 过程中的文件结构类型:

3.2 项目结构

graphrag 是 GraphRAG 项目的核心包,包含了所有的关键代码逻辑。下面有几个重要的子目录,每个目录负责不同的功能模块:
  • config 目录:存储 GraphRAG 配置后的对象,在 GraphRAG 启动时,会读取配置文件,并将配置解析为 config 目录下的各种对象;
  • index 目录:核心包,所有索引相关的核心逻辑;
  • query 目录:核心包,查询相关的类和逻辑,当用户提交查询请求时,query 目录下的代码会负责解析查询、检索知识图谱、生成回答等一系列操作;
  • model 目录:核心领域模型,如文本、文档、主题、关系等,GraphRAG 中的核心概念和数据结构,其他模块都围绕着这些模型进行操作和处理;
  • llm 目录:支持的 LLM 的实现。如果要自定义集成通义千问,就需要在这个目录下进行实现;
  • vector_stores 目录:包含向量数据库的实现。如果要自定义向量存储,需要在这个目录下进行实现。

3.3 运行& Debug 项目

不同于 pip 的包安装,这里我们要在 pycharm 里面配置下如何从代码的形式运行项目的内容,官方入门给的几个案例,我们通过代码的形式运行:
mkdir -p ./ragtest/input

# 这一步可以随便替换成一些其他的文档,小一点的, 这样效率比较开,可以更快的验证下我们的改造结果
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt > ./ragtest/input/book.txt

初始化项目:

python -m graphrag.index --init --root ./ragtest

对文档进行索引:

python -m graphrag.index --root ./ragtest

进行本地查询:

python -m graphrag.query \
--root ./ragtest \
--method local \
"Who is Scrooge, and what are his main relationships?"
如果直接运行上面的命令,会发现无法运行,让我们配置一下:
  • 运行方式选择模块运行;
  • 模块后面参考上述官方的命令,给出的具体模块;
  • 接下来填具体的参数,还有工作目录不要忘了。

在上述配置完成之后,你就可以 debug 项目,一步步了解项目中内部的各种细节了,模块的入口类在包下的 __main__.py 文件中。

04

GraphRAG 支持通义千问

4.1 修改的内容

1、项目中默认支持的 LLM 类型是没有通义千问的,因此在枚举类型上要支持通义千问;
2、在进行 index 的时候,会有一步 load_llm 的操作,我们在配置文件中定义的千问类型,在 load_llm 中实现,兼容下原本的接口。
3、在查询的时候,默认使用 OpenAI 的客户端,判断下配置文件的类型,如果是 qwen 的类型,使用我们自己的千问实现。
项目中的 index 和 query 的 llm 是两套不同的视线,我觉得其实可以合并在一起的,不过为了先走通,就是在 index 和 query 都实现了一遍。
核心是在 llm 目录下新增了一个 qwen 的包;在 query 的 llm/qwen 目录下新增了 qwen 的问答实现。

4.2 支持 Qwen 类型的配置

在 config 的 enums 中增加下千问的几个枚举,不然直接在配置文件中写 qwen 会报类型无法转换错误。

4.3 使用 Qwen 进行 Index

在 index 的时候,执行逻辑会走到 load_llm,在加载 llm 的部分,支持下 QwenLLM 的实现。
然后实现对应的方法和类,我再给出我们的 QwenCompletionLLM 以及
def _load_qwen_llm(
        on_error: ErrorHandlerFn,
        cache: LLMCache,
        config: dict[str, Any],
        azure=False,
):
    log.info(f"Loading Qwen completion LLM with config {config}")
    return QwenCompletionLLM(config)

def _load_qwen_embeddings_llm(
        on_error: ErrorHandlerFn,
        cache: LLMCache,
        config: dict[str, Any],
        azure=False,
):
    log.info(f"Loading Qwen embeddings LLM with config {config}")
    return DashscopeEmbeddingsLLM(config);

通过兼容原本的方法,到这里索引部分就可以通过 Qwen 完全进行使用了。


# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import asyncio
import json
import logging
from http import HTTPStatus
from typing import Unpack, List, Dict

import dashscope
import regex as re

from graphrag.config import LLMType
from graphrag.llm import LLMOutput
from graphrag.llm.base import BaseLLM
from graphrag.llm.base.base_llm import TIn, TOut
from graphrag.llm.types import (
    CompletionInput,
    CompletionOutput,
    LLMInput,
)

log = logging.getLogger(__name__)


class QwenCompletionLLM(
    BaseLLM[
        CompletionInput,
        CompletionOutput,
    ]
)
:

    def __init__(self, llm_config: dict = None):
        log.info(f"llm_config: {llm_config}")
        self.llm_config = llm_config or {}
        self.api_key = self.llm_config.get("api_key", "")
        self.model = self.llm_config.get("model", dashscope.Generation.Models.qwen_turbo)
        # self.chat_mode = self.llm_config.get("chat_mode", False)
        self.llm_type = llm_config.get("type", LLMType.StaticResponse)
        self.chat_mode = (llm_config.get("type", LLMType.StaticResponse) == LLMType.QwenChat)

    async def _execute_llm(
            self,
            input: CompletionInput,
            **kwargs: Unpack[LLMInput],
    )
 -> CompletionOutput:

        log.info(f"input: {input}")
        log.info(f"kwargs: {kwargs}")

        variables = kwargs.get("variables", {})

        # 使用字符串替换功能替换占位符
        formatted_input = replace_placeholders(input, variables)

        if self.chat_mode:
            history = kwargs.get("history", [])
            messages = [
                *history,
                {"role""user", "content": formatted_input},
            ]
            response = self.call_with_messages(messages)
        else:
            response = self.call_with_prompt(formatted_input)

        if response.status_code == HTTPStatus.OK:
            if self.chat_mode:
                return response.output["choices"][0]["message"]["content"]
            else:
                return response.output["text"]
        else:
            raise Exception(f"Error {response.code}{response.message}")

    def call_with_prompt(self, query: str):
        print("call_with_prompt {}".format(query))
        response = dashscope.Generation.call(
            model=self.model,
            prompt=query,
            api_key=self.api_key
        )
        return response

    def call_with_messages(self, messages: list[dict[str, str]]):
        print("call_with_messages {}".format(messages))
        response = dashscope.Generation.call(
            model=self.model,
            messages=messages,
            api_key=self.api_key,
            result_format='message'
        )
        return response

    # 主函数
    async def _invoke_json(self, input: TIn, **kwargs) -> LLMOutput[TOut]:
        try:
            output = await self._execute_llm(input, **kwargs)
        except Exception as e:
            print(f"Error executing LLM: {e}")
            return LLMOutput[TOut](output=None, json=None)

        # 解析output的内容
        extracted_jsons = extract_json_strings(output)

        if len(extracted_jsons) > 0:
            json_data = extracted_jsons[0]
        else:
            json_data = None

        try:
            output_str = json.dumps(json_data)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON: {e}")
            output_str = None

        return LLMOutput[TOut](
            output=output_str,
            json=json_data
        )


def replace_placeholders(input_str, variables):
    for key, value in variables.items():
        placeholder = "{" + key + "}"
        input_str = input_str.replace(placeholder, value)
    return input_str


def preprocess_input(input_str):
    # 预处理输入字符串,移除或转义特殊字符
    return input_str.replace('<', '<').replace('>', '>')


def extract_json_strings(input_string: str) -> List[Dict]:
    # 正则表达式模式,用于匹配 JSON 对象
    json_pattern = re.compile(r'(\{(?:[^{}]|(?R))*\})')

    # 查找所有匹配的 JSON 子字符串
    matches = json_pattern.findall(input_string)

    json_objects = []
    for match in matches:
        try:
            # 尝试解析 JSON 子字符串
            json_object = json.loads(match)
            json_objects.append(json_object)
        except json.JSONDecodeError:
            # 如果解析失败,忽略此子字符串
            log.warning(f"Invalid JSON string: {match}")
            pass

    return json_objects


实现下对应的 Embeding 模型;

"""The EmbeddingsLLM class."""
import logging

log = logging.getLogger(__name__)

from typing import Unpack
from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
    EmbeddingInput,
    EmbeddingOutput,
    LLMInput,
)

from http import HTTPStatus
import dashscope
import logging

log = logging.getLogger(__name__)


class QwenEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]):
    """A text-embedding generator LLM using Dashscope's API."""

    def __init__(self, llm_config: dict = None):
        log.info(f"llm_config: {llm_config}")
        self.llm_config = llm_config or {}
        self.api_key = self.llm_config.get("api_key""")
        self.model = self.llm_config.get("model", dashscope.TextEmbedding.Models.text_embedding_v1)

    async def _execute_llm(
            self, input: EmbeddingInput, **kwargs: Unpack[LLMInput]
    )
 -> EmbeddingOutput:

        log.info(f"input: {input}")

        response = dashscope.TextEmbedding.call(
            model=self.model,
            input=input,
            api_key=self.api_key
        )

        if response.status_code == HTTPStatus.OK:
            res = [embedding["embedding"for embedding in response.output["embeddings"]]
            return res
        else:
            raise Exception(f"Error {response.code}{response.message}")
通过刚才我们配置的运行方式,配置下 Qwen,运行下,然后可以通过 BigData Viwer 看到里面的内容:
在 indexing-engine.log 里面也可以看到详细的内容

4.4 使用 Qwen 进行 Query

在 GraphRAG 中,query 和 index 用的是不同的 BaseLLM 抽象类,并且在 Query 这里默认用的 OpenAIEmbeding,这里我们也修改一下。
  • query 相比 index 支持了流式的输出内容:

Qwen 的 Query 问答实现:
import asyncio
import logging
from http import HTTPStatus
from typing import Any

import dashscope
from tenacity import (
    Retrying,
    RetryError,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from graphrag.query.llm.base import BaseLLMCallback, BaseLLM
from graphrag.query.progress import StatusReporter, ConsoleStatusReporter

log = logging.getLogger(__name__)

class DashscopeGenerationLLM(BaseLLM):
    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 10
        request_timeout: float = 180.0
        retry_error_types: tuple[type[BaseException]] = (Exception,)
        reporter: StatusReporter = ConsoleStatusReporter()
    )
:

        self.api_key = api_key
        self.model = model or dashscope.Generation.Models.qwen_turbo
        self.max_retries = max_retries
        self.request_timeout = request_timeout
        self.retry_error_types = retry_error_types
        self._reporter = reporter

    def generate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    )
 -> str:

        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    return self._generate(
                        messages=messages,
                        streaming=streaming,
                        callbacks=callbacks,
                        **kwargs,
                    )
        except RetryError as e:
            self._reporter.error(
                message="Error at generate()", details={self.__class__.__name__: str(e)}
            )
            return ""
        else:
            return ""

    async def agenerate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    )
 -> str:

        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    return await asyncio.to_thread(
                        self._generate,
                        messages=messages,
                        streaming=streaming,
                        callbacks=callbacks,
                        **kwargs,
                    )
        except RetryError as e:
            self._reporter.error(f"Error at agenerate(): {e}")
            return ""
        else:
            return ""

    def _generate(
            self,
            messages: str | list[str],
            streaming: bool = False,
            callbacks: list[BaseLLMCallback] | None = None,
            **kwargs: Any,
    )
 -> str:

        if isinstance(messages, list):
            response = dashscope.Generation.call(
                model=self.model,
                messages=messages,
                api_key=self.api_key,
                stream=streaming,
                incremental_output=streaming,
                timeout=self.request_timeout,
                result_format='message'
                **kwargs,
            )
        else:
            response = dashscope.Generation.call(
                model=self.model,
                prompt=messages,
                api_key=self.api_key,
                stream=streaming,
                incremental_output=streaming,
                timeout=self.request_timeout,
                **kwargs,
            )

        # if response.status_code != HTTPStatus.OK:
        #     raise Exception(f"Error {response.code}: {response.message}")

        if streaming:
            full_response = ""
            for chunk in response:
                if chunk.status_code != HTTPStatus.OK:
                    raise Exception(f"Error {chunk.code}{chunk.message}")

                decoded_chunk = chunk.output.choices[0]['message']['content']
                full_response += decoded_chunk
                if callbacks:
                    for callback in callbacks:
                        callback.on_llm_new_token(decoded_chunk)
            return full_response
        else:
            if isinstance(messages, list):
                return response.output["choices"][0]["message"]["content"]
            else:
                return response.output["text"]

实现 Query 的 Embedding 对象:

import asyncio
import logging
from typing import Any

import dashscope
from tenacity import (
    Retrying,
    RetryError,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.progress import StatusReporter, ConsoleStatusReporter

log = logging.getLogger(__name__)


class DashscopeEmbedding(BaseTextEmbedding):

    def __init__(
            self,
            api_key: str | None = None,
            model: str = dashscope.TextEmbedding.Models.text_embedding_v1,
            max_retries: int = 10
            retry_error_types: tuple[type[BaseException]] = (Exception,)
            reporter: StatusReporter = ConsoleStatusReporter()
    )
:

        self.api_key = api_key
        self.model = model
        self.max_retries = max_retries
        self.retry_error_types = retry_error_types
        self._reporter = reporter

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        try:
            embedding = self._embed_with_retry(text, **kwargs)
            return embedding
        except Exception as e:
            self._reporter.error(
                message="Error embedding text"
                details={self.__class__.__name__: str(e)},
            )
            return []

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        try:
            embedding = await asyncio.to_thread(self._embed_with_retry, text, **kwargs)
            return embedding
        except Exception as e:
            self._reporter.error(
                message="Error embedding text asynchronously"
                details={self.__class__.__name__: str(e)},
            )
            return []

    def _embed_with_retry(self, text: str, **kwargs: Any) -> list[float]:
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    response = dashscope.TextEmbedding.call(
                        model=self.model,
                        input=text,
                        api_key=self.api_key,
                        **kwargs,
                    )
                    if response.status_code == 200:
                        embedding = response.output["embeddings"][0]["embedding"]
                        return embedding
                    else:
                        raise Exception(f"Error {response.code}{response.message}")
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()"
                details={self.__class__.__name__: str(e)},
            )
            return []

运行下 Query 的效果:

可以看到使用的是 Qwen 的模型进行的问答:

4.5 项目中的一些关键节点

创建工作流的地方,以及默认的工作流:

4.6 遇到错误怎么办

上面已经提到了可以配置 pycharm 来配置 debug 断点;
在执行错误的时候,在 output 下面会有对应的执行详细信息,根据错误信息,可以在对应的地方加上断点查看错误的原因是什么。

05

GraphRAG 的核心步骤

参考论文《From Local to Global: A Graph RAG Approach to Query-Focused Summarization》的描述:https://arxiv.org/pdf/2404.16130
GraphRAG 的主要 Pipeline 步骤说明:
1. 文本分块 (Source Documents → Text Chunks),将源文档分割成较小的文本块,每块大约 600 个 token。块之间有 100 个 token 的重叠,以保持上下文连贯性;
2. 元素实例提取 (Text Chunks → Element Instances):使用 LLM 从每个文本块中提取实体、关系和声明。实体包括名称、类型和描述。关系包括源实体、目标实体和描述。使用多轮"gleaning"技术来提高提取质量;
  • "Gleaning" 是一种迭代式的信息提取方法。初始提取: LLM 首先对文本块进行一次实体和关系提取。评估:LLM 被要求评估是否所有实体都被提取出来了。迭代提取::如果 LLM 认为有遗漏,它会被提示进行额外的"gleaning"轮次,尝试提取之前可能遗漏的实体。多轮进行:这个过程可以重复多次,直到达到预设的最大轮次或 LLM 认为没有更多实体可提取。
create_final_entities.parquet
create_final_nodes.parquet
create_final_relationships.parquet
3. 元素摘要生成 (Element Instances → Element Summaries):将相同元素的多个实例合并成单一的描述性文本块。
4. 图社区检测 (Element Summaries → Graph Communities):将实体作为节点,关系作为边构建无向加权图。
  • 使用 Leiden 算法进行检测,得到层次化的社区结构,Leiden 算法帮助我们把大量的文本信息组织成有意义的群组,使得我们可以更容易地理解和处理这些信息。

5. 社区摘要生成 (Graph Communities → Community Summaries):
  • 为每个社区生成报告式摘要;
  • 对于叶子级社区,直接总结其包含的所有元素;
  • 对于高层社区,递归地利用子社区摘要。

6. 查询回答 (Community Summaries → Community Answers → Global Answer):
  • Community Summaries(社区摘要):预先生成的,包含了图中每个社区(即相关实体群组)的概要信息。它们存储了关于每个主题领域的关键信息,通过问题找到一些相关的主题(社区摘要);
  • Community Answers(社区回答):当收到用户查询时,系统会并行处理每个社区摘要,对每个社区摘要,系统会生成一个针对用户问题的部分答案,系统还会给每个部分答案评分,表示其对回答问题的相关性;
  • Global Answer:系统会收集所有有用的部分答案(过滤掉评分为0的答案),然后,它会按照相关性评分对这些答案进行排序。最后,系统会综合这些部分答案,生成一个全面、连贯的最终答案。

06

小结


这里主要想和大家分享下如何定制下 GraphRAG 支持千问模型,方便更多的同学体验下 GraphRAG,当然这还只是第一步,GraphRAG 还不能直接应用到真实的场景中。
官网上有一些架构的设计理念和过程,也可以参考学习,看到 GraphRAG 的社区热度也挺高的,估计很快就可以作为一个相对成熟的方案引入到实际的系统中。
为了方便学习,我把上述的改动代码上传到了代码仓库,感兴趣的同学可以试一下,也可以继续进行定制和优化...
代码仓库地址:https://code.alibaba-inc.com/aihehe.ah/biz_graphrag
接下来尝试下怎么支持自定义的向量存储以及调研下能否和业务集成,比如:
  • 自定义 VectorStore 实现
  • GraphRAG 可视化过程
  • 业务集成 GraphRAG



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询