AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


100行代码演示LoRA fine-tuning!
发布日期:2024-04-23 23:12:18 浏览次数: 2637 来源:牛爷儿


引言

LoRA fine-tuning,冻结模型参数,引入低阶矩阵,更新一小部分权重来适应特定任务。通过减少训练参数的数量,降低了计算成本和内存需求,同时,保留了预训练模型的知识和泛化能力。

LoRA方案中,提到的rank和alpha是什么?

之前的文章解释了rank是什么,这里再重新回顾下,同时,也解释下alpha超参数的含义。

在LoRA(Low-Rank Adaptation)fine-tuning中,alpha(α)和rank(秩)是超参数,共同决定了fine-tuning过程中,权重更新的效率和效果。

Alpha(α)的作用是控制LoRA权重更新的规模,是一个缩放因子,用于调整LoRA权重(W_A和W_B)在前向传播中的影响程度。通过调整alpha的值,可以控制fine-tuning过程中,对原始预训练模型权重的修改程度。如果alpha设置得较大,那么LoRA权重的变化对模型的影响也会更大;反之则反。

Rank(秩)在LoRA中指的是在LoRA权重矩阵分解中使用的秩。秩决定了LoRA权重矩阵分解中两个矩阵(W_A和W_B)的维度。较低的秩意味着更少的参数需要更新,从而减少了计算资源的需求和内存占用。然而,如果秩太低,可能无法充分捕捉到fine-tuning任务所需的特征变化,从而影响模型性能。因此,选择合适的rank也非常重要,实际项目中,需要不断验证调整,得到较好的结果。

Alpha和rank之间的关系是相互影响的,在实际操作中,调整alpha和rank的相对大小可以改变fine-tuning的效果。例如,如果rank保持不变,增加alpha会增强fine-tuning的效果,而减少alpha则会减弱这些效果。

LoRA的实现中,原始的模型权重保持不变,而是通过在前向传播过程中引入W_A和W_B的乘积来模拟权重的更新。这个过程可以表示为:


h = x @ (W_A @ W_B) * α


其中,h是模型的输出,x是输入,W_A和W_B是LoRA权重,α是缩放因子。通过这种方式,α直接影响了LoRA权重更新,对模型输出的贡献程度。如果α设置得较大,那么W_A和W_B的乘积对输出的影响就会更大,从而在fine-tuning过程中更强烈地调整模型以适应特定任务。相反,如果α设置得较小,那么权重更新的影响就会减弱,模型的变化就会更保守。

利用peft,transformer等库,实现LoRA微调

# step1 安装依赖!pip install trl transformers accelerate datasets bitsandbytes einops torch huggingface-hub git+https://github.com/huggingface/peft.git
from datasets import load_datasetfrom random import randrangeimport torchfrom transformers import AutoTokenizer, AutoModelForSeq2SeqLM,TrainingArguments,pipelinefrom peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLMfrom trl import SFTTrainer
# 训练数据使用:https://huggingface.co/datasets/samsumdataset = load_dataset("samsum")
model_name = "google/flan-t5-small"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
#Makes training faster but a little less accurate model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
#setting padding instructions for tokenizertokenizer.pad_token = tokenizer.eos_tokentokenizer.padding_side = "right"
def prompt_instruction_format(sample):  # 格式化promptreturn f"""### Instruction:Use the Task below and the Input given to write the Response:
### Task:Summarize the Input
### Input:{sample['dialogue']}
### Response:{sample['summary']}"""
# Create the trainertrainingArgs = TrainingArguments(output_dir='output',num_train_epochs=1,per_device_train_batch_size=4,save_strategy="epoch",learning_rate=2e-4)
peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.1,r=64,bias="none",task_type="CAUSAL_LM",)
trainer = SFTTrainer(model=model,train_dataset=dataset['train'],eval_dataset = dataset['test'],peft_config=peft_config,tokenizer=tokenizer,packing=True,formatting_func=prompt_instruction_format,args=trainingArgs,)
trainer.train()

训练过程建议使用google colab,加载数据集等操作,速度飞快。


53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询