AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


基于 LLamafactory 的异步API高效调用实现与速度对比
发布日期:2024-12-02 21:31:56 浏览次数: 1553 来源:AI悠闲区


背景

原先经常调用各家的闭源大模型的API,如果使用同步的方式调用,速度会很慢。为了加快 API 的调用速度,决定使用异步调用 API 的方式。

简介

本文编写的代码,支持原生的 llamafactory 的数据集导入方式。
推理速度远远快于同步的 API 调用方式。基于 langchain_openai.ChatOpenAI 的 invoke 方法实现异步调用。
下述代码的主要工作介绍如下:

  • 使用 LLamafactory 的原生方法加载 数据集;
  • 封装了异步调用工具类AsyncAPICall,限制API的调用速度,逐块推理,避免程序崩溃导致所有数据丢失;

async_call_api.py

# pip install langchain langchain_openai

import os
import sys
import json
import asyncio


import fire
from tqdm import tqdm
from dataclasses import dataclass
from aiolimiter import AsyncLimiter
from typing import List
import pandas as pd
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

from llamafactory.hparams import get_train_args
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.data.loader import _get_merged_dataset

load_dotenv()


class AsyncLLM:
   def __init__(
       self,
       model: str = "gpt-3.5-turbo",
       base_url: str = "http://localhost:{}/v1/".format(
           os.environ.get("API_PORT", 8000)
       )
,
       api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
       num_per_second: int = 6,
       **kwargs,
   )
:

       self.model = model
       self.base_url = base_url
       self.api_key = api_key
       self.num_per_second = num_per_second

       self.limiter = AsyncLimiter(self.num_per_second, 1)

       self.llm = ChatOpenAI(
           model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
       )

   async def __call__(self, text):
       # 限速
       async with self.limiter:
           return await self.llm.ainvoke([text])


llm = AsyncLLM(
   base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
   api_key="{}".format(os.environ.get("API_KEY", "0")),
   num_per_second=10,
)
llms = [llm]


@dataclass
class AsyncAPICall:
   uid: str = "0"

   @staticmethod
   async def _run_task_with_progress(task, pbar):
       result = await task
       pbar.update(1)
       return result

   @staticmethod
   def async_run(
       llms: List[AsyncLLM],
       data: List[str],
       keyword: str = "",
       output_dir: str = "output",
       chunk_size=500,
   )
-> List[str]:


       async def infer_chunk(llms: List[AsyncLLM], data: List):
           results = [llms[i % len(llms)](text) for i, text in enumerate(data)]
           with tqdm(total=len(results)) as pbar:
               results = await asyncio.gather(
                   *[
                       AsyncAPICall._run_task_with_progress(task, pbar)
                       for task in results
                   ]
               )
           return results

       idx = 0
       all_df = []
       file_exist_skip = False
       user_confirm = False

       while idx < len(data):
           file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")

           if os.path.exists(file_path):
               if not user_confirm:
                   while True:
                       user_response = input(
                           f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: "
                       )
                       if user_response.lower() == "y":
                           user_confirm = True
                           file_exist_skip = True
                           break
                       elif user_response.lower() == "n":
                           user_confirm = True
                           file_exist_skip = False
                           break

               if file_exist_skip:
                   tmp_df = pd.read_csv(file_path)
                   all_df.append(tmp_df)
                   idx += chunk_size
                   continue

           tmp_data = data[idx : idx + chunk_size]
           loop = asyncio.get_event_loop()
           tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
           tmp_result = [item.content for item in tmp_result]

           tmp_df = pd.DataFrame({"infer": tmp_result})

           if not os.path.exists(p := os.path.dirname(file_path)):
               os.makedirs(p, exist_ok=True)

           tmp_df.to_csv(file_path, index=False)
           all_df.append(tmp_df)
           idx += chunk_size

       all_df = pd.concat(all_df)
       return all_df["infer"]


def async_api_infer(
   model_name_or_path: str = "",
   eval_dataset: str = "",
   template: str = "",
   dataset_dir: str = "data",
   do_predict: bool = True,
   predict_with_generate: bool = True,
   max_samples: int = None,
   output_dir: str = "output",
   chunk_size=50,
)
:


   if len(sys.argv) == 1:
       model_args, data_args, training_args, finetuning_args, generating_args = (
           get_train_args(
               dict(
                   model_name_or_path=model_name_or_path,
                   dataset_dir=dataset_dir,
                   eval_dataset=eval_dataset,
                   template=template,
                   output_dir=output_dir,
                   do_predict=True,
                   predict_with_generate=True,
                   max_samples=max_samples,
               )
           )
       )
   else:
       model_args, data_args, training_args, finetuning_args, generating_args = (
           get_train_args()
       )

   dataset = _get_merged_dataset(
       data_args.eval_dataset, model_args, data_args, training_args, "sft"
   )

   labels = [item[0]["content"] for item in dataset["_response"]]
   prompts = [item[0]["content"] for item in dataset["_prompt"]]

   infers = AsyncAPICall.async_run(
       llms,
       prompts,
       chunk_size=chunk_size,
       output_dir=training_args.output_dir,
   )

   if not os.path.exists(training_args.output_dir):
       os.makedirs(training_args.output_dir, exist_ok=True)

   output_prediction_file = os.path.join(
       training_args.output_dir, "generated_predictions.jsonl"
   )

   with open(output_prediction_file, "w", encoding="utf-8") as writer:
       res: List[str] = []
       for text, pred, label in zip(prompts, infers, labels):
           res.append(
               json.dumps(
                   {"prompt": text, "predict": pred, "label": label},
                   ensure_ascii=False,
               )
           )
       writer.write("\n".join(res))


if __name__ == "__main__":
   fire.Fire(async_api_infer)

上述代码支持 yaml 文件:

async.yaml

## model
model_name_or_path: qwen/Qwen2.5-7B-Instruct

### method
do_predict: true

### dataset
dataset_dir: ../data/
eval_dataset: alpaca_zh_demo
template: qwen

### output
output_dir: output
# overwrite_output_dir: true


### eval
predict_with_generate: true
max_samples: 100

使用下述代码运行:

python async_call_api.py async.yaml

异步调用的API推理结果:

{"prompt": "识别并解释给定列表中的两个科学理论:细胞理论和日心说。", "predict": "细胞理论和日心说是两个重要的科学理论,分别属于生物学和天文学领域。\n\n1. 细胞理论:细胞理论是生物学中的一个基本理论,它描述了所有生命体的基本结构和功能单位——细胞。细胞理论包含以下三个主要观点:\n\n   a. 所有生物体都是由一个或多个细胞组成的。\n   b. 细胞是生命活动的基本单位,负责维持生物体的生命活动。\n   c. 新细胞只能由已存在的细胞分裂产生。\n\n   细胞理论的提出和发展对于生物学研究具有重要意义,它揭示了生物体的基本结构和功能单位,为生物学研究提供了理论基础。\n\n2. 日心说:日心说是一种关于太阳系结构的理论,它描述了太阳在太阳系中的中心位置,而地球和其他行星则围绕太阳旋转。日心说最初由古希腊天文学家阿里斯塔克提出,但直到16世纪,哥白尼重新提出这一理论,并得到了进一步的发展和完善。日心说的提出颠覆了传统的地心说观念,为现代天文学的发展奠定了基础。\n\n这两个科学理论分别在生物学和天文学领域中具有重要的意义,它们揭示了生命体和宇宙的基本结构和运行规律。", "label": "细胞理论是生物科学的一个理论,它认为所有生命体都是由微小的基本单元——细胞所构成。这是生物学的一个基础理论,认为细胞是所有生物的基本结构和功能单位,所有的生物都是由一个或多个细胞组成,细胞只能通过细胞分裂产生新的细胞。这一理论由薛定谔、施瓦内和雪莱于1839年首次提出。\n\n日心说是指太阳是太阳系的中心,也就是说,行星围绕太阳旋转的理论。这个理论打破了传统的地心说观点,认为地球并不是宇宙的中心。日心说的提出者是尼古拉·哥白尼,他在16世纪初发表了他的著作《天体运行论》,阐述了太阳系行星围绕太阳运行的模型,为天文学的发展做出了巨大贡献。"},...,
}

在输出结果中, predict 是大模型的推理结果。方便大家对比 predict 和 label,并评估大模型推理的精度。

为了避免大模型中途程序崩溃,把原始数据分块进行推理。这样即使程序中途崩溃,也能基于之前保存的分快数据继续推理,而不用重新开始推理。

速度对比

异步调用速度

下面是两个异步调用的进度条:

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00, 2.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00, 2.22it/s]

上述异步实验总共数据为100条数据,分块大小为50,故有2个进度条。100条速度44秒全部处理完成,平均处理速度 每秒处理2.2条数据。

同步调用速度

同步调用 LLM api 的代码很简单,如下所示:

infers = []
for prompt in tqdm(prompts):
   infers.append(llm.llm.invoke(prompt))

下面是同步调用的进度条:

100%|████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [06:54<00:00, 4.15s/it]

如果使用同步调用,100条数据,总共耗时 6分54秒,平均每条耗时4.15秒。

方法
推理100条数据时间
同步
6分54秒
异步
44秒

对比之下,异步调用比同步调用快了大约 9.41 倍。



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询