AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


基于Agent的金融问答系统:代码重构
发布日期:2024-10-15 08:49:28 浏览次数: 1757 来源:一起AI技术

前言

在上一章【项目实战】基于Agent的金融问答系统:前后端流程打通,我们已经完成了金融问答系统的前后端搭建,形成了可用的Demo。本章,我们将介绍代码重构的过程,并介绍一些优化点。

代码重构简介

在开启本章介绍之前,请允许我花点时间啰嗦两句,聊一聊代码重构的哪些事儿。

在过去经历的项目中,代码重构很少被人重视。看着像?一样的代码(抱歉爆粗口,我所经历的一些项目包括我自己曾经写的代码,回看确实像?一样),它们并没有被好好清理,然后我们在?上面不断加需求,导致需求迭代越来越难,Bug越来越多...

这种事情现在每天还在不断地发生着,所以我决定有必要聊一聊代码重构。

什么是代码重构

代码重构是指对现有代码进行修改,以改善其结构、可读性和可维护性,而不改变其外部行为。重构的主要目的是提高代码质量,使其更易于理解和扩展。

代码重构的目的

  • • 提高可读性:使代码更易于理解,便于团队成员快速上手。

  • • 增强可维护性:降低后续修改和扩展的难度,减少潜在的错误。

  • • 优化性能:在不改变功能的情况下,提升代码的执行效率。

  • • 消除重复代码:通过抽象和重用,减少冗余,提高代码的整洁性。

代码重构的重要性

据统计,不好的代码会占用更多开发的时间。

代码重构的难点

通过代码重构提升代码质量既然如此重要,那么为什么很少有项目开展呢?

究其原因,可能有三点:

  • • 第一种:没有精力重构。开发工程师经常性被老板或者产品牵着鼻子走,完成一个需求接着一个新的需求,所以很少开展重构工作。这种情况在技术性为导向的项目还好,在以产品或市场为导向的项目中,尤其严重。

  • • 第二种:没有重构的思维。很多的开发工程师没有重构的思维甚至想法,他们以完成需求交付为目的,需求交付了也就代表他的工作结束了。

    我曾经与谷歌回来的一位朋友有次交流,我们探讨的内容是:为什么国内的研发人员代码质量意识薄弱?他说其中一个很重要的原因是:硅谷的很多从业者,是因为热爱,热爱编程、热爱技术,所以视自己写的代码为一件艺术品,力求精益求精;而国内有很多从业者,是因为生存,是因为做开发给钱多,是一份养家糊口的一份工作而已,因为缺少热爱,所以交差了事即可。对此,我深以为然。
  • • 第三种:没有重构的方法论。虽然我们很像做重构,但是重构工作就像修复一辆越开越慢的车子,如果没有科学的方法,有可能出现拆了重装之后,反而多了几个螺丝的问题,这会让老板更加恐怖。

本章,我将试图以这个金融问答系统为例,简单介绍一些代码重构的原则、方法。

代码重构的过程

1、搭建测试框架以及用例集

在开展代码重构前,我们要搭建好一个便于回归测试的测试框架,通过边重构边回归的方式,可以快速定位问题所在,以此降低问题排查的成本。

我们在app目录下,已经创建了一个test_framework.py中,继续补充测试用例集,例如:

在大厂中,回归测试一般会使用单元测试框架(如pytest)来进行执行,由于本例中我们的方法较为简单,所以就没有使用pytest。

2、消灭代码中的坏味道

2.1、统一管理配置相关内容

在之前实现的RAG管理模块中,有很多的配置是硬编码写在代码初始化中的,例如:

# 原始的rag.py
class RagManager:
    def __init__(self,
                 chroma_server_type="http",
                 host="localhost", port=8000,
                 persist_path="chroma_db",
                 llm=None, embed=None):
        self.llm = llm
        self.embed = embed

        chrom_db = ChromaDB(chroma_server_type=chroma_server_type,
                            host=host, port=port,
                            persist_path=persist_path,
                            embed=embed)
        self.store = chrom_db.get_store()

我们可以将所有的配置相关抽取到一个settings.py中,然后在使用的代码中通过引用settings.py来进行配置。

# setttings.py

"""
Chroma向量数据库使用时的相关的配置
"""

# 默认的ChromaDB的服务器类别
CHROMA_SERVER_TYPE ="http"
# 默认本地数据库的持久化目录
CHROMA_PERSIST_DB_PATH ="chroma_db"

CHROMA_HOST = os.getenv("CHROMA_HOST","localhost")
CHROMA_PORT =int(os.getenv("CHROMA_PORT",8000))
CHROMA_COLLECTION_NAME ="langchain"

说明:

  • • 为了有别于变量的命名,对于配置我们使用大写的变量名,例如:CHROMA_HOST、CHROMA_PORT等。

# 重构的rag.py
import settings


classRagManager:
def__init__(self,
                 vector_db_class=ChromaDB,  # 默认使用 ChromaDB
                 db_config=None,  # 数据库配置参数
                 llm=None, embed=None,
                 retriever_cls=SimpleRetrieverWrapper, **retriever_kwargs):
        self.llm = llm
        self.embed = embed
        logger.info(f'初始化llm大模型:{self.llm}')
        logger.info(f'初始化embed模型:{self.embed}')

# 如果没有提供 db_config,使用默认配置
if db_config isNone:
            db_config ={
"chroma_server_type": settings.CHROMA_SERVER_TYPE,
"host": settings.CHROMA_HOST,
"port": settings.CHROMA_PORT,
"persist_path": settings.CHROMA_PERSIST_DB_PATH,
"collection_name": settings.CHROMA_COLLECTION_NAME
}
            logger.info(f'初始化向量数据库配置:{db_config}')

# 创建向量数据库实例
        self.vector_db = vector_db_class(**db_config, embed=self.embed)
        self.store = self.vector_db.get_store()

说明:

  • • 上述代码中通过import settings,在使用配置时通过settings.CHROMA_SERVER_TYPE、settings.CHROMA_HOST等来引用。

2.2、处理参数过长的问题

在原始代码中,随着我们的需求迭代,在创建RAG时需要传入多个的参数,例如:

  • • chroma_server_type

  • • host

  • • port

  • • persist_path

  • • collection_name

如果按照原来的方法写函数,那么函数的参数列表就会非常长,如下:

RagManager(chroma_server_type="http", host="localhost", port=8000, persist_path="chroma_db", collection_name="langchain",llm , embed)

对于这种参数的问题,我们可以通过使用字典来处理,如下:

db_config = {
                "chroma_server_type": settings.CHROMA_SERVER_TYPE,
                "host": settings.CHROMA_HOST,
                "port": settings.CHROMA_PORT,
                "persist_path": settings.CHROMA_PERSIST_DB_PATH,
                "collection_name": settings.CHROMA_COLLECTION_NAME,
            }

RagManager(vector_db_class=ChromaDB, db_config=db_config, llm=self.llm, embed=self.embed)

说明:

  • • db_config是一个字典,可以包含多个配置参数,例如:chroma_server_typehostportpersist_pathcollection_name等。

  • • db_config中的参数可以通过**关键字来解包,从而传入到函数中。

  • • RagManager 的初始化函数中,通过**关键字来解包db_config,从而传入到ChromaDB的初始化函数中。

2.3、减少重复代码

【项目实战】基于Agent的金融问答系统:RAG检索模块初建成中,我们曾实现了一个pdf_processor.py, 该函数主要的工作是:

def process_pdfs(self)# 处理pdf文件
defprocess_pdfs_group(self, pdf_files_group)# 分组处理pdf文件
defload_pdf_files(self)# 加载pdf文件
defload_pdf_content(self, pdf_path)# 读取pdf文件内容
defsplit_text(self, documents)# 分割读取到的文本
definsert_docs_chromadb(self, docs, batch_size)  # 向向量数据库中插入数据

如果我们要将PDF文件给ElasticSearch服务里,那么这个过程大部分实现逻辑都是一样的,只是插入的对象不同,一个是向向量数据库中插入,一个是向elasticsearch中插入。

这种情况下,

  • • 不好的做法:复制上述代码到一个新的函数中,然后将最后一步insert_docs_chromadb()改为insert_docs_elasticsearch(),这样会导致代码重复。

  • • 较好的做法:对上述的插入过程进行重构,将插入函数通过函数类来调用,通过一个参数vector_db_class来决定插入向量数据库还是ElasticSearch。

# 重构后的pdf_processor.py

import os
import logging
import time
from tqdm import tqdm
from langchain_community.document_loaders importPyMuPDFLoader
from langchain_text_splitters importRecursiveCharacterTextSplitter
from rag.vector_db importVectorDB
from rag.elasticsearch_db importTraditionDB
from utils.logger_config importLoggerManager

logger =LoggerManager().logger


classPDFProcessor:
def__init__(self, directory, db_type='vector', **kwargs):
"""
        初始化 PDF 处理器
        :param directory: PDF 文件所在目录
        :param db_type: 数据库类型 ('vector' 或 'es')
        :param kwargs: 其他参数
        """

        self.directory = directory  # PDF 文件所在目录
        self.db_type = db_type  # 数据库类型
        self.file_group_num = kwargs.get('file_group_num',20)# 每组处理的文件数
        self.batch_num = kwargs.get('batch_num',6)# 每次插入的批次数量
        self.chunksize = kwargs.get('chunksize',500)# 切分文本的大小
        self.overlap = kwargs.get('overlap',100)# 切分文本的重叠大小
        logger.info(f"""
                    初始化PDF文件导入器:
                    配置参数:
                    - 导入的文件路径:{self.directory}
                    - 每次处理文件数:{self.file_group_num}
                    - 每批次处理样本数:{self.batch_num}
                    - 切分文本的大小:{self.chunksize}
                    - 切分文本重叠大小:{self.overlap}
                    """
)

# 根据数据库类型初始化相应的客户端
if db_type =='vector':
            self.vector_db = kwargs.get('vector_db')# 向量数据库实例
            self.es_client =None

            logger.info(f'导入的目标数据库为:向量数据库')
elif db_type =='es':
            self.vector_db =None
            self.es_client = kwargs.get('es_client')# Elasticsearch 客户端

            logger.info(f'导入的目标数据库为:ES数据库')
else:
raiseValueError("db_type must be either 'vector' or 'es'.")

defload_pdf_files(self):
# 这部分代码未做修改,具体内容省略

defload_pdf_content(self, pdf_path):
# 这部分代码未做修改,具体内容省略

defsplit_text(self, documents):
# 这部分代码未做修改,具体内容省略

defprocess_pdfs(self):
# 这部分代码未做修改,具体内容省略

definsert_docs(self, docs, insert_function, batch_size=None):
"""
        将文档插入到指定的数据库,并显示进度
        :param docs: 要插入的文档列表
        :param insert_function: 插入函数
        :param batch_size: 批次大小
        """

if batch_size isNone:
            batch_size = self.batch_num

        logging.info(f"Inserting {len(docs)} documents.")
        start_time = time.time()
        total_docs_inserted =0

        total_batches =(len(docs)+ batch_size -1)// batch_size

with tqdm(total=total_batches, desc="Inserting batches", unit="batch")as pbar:
for i inrange(0,len(docs), batch_size):
                batch = docs[i:i + batch_size]
                insert_function(batch)# 调用传入的插入函数

                total_docs_inserted +=len(batch)

# 计算并显示当前的TPM
                elapsed_time = time.time()- start_time
if elapsed_time >0:
                    tpm =(total_docs_inserted / elapsed_time)*60
                    pbar.set_postfix({"TPM":f"{tpm:.2f}"})

                pbar.update(1)

definsert_to_vector_db(self, docs):
"""
        将文档插入到 VectorDB
        """

        self.vector_db.add_with_langchain(docs)

definsert_to_elasticsearch(self, docs):
"""
        将文档插入到 Elasticsearch
        """

        self.es_client.add_documents(docs)

defprocess_pdfs_group(self, pdf_files_group):
# 读取PDF文件内容
        pdf_contents =[]

for pdf_path in pdf_files_group:
# 读取PDF文件内容
            documents = self.load_pdf_content(pdf_path)

# 将documents 逐一添加到pdf_contents
            pdf_contents.extend(documents)

# 将文本切分成小段
        docs = self.split_text(pdf_contents)

if self.db_type =='vector':
# 将文档插入到 VectorDB
            self.insert_docs(docs, self.insert_to_vector_db)
elif self.db_type =='es':
# 将文档插入到 Elasticsearch
            self.insert_docs(docs, self.insert_to_elasticsearch)
else:
raiseValueError("db_type must be either 'vector' or 'es'.")

说明:

  • • 在类的初始化函数中,我们通过一个参数vector_db来连接对应的数据库实例,同时传入db_type告知PDF处理器需要操作的数据库类型。

  • • 在处理PDF文件时,我们通过参数db_type来决定插入向量数据库还是ElasticSearch。

  • • 在插入文档 insert_docs 中,根据上一步骤传入的 insert_function 来调用具体的插入函数:如果是插入向量数据库,则传入的函数为self.insert_to_vector_db,那么调用时也会调用 insert_to_vector_db ;如果是插入ElasticSearch,则传入的函数为self.insert_to_elasticsearch,那么调用时会调用 insert_to_elasticsearch 。

2.4、使用静态扫描工具优化代码风格

我们可以使用静态扫描工具对代码进行风格优化,如Pylint、Flake8等,一般情况下PyCharm中会自带这些工具。

具体方法:

  1. 1. 启动PyCharm

  2. 2. 打开工程时,选择app目录

  3. 3. 打开任意.py文件后,右上角会有静态扫描问题提示(如下图)

  4. 4. 根据静态扫描的问题,进行代码风格修正(常见代码风格问题请见附录部分)

3、回归测试

在进行上面每一步重构时,都需要使用test_framework.py进行回归测试,确保重构后的代码没有引入新的错误。

由于本项目重构细节的内容非常多,不能一一列举,重构后的内容请查看Gitee或者Github仓库的代码。

内容小结

  • • 代码重构是一件非常重要的工作,它可以帮助提高代码质量,提升代码可读性和可维护性,进而为后续的迭代开发提供基础。

  • • 重构过程的一般步骤:

    • • 代码重构前,需要提前准备好测试框架和测试用例。

    • • 代码重构时,进行代码优化修改,并进行单元测试。

    • • 重构完一个模块并测试通过后,再进行下一个模块的重构。

    • • 重构完成后,进行整体测试,确保重构后的代码没有引入新的错误。

  • • 代码重构常见的优化方向:

    • • 统一配置项的管理,不要在代码中写死配置。

    • • 减少重复代码,使用函数、类等封装代码,提高代码复用性。

    • • 对于参数超长的情况,使用字典或元组等结构。

    • • 使用静态扫描工具来优化代码风格。

    • • ......




53AI,企业落地大模型首选服务商

产品:场景落地咨询+大模型应用平台+行业解决方案

承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询