微信扫码
与创始人交个朋友
我要投稿
导语:最近 GraphRAG 在社区很火,作者亲自体验后,发现了一些可以探讨和改进的地方,本文主要介绍了如何改造 GraphRAG 以支持自定义的 LLM。
01
为什么在 RAG 中引入知识图谱?
相比传统的 RAG 方法,Graph RAG 在处理全局性问题时表现更好。
02
GraphRAG 改造计划
GraphRAG 目前更像是一个 Demo 产品,想和业务结合现在也没什么可以操作的地方,肯定是需要自定义的。
这篇文章我会首先介绍下如何改造 GraphRAG 以支持自定义的 LLM,同时我把修改 GraphRAG 的代码也开源在 GitHub 上了,也欢迎感兴趣的朋友共同建设...
03
环境准备
3.1 安装依赖
git clone git@github.com:microsoft/graphrag.git
# 先安装pipx
brew install pipx
pipx ensurepath
sudo pipx ensurepath --global # optional to allow pipx actions in global scope. See "Global installation" section below.
# 安装poetry
pipx install poetry
poetry completions zsh > ~/.zfunc/_poetry
mkdir $ZSH_CUSTOM/plugins/poetry
poetry completions zsh > $ZSH_CUSTOM/plugins/poetry/_poetry
poetry install
3.2 项目结构
vector_stores 目录:包含向量数据库的实现。如果要自定义向量存储,需要在这个目录下进行实现。
3.3 运行& Debug 项目
mkdir -p ./ragtest/input
# 这一步可以替换成其他小一点的文档,这样效率比较高,可以更快地验证我们的改造结果
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt > ./ragtest/input/book.txt
初始化项目:
python -m graphrag.index --init --root ./ragtest
对文档进行索引:
python -m graphrag.index --root ./ragtest
进行本地查询:
python -m graphrag.query \
--root ./ragtest \
--method local \
"Who is Scrooge, and what are his main relationships?"
接下来填具体的参数,还有工作目录不要忘了。
04
GraphRAG 支持通义千问
4.1 修改的内容
4.2 支持 Qwen 类型的配置
4.3 使用 Qwen 进行 Index
def _load_qwen_llm(
    on_error: ErrorHandlerFn,
    cache: LLMCache,
    config: dict[str, Any],
    azure=False,
):
    """Create the Qwen completion LLM from ``config``.

    ``on_error``, ``cache`` and ``azure`` are accepted but not used in this
    function; presumably they are there to match the signature shared by the
    other loader functions -- confirm against the loader registry.

    Returns:
        A ``QwenCompletionLLM`` built from ``config``.
    """
    message = f"Loading Qwen completion LLM with config {config}"
    log.info(message)
    llm = QwenCompletionLLM(config)
    return llm
def _load_qwen_embeddings_llm(
    on_error: ErrorHandlerFn,
    cache: LLMCache,
    config: dict[str, Any],
    azure=False,
):
    """Create the Qwen embeddings LLM from ``config``.

    ``on_error``, ``cache`` and ``azure`` are accepted but not used in this
    function; presumably they are there to match the signature shared by the
    other loader functions -- confirm against the loader registry.

    Returns:
        A ``DashscopeEmbeddingsLLM`` built from ``config``.
    """
    message = f"Loading Qwen embeddings LLM with config {config}"
    log.info(message)
    embeddings_llm = DashscopeEmbeddingsLLM(config)
    return embeddings_llm
通过兼容原有的方法,到这里索引部分就可以完全通过 Qwen 来完成了。
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import asyncio
import json
import logging
from http import HTTPStatus
from typing import Unpack, List, Dict
import dashscope
import regex as re
from graphrag.config import LLMType
from graphrag.llm import LLMOutput
from graphrag.llm.base import BaseLLM
from graphrag.llm.base.base_llm import TIn, TOut
from graphrag.llm.types import (
CompletionInput,
CompletionOutput,
LLMInput,
)
log = logging.getLogger(__name__)
class QwenCompletionLLM(
    BaseLLM[
        CompletionInput,
        CompletionOutput,
    ]
):
    """Completion LLM that forwards prompts to ``dashscope.Generation``.

    The ``type`` entry of the config selects the mode: when it equals
    ``LLMType.QwenChat`` a message list is sent and the reply is read from
    ``choices[0].message.content``; otherwise a plain prompt is sent and the
    reply is read from ``text``.
    """

    def __init__(self, llm_config: dict = None):
        """Store the settings read from ``llm_config``.

        Args:
            llm_config: Mapping with optional ``api_key``, ``model`` and
                ``type`` entries. ``None`` is treated as an empty mapping.
        """
        # NOTE(review): this logs the whole config, including ``api_key``.
        log.info(f"llm_config: {llm_config}")
        self.llm_config = llm_config or {}
        self.api_key = self.llm_config.get("api_key", "")
        self.model = self.llm_config.get("model", dashscope.Generation.Models.qwen_turbo)
        # Read from self.llm_config (never None) rather than the raw argument,
        # so the default ``llm_config=None`` does not raise AttributeError.
        self.llm_type = self.llm_config.get("type", LLMType.StaticResponse)
        self.chat_mode = self.llm_type == LLMType.QwenChat

    async def _execute_llm(
        self,
        input: CompletionInput,
        **kwargs: Unpack[LLMInput],
    ) -> CompletionOutput:
        """Fill the prompt placeholders, call DashScope and return the text.

        Args:
            input: Prompt template; ``{name}`` placeholders are filled from
                ``kwargs["variables"]``.
            **kwargs: ``variables`` (placeholder values) and, in chat mode,
                ``history`` (earlier messages) are read; both may be missing
                or ``None``.

        Returns:
            The generated text.

        Raises:
            Exception: If the response status is not ``HTTPStatus.OK``.
        """
        log.info(f"input: {input}")
        log.info(f"kwargs: {kwargs}")
        # ``or {}`` / ``or []`` also cover an explicit ``None`` value.
        variables = kwargs.get("variables") or {}
        # Substitute the placeholders with plain string replacement.
        formatted_input = replace_placeholders(input, variables)

        if self.chat_mode:
            history = kwargs.get("history") or []
            messages = [
                *history,
                {"role": "user", "content": formatted_input},
            ]
            response = self.call_with_messages(messages)
        else:
            response = self.call_with_prompt(formatted_input)

        if response.status_code != HTTPStatus.OK:
            raise Exception(f"Error {response.code}: {response.message}")
        if self.chat_mode:
            return response.output["choices"][0]["message"]["content"]
        return response.output["text"]

    def call_with_prompt(self, query: str):
        """Send ``query`` as a plain prompt and return the raw response."""
        print("call_with_prompt {}".format(query))
        response = dashscope.Generation.call(
            model=self.model,
            prompt=query,
            api_key=self.api_key
        )
        return response

    def call_with_messages(self, messages: list[dict[str, str]]):
        """Send a chat message list and return the raw response."""
        print("call_with_messages {}".format(messages))
        response = dashscope.Generation.call(
            model=self.model,
            messages=messages,
            api_key=self.api_key,
            result_format='message',
        )
        return response

    async def _invoke_json(self, input: TIn, **kwargs) -> LLMOutput[TOut]:
        """Run the LLM and return the first JSON object found in its reply.

        Any exception from the LLM call is printed and turned into an
        ``LLMOutput`` whose ``output`` and ``json`` are both ``None``.

        Returns:
            ``LLMOutput`` with ``json`` set to the first parsed object (or
            ``None`` when the reply contains none) and ``output`` set to that
            value serialized with ``json.dumps``.
        """
        try:
            output = await self._execute_llm(input, **kwargs)
        except Exception as e:
            print(f"Error executing LLM: {e}")
            return LLMOutput[TOut](output=None, json=None)

        # Pull the JSON objects out of the raw reply; only the first is used.
        extracted_jsons = extract_json_strings(output)
        json_data = extracted_jsons[0] if extracted_jsons else None

        try:
            output_str = json.dumps(json_data)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON: {e}")
            output_str = None

        return LLMOutput[TOut](
            output=output_str,
            json=json_data
        )
def replace_placeholders(input_str, variables):
    """Replace every ``{key}`` in ``input_str`` with its value.

    Replacements are applied one key at a time, in the mapping's iteration
    order, so text inserted by an earlier key can be matched by a later key.

    Args:
        input_str: Template text containing ``{key}`` placeholders.
        variables: Mapping of placeholder name to value. ``None`` or an empty
            mapping leaves ``input_str`` unchanged. Values are converted with
            ``str()`` so numbers and other non-string values do not make
            ``str.replace`` raise ``TypeError``.

    Returns:
        The text with all known placeholders substituted.
    """
    if not variables:
        return input_str
    for key, value in variables.items():
        input_str = input_str.replace("{" + key + "}", str(value))
    return input_str
def preprocess_input(input_str):
    """Escape angle brackets in ``input_str`` as HTML entities.

    ``<`` becomes ``&lt;`` and ``>`` becomes ``&gt;``. No other character is
    touched; in particular ``&`` is not escaped.

    NOTE(review): the previous body replaced ``<`` with ``<`` and ``>`` with
    ``>`` (a no-op), which looks like the entities were decoded when the
    listing was copied -- confirm against the original repository.
    """
    return input_str.replace('<', '&lt;').replace('>', '&gt;')
def extract_json_strings(input_string: str) -> List[Dict]:
    """Return every JSON object embedded in ``input_string``, in order.

    The text is scanned for ``{`` characters and ``json.JSONDecoder.raw_decode``
    is tried at each one. Letting the JSON parser find the end of the object
    (instead of counting braces with a regular expression) keeps objects whose
    string values contain ``{`` or ``}`` intact.

    Args:
        input_string: Arbitrary text that may contain JSON objects.

    Returns:
        The decoded objects. A ``{`` that does not start valid JSON is logged
        and skipped; scanning then continues at the next ``{``.
    """
    decoder = json.JSONDecoder()
    json_objects = []
    position = input_string.find("{")
    while position != -1:
        try:
            json_object, end = decoder.raw_decode(input_string, position)
        except json.JSONDecodeError:
            # Not valid JSON from this brace; log a short excerpt and move on.
            log.warning(f"Invalid JSON string: {input_string[position:position + 80]}")
            position = input_string.find("{", position + 1)
            continue
        json_objects.append(json_object)
        # Resume after the decoded object so nested objects are not repeated.
        position = input_string.find("{", end)
    return json_objects
实现对应的 Embedding 模型:
"""The EmbeddingsLLM class."""
import logging
log = logging.getLogger(__name__)
from typing import Unpack
from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
EmbeddingInput,
EmbeddingOutput,
LLMInput,
)
from http import HTTPStatus
import dashscope
import logging
log = logging.getLogger(__name__)
class QwenEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]):
    """Embedding LLM that delegates to ``dashscope.TextEmbedding``."""

    def __init__(self, llm_config: dict = None):
        """Read ``api_key`` and ``model`` from ``llm_config``.

        ``None`` is treated as an empty mapping; a missing ``api_key`` becomes
        an empty string and a missing ``model`` falls back to
        ``text_embedding_v1``.
        """
        log.info(f"llm_config: {llm_config}")
        settings = llm_config or {}
        self.llm_config = settings
        self.api_key = settings.get("api_key", "")
        self.model = settings.get("model", dashscope.TextEmbedding.Models.text_embedding_v1)

    async def _execute_llm(
        self, input: EmbeddingInput, **kwargs: Unpack[LLMInput]
    ) -> EmbeddingOutput:
        """Request embeddings for ``input`` and return the vectors.

        Returns:
            One ``embedding`` entry per item in the response's ``embeddings``.

        Raises:
            Exception: If the response status is not ``HTTPStatus.OK``.
        """
        log.info(f"input: {input}")
        response = dashscope.TextEmbedding.call(
            model=self.model,
            input=input,
            api_key=self.api_key
        )
        if response.status_code != HTTPStatus.OK:
            raise Exception(f"Error {response.code}: {response.message}")

        vectors = []
        for item in response.output["embeddings"]:
            vectors.append(item["embedding"])
        return vectors
4.4 使用 Qwen 进行 Query
query 相比 index 支持了流式的输出内容:
import asyncio
import logging
from http import HTTPStatus
from typing import Any
import dashscope
from tenacity import (
Retrying,
RetryError,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from graphrag.query.llm.base import BaseLLMCallback, BaseLLM
from graphrag.query.progress import StatusReporter, ConsoleStatusReporter
log = logging.getLogger(__name__)
class DashscopeGenerationLLM(BaseLLM):
    """Query-side LLM that calls ``dashscope.Generation`` with retries.

    A list passed as ``messages`` is sent as a chat message list with
    ``result_format='message'``; anything else is sent as a plain prompt.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 10,
        request_timeout: float = 180.0,
        retry_error_types: tuple[type[BaseException]] = (Exception,),
        reporter: StatusReporter = ConsoleStatusReporter(),
    ):
        """Store the connection and retry settings.

        Args:
            api_key: Key passed to every DashScope call.
            model: Model name; falls back to ``qwen_turbo`` when falsy.
            max_retries: Attempt limit handed to ``stop_after_attempt``.
            request_timeout: Value passed as ``timeout`` to every call.
            retry_error_types: Exception types that trigger a retry.
            reporter: Receives error reports from ``generate``/``agenerate``.
        """
        self.api_key = api_key
        self.model = model or dashscope.Generation.Models.qwen_turbo
        self.max_retries = max_retries
        self.request_timeout = request_timeout
        self.retry_error_types = retry_error_types
        self._reporter = reporter

    def generate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate text synchronously, retrying on ``retry_error_types``.

        Returns:
            The generated text, or ``""`` when a ``RetryError`` is caught.
        """
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    return self._generate(
                        messages=messages,
                        streaming=streaming,
                        callbacks=callbacks,
                        **kwargs,
                    )
        except RetryError as e:
            # NOTE(review): with reraise=True tenacity is documented to
            # re-raise the last exception instead of RetryError, so this
            # handler may never run -- confirm.
            self._reporter.error(
                message="Error at generate()", details={self.__class__.__name__: str(e)}
            )
            return ""
        else:
            return ""

    async def agenerate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate text from async code by running ``_generate`` in a thread.

        Returns:
            The generated text, or ``""`` when a ``RetryError`` is caught.
        """
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    return await asyncio.to_thread(
                        self._generate,
                        messages=messages,
                        streaming=streaming,
                        callbacks=callbacks,
                        **kwargs,
                    )
        except RetryError as e:
            self._reporter.error(f"Error at agenerate(): {e}")
            return ""
        else:
            return ""

    @staticmethod
    def _extract_text(response: Any, chat_mode: bool) -> str:
        """Read the generated text from a response or a streamed chunk."""
        if chat_mode:
            return response.output["choices"][0]["message"]["content"]
        return response.output["text"]

    def _generate(
        self,
        messages: str | list[str],
        streaming: bool = False,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Perform one DashScope call and return the generated text.

        Args:
            messages: A message list (chat mode) or a single prompt string.
            streaming: When true the response is consumed chunk by chunk and
                every chunk's text is passed to the callbacks.
            callbacks: Objects whose ``on_llm_new_token`` receives each
                streamed chunk; ignored when not streaming.
            **kwargs: Forwarded unchanged to ``dashscope.Generation.call``.

        Raises:
            Exception: If the response, or any streamed chunk, has a status
                other than ``HTTPStatus.OK``.
        """
        chat_mode = isinstance(messages, list)
        if chat_mode:
            response = dashscope.Generation.call(
                model=self.model,
                messages=messages,
                api_key=self.api_key,
                stream=streaming,
                incremental_output=streaming,
                timeout=self.request_timeout,
                result_format='message',
                **kwargs,
            )
        else:
            response = dashscope.Generation.call(
                model=self.model,
                prompt=messages,
                api_key=self.api_key,
                stream=streaming,
                incremental_output=streaming,
                timeout=self.request_timeout,
                **kwargs,
            )

        if streaming:
            # The streamed response is iterated chunk by chunk, so the status
            # is checked on each chunk rather than on the response itself.
            pieces = []
            for chunk in response:
                if chunk.status_code != HTTPStatus.OK:
                    raise Exception(f"Error {chunk.code}: {chunk.message}")
                # Pick the field that matches the request mode; a plain
                # prompt is sent without result_format='message'.
                decoded_chunk = self._extract_text(chunk, chat_mode)
                pieces.append(decoded_chunk)
                for callback in callbacks or []:
                    callback.on_llm_new_token(decoded_chunk)
            return "".join(pieces)

        # Fail with the API's own error instead of indexing a missing output.
        if response.status_code != HTTPStatus.OK:
            raise Exception(f"Error {response.code}: {response.message}")
        return self._extract_text(response, chat_mode)
实现 Query 的 Embedding 对象:
import asyncio
import logging
from typing import Any
import dashscope
from tenacity import (
Retrying,
RetryError,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.progress import StatusReporter, ConsoleStatusReporter
log = logging.getLogger(__name__)
class DashscopeEmbedding(BaseTextEmbedding):
    """Query-side text embedding that calls ``dashscope.TextEmbedding``."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = dashscope.TextEmbedding.Models.text_embedding_v1,
        max_retries: int = 10,
        retry_error_types: tuple[type[BaseException]] = (Exception,),
        reporter: StatusReporter = ConsoleStatusReporter(),
    ):
        """Store the API key, model name, retry settings and reporter."""
        self._reporter = reporter
        self.retry_error_types = retry_error_types
        self.max_retries = max_retries
        self.model = model
        self.api_key = api_key

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Embed ``text``; report any exception and return ``[]`` instead."""
        try:
            return self._embed_with_retry(text, **kwargs)
        except Exception as e:
            error_details = {self.__class__.__name__: str(e)}
            self._reporter.error(
                message="Error embedding text",
                details=error_details,
            )
            return []

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """Embed ``text`` in a worker thread; report errors and return ``[]``."""
        try:
            return await asyncio.to_thread(self._embed_with_retry, text, **kwargs)
        except Exception as e:
            error_details = {self.__class__.__name__: str(e)}
            self._reporter.error(
                message="Error embedding text asynchronously",
                details=error_details,
            )
            return []

    def _embed_with_retry(self, text: str, **kwargs: Any) -> list[float]:
        """Call DashScope under the retry policy and return the first vector.

        A response whose status is not 200 raises ``Exception`` inside the
        attempt so the retry policy can handle it. A caught ``RetryError`` is
        reported and ``[]`` is returned.
        """
        retry_policy = Retrying(
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential_jitter(max=10),
            reraise=True,
            retry=retry_if_exception_type(self.retry_error_types),
        )
        try:
            for attempt in retry_policy:
                with attempt:
                    response = dashscope.TextEmbedding.call(
                        model=self.model,
                        input=text,
                        api_key=self.api_key,
                        **kwargs,
                    )
                    if response.status_code != 200:
                        raise Exception(f"Error {response.code}: {response.message}")
                    return response.output["embeddings"][0]["embedding"]
        except RetryError as e:
            error_details = {self.__class__.__name__: str(e)}
            self._reporter.error(
                message="Error at embed_with_retry()",
                details=error_details,
            )
            return []
运行下 Query 的效果:
4.5 项目中的一些关键节点
4.6 遇到错误怎么办
05
GraphRAG 的核心步骤
使用 Leiden 算法进行检测,得到层次化的社区结构,Leiden 算法帮助我们把大量的文本信息组织成有意义的群组,使得我们可以更容易地理解和处理这些信息。
对于高层社区,递归地利用子社区摘要。
06
小结
业务集成 GraphRAG
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-07-18
2024-05-05
2024-09-04
2024-06-20
2024-05-19
2024-07-09
2024-07-09
2024-07-07
2024-06-13
2024-07-07