LoRA fine-tuning freezes the pretrained model's parameters and introduces low-rank matrices, so only a small number of additional weights are updated to adapt the model to a specific task. Reducing the number of trainable parameters lowers compute cost and memory requirements, while preserving the pretrained model's knowledge and generalization ability.
In a LoRA implementation, the original model weights stay unchanged; instead, the product of W_A and W_B is added in the forward pass to simulate an update to the weights. With the frozen pretrained weight W made explicit, the forward pass can be written as:
h = x @ W + x @ (W_A @ W_B) * α
Here h is the layer output, x is the input, W_A and W_B are the LoRA weight matrices, and α is a scaling factor. α directly controls how much the LoRA update contributes to the model's output. If α is set large, the product of W_A and W_B influences the output more strongly, so fine-tuning adjusts the model more aggressively toward the target task; if α is set small, the update's influence is weakened and the model changes more conservatively. (In Hugging Face PEFT, the effective scale is lora_alpha / r, so the same intuition applies relative to the chosen rank.)
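To make this concrete, below is a minimal sketch of a LoRA linear layer in PyTorch. The class name, dimensions, and initialization are illustrative assumptions, not code from the training script that follows:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Frozen pretrained linear layer plus a trainable low-rank update (illustrative sketch)
    def __init__(self, in_dim, out_dim, r=8, alpha=16):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)   # holds the frozen pretrained weight W
        self.linear.weight.requires_grad_(False)
        self.linear.bias.requires_grad_(False)
        self.W_A = nn.Parameter(torch.randn(in_dim, r) * 0.01)  # low-rank factor A
        self.W_B = nn.Parameter(torch.zeros(r, out_dim))        # factor B zero-init, so the update starts at 0
        self.alpha = alpha

    def forward(self, x):
        # h = x @ W + x @ (W_A @ W_B) * alpha
        return self.linear(x) + (x @ self.W_A @ self.W_B) * self.alpha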
# Step 1: install dependencies
!pip install trl transformers accelerate datasets bitsandbytes einops torch huggingface-hub git+https://github.com/huggingface/peft.git
from datasets import load_dataset
from random import randrange
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, pipeline
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLM
from trl import SFTTrainer
# Training data: https://huggingface.co/datasets/samsum
dataset = load_dataset("samsum")
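# Optional: peek at a random training sample; samsum rows carry 'dialogue' and 'summary' fields
print(dataset['train'][randrange(len(dataset['train']))])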
model_name = "google/flan-t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# pretraining_tp = 1 disables tensor-parallel linear splitting (faster, slightly less accurate);
# it is a Llama-specific config knob and effectively a no-op for T5
model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Padding setup: T5's tokenizer already ships a pad token, so the fallback below only
# matters for decoder-only models whose tokenizers lack one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
def prompt_instruction_format(sample):
    # Format a sample into an instruction prompt
    return f"""### Instruction:
Use the Task below and the Input given to write the Response:
### Task:
Summarize the Input
### Input:
{sample['dialogue']}
### Response:
{sample['summary']}
"""
# Training arguments
trainingArgs = TrainingArguments(
    output_dir='output',
    num_train_epochs=1,
    per_device_train_batch_size=4,
    save_strategy="epoch",
    learning_rate=2e-4
)
peft_config = LoraConfig(
    lora_alpha=16,     # the scaling factor α discussed above
    lora_dropout=0.1,
    r=64,              # rank of the low-rank matrices
    bias="none",
    task_type="SEQ_2_SEQ_LM",  # flan-t5 is an encoder-decoder model, not a causal LM
)
# Create the trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    peft_config=peft_config,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=prompt_instruction_format,
    args=trainingArgs,
)
trainer.train()
For training, Google Colab is a convenient choice: loading the dataset and the other steps above run quickly there.
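Once training finishes, one possible follow-up (not shown in the original script) is to save the adapter and try it on a test dialogue. The output path and generation settings below are illustrative assumptions:

# Save the trained LoRA adapter (path is an arbitrary choice)
trainer.save_model("output/lora-flan-t5-samsum")

# Reload the base model and attach the adapter for inference
from peft import PeftModel
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
inference_model = PeftModel.from_pretrained(base_model, "output/lora-flan-t5-samsum")

# Build a prompt without the reference summary and generate one
sample = dataset['test'][0]
prompt = f"""### Instruction:
Use the Task below and the Input given to write the Response:
### Task:
Summarize the Input
### Input:
{sample['dialogue']}
### Response:
"""
inputs = tokenizer(prompt, return_tensors="pt")
outputs = inference_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))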