上街免费网站建设,如何绑定网站,义乌网站建设方案详细,如何建设销售型企业网站本文要介绍的是大模型的微调训练方法之一----LoRA。
0 背景
现在大模型非常火爆#xff0c;大家都在想方设法应用大模型。 当前很多大模型虽说可以zero-shot直接使用#xff0c; 但是在具体应用上一般还是微调一下效果更好#xff0c; 也就是常说的finetune。 在小模型时代…本文要介绍的是大模型的微调训练方法之一----LoRA。
0 背景
现在大模型非常火爆大家都在想方设法应用大模型。 当前很多大模型虽说可以zero-shot直接使用 但是在具体应用上一般还是微调一下效果更好 也就是常说的finetune。 在小模型时代 finetune不是个问题。 但大模型时代 finetune是个大问题。 这是因为现在的大模型参数动辄10B起 训练的代价非常高昂即使是finetune也对计算资源有很高要求finetune只是训练的步数少 对显存等计算资源的占用并没有少)。 没个上百G的显存是玩不动的 这对普通人的门槛实在太高了。
那么高效的finetune方式就非常必要了。LoRA就是高效finefune方法的一种。
1 LoRA原理
LoRA论文 LoRA: Low-Rank Adaptation of Large Language Models LoRA的原理非常简单 先上一张图 其实从图上已经能清楚地看到大致的原理的。 通俗地讲 它的原理是这样的大模型都是过参数化的 当用于特定任务时 其实只有一小部分参数起主要作用。 也就是参数矩阵维度很高 但可以用低维矩阵分解近似。其实这个思想与矩阵特征向量 主成分分析 压缩感知等有异曲同工之妙。
具体做法是 在网络中增加一个旁路结构旁路是A和B两个矩阵相乘。 A矩阵的维度是dxr, B 矩阵的维度是rxd, 其中rd, 一般r取1248就够了。那么这个旁路的参数量将远远小于原来网络的参数W。LoRA训练时 我们冻结原来网络的参数W, 只训练旁路参数A和B。 由于A和B的参数量远远小于W, 那么训练时需要的显存开销就大约等于推理时的开销。 对采用Adam优化器来说 需要的显存就大约相当于全参数finetune的1/3, 极大地减小了训练的代价。
论文中作者的实验也证明了这一点。 在GPT-3 175B的finetune中 采用LoRA微调显存的消耗从1.2TB 降低到了350GB, 大约是三分之一
其实采用这种旁路相加的方式 与ResNet的跳连方式也有异曲同工之妙。 原网络的参数不变 在旁路上做些微小改变 适应特定新任务。 这样就可以让网络基本保持原来的能力 在特定任务上更精进了一步。
值得注意的是 LoRA微调并没有改变原有的预训练参数 只是针对特定任务微调出了新的少量参数 新的这些参数要与原有的预训练参数配合使用实际使用时 都是把旁路的参数和原来的参数直接合并 也就是参数相加 这样就完全不会增加推理时间。这是非常方便的 针对不同的任务 都可以训练出自己的LoRA参数 然后与原本的预训练参数结合 做成插件式的应用。 这就是最近大火的SD LoRA。全参数微调一般没这个条件 但LoRA微调还是可以的。 目前Civitai上有上万LoRA的模型 并且还在迅速增加。
2 代码详解
LoRA代码 https://github.com/microsoft/LoRA
LoRA原理很简单 代码实现也不复杂。 简单地说在模型实现上 要在特定的模块上加一个旁路 这个旁路就是两个矩阵相乘的形式。这些特定的模块理论上可以是任何模块 目前作者实现的是在Linear, Embeding, Conv, Attention(只改其中的q和v)这些模块上加。
具体实现见https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
拿其中的Linear做个简单分析吧 其他都是类似的。
class LoRALayer():def __init__(self, r: int, lora_alpha: int, lora_dropout: float,merge_weights: bool,):self.r rself.lora_alpha lora_alpha# Optional dropoutif lora_dropout 0.:self.lora_dropout nn.Dropout(plora_dropout)else:self.lora_dropout lambda x: x# Mark the weight as unmergedself.merged Falseself.merge_weights merge_weightsclass Linear(nn.Linear, LoRALayer):# LoRA implemented in a dense layerdef __init__(self, in_features: int, out_features: int, r: int 0, lora_alpha: int 1, lora_dropout: float 0.,fan_in_fan_out: bool False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)merge_weights: bool True,**kwargs):nn.Linear.__init__(self, in_features, out_features, **kwargs)LoRALayer.__init__(self, rr, lora_alphalora_alpha, lora_dropoutlora_dropout,merge_weightsmerge_weights)self.fan_in_fan_out fan_in_fan_out# Actual trainable parametersif r 0:self.lora_A nn.Parameter(self.weight.new_zeros((r, in_features)))self.lora_B nn.Parameter(self.weight.new_zeros((out_features, r)))self.scaling self.lora_alpha / self.r# Freezing the pre-trained weight matrixself.weight.requires_grad Falseself.reset_parameters()if fan_in_fan_out:self.weight.data self.weight.data.transpose(0, 1)def reset_parameters(self):nn.Linear.reset_parameters(self)if hasattr(self, lora_A):# initialize A the same way as the default for nn.Linear and B to zeronn.init.kaiming_uniform_(self.lora_A, amath.sqrt(5))nn.init.zeros_(self.lora_B)def train(self, mode: bool True):def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else wnn.Linear.train(self, mode)if mode:if self.merge_weights and self.merged:# Make sure that the weights are not mergedif self.r 0:self.weight.data - T(self.lora_B self.lora_A) * self.scalingself.merged Falseelse:if self.merge_weights and not self.merged:# Merge the weights and mark itif self.r 0:self.weight.data T(self.lora_B self.lora_A) * self.scalingself.merged True def forward(self, x: torch.Tensor):def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else wif self.r 0 and not self.merged:result F.linear(x, T(self.weight), biasself.bias) result (self.lora_dropout(x) self.lora_A.transpose(0, 1) self.lora_B.transpose(0, 1)) * self.scalingreturn resultelse:return F.linear(x, T(self.weight), biasself.bias)
在Linear层的实现上多继承了一个LoRALayer, LoRALayer中就是设置了一些参数 最主要的就是上面的讲道的矩阵的秩r了其他就是一些辅助参数 如控制训练和推理时主路参数和旁路参数是否合并等等。 在Linear层中 多定义了A和B两个可训练的参数矩阵 然后在forward中把主路和旁路输出相加 基本上就是完全按照原理来的。
3 使用
实际使用LoRA微调时 也不用自己向上面那样实现了。上面的loralib库已经实现好了 直接使用就好了。具体而言 就是把网络中原来使用nn.Linear用loralib库中的Linear替换就可以了 其他的模块同理。
实际上 还有更简洁的方式huggingface pert库很贴心地把各种finetune方式都做了集成 更加简单和方便。