微信扫码
添加专属顾问
我要投稿
导读
通过阅读本文你能够:
了解如何微调 Llama3增强知识图谱关系抽取
获取微调代码数据集&代码
关注公众号
后台发送 "llama3微调知识图谱" ,获取本文微调数据集&微调代码
后台发送"AI应用","RAG","大模型" 等获取更多相关文章
通过使用 Llama3–70B 创建的合成数据集对 Llama3–8B 进行微调,增强关系抽取
前提
关系抽取(RE)是从非结构化文本中提取关系,以识别各种命名实体之间的联系的任务。它与命名实体识别(NER)一起完成,并且是自然语言处理流程中的一个基本步骤。
随着大型语言模型(LLMs)的兴起,传统的监督方法涉及标记实体范围和分类它们之间的关系(如果有的话)的方法得到增强或完全被基于LLM的方法所取代
Llama3 是生成式AI领域最新的重大发布
基础模型有两种规模,8B和70B,预计将很快发布400B模型。
这些模型可以在HuggingFace平台上获得。70B变体为Meta的新聊天网站Meta.ai提供支持,并且表现与ChatGPT相当。
8B模型是其类别中性能最好的之一。Llama3的架构类似于Llama2,性能提升主要是由于数据升级。该模型配备了升级的分词器和扩展的上下文窗口。它被标记为开源,尽管只有一小部分数据被发布。总的来说,这是一个出色的模型,而且,开源。
Teacher and Student
Llama3-70B可以产生惊人的结果,但由于其规模,它在本地系统上使用起来不切实际,成本昂贵且难以使用。
因此,为了利用其能力,我们让Llama3-70B教会较小的Llama3-8B从非结构化文本中提取关系的任务。
具体来说,借助Llama3-70B的帮助,我们构建了一个旨在进行关系抽取的监督微调数据集。然后,我们使用这个数据集对Llama3-8B进行微调,以增强其关系抽取能力。
要在与此文相关的Google Colab笔记本中复制代码,你需要:
HuggingFace token(用于保存微调模型,可选)和Llama3访问权限,可以通过遵循其中一个模型卡片上的说明来获取
免费的GroqCloud账户(您可以使用Google账户登录)和相应的API key
P.S. 搞AI模型的都应该有以上账户, 如果你没有,应该反思了。。。
工作区设置
对于这个项目,我使用了配备 A100 GPU 和高内存设置的 Google Colab Pro(可以买个pro账号,也不贵)
首先安装所有所需的库:
!pip install -q groq
!pip install -U accelerate bitsandbytes datasets evaluate
!pip install -U peft transformers trl
注意到,整个设置从一开始就可以正常工作,没有任何依赖问题或需要从源代码安装 transformers,尽管模型是新的。
需要给予 Google Colab 对驱动器和文件的访问权限,并设置工作目录:
# For Google Colab settings
from google.colab import userdata, drive
# 这将提示进行授权
drive.mount('/content/drive')
# 设置工作目录
%cd '/content/drive/MyDrive/postedBlogs/llama3RE'
对于那些希望将模型上传到 HuggingFace Hub 的人,我们需要上传 Hub 凭据。在我的情况下,这些凭据存储在 Google Colab 的key中,可以通过左侧的键按钮访问。这一步是可选的。
Hugging Face Hub 设置
from huggingface_hub import login
上传 HuggingFace token(应具有写入权限)从 Colab secrets
HF = userdata.get('HF')
上传模型到 HuggingFace
login(token=HF,add_to_git_credential=True)
我还添加了一些路径变量来简化文件访问:
# 为数据文件夹创建一个路径变量
data_path = '/content/drive/MyDrive/postedBlogs/llama3RE/datas/'
# 完整的微调数据集
sft_dataset_file = f'{data_path}sft_train_data.json'
# Data collected from the the mini-test
mini_data_path = f'{data_path}mini_data.json'
# 测试数据包含所有三个输出
all_tests_data = f'{data_path}all_tests.json'
# 调整后的训练数据集
train_data_path = f'{data_path}sft_train_data.json'
# 创建一个路径变量,用于将SFT模型保存到本地
sft_model_path = '/content/drive/MyDrive/llama3RE/Llama3_RE/'
现在我们的工作空间已经设置好,我们可以继续第一步,即为关系抽取任务构建一个合成数据集。
使用 Llama3–70B 创建一个用于关系抽取的合成数据集
目前有几个关系抽取数据集可供使用,其中最著名的是:
CoNLL04. (https://paperswithcode.com/dataset/conll04)数据集。
此外,还有一些出色的数据集,比如在HuggingFace上可用的web_nlg,以及由AllenAI开发的SciREX。然而,大多数这些数据集都带有限制性许可。
受`web_nlg`数据集格式的启发,我们将构建自己的数据集。如果我们计划对在我们的数据集上训练的模型进行微调,这种方法将特别有用。首先,我们需要一系列用于关系抽取任务的短句子。我们可以以各种方式编制这个语料库。
收集一系列句子
我们将使用databricks-dolly-15k
https://huggingface.co/datasets/databricks/databricks-dolly-15k),
这是由Databricks员工在2023年生成的开源数据集。该数据集旨在进行监督微调,包括四个特征:指令、上下文、响应和类别。在分析了八个类别之后,我决定保留`information_extraction`类别中上下文的第一句话。数据解析步骤如下所示:
from datasets import load_dataset
加载数据集
dataset = load_dataset("databricks/databricks-dolly-15k")
选择数据集中所需的类别
ie_category = [e for e in dataset["train"] if e["category"]=="information_extraction"]
保留每个实例的上下文
ie_context = [e["context"] for e in ie_category]
将文本分割成句子(在句号处),并保留第一句话
reduced_context = [text.split('.')[0] + '.' for text in ie_context]
仅保留指定长度的序列(使用字符长度)
sampler = [e for e in reduced_context if 30 < len(e) < 170]
选择过程产生了一个包含1,041个句子的数据集。鉴于这是一个迷你项目,我没有手动挑选句子,因此,一些样本可能并不理想适合我们的任务。在一个专门用于生产的项目中,我会仔细挑选只有最合适的句子。然而,对于这个项目的目的,这个数据集将足够。
格式化数据
我们首先需要创建一个系统消息,该消息将定义输入提示并指示模型如何生成答案:
system_message = """您是一位经验丰富的注释者。
从以下文本中提取所有实体及其之间的关系。
将答案写成三元组实体1|关系|实体2。
不要添加其他内容。
示例文本:Alice is from France.
答案:Alice|is from|France.
"""
由于这是实验阶段,我将对模型的要求保持在最低限度。我确实测试了几个其他提示,包括一些要求以CoNLL格式输出的提示,其中实体被分类,模型表现得相当不错。
但是,为了简单起见,现在我们将坚持基本原则。
我们还需要将数据转换为对话格式:
messages = [[
{"role": "system","content": f"{system_message}"},
{"role": "user", "content": e}] for e in sampler]
Groq 客户端和 API
Llama3 刚刚发布了几天,API 选项的可用性仍然有限。
虽然 Llama3–70B 提供了聊天界面,但这个项目需要一个能够用几行代码处理我的 1,000 个句子的 API。
以下解释了如何免费使用 GroqCloud API。更多详情请参考视频。
提醒一下:需要登录并从 GroqCloud 网站检索免费的 API 密钥。
我的 API 密钥已经保存在 Google Colab 的 secrets 中。
首先初始化 Groq 客户端:
import os
from groq import Groq
gclient = Groq(
api_key=userdata.get("GROQ"),
)
接下来,我们需要定义一些辅助函数,这些函数将使我们能够有效地与 Meta.ai 聊天界面进行交互
import time
from tqdm import tqdm
def process_data(prompt):
"""发送一个请求并检索模型的生成内容。"""
chat_completion = gclient.chat.completions.create(
messages=prompt, # 发送给模型的输入提示
model="llama3-70b-8192", # 根据 GroqCloud 标签
temperature=0.5, # 控制多样性
max_tokens=128, # 生成的最大标记数
top_p=1, # 考虑的可能性加权选项的比例
stop=None, # 表示停止生成的字符串
stream=False, # 如果设置,将发送部分消息
)
return chat_completion.choices[0].message.content
def send_messages(messages):
"""以批处理的方式处理消息,并在批处理之间暂停。"""
batch_size = 10
answers = []
for i in tqdm(range(0, len(messages), batch_size)): # 每批处理 10 条消息
batch = messages[i:i+10] # 获取下一批消息
for message in batch:
output = process_data(message)
answers.append(output)
if i + 10 < len(messages): # 检查是否还有批次
time.sleep(10) # 等待 10 秒
return answers
第一个函数 process_data() 作为 Groq 客户端的聊天完成函数的包装器。第二个函数 send_messages() 以小批量处理数据。如果您在 Groq 游乐场页面上跟随设置链接,您将找到一个指向 Limits 的链接,其中详细说明了我们可以使用免费 API 的条件,包括对请求和生成的标记数量的限制。为了避免超出这些限制,我在每批处理 10 条消息后添加了 10 秒的延迟,尽管在我的情况下这并不是严格必要的。您可能希望尝试这些设置。
现在剩下的就是生成我们的关系抽取数据并将其与初始数据集集成:
输出:
使用 Llama3-70B 进行数据生成
answers = send_messages(messages)
将输入数据与生成的数据集合并
combined_dataset = [{'text': user, 'gold_re': output} for user, output in zip(sampler, answers)]
评估 Llama3–8B 对关系抽取的性能
在进行模型的微调之前,重要的是对其在几个样本上的性能进行评估,以确定是否确实需要进行微调。
构建测试数据集
我们将从刚刚构建的数据集中选择 20 个样本并将它们设置为测试集。数据集的其余部分将用于微调。
import random
random.seed(17)
# 选择 20 个随机条目
mini_data = random.sample(combined_dataset, 20)
# 构建对话格式
parsed_mini_data = [[{'role': 'system', 'content': system_message},
{'role': 'user', 'content': e['text']}] for e in mini_data]
```python
{'text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.',
'gold_re': 'people|were aware of|shocks\nshocks|from|electric fish\nelectric fish|had|electricity',
'test_re': 'electric fish|were aware of|shocks'}
对于完整的测试数据集,请参考Google Colab笔记本。
仅从这个例子中,就可以清楚地看出Llama3-8B在其关系抽取能力方面有待改进。让我们努力提升它。
Llama3–8B的监督微调
我们将利用一整套技术来辅助我们,包括 QLoRA 和 Flash Attention。我不会在这里深入讨论选择超参数的具体细节,但如果你对进一步探索感兴趣,可以查看这些参考资料[4]和[5]。
A100 GPU 支持 Flash Attention 和 bfloat16,并且具有约40GB的内存,这对我们的微调需求是足够的。
准备 SFT 数据集
我们首先将数据集解析为对话格式,包括系统消息、输入文本和期望的答案,我们从 Llama3–70B 生成中获取。然后将其保存为 HuggingFace 数据集:
def create_conversation(sample):
return {
"messages": [
{"role": "system","content": system_message},
{"role": "user", "content": sample["text"]},
{"role": "assistant", "content": sample["gold_re"]}
]
}
from datasets import load_dataset, Dataset
train_dataset = Dataset.from_list(train_data)
转换为口语化格式
train_dataset = train_dataset.map(create_conversation,
remove_columns=train_dataset.features,
batched=False)
选择模型
model_id = "meta-llama/Meta-Llama-3-8B"
加载分词器
from transformers import AutoTokenizer
Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id,
use_fast=True,
trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
设置最大长度
tokenizer.model_max_length = 512
选择量化参数
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
加载模型
from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training
from trl import setup_chat_format
device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() elseNone
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
attn_implementation="flash_attention_2",
quantization_config=bnb_config
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)
LoRA 配置
from peft import LoraConfig
# 根据Sebastian Raschka的研究结果
peft_config = LoraConfig(
lora_alpha=128, #32
lora_dropout=0.05,
r=256, #16
bias="none",
target_modules=["q_proj", "o_proj", "gate_proj", "up_proj",
"down_proj", "k_proj", "v_proj"],
task_type="CAUSAL_LM",
)
当目标是所有线性层时,可以获得最佳结果。如果内存限制是一个问题,选择更标准的值,如alpha=32和rank=16可能是有益的,因为这些设置会导致参数大大减少。
训练参数
from transformers import TrainingArguments
# 适应自 Phil Schmid 博客文章
args = TrainingArguments(
output_dir=sft_model_path, # 保存模型和存储库 ID 的目录
num_train_epochs=2, # 训练周期数
per_device_train_batch_size=4, # 训练期间每个设备的批处理大小
gradient_accumulation_steps=2, # 执行向后/更新传递之前的步骤数
gradient_checkpointing=True, # 使用梯度检查点以节省内存,在分布式训练中使用
optim="adamw_8bit", # 如果内存不足,请选择 paged_adamw_8bit
logging_steps=10, # 每 10 步记录一次日志
save_strategy="epoch", # 每个周期保存检查点
learning_rate=2e-4, # 学习率,基于 QLoRA 论文
bf16=True, # 使用 bfloat16 精度
tf32=True, # 使用 tf32 精度
max_grad_norm=0.3, # 基于 QLoRA 论文的最大梯度范数
warmup_ratio=0.03, # 基于 QLoRA 论文的预热比例
lr_scheduler_type="constant", # 使用恒定学习率调度程序
push_to_hub=True, # 将模型推送到 Hugging Face hub
hub_model_id="llama3-8b-sft-qlora-re",
report_to="tensorboard", # 报告指标到 tensorboard
)
如果选择在本地保存模型,可以省略最后三个参数。您还可能需要调整 per_device_batch_size 和 gradient_accumulation_steps 以防止内存不足(OOM)错误。
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=sft_dataset,
peft_config=peft_config,
max_seq_length=512,
tokenizer=tokenizer,
packing=False, # True if the dataset is large
dataset_kwargs={
"add_special_tokens": False, # the template adds the special tokens
"append_concat_token": False, # no need to add additional separator token
}
)
trainer.train()
trainer.save_model()
训练,包括模型保存,大约花了10分钟。
让我们清除内存,为推理测试做准备。如果您使用的是内存较少的 GPU,并遇到 CUDA 内存不足 (OOM) 错误,您可能需要重新启动运行时。
import torch
import gc
del model
del tokenizer
torch.cuda.empty_cache()
使用 SFT 模型进行推理
在这最后一步中,我们将加载半精度的基础模型,以及 Peft 适配器。对于这个测试,我选择不将模型与适配器合并。
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
import torch
HF 模型
peft_model_id = "solanaO/llama3-8b-sft-qlora-re"
# 使用 PEFT 适配器加载模型
model = AutoPeftModelForCausalLM.from_pretrained(
peft_model_id,
device_map="auto",
torch_dtype=torch.float16,
offload_buffers=True
)
接下来,我们加载分词器:
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
然后我们构建文本生成管道:
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
我们加载测试数据集,其中包括我们之前设置的 20 个样本,并以对话式格式对数据进行格式化。但是,这次我们省略了助手消息,并将其格式化为 Hugging Face 数据集:
def create_input_prompt(sample):
return {
"messages": [
{"role": "system","content": system_message},
{"role": "user", "content": sample["text"]},
]
}
from datasets import Dataset
test_dataset = Dataset.from_list(mini_data)
转换为口语化格式
test_dataset = test_dataset.map(create_input_prompt,
remove_columns=test_dataset.features,
batched=False)
样本测试
让我们使用 SFT Llama3–8B 生成关系抽取输出,并将其与之前的两个输出在单个实例上进行比较:
生成输入提示
prompt = pipe.tokenizer.apply_chat_template(test_dataset[2]["messages"][:2], tokenize=False, add_generation_prompt=True) Output: ```markdown # 生成输出 outputs = pipe(prompt, max_new_tokens=128, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, )# 显示结果 print(f"问题: {test_dataset[2]['messages'][1]['content']}\n") print(f"Gold-RE: {test_sampler[2]['gold_re']}\n") print(f"LLama3-8B-RE: {test_sampler[2]['test_re']}\n") print(f"SFT-Llama3-8B-RE: {outputs[0]['generated_text'][len(prompt):].strip()}")
我们发现,通过微调,Llama3-8B的关系提取能力显著提高。尽管微调数据集既不是非常干净,也不是特别大,但结果很不错。
有关20个样本数据集的完整结果,请参阅Google Colab notebook。请注意,推断测试需要更长时间,因为我们以半精度加载模型。
结论
总之,通过利用Llama3-70B和一个可用的数据集,我们成功地创建了一个合成数据集,然后用它来对Llama3-8B进行特定任务的微调。这个过程不仅使我们熟悉了Llama3,还使我们能够应用来自Hugging Face的简单技术。我们观察到,与Llama2密切合作的经验与Llama3非常相似,显著改进的地方是增强了输出质量和更有效的分词器。
对于那些有兴趣进一步挑战界限的人,可以考虑用更复杂的任务来挑战模型,比如对实体和关系进行分类,并利用这些分类来构建知识图谱。
参考文献
Somin Wadhwa, Silvio Amir, Byron C. Wallace, 重温大语言模型时代的关系抽取, [arXiv.2305.05003](https://arxiv.org/pdf/2305.05003.pdf) (2023).
Meta, Meta Llama 3: 迄今为止最有能力的开放式LLM介绍, 2024年4月18日 ([link)](https://ai.meta.com/blog/meta-llama-3/).
Philipp Schmid, Omar Sanseviero, Pedro Cuenca, Youndes Belkada, Leandro von Werra, [欢迎 Llama 3 - Met's new open LLM,](https://huggingface.co/blog/llama3) 2024年4月18日.
Sebastian Raschka, [使用LoRA(低秩适应)对LLM进行微调的实用技巧](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms), Ahead of AI, 2023年11月19日.
Philipp Schmid, [2024年如何使用Hugging Face对LLM进行微调](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl) 2024年1月22日.
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-02-01
2025-01-01
2024-08-13
2025-02-04
2024-07-25
2024-04-25
2024-06-13
2024-09-23
2024-04-26
2024-08-21
2025-03-17
2025-03-17
2025-03-17
2025-03-17
2025-03-17
2025-03-17
2025-03-16
2025-03-16