PiSSA方法 | 仅修改Lora初始化方式显著提高模型微调效果
写在前面
大家好,我是知乎@孟繁续。
今天给大家带来一篇高效微调LLM的文章,仅修改Lora的初始化方式,就可以显著提高微调效果,论文全称《PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models》。
Paper: https://arxiv.org/abs/2404.02948
Github: https://github.com/GraphPKU/PiSSA
知乎: https://zhuanlan.zhihu.com/p/687583780
PiSSA:一种参数高效微调方法
随着大模型的参数量日益增长,微调整个模型的开销逐渐变得难以接受。本文提出了一种名为PiSSA的参数高效微调方法,能够用很少的参数近似全参数微调。受到Intrinsic SAID [1]“预训练大模型参数具有低秩性”的启发,PiSSA对预训练模型进行奇异值分解。将模型中的参数分解为被称为适配器(adapter)的矩阵 和 相乘,,再加上一个被称为残差矩阵的 用于修正误差(如图1c所示)。SVD分解的主奇异值和奇异向量被用于初始化适配器,残余的奇异值和奇异向量被用于初始化残差矩阵。因此适配器中的参数是模型的核心参数,而残差矩阵中的参数是修正参数。我们微调参数量较小的核心适配器,冻结参数量大的残差矩阵,就达成了用很少的参数达到近似全参数微调的效果。
PiSSA的名称就来源于主(Principal)奇异值(Singular values)和奇异向量(Singular vectors)适配(Adaption,表示适应下游任务的方法)或者适配器(Adapter,表示插入的可训练的模块),这两个含义在我们的文章中会混着用。PiSSA的发音类似“披萨”(pizza),含义也和披萨很像-整个大模型是一个完整的披萨,我们切掉其中一角,而且是馅料最丰富的一角(主奇异值、奇异向量),重新烘焙(在下游任务上微调)成我们喜欢的口味。
PiSSA在模型架构上和之前广泛使用的LoRA[2]完全一致(如图1b所示),只是初始化的方式不同(如以下代码所示)。因此PiSSA作为LoRA的一种可选初始化方式,可以在peft包中很方便的进行调用。相同的架构也带来很多好处,比如大多数LoRA的优点,PiSSA都直接继承了,如:对残差模型使用4bit量化[3],减小训练开销;训练完成能合并进残差模型,不改变推理过程的模型架构;无需分享完整模型参数,只需要分享参数量很少的PiSSA模块,使用者直接加载PiSSA模块就能自动进行奇异值分解以及赋值;一个模型可以同时使用多个PiSSA模块等等。再比如,一些对LoRA方法的改进,也能与PiSSA进行结合:比如不固定每层的秩,通过学习找到最佳的秩[4];再比如用PiSSA指导更新[5],从而突破秩的限制等等。
# LoRA在peft包中的初始化方式后面:
nn.init.normal_(self.lora_A.weight, std=1 / self.r)
nn.init.zeros_(self.lora_B.weight)
# 增加一种PiSSA初始化可选项:
Ur, Sr, Vr = svd_lowrank(self.base_layer.weight, self.r, niter=4)
# 注意:由于self.base_layer.weight的维度是(out_channel,in_channel,所以AB的顺序相比图示颠倒了一下)
self.lora_A.weight = torch.diag(torch.sqrt(Sr)) @ Vh.t()
self.lora_B.weight = Ur @ torch.diag(torch.sqrt(Sr))
self.base_layer.weight = self.base_layer.weight - self.lora_B.weight @ self.lora_A.weight
然而PiSSA和LoRA背后的原理,截然不同。同样受到Intrinsic SAID[1]启发,LoRA认为大模型微调前后矩阵的变化具有很低的本征秩,因此通过和相乘得到的低秩矩阵来模拟模型的变化。初始阶段,使用高斯噪声初始化,使用0初始化,则, 因此保证模型初始能力没有变化。然后微调和就实现了对进行更新。而我们的PiSSA不关心,而是认为具有很低的本征秩r。因此直接对进行奇异值分解,并用修正误差:。使用SVD分解后奇异值最大的个奇异值、奇异向量进行初始化:
,
,
残差矩阵使用其余的奇异值、奇异向量进行初始化:
,
由于的本征秩为:因此, ,所以。我们直接对进行微调,冻结不重要的修正项。相比LoRA训练由高斯噪声以及0初始化的修正项,固定核心的base model,我们的PiSSA收敛更快、效果更好。
对比高中低奇异值微调效果实验
为了验证使用不同大小奇异值、奇异向量初始化适配器对模型的影响,我们分别使用高、中、低奇异值初始化LLaMA 2-7B[6]、Mistral-7B-v0.1[7]、Gemma-7B[8]的适配器,然后在MetaMathQA[9]数据集上进行训练,将实验结果展示在图2中。从图中可以看出,使用主要奇异值初始化的方法训练损失最小,在GSM8K[10]和MATH[11]验证集上的准确率更高。这一现象验证了我们微调主要奇异值、奇异向量的正确性。
对比PiSSA、LoRA在不同的可训练参数量下微调的效果
我们继续在数学任务上对模型可训练的参数量与效果之间的关系进行消融。从图3.1发现在训练初期,PiSSA的训练loss下降特别快,而LoRA存在不下降,甚至略有上升的阶段。此外,PiSSA的训练loss全程低于LoRA,说明对训练集拟合的更好;从图3.2、3.3、3.4可以看出在每种setting下,PiSSA的loss始终比LoRA低,准确率始终比LoRA高,PiSSA能够使用更少的可训练参数追赶上全参数微调的效果
在不同的任务上对比PiSSA、LoRA微调效果
我们通过微调提升llama 2-7B、Mistral-7B以及Gemma-7B基础模型的数学、代码和对话能力。其中在MetaMathQA[9]上训练,在GSM8K[10]和MATH[11]数据集上验证模型的数学能力;在CodeFeedBack[12]上训练,在HumanEval[13]和MBPP[14]上数据集上验证模型的代码能力;在WizardLM-Evol-Instruct[15]上训练,在MT-Bench[16]上验证模型的对话能力。实验结果展示在下表中。从下表可以看出,使用相同的可训练参数,PiSSA的训练效果显著超越了LoRA的效果,甚至超越了全参数微调的效果。
快速奇异值分解
PiSSA继承了LoRA的优点,使用起来方便又效果比LoRA好,那么代价是什么?PiSSA还是有缺点的,缺点就是在初始化阶段,需要对模型进行奇异值分解。虽然仅在初始时分解一次,但是还是需要几分钟甚至几十分钟的开销。因此我们使用一种快速奇异值分解[17]方法替代标准的SVD分解,通过下表的实验可以看出,仅需几秒钟的时间,就能逼近标准SVD分解的训练集拟合效果。其中Niter表示迭代次数,Niter越大,时间越久但是误差越小。Niter = ∞表示标准SVD。表格中的平均误差表示,快速奇异值分解与标准SVD得到的之间的平均距离。
写在最后
本文通过对预训练模型的权重进行奇异值分解,将其中最重要的参数用于初始化一个名为PiSSA的适配器。微调这个适配器就能近似达到直接训练完整模型的效果。实验表明,PiSSA比LoRA收敛更快,最终效果更好,唯一的代价就是需要几秒的初始化过程。未来我们将会继续验证1)在更多的模型和任务上,PiSSA是否能全面超越LoRA?2)给足够长的训练步数,让两种方法尽可能拟合数据,LoRA能否追得上PiSSA的表现?3)之前与LoRA结合的方法,与PiSSA结合后能否进一步提升PiSSA的效果?4)如何从理论角度解释PiSSA优势的来源?我们期待听到来自社区的使用体验,以及意见建议。那么,您愿意为了更好的训练效果,多花几秒钟时间,一键更改LoRA的初始化为PiSSA吗?
PS:给公众号添加【星标⭐️】不迷路!您的点赞、在看、关注是我坚持的最大动力!
欢迎多多关注公众号「NLP工作站」,加入交流群,交个朋友吧,一起学习,一起进步!
我们的口号是“生命不止,学习不停”!
往期推荐:
InternLM2技术报告 Qwen1.5-MoE模型:2.7B的激活参数量达到7B模型的性能 RAG与Long-Context之争—没必要争 角色扮演大模型的碎碎念 自我蒸馏方法-减轻大模型微调过程中的灾难性遗忘 Yi技术报告细节分享 大模型增量预训练新技巧-解决灾难性遗忘 如何提高LLMs的文本表征(Text Embedding)能力? DEITA-大模型指令微调的数据高效筛选方法 大模型微调技巧 | 高质量指令数据筛选方法-MoDS 辟谣!微软撤回声称ChatGPT为20B参数的论文,并给出解释。 如何看待微软论文声称 ChatGPT 是 20B (200亿) 参数量的模型? 大模型微调技巧-在Embeeding上加入噪音提高指令微调效果 如何从数据集中自动识别高质量的指令数据 BaiChuan2技术报告细节分享&个人想法 大模型LLM微调经验总结&项目更新 打造LLM界的Web UI 是我们在训练大模型,还是大模型在训练我们? Llama2技术细节&开源影响 大模型时代-行业落地再思考 垂直领域大模型的一些思考及开源模型汇总 如何评估大模型-LLMs的好坏? 总结|Prompt在NER场景的应用
参考文献
[1] Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning [2] LoRA: Low-Rank Adaptation of Large Language Models [3] QLoRA: Efficient Finetuning of Quantized LLMs [4] AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning [5] Delta-LoRA: Fine-Tuning High-Rank Parameters with the Delta of Low-Rank Matrices [6] Llama 2: Open Foundation and Fine-Tuned Chat Models [7] Mistral 7B [8] Gemma: Open Models Based on Gemini Research and Technology [9] MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models [10] Training Verifiers to Solve Math Word Problems [11] Measuring Mathematical Problem Solving With the MATH Dataset [12] OpenCodeInterpreter: Integrating Code Generation with Execution and Refinement [13] Evaluating Large Language Models Trained on Code [14] Program Synthesis with Large Language Models [15] WizardLM: Empowering Large Language Models to Follow Complex Instructions [16] Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena [17] Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions