AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


开源 LLM 在 text-to-sql 任务的 baseline 效果介绍
发布日期:2024-04-29 12:07:31 浏览次数: 1782


本文是对基于开源项目 DB-GPT-Hub 利用开源的 LLM 在 spider 数据集上的 text-to-sql 任务的 baseline 效果介绍。其中开源 LLM 包括 Llama2-7B-Chat、Llama2-13B-Chat、CodeLlama-7B-Instruct、CodeLlama-13B-Instruct、Baichuan2-7B-Chat、Baichuan2-13B-Chat、Qwen-7B-Chat、Qwen-14B-Chat、ChatGLM3-6b,对 spider 数据集进行基于 LoRA 和 QLoRA 的训练,在 spider 官方的评估集上评估其执行准确率。

  • 项目地址:https://github.com/eosphoros-ai/DB-GPT-Hub

具体实验结果和对应的各个模型如上表,其中 method 中的 base 是指未经训练直接用模型本身进行预测评估,lora,qlora 是指大模型基于 LoRA 和 QLoRA 方式的训练方式。EX 中的 easy、 medium、 hard、 extra 分别为评估集中四个难度等级数据的准确率,all 为在所有难度等级上的整体执行准确率。上述训练集均只采用了 Spider 官方的训练数据,且只训练 8 个 epoch。

总体而言,采用 LoRA 和 QLoRA 方法进行微调后的模型相比 base 模型有显著提升。尤其是 CodeLlama-13B-Instruct 模型采用 LoRA 微调后在该任务上效果最好,在四个难度上的执行准确率都有大幅度提高,整体执行准确率达到 0.746 ,是目前效果最佳的模型。随着模型规模的增加,效果也有所提升。例如 Llama2 从 7B 到 13B ,准确率有了约 4-5% 的绝对提升。CodeLlama 从 7B 到 13B 提升更加明显,整体准确率从 0.149 上升到 0.539。国产模型方面, Qwen 的表现效果最好,在 14B 级别时效果整体准确率效果可以达到 0.701。


01

微调训练


随着模型规模的扩大和采用 LoRA、QLoRA 等方法的提示学习,可以有效提升开源 LLM 在文本到 SQL 转换任务上的效果。

详细的训练参数以 Qwen-7B-Chat 进行 LoRA 训练为例,如下:

CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \--model_name_or_path /home/model_files/Qwen-7B-Chat \--do_train \--dataset example_text2sql_train \--max_source_length 2048 \--max_target_length 512 \--template chatml \--finetuning_type lora \--lora_rank 64 \--lora_alpha 32 \--lora_target c_attn \--output_dir dbgpt_hub/output/adapter/qwen-7b-2048_epoch8_lora \--overwrite_cache \--overwrite_output_dir \--per_device_train_batch_size 1 \--gradient_accumulation_steps 16 \--lr_scheduler_type cosine_with_restarts \--logging_steps 500 \--save_steps 2000 \--learning_rate 2e-4 \--num_train_epochs 8 \--plot_loss \--bf16

如果是使用 QLoRA 方法训练的话,同样以 Qwen 模型为例,训练参数如下所示:(主要设置参数量化精度 quantization_bit 为 4)

CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \--model_name_or_path /home/model_files/Qwen-14B-Chat \--do_train \--dataset example_text2sql_train \--max_source_length 2048 \--max_target_length 512 \--template chatml \--quantization_bit 4 \--finetuning_type lora \--lora_rank 64 \--lora_alpha 32 \--lora_target c_attn \--output_dir dbgpt_hub/output/adapter/qwen-14b-2048_epoch8_qlora \--overwrite_cache \--overwrite_output_dir \--per_device_train_batch_size 1 \--gradient_accumulation_steps 16 \--lr_scheduler_type cosine_with_restarts \--logging_steps 500 \--save_steps 2000 \--learning_rate 2e-4 \--num_train_epochs 8 \--plot_loss \--bf16

同时,DB-GPT-Hub 项目还发布了 pip 包,用来降低 Text2SQL 训练的门槛, 除了通过仓库中提供的脚本的方式进行微调之外,还可以使用项目提供的 Python 包进行微调。

安装方式。直接采用 pip 安装即可:

pip install dbgpt_hub

使用方式。微调代码相关如下:

from dbgpt_hub.data_process import preprocess_sft_datafrom dbgpt_hub.train import start_sftfrom dbgpt_hub.predict import start_predictfrom dbgpt_hub.eval import start_evaluate
data_folder = "dbgpt_hub/data"data_info = [{"data_source": "spider","train_file": ["train_spider.json", "train_others.json"],"dev_file": ["dev.json"],"tables_file": "tables.json","db_id_name": "db_id","is_multiple_turn": False,"train_output": "spider_train.json","dev_output": "spider_dev.json",}]
train_args = {"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf","do_train": True,"dataset": "example_text2sql_train","max_source_length": 2048,"max_target_length": 512,"finetuning_type": "lora","lora_target": "q_proj,v_proj","template": "llama2","lora_rank": 64,"lora_alpha": 32,"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora","overwrite_cache": True,"overwrite_output_dir": True,"per_device_train_batch_size": 1,"gradient_accumulation_steps": 16,"lr_scheduler_type": "cosine_with_restarts","logging_steps": 50,"save_steps": 2000,"learning_rate": 2e-4,"num_train_epochs": 8,"plot_loss": True,"bf16": True,}
predict_args = {"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf","template": "llama2","finetuning_type": "lora","checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora","predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json","predict_out_dir": "dbgpt_hub/output/","predicted_out_filename": "pred_sql.sql",}
evaluate_args ={"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql","gold": "./dbgpt_hub/data/eval_data/gold.txt","gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt","db": "./dbgpt_hub/data/spider/database","table": "./dbgpt_hub/data/eval_data/tables.json","table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json","etype": "exec","plug_value": True,"keep_distict": False,"progress_bar_for_each_datapoint": False,"natsql": False,}
preprocess_sft_data(data_folder = data_folder,data_info = data_info)
start_sft(train_args)start_predict(predict_args)start_evaluate(evaluate_args)



02

得分展示



为了进一步展示 DB-GPT-Hub 项目取得的模型基础实验进展,项目提供了查看所有的模型基线得分以及具体的单个数据集上的实验得分、单个模型的实验得分等。

比如:查看所有的模型基线得分:

from dbgpt_hub.baseline import show_scoresshow_scores()

显示结果如下所示:会默认按照平均精度降序输出,目前最高的平均精度为 codellama-13b-instruct 模型使用 lora 方法在 spider 数据集上训练,EX 为 0.746。

比如:还可以查看在数据集 spider 上的所有实验基线得分:

from dbgpt_hub.baseline import show_scoreshow_score(dataset="spider")

比如:查看在数据集 spider,模型 llama2-7b-chat 的实验基线得分:

from dbgpt_hub.baseline import show_scoreshow_score(dataset="spider", model="llama2-7b-chat")

比如:查看在数据集 spider,模型为 llama2-7b-chat,微调方法为 lora 的实验得分:

from dbgpt_hub.baseline import show_scoreshow_score(dataset="spider", model="llama2-7b-chat", method="lora")

from dbgpt_hub.baseline import show_scoreshow_score(dataset="spider", model="llama2-7b-chat", method="lora", prompt="alpaca")

最后,DB-GPT-Hub 项目目前 star 800+,欢迎关注和共建~


53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询