php开发网站,关于网站建设的调研报告,网站建设实训,衡水做网站优化我自己的原文哦~ https://blog.51cto.com/whaosoft/12472192
#LightSeq
最高加速9倍#xff01;字节跳动开源8比特混合精度Transformer引擎,近年来#xff0c;Transformer 已经成为了 NLP 和 CV 等领域的主流模型#xff0c;但庞大的模型参数限制了它的高效训练和推理。…我自己的原文哦~ https://blog.51cto.com/whaosoft/12472192
#LightSeq
最高加速9倍字节跳动开源8比特混合精度Transformer引擎,近年来Transformer 已经成为了 NLP 和 CV 等领域的主流模型但庞大的模型参数限制了它的高效训练和推理。于是字节跳动在 2019 年 12 月和 2021 年 6 月分别推出了高效推理和训练引擎 LightSeq大大加速了 Transformer 系列模型的训练和推理也打通了 Transformer 从训练到推理的整个流程极大优化了用户使用体验。最近LightSeq 训练引擎相关论文[1]被录用难度极高的超算领域国际顶会 SC22 接收得到了学术界的广泛认可
SC22 接收论文https://sc22.supercomputing.org/presentation/?idpap211sesssess154代码地址https://github.com/bytedance/lightseq
如何继续提升速度降低计算精度是比较直接的方法。2017 年以来fp16 混合精度技术 [2] 获得了广泛应用。在对模型效果无损的前提下将模型训练和推理的速度提升了 50% 以上。而为了维持模型效果更低精度的方法例如 int8通常需要使用如下传统方案
首先使用 fp16 混合精度将模型训练至收敛然后在模型计算密集型算子的权重、输入和输出位置处插入伪量化结点进行量化感知训练最后将带有伪量化结点的模型计算图转换到专用的 int8 推理引擎中进行服务部署和模型推理。
虽然在多数任务上上述方案可以实现模型效果无损但还是存在以下问题
使用方法复杂。例如要多一次量化感知训练 [4] 的过程并且带有伪量化节点的计算图转换复杂。训练速度慢。由于目前流行的深度学习框架不支持 int8 精度所以量化感知训练需要插入 fp16 的伪量化结点来模拟 int8 量化导致量化感知训练反而比 fp16 混合精度训练慢 2-3 倍。推理部署难且加速比低。对比 fp32、fp16 等类型int8 硬件和底层软件库优化相对滞后。例如在 NVIDIA GPU 上int8 矩阵乘法加速受限于硬件架构和特定 shape实际加速比远远低于理论值。
在下文中如无特殊说明量化都是指的 int8 精度的量化。 针对这些问题字节跳动推出了全新版本的 LightSeq GPU 量化训练与推理引擎。支持 Transformer 系列模型的量化训练与推理并做到了开箱即用用户友好。LightSeq 快准狠地实现了 int8 精度的量化训练和推理
快A100 多卡训练最高加速 5.2 倍T4 单卡推理最高加速 8.9 倍。准训练和推理效果基本无损。狠相同数据量下显存占用最高减少 68%模型存储空间减少 75%。
总体来说LightSeq 新版量化训练与推理引擎具有如下几个优点
1. 丰富的支持
支持完整的 Transformer 模块和多种解码算法支持 Transformer、BERT、GPT、BART、ViT 等多种模型结构支持 Fairseq、Hugging Face、NeurST 等多种训练框架接入量化训练、导出模型以及量化推理提供了丰富的样例供用户参考。
2. 卓越的性能
相比于 fp16 精度的 LightSeq 推理引擎int8 量化还可以进一步加速最高 70%相比于 PyTorch 推理更是达到了最高 8.9 倍的加速比。同时显存占用相比 fp16 推理引擎降低了 30% 左右模型存储空间只需要原来的四分之一。最后经过多个任务的验证推理效果几乎无损。
3. 便捷的使用
LightSeq 已经针对多个训练库进行了量化支持可以一键开启量化训练然后轻松导出为 LightSeq 支持的模型格式最后实现量化推理。除此之外LightSeq 还支持训练后量化无需额外训练即可体验量化推理。 如上图所示为了最大程度减小量化带来的损失首先需要用 fp16 精度训练一个浮点数模型将模型效果训到最好。然后开启量化进行 finetune得到微调过的量化模型此时模型效果已经基本恢复到浮点数模型的水平。接着将量化模型转换为 LightSeq 支持的 PB 或者 HDF5 模型格式最后用 LightSeq 进行量化推理。
安装方法
LightSeq 安装非常简单只需要一行命令即可
pip install lightseq
量化训练
LightSeq 支持 Fairseq、Hugging Face、NeurST 等训练框架的量化接入同时也可以自定义模型并开启量化训练。以 encoder 层为例只需要先定义浮点数模型然后开启量化即可
from lightseq.training import LSTransformerEncoderLayer
from lightseq.training.ops.pytorch.quantization import enable_quant
config LSTransformerEncoderLayer.get_config(modelbert-base,max_batch_tokens4096,max_seq_len512,fp16True,local_rank0,
)
layer LSTransformerEncoderLayer(config)
# 开启量化
layer.apply(enable_quant)
量化推理
LightSeq 提供了便捷的 python 推理接口只需要三行代码即可实现快速的量化推理
import lightseq.inference as lsi
model lsi.QuantTransformer(pb_path, batch_size)
result model.infer(input)
此外 LightSeq 还提供了 BERT、GPT、ViT 等模型的 python 接口分别调用 QuantBert、QuantGpt 和 QuanVit 即可体验。
梯度通信量化
LightSeq 支持 Transformer 模型的梯度通信量化[5]使用 Fairseq 或者 Hugging Face 即可轻松开启分布式量化训练并同时支持浮点数模型和量化模型。在构建模型后只需要为模型注册一个 communication hook 即可开启梯度通信量化再开始训练过程。
from lightseq.training.gradient_comm_quantization import encode_and_decode, GCQState
from torch.nn.parallel import DistributedDataParallel
# model could be from Fairseq or Hugging Face, wrapped by DDP
model DistributedDataParallel(model)
state GCQState(process_group)
# register hook
model.register_comm_hook(statestate, hookencode_and_decode)
性能测试
LightSeq 在多个任务上测试了量化训练、量化推理和梯度通信量化的速度并且分析了显存占用情况和量化模型的效果。
量化训练速度 LightSeq 在 8 张 A100 显卡上进行了训练实验主要对比对象是 Fairseq 的 Transformer、Hugging Face 的 BERT、GPT2 和 ViT。
可以看出四种模型结构加速趋势都是类似的加速比都会随着数据量的增大而减小原因有三点
随着数据量的增大矩阵乘法 GEMM 的占比会明显增加因此 PyTorch QAT 增加的额外的伪量化结点时间占比会逐渐减小最后速度会和 PyTorch fp16 无限接近。与此同时随着 GEMM 占比升高LightSeq fp16 自定义算子的提速效果也逐渐减小因此时间上也会和 PyTorch fp16 无限接近。由于 Ampere 架构显卡上 int8 GEMM 在 shape 较小时甚至不如 fp16 GEMM 快在大 shape 下才能稍快一点因此随着数据量增大LightSeq int8 也会无限接近 LightSeq fp16 的速度。
量化推理速度 LightSeq 在单张 T4 显卡上进行了推理实验主要对比对象是 Hugging Face 的 Transformer、BERT、GPT2 和 ViT。
可以看出随着输入数据量的增大LightSeq 与 PyTorch 的差距会逐渐减小这也是 GEMM 占比升高造成的。比较 LightSeq fp16 和 LightSeq int8可以看出随着数据量的增大LightSeq int8 越来越快。这是因为在 T4 显卡上int8 GEMM 的加速会随着 shape 的增大而有明显增加。因此在 T4 显卡上进行量化推理时输入数据量越大加速效果越好。 LightSeq 还针对机器翻译多个语向和多个测试集测试了不同 batch size 下LightSeq int8 推理相对于 LightSeq fp16 推理的加速比实验同样是在单张 T4 显卡上进行的采用的模型都是标准的 Transformer-Big。
可以得到和上文中相同的结论随着 batch size 的增大量化推理的加速比会逐渐升高。相比于 LightSeq fp16最高还可以再加速近 70%这极大地缩短了线上翻译模型的推理延时。 最后如上图所示为了展示自动 GEMM 调优技术的效果LightSeq 测试对比了 A100 显卡上 Transformer 和 BERT 模型 fp16、int8 调优前和 int8 调优后的延时。可以看出调优前某些 shape 的 int8 GEMM 速度甚至比 fp16 还要慢而调优后全面超越了 fp16。
显存占用 LightSeq 分析了不同 batch size 下量化模型相对于浮点数模型显存占用的加速比。可以看出随着 batch size 的增大量化模型的显存占用优势更明显最高可以减少 30% 左右。而 LightSeq fp16 引擎相对于 PyTorch 模型也极大程度减少了显存占用因此 LightSeq int8 引擎最终能够减少最多 68% 左右的显存。
量化模型效果 针对机器翻译多个语向和多个测试集LightSeq 测试了量化模型推理相对于浮点数模型 BLEU 的损失采用的模型都是标准的 Transformer-Big。
在数据量较大的语向 en2zh 上LightSeq int8 相对 BLEU 损失较大些最大达到了 - 0.4。而在数据量较小的语向 en2es 上LightSeq int8 不仅没有任何效果损失反而比浮点数模型更好。总体而言int8 量化模型的平均 BLEU 相比浮点数模型基本无损。在 GLUE 和 SQuAD 等多个任务上LightSeq 也验证了量化模型的效果。
梯度通信量化 由于在多机多卡场景下通信瓶颈更加明显所以梯度通信量化主要应用在分布式训练场景。因此 LightSeq 在 2 机 8 卡的 A100 上进行了分布式训练的速度测试。
可以看出梯度通信量化的训练加速效果整体上随着输入数据的增大而减弱。这主要是因为随着输入数据的增大计算时间占比升高梯度通信时间占比减少梯度量化的收益也随之减小。
LightSeq 还额外增加了不同数量网卡NIC下的训练速度测试。可以看到使用梯度通信量化的分布式训练速度相比原始的 LightSeq fp16 有大幅度提升。
量化技术
int8 量化的加速收益主要来自如下几个方面
GEMM 精度从 fp16 降低到 int8 后计算时间缩短自定义算子采用 int8 输入输出后数据读写时间缩短梯度采用 int8 存储后多机之间通信时间缩短。
以 Transformer 模型为例经过 LightSeq fp16 引擎加速后自定义算子时间大大缩短而 GEMM 时间占比提升到了 90% 左右因此优化的重点转移到了 GEMM 提速。将 fp16 GEMM 替换为 int8 GEMM 不仅可以缩短 GEMM 时间还可以减小前后算子的输入输出位宽从而减小读写数据的时间。最后多机训练的瓶颈主要在梯度的通信将梯度量化为 int8 精度可以大大加快分布式训练的速度。
量化原理 为了弥补量化带来的精度损失通常需要用量化感知训练来模拟量化过程。如上图所示量化感知训练就是将 float GEMM 的两个 float 输入分别做一遍量化和反量化称之为伪量化结点离散化成分段的浮点数输入然后进行 float GEMM 运算。得到结果后再次进行量化与反量化得到最终的浮点数结果。而量化的过程是不可导的因此需要用 STE 方法来估计量化参数的梯度。之所以量化感知训练中需要插入伪量化结点然后用 float GEMM 去模拟量化过程是因为 TensorFlow 和 PyTorch 等训练框架不支持 int8 GEMM。 而 LightSeq 量化训练直接采用 int8 GEMM 来真实还原量化过程因此相比传统的实现要更快且更加节省显存。在推理的时候同样采用离散化后的整数进行 int8 GEMM 运算最后再反量化回浮点数结果。量化推理过程和量化训练完全一致并且和传统的量化感知训练是完全等价的。
量化位置 整个量化 Transformer 的网络结构如上图所示红色箭头表示需要加上量化和反量化结点的位置。
首先所有 int8 GEMM 的输入和输出都需要进行量化。由于 int8 GEMM 的 shape 限制部分 GEMM例如注意力分数的计算仍然采用 float GEMM。此外第二层 FFN 的 GEMM 采用的是 int32 的输出因为它的 GEMM 输入是 ReLU 激活函数的输出结果只包含正数非对称因此如果采用 int8 输出的 GEMM将无法反量化为正确的浮点数结果。
然后所有的模型权重 weight 都需要存储为 int8 类型因此需要对 weight 做量化。而权重 bias 参数量较小无需量化保留 float 精度反而可以提升模型效果。
最后需要对 decoder 端的 cache 进行量化。因为在推理时decoder 端的 cache 需要频繁进行读写因此将 cache 量化为 int8 可以大大加快解码的速度。
量化策略 将一个浮点数矩阵量化为 int8 整数矩阵有很多方法LightSeq 采用的是对称量化即将正负数范围对称的浮点数区间等比例地映射到整数区间 [-127, 127] 上。
而实际上浮点数矩阵的数值范围通常并不对称存在极少的离群值。如果直接按照离群值的范围来量化矩阵会影响到量化后的精度所以需要先对矩阵进行数值截断。
LightSeq 采用 PACT 方法进行截断[6]将截断的范围当作模型可学习的参数然后利用 STE 算法去估计参数的梯度并进行反向传播优化。根据实践经验权重 weight 的初始截断范围设为[-1, 1]中间结果的初始截断范围设为[-16, 16]可以在大部分任务上达到最好的效果。最后经过截断范围和其他模型参数的联合优化量化模型的效果可以达到基本无损。
梯度通信量化
针对分布式训练场景LightSeq 推出了梯度量化压缩技术。即对浮点精度的梯度进行 int8 量化以减少梯度通信的时间消耗从而加速训练这就是梯度通信量化GCQ。 如上图所示梯度通信量化的主要流程如下
计算每张卡上各自梯度的截断范围对截断范围执行 all-reduce max 操作每张卡使用统一的截断范围对各自梯度进行 int8 量化对 int8 梯度执行 all-reduce sum 操作每张卡对 all-reduce 后的梯度进行反量化还原为浮点数梯度并进行参数更新。
为了解决 int8 梯度在 all-reduce 过程中溢出的问题LightSeq 首先将每张卡上的浮点数梯度除以卡数再使用除之前的截断范围进行量化最后进行 all-reduce 操作。这样每张卡上量化后的 int8 整数 all-reduce 完就不会溢出但是单卡实际用于量化的比特数也因此而减少所以目前方案在 2 机 8 卡效果几乎无损但随着卡数的上涨训练效果会有所下降。以 en2de 和 en2fr 翻译任务为例在 4 机 8 卡上进行分布式量化训练BLEU 值分别会下降 0.4 和 1.5 左右。未来 LightSeq 将会持续探索更好的方法来解决这一问题。
通用技术
除了上一章节中提到的量化技术以外此次更新 LightSeq 还提出了几种通用的优化技术不仅可以应用在量化模型中也适用于其它所有精度模型的训练与推理。
算子融合 上图是 encoder 模块量化训练的计算图LightSeq 将两次 GEMM 运算之间的所有操作融合成一个算子[7]减少了 kernel 调用的次数因此减少了总的计算时间。
图中黄色矩形表示 int8 GEMM绿色矩形表示 float GEMM。这里采用 float GEMM 是由于 shape 的限制不适合使用 int8 GEMM 加速。红色箭头表示流动数据的类型是 int8绿色箭头表示第二层 FFN 的 GEMM 输出是 int32 数据类型。int8 GEMM 输入输出的量化与反量化操作都被融合到了前后 kernel 里这不仅可以减少数据搬运还可以减小显存占用。 在推理时LightSeq 还针对 decoder 做了优化。如上图所示在计算 self-attention 时注意力得分的维度是(batch size, 1, sequence length)。因此在计算 value 乘积时可以不采用 GEMM 运算而直接手写加权求和的算子从而将图中虚线框中的计算融合成一个 kernel。
自动显存管理 模型量化引入了更复杂的张量类型和张量依赖关系这给显存管理带来新的挑战。为此LightSeq 设计了新的显存管理机制。如上图所示主要包括以下过程
训练启动前根据每个算子的拓扑依赖关系自动计算每个张量的生命周期及显存空间大小。其中包含动态维度的张量按照此维度的最大量进行计算例如机器翻译任务中的最大句长和最大 batch 句子数量。这些最大量在训练前已被指定张量确定生命周期和大小后分析显存复用关系。其中无生命周期重合的张量可以共用一片显存空间所有显存空间都是无数据类型的可以被分配到任意数据类型的张量上根据张量显存复用关系申请多段显存空间为每个张量分配实际的显存起止地址。
张量显存复用的分析LightSeq 借鉴了论文 [3] 中提出的 Greedy by Size for Offset Calculation 方法做了三个改进
支持了整个训练过程的显存复用forward/backward不同数据类型能做到显存复用int8/fp16/fp32在多段显存空间上容纳所有张量而非一段非常大的显存空间这样能有效提升显存利用率。
自动 GEMM 调优
LightSeq 的 int8 GEMM 采用了 NVIDIA 的 cuBLASLt 库这也是目前 NVIDIA 显卡上最为高效的矩阵运算库。但是输入数据的 shape 或者显卡不同的话GEMM 所采用的最优配置例如数据排布、GEMM 算法等等也可能不同因此需要进行自动选取。LightSeq 采取的自动调优方案如下
在多种型号显卡上例如 T4 和 A100进行不同 shape 的 GEMM 最优配置搜索并将结果保存到配置文件中用户只需要下载即可模型初始化时加载对应型号显卡的配置文件解析并保存到键值对为 (shape, 最优配置) 的字典中。如果没有对应型号显卡的配置文件或者没有需要的 GEMM shape那么用户可以选择自己搜索并保存或者直接使用默认配置模型前向或后向计算时根据输入的 shape 在字典中寻找最优配置然后进行 GEMM 计算。如果没有找到对应的 shape那么直接采用默认的配置。
未来工作
未来 LightSeq 还将继续探索移动端的低精度量化、反向传播中梯度的量化、大模型量化等方向。 #SCTNet
SCTNet一种带有transformer语义信息的单分支CNN用于实时分割。借助于提出的transformer类CNN块CFBlock和语义信息对齐模块SCTNet可以在训练中从transformer分支捕获丰富的语义信息。 80.5mIoU62.8FPS! 华科与美团联合提出单分支推理分割架构
最新的实时语义分割方法通常采用额外的语义分支来追求丰富的长距离上下文。然而额外的分支会带来不必要的计算开销并减缓推理速度。为了消除这一困境我们提出了SCTNet一种带有transformer语义信息的单分支CNN用于实时分割。
https://arxiv.org/abs/2312.17071
https://github.com/xzz777/SCTNet
SCTNet在保留轻量级单分支CNN高效性的同时还拥有语义分支的丰富语义表示。考虑到transformer提取长距离上下文的卓越能力SCTNet将transformer作为仅用于训练的语义分支。借助于提出的transformer类CNN块CFBlock和语义信息对齐模块SCTNet可以在训练中从transformer分支捕获丰富的语义信息。在推理过程中只需要部署单分支CNN。我们在Cityscapes,ADE20K和COCO-Stuff-10K上进行了广泛的实验结果表明我们的方法达到了新的最先进水平。 本文贡献主要包含以下三点
我们提出了一种新的单支实时分割网络SCTNet。通过学习从Transformer到CNN的语义信息对齐来提取丰富的语义信息SCTNet在保持轻量级单支CNN快速推理速度的同时具有Transformer的高准确性。为了缓解CNN特征和Transformer特征之间的语义鸿沟我们设计了CFBlock(ConvFormer Block)它可以仅使用卷积操作捕获长距离上下文。此外我们提出了SIAM(语义信息对齐模块)以更有效地对齐特征。在Cityscapes、ADE20K和COCO-Stuff-10K上的大量实验结果表明所提的SCTNet在实时语义分割方面优于现有的最新方法. SCTNet为提高实时语义切分的速度和性能提供了一个新的视角
本文方案 降低计算成本的同时获得丰富的语义信息我们将现在流行的两个分支架构拆解为
一个CNN分支进行推断一个Transformer分支用于训练阶段语义对齐。
Backbone 为了提高推理速度SCTNet采用了典型的分层CNN骨干。SCTNet的Stem模块由两个3×3卷积构成前两个阶段是由堆叠的残积模块组成的后两个阶段则是由所提CFBlock构成。CFBlock采用了几个精心设计的卷积操作来执行类似于Transformer块的远程上下文捕获功能。
Decoder Head 解码头由DAPPM与分割头构成为进一步丰富上下文信息作者在Stage4后面添加了DAPPM。然后作者将S2和S4输出进行拼接并送入分割头。
Training Phase 众所周知Transformer在捕获全局语义上下文方面表现出色。另一方面CNN已被证明比变换器更适合于对分层局部信息进行建模。受Transformer和CNN优点的启发我们探索配备一个具有这两种优点的实时分割网络。我们提出了一个单分支CNN它学习将其特征与强大的Transformer的特征对齐。这种特征对齐使单分支CNN能够提取丰富的全局上下文和详细的空间信息。具体而言SCTNet采用了一个仅作用在训练阶段的Transformer作为语义分支来提取强大的全局语义上下文语义信息对齐模块监督卷积分支以对齐来自Transformer的高质量全局上下文。
Inference Phase 为了避免两个分支的巨大计算成本在推理阶段只部署了CNN分支。利用transformer对齐的语义信息单分支CNN可以生成准确的分割结果而无需额外的语义或昂贵的密集融合。更具体地说输入图像被送入到单分支层次卷积主干中解码器头拾取主干中的特征并进行简单的拼接进行像素分类.
本文实验 上图与表为Cityscapes语义分割上不同方案的性能对比从中可以看到
所提SCTNet以大幅优势优于其他实时分割方案取得了最佳的速度-精度均衡所提SCTNet-B-Seg100去的了80.5%mIoU且速度达62.8FPS达成实时分割新SOTA所提SCTNet-B-Seg75取得了79.8%mIoU比RTFormer-B与DDRnet-23精度更高同时速度快两倍在所有输入分辨率下所提SCTNet-B均比其他方案指标更优此外SCTNet-S同样取得了比STDC2、RTFormer-S、SeaFormer-B、TopFormer-B更优的性能均衡。 上表为ADE20K与COCO-Stuff-10K两个数据集上不同分割方案的性能对比很明显所提SCTNet同样取得了更优的速度-精度均衡。 #STKET
作者提出了一种基于时空知识嵌入的 TransformerSTKET将先验时空知识纳入多头交叉注意机制中从而学习更多有代表性的视觉关系表示。 基于时空知识的视频场景图生成
视频场景图生成VidSGG旨在识别视觉场景中的对象并推断它们之间的视觉关系。该任务不仅需要全面了解分散在整个场景中的每个对象还需要深入研究它们在时序上的运动和交互。为此我们进行了相关的探索并发现每对物体组合及其它们之间的关系在每个图像内具有空间共现相关性并且在不同图像之间具有时间一致性/转换相关性。基于这些先验知识我们提出了一种基于时空知识嵌入的 TransformerSTKET将先验时空知识纳入多头交叉注意机制中从而学习更多有代表性的视觉关系表示。具体来说我们首先以统计方式学习空间共现和时间转换相关性。然后我们设计了时空知识嵌入层对视觉表示与知识之间的交互进行充分探索分别生成空间和时间知识嵌入的视觉关系表示。最后我们聚合这些特征以预测最终的语义标签及其视觉关系。大量实验表明我们所提出的框架大幅优于当前竞争算法。目前该论文已经被人工智能顶级期刊 IEEE T-IP接收。
论文链接https://arxiv.org/abs/2309.13237
代码链接https://github.com/HCPLab-SYSU/STKET
1. 概述
随着场景理解领域的快速发展许多研究者们开始尝试利用各种框架解决场景图生成Scene Graph Generation, SGG任务并已取得了不俗的进展。但是这些方法往往只考虑单张图像的情况忽略了时序中存在着的大量的上下文信息导致现有大部分场景图生成算法在无法准确地识别所给定的视频中包含的动态视觉关系。因此许多研究者致力于开发视频场景图生成Video Scene Graph Generation, VidSGG算法来解决这个问题。
目前的工作主要关注从空间和时间角度聚合对象级视觉信息以学习对应的视觉关系表示。然而由于各类物体与交互动作的视觉外表方差大以及视频收集所导致的视觉关系显著的长尾分布单纯的仅用视觉信息容易导致模型预测错误的视觉关系。
针对上述问题我们做了以下两方面的工作
首先我们提出挖掘训练样本中包含的先验时空知识用以促进视频场景图生成领域。其中先验时空知识包括1空间共现相关性某些对象类别之间的关系倾向于特定的交互。2时间一致性/转换相关性给定对的关系在连续视频剪辑中往往是一致的或者很有可能转换到另一个特定关系。其次我们提出了一种新颖的基于时空知识嵌入的 Transformer (Spatial-Temporal Knowledge-Embedded Transformer, STKET) 框架。该框架将先验时空知识纳入多头交叉注意机制中从而学习更多有代表性的视觉关系表示。根据在测试基准上得到的比较结果我们发现我们所提出的 STKET 框架优于以前的最先进方法。 图 1. 由于视觉外表多变和视觉关系的长尾分布导致视频场景图生成充满挑战
2. 基于时空知识嵌入的 Transformer
2.1. 时空知识表示
在推断视觉关系时人类不仅利用视觉线索还利用积累的先验知识 [1, 2]。受此启发我们提出直接从训练集中提取先验时空知识以促进视频场景图生成任务。其中空间共现相关性具体表现为当给定物体组合后其视觉关系分布将高度倾斜例如“人”与“杯子”之间的视觉关系的分布明显不同于“狗”与“玩具”之间的分布和时间转移相关性具体表现为当给定前一时刻的视觉关系后各个视觉关系的转换概率将大幅变化例如当已知前一时刻的视觉关系为“吃”时下一时刻视觉关系转移为“书写”的概率大幅下降。如图 2 所示我们可以直观地感受到给定物体组合或之前的视觉关系后预测空间可以被大幅的缩减。 图 2. 视觉关系的空间共现概率 [3] 与时间转移概率 图 3. 学习空间 (a) 和时间 (b) 知识表示的过程
2.2. 知识嵌入注意力层
空间知识通常包含有关实体之间的位置、距离和关系的信息。另一方面时间知识涉及动作之间的顺序、持续时间和间隔。鉴于它们独特的属性单独处理它们可以允许专门的建模更准确地捕获固有模式。因此我们设计了时空知识嵌入层彻底探索视觉表示与时空知识之间的相互作用。 图 4. 空间 (左侧) 和时间 (右侧) 知识嵌入层
2.3. 时空聚合模块
如前所述空间知识嵌入层探索每个图像内的空间共现相关性时间知识嵌入层探索不同图像之间的时间转移相关性以此充分探索了视觉表示和时空知识之间的相互作用。尽管如此这两层忽略了长时序的上下文信息而这对于识别大部分动态变化的视觉关系具有帮助。为此我们进一步设计了时空聚合STA模块来聚合每个对象对的这些表示以预测最终的语义标签及其关系。它将不同帧中相同主客体对的空间和时间嵌入关系表示作为输入。具体来说我们将同一对象对的这些表示连接起来以生成上下文表示。然后为了在不同帧中找到相同的主客体对我们采用预测的对象标签和 IoU即并集交集来匹配帧中检测到的相同主客体对 。最后考虑到帧中的关系在不同批次中有不同的表示我们选择滑动窗口中最早出现的表示。
3. 实验结果
为了全面评估所提出的框架的性能我们除了对比现有的视频场景图生成方法STTran, TPI, APT外我们也选取了先进的图像场景图生成方法KERN, VCTREE, ReIDN, GPS-Net进行比较。其中为确保对比的公平图像场景图生成方法通过对每一帧图像进行识别从而达到对所给定视频生成对应场景图的目标。 #SeTformer
这里提出了SeTformer一种新的transformer其中DPSA完全被Self-optimal Transport (SeT)取代以实现更好的性能和计算效率。在小型和基准尺寸模型下SeTformer在ImageNet-1K上实现了令人印象深刻的84.7%和86.2%的top-1准确率。
论文链接https://arxiv.org/pdf/2401.03540.pdf
Transformer变压器最初是用于自然语言处理NLP的技术在视觉领域得到了显著的流行这要归功于Vision TransformerViT的开创性工作它的优势已经在各种视觉任务中得到了证明包括图像分类、目标检测、分割等。对于捕获长距离依赖关系点积自注意力DPSA与softmax归一化在transformer中起着至关重要的作用。然而该模型的计算导致了二次时间和内存复杂度使得训练长序列模型变得困难。
简介
本文提出了SeTformer一种新的transformer其中DPSA完全被Self-optimal Transport (SeT)取代以实现更好的性能和计算效率。SeT基于两个基本的softmax属性保持非负的注意力矩阵和使用非线性的重新加权机制来强调输入序列中重要的标记。通过引入一个用于最优传输的核成本函数SeTformer有效地满足了这些属性。特别是在小型和基准尺寸模型下SeTformer在ImageNet-1K上实现了令人印象深刻的84.7%和86.2%的top-1准确率。在目标检测方面 SeTformer-base相比FocalNet同类产品超出2.2 mAP 使用的参数和浮点运算数分别减少了38%和29%。在语义分割方面 我们的基准模型相比NAT超出了3.5 mIoU并且参数减少了33%。SeTformer在GLUE基准测试中也取得了最先进的语言建模结果。这些发现凸显了SeTformer在视觉和语言任务中的适用性。 方法与模型
我们的目标是开发一种强大而高效的自注意力模型尤其注重简单性。我们不添加任何复杂模块如卷积、平移窗口或注意力偏置以提高视觉任务的性能。事实上我们采用了不同的策略。SeT利用了softmax的重要性质包括非负性和重新加权机制同时在设计中也注重了效率。使用具有正定PD核的RKHS避免了聚合负相关信息。SeT通过OT引入了非线性的重新加权方案。这涉及在RKHS中计算输入和参考集之间的对齐得分。这个过程引入了对齐得分的非线性给元素分配权重以突出它们的重要性。这有助于模型捕捉复杂关系并强调局部相关性。 SeTformer 架构首先是一个下采样的卷积层然后是包含多个 SeT 块的四个序列阶段。连续的阶段通过降采样层相连降低空间尺寸同时加倍深度。在右边我们展示了我们的注意力计算将 x 和 y 元素映射到RKHS然后通过 x 和 y 之间的 OT 计算聚合 x 特征如果它们与相应的参考对齐良好。
我们使用Swin作为我们的基线模型用我们的SeT模块替换其自注意力。我们的模型由四个阶段组成每个阶段具有不同的空间分辨率结果是输入图像的1/4大小。输入使用两层3×3卷积和2×2步幅进行嵌入。在每个阶段之后除了最后一层外都有一个通过3×3卷积和2×2步幅进行下采样的模块。这与Swin不同Swin使用的是非重叠的2×2卷积。 1 Representing local image neighborhoods in an RKHS
为 了 保 持 线 性 计 算 我 们 将 输 入 特 征 向 量 嵌 入到 一 个RKHS中 其 中 点 评 估 采 用 线 性 函 数 的 形式。核方法使我们能够通过一个正定核函数K将数据从其原始空间X映射到一个高维希尔伯特空间特征空间F中。对于函数uX → F特征映射正定核函数表示为K(x, x′) 〈u(x), u(x′)〉F。鉴于u(x)可以是无穷维的核技术允 许 从Rk中 导 出 一 个 有 限 维 度 的 表 示v(x)其 中 内 积〈v(xi), v(x′j)〉表 示K(x, x′)。正 如所示如果K是正定的对于任意的x和x′我们有K(x, x′) ≥ 0这与softmax算子的非负性质一致。
2 Optimal transport (OT)
我们模型中的一个基本作用是通过学习它们之间的映射将相关令牌进行聚合。我们的加权聚合依赖于被视为不同测度或加权点云的元素x和x′之间的输运计划。OT在对齐问题中得到了广泛应用并且具有捕捉数据几何形状的出色能力。在本文中我们专注于Kantorovich形式的OT ´ 其中使用熵正则化来平滑输运计划
3 Self-optimal Transport (SeT)
对 于 一 个 输 入 特 征 向 量x和 一 个 位 于X中 的 参考ym我们进行以下步骤(i)将特征向量x和y表示为RKHSF中 的 元 素 (ii)使 用OT将x的 元 素 与y对 齐(iii)对x的元素进行加权聚合得到一个对齐矩阵A。我们使用参考y来实现高效的元素聚合。参考集合中的每个元素都作为一个”对齐单元”输入特征通过加权求和在这些单元中进行聚合。这些权重指示了输入和参考之间的对应关系通过OT计算得出。假设我们有一个输入特征向量x {x1, . . . , xn}其中x属于X ∈ Rd是从输入图像中随机提取的。在Nystrom¨ 近似方法的背景下y的样本是通过对训练集X中的特征向量进行K-means聚类来获得的质心从而我们得到y {y1, . . . , ym}其中m ≤ n。使用参考集合有助于优化计算过程并使模型能够有效地处理更长的输入序列。设k是一个正定的核函数如定义在RKHS上的高斯核函数以及映射u : Rd → F。我们创建一个大小为n × m的矩阵k用于存储比较k(xi, yj )的结果。接下来我们根据公式(2)计算x和y之间的传输计划得到大小为n × m的矩阵T(x, y)。传输计划找到将输入特征与参考元素对齐的最佳方法同时最小化对齐成本。 4 Projecting onto a linear subspace
当处理有限维度的u(x)时 Ay(x)可 以 直 接 计 算 而 不 会 引 起 重 大的计算开销。对于无限维或高维的u(x)Nystrom¨ 算法提 供了 一 种 有 效 的 近 似 方 法 来 嵌 入væRd → Rk。Nystrom¨ 算 法 通 过 对 列 和 行 进 行 采 样 并 将 输 入从 特 征 空 间F投 影 到 线 性 子 空 间F1上 来 近 似 计 算传 输 计 划 从 而 得 到 嵌 入〈v(xi), v(x′j)〉F1。子 空间F1由k个中心u(z1), . . . , u(zk)张成。显式公式v(xi) k(z, z)−1/2k(z, xi)表示将z z1, . . . , zk作为中心来进行新的嵌入。这种高效的方法只需要执行K-means聚类并计算逆平方根矩阵。
5 Linear positional encoding
为了将位置信息融入我们的模型中我们采用了的方法在输入集和参考集之间的相似性上应用了指数惩罚基于它们的位置距离。这涉及到对T(v(x), y)与一个距离矩阵M进行乘法运算其中Mij e(− 1τ2)(α−β)2其中α i/nβ j/mτ表示平滑参数。我们考虑了内容和位置信息的相似性权重与其他位置编码方法相比取得了优秀的性能。
实验与结果
我 们 在 图 像 和 语 言 领 域 进 行 了 实 验 包括ImageNet、COCO和ADE20K以及GLUE以展示我们的模型的影响。我们对超参数进行了微调例如参考数量mOT中的熵正则化ϵ以及位置嵌入中的τ。我们观察到ϵ和τ在任务之间表现稳定但对于值m的选
224x224分辨率ImageNet-1K的分类准确率 SeTformer模型以较小的模型大小、Flops和吞吐量稳定优于ConvNeXt。我们的Mini模型的准确率超过Swin-T模型0.4%参数量减少40%28M → 16MFlops减少37%。我们的Tiny模型83.9%在性能上超过CSWin 1.2%并具有类似的模型大小速度提升12%从701/s到785/s。与FocalNet-T模型相比它在性能上表现更优提高了1.6%。使用更大的模型我们在较少的参数和较低的计算成本下实现了最先进的性能。例如SeTformer-B模型在超过24%和36%的Flops和参数减少的情况下将NAT-B模型84.4%的准确率提高了1.8%。我们还注意到吞吐量是在V100 GPU上测量的。
COCO数据集上Mask R-CNN目标检测结果 SeTformer在卷积神经网络如ResNet和Transformer骨 干 网 络 如CSWin、 NAT、 MViTv2方 面 表 现 优 异。例 如 SeTformer-T的APb为49.3APm为44.0 相 较 于NAT-T增 加 了1.6和1.4个 百 分 点同时计算量更小 模型尺寸更小。在扩展规模方面SeTformer-B的APb为51.9相比于CSWin-B的50.8增加了1.1个百分点同时参数减少28%计算量减少33%。
ADE20K数据集上的语义分割结果 语义分割任务上我们的模型优于现有最先进的方法例如相比于CSWin的对应模型SeTformer-T和SeTformer-S的mIoUSS/MS分别提高了1.3 / 0.7和0.7 / 0.4同时具有更轻、更低复杂度的优势。 #FindReplace Transforme
论文新提出了一种名为“FindReplace Transformer”的多 Transformer 架构并证明了通过集成多个Transformer能够解决单一 Transformer 无法胜任的任务。
ICLR 匿名研究单一 Transformer 不具备图灵完备性但多 Transformer 可以。
Transformer 自 2017 年出世以来就在 AI 领域高举高打ChatGPT 引发全球大型语言模型热潮后更是在 NLP 领域被赋予了神话般的地位。
但近日一篇正在审核中的 ICLR 2023 投稿论文如下经研究后提出一个观点单一 Transformer 并不具备图灵完备性其计算能力存在理论上的局限性在圈内引起关注。
由于该论文正在审核中作者信息没有被公开。
论文链接https://openreview.net/pdf?idMGWsPGogLH
与此同时该论文新提出了一种名为“FindReplace Transformer”的多 Transformer 架构并证明了通过集成多个Transformer能够解决单一 Transformer 无法胜任的任务。
这项研究直接对标并超越了当前最先进的GPT-4模型在一系列极具挑战性的基准测试中展现了显著的优势和潜力。
1 被神化的 Transformer 局限在哪里
图灵完备性是评判一个计算系统强大与否的关键指标。如果一个系统被确认为图灵完备则理论上只要赋予其充足的运行时间和内存资源即可以执行任何可计算的算法。
在实际应用中尽管 Transformer 模型在诸多自然语言处理任务上表现卓越但其能力受到设计上的固有限制例如固定的上下文窗口长度和有限的词汇表大小。这意味着 Transformer 模型并不具备解决所有类型计算问题的能力特别是那些需要无限存储空间或无限制迭代过程的问题。
在论文中研究团队特别指出基础的语言模型工作原理在于根据前 k 个词语的概率来预测下一个词语。在 NLP 领域通常会构建一些专门针对固定长度输入输出序列设计的模型集合或框架并将这类模型归入 MF_SMF 类别。
Transformer 作为 MF_SMF 这一框架下的具体实例其图灵完备性的缺失得到了该研究团队的理论论证。他们基于以下逻辑
首先回顾计算理论的基础图灵停机问题是不可判定的意味着不存在一个通用的方法来判断任意给定程序何时终止运行就如同无法找到一把万能钥匙预测每一场棋局结束时间一样。这一原理同样适用于评估模型是否会在执行过程中陷入无尽循环而无法自拔。
研究者进而分析了 MF_S这里假设 MF_S 代表 MF_SMF 中的子集集合中的模型
假设可以构建一个算法H它可以准确判断MF_S中任意模型m是否终止。假设MF_S集合中存在一个模型m’它足够强大以至于能够模拟任何图灵机的计算过程包括那些永远不会停止的图灵机。根据算法H的假设能力如果MF_S集合中的模型m’能够模拟那些不会停止的图灵机那么算法H应该能够预测m’在模拟这些图灵机时是否会停止。然而根据图灵的停机问题不可判定定理我们知道实际上不可能存在这样一个算法H因为它会与图灵的定理相矛盾。因此MF_S集合中不可能存在能够模拟所有图灵机行为的模型m’也就是说MF_S中没有任何模型是图灵完备的。
Transformer便属于 MF_SMF所以 Transformer 不具备图灵完备性。
研究人员指出Transformer在处理自然语言任务尤其是在机器翻译方面有明显的优势。这类模型能够通过递归的方式输入序列并生成更新后的序列从而逐个预测下一个符号。
但是尽管Transformer模型能够基于之前的字符序列连续生成新的字符序列每次接收一段输入字符后产出相应的输出字符并利用新产生的字符序列进行迭代计算它还是受到了上下文长度k和词汇表大小v的限制。这意味着它能够处理的不同字符组合的数量不会超过v^k种。
例如当 Transformer 遇到重复输入时由于它的无状态特性这有利于并行训练多个序列模型必须保证对同一输入产生一致的输出结果。这可能导致在某些情况下模型陷入无限循环的模式即只能生成有限数量的、最多为v^k种不同的输出序列或者在自我复制的过程中无法停止。
与Transformer相比图灵在1936年提出的图灵机概念具有无限的计算潜力不受这些结构性的限制能够模拟任何可计算的过程确保不会陷入类似的有限循环困境。
2 如何超越 GPT-4
实验结果显示单个 Transformer 架构并不具备图灵完备性而多 Transformer 则有能力实现图灵完备如论文中所提出的 FindReplace Transformer、并执行如 GPT-4 等最先进的 Transformer 模型所无法解决的问题。
论文中创新性地将 Find Transformer 与 Replace Transformer 相结合构建了FindReplace Transformer体系结构——这是一个能在任意长度序列上运行的多Transformer系统在论文中被形象地比喻为“磁带”Tape。
该系统由 Find Transformer、Replace Transformer 以及 Map 三部分组成其中 Map 是一个从 Replace Transformer 到 Find Transformer 所涉及的有序集合的函数映射关系。
具体运作时Find Transformer 会在输入序列中定位并标识出需要由 Replace Transformer 处理的部分内容。这两个组件各自具有固定的上下文长度 k并依次对“磁带”上的每个长度为k的子序列进行分析Find Transformer 会选择那些在最终层产生最高激活值的特定子序列。
随后Replace Transformer 会接收 Find Transformer 标识出的子序列作为输入并基于此生成一个新的长度为k的输出序列这个过程利用了 Map 关联的 f∈Map(r) 规则确保了两个 Transformer 之间的协同工作及信息传递。
那这个 FindReplace Transformer 的多 Transformer 系统是如何可以实现图灵完备的呢
简单来说FindReplace Transformer 是一个学习简化的机器。在编程语言的基石 λ 演算 中有三条被称为“归约”Reduction的规则:
Alpha Reduction这是一个绑定变量的重命名。它被用来避免命名冲突。例如在λ 演算的项 λx.x我们可以化简成 λy.y且不改变其意思。Beta Reduction这是将函数应用于其参数的过程。例如在λ项(λx.x)y表示将函数λx.x作用于参数y我们可以化简成 y。Eta Reduction这是对函数和参数的简化。如果你有一个函数比如λx.(fx)而x不出现在f中那么这个就可以化简为f。
FindReplace Transformer 的多Transformer 系统之所以能够实现图灵完备性关键在于其架构设计和训练方式允许模型通过一系列组合操作模拟类似于 λ 演算中的归约规则。尽管单个 Transformer 受限于上下文长度、词汇表大小等因素但通过构建一个多 Transformer 协作的框架并结合特定的学习机制这些简单且局部的“查找与替换”操作得以在更复杂的计算任务中累积并形成强大的综合效应。
具体来说在FindReplace Transformer中多个 Transformer 可能被专门设计来分别或协同地处理不同类型的简化归约任务例如模拟 Alpha Reduction 进行变量重命名、模拟 Beta Reduction 执行函数应用以及模拟 Eta Reduction进行函数简化等。每个 Transformer 可能专注于理解和学习如何执行这类简单的转换操作并将结果传递给下一个Transformer从而逐步构建起复杂问题的解决方案。
虽然单个 Transformer 不具备图灵完备性但当它们以特定的方式组织起来并协同工作时可以模拟通用图灵机的逻辑行为进而实现对任意可计算问题的解决能力。这样的体系结构让FindReplace Transformer在处理大规模、多层次的复杂问题时展现出超越传统单一Transformer的性能表现实现了更高阶的计算能力。
2023年当OpenAI 发布GPT-4时微软研究院的研究人员发表了一篇题为“Sparks of Artificial General Intelligence(Bubeck et al., 2023)”的论文阐述了早期AGI所面临的局限性。
研究者们以汉诺塔问题为例进行了说明。汉诺塔是一个经典的递归问题要求玩家将按照大小顺序堆叠的圆盘从一根柱子移动到另一根柱子上期间只能移动一个圆盘且任何时候大盘不能位于小盘之上借助第三根柱子作为中转。
GPT-4无法解决这个复杂的推理问题从而突显了当前Transformer在推理过程中缺乏规划能力。
研究者对比了几种模型在解决完整汉诺塔问题上的表现。随着问题规模增大其难度呈指数级上升规模为n的问题其解决方案需要2^n - 1步操作。FindReplace Transformer在此任务上表现出色甚至能生成比GPT-4至少长18倍的正确解决方案。 除了在汉诺塔这个GPT-4都难以解决的问题上表现优越之外在其他AI任务如创作满足特定条件的诗歌等FindReplace Transformer都能超越GPT-4这反映了其在泛化能力上的优势。 3 结语
FindReplace Transformer模型通过创新性地结合多个Transformer单元并模拟λ演算中的归约规则在处理如汉诺塔问题等复杂组合任务时展现出了超越传统单个Transformer的优越性能。
这一研究成果揭示了多Transformer系统在实现图灵完备性方面的潜力也证明了在面对特定计算难题时提高模型的逻辑推理和抽象表达能力的重要性。
而纵观整个人工智能技术的发展从深度学习兴起到大模型浪潮来袭每一次技术迭代人们都对于新技术报以极大的热情与崇拜。
然而无论是深度学习还是Transformer架构亦或是如今新出现FindReplace Transformer架构所带给我们的启示是在研究和应用深度学习技术时都需要避免过分神化任何技术应该理性地看待每一项技术关注其优势和局限并结合实际问题来选择和调整合适的技术。只有这样才能不断地在通往人工通用智能AGI的道路上迈进。 #Soft MoE
本文提出了一种可微的稀疏混合专家 Transformer 模型 (fully-differentiable sparse Transformer) Soft MoE 来解决端到端训练困难的问题同时也能够保持 MoE 方法的优势即以较低的推理成本更大的模型容量。
Soft MoE 提出了一种新的可微稀疏混合专家模型稀疏混合专家 (Sparse Mixture of Experts, MoE) 是一种在保证模型训练和推理的成本不显著增加的情况下大幅度提升模型容量的方法。
MoE 方法已经有很长的一段历史了是一种扩大模型容量的经典高效的做法但是它的缺点是
训练不稳定Token Dropping 的问题较难扩展 Expert 的数量低效率的微调
造成以上问题的一个原因是 MoE 的端到端训练困难因此本文提出了一种可微的稀疏混合专家 Transformer 模型 (fully-differentiable sparse Transformer) Soft MoE 来解决端到端训练困难的问题同时也能够保持 MoE 方法的优势即以较低的推理成本更大的模型容量。Soft MoE 的特点是给每个专家输入不同 token 的权重混合。
视觉实验结果证明Soft MoE 大大优于标准 ViT 和流行的 MoE 方法比如 128 个 Expert16 个 MoE 层的 Soft MoE-Huge/14 模型参数比 ViT-Huge/14 多 40 倍但推理时间成本仅增长 2%同时性能要好得多。
1 Soft MoE一种完全可微的稀疏 Transformer
论文名称 From Sparse to Soft Mixtures of Experts
论文地址 https://arxiv.org/pdf/2308.00951.pdf
1 Soft MoE 论文解读
1.1 背景把离散优化问题变为可微的优化问题
稀疏混合专家 (Sparse Mixture of Experts, MoE) 是一种在保证模型训练和推理的成本不显著增加的情况下大幅度提升模型容量的方法。在视觉语言和多模态任务中都取得了成功代表像视觉的 V-MoE[1]文本的 Switch Transformer[2]和多模态的 LIMoE[3]。
如下图1左所示稀疏 MoE Transformer 的核心是一个离散优化问题即模型需要决定每个输入 token 应该输入哪些 Expert 里面这些 Expert 一般是 MLP 模块。输入 token 和 Expert 之间的匹配 (token-to-expert match) 是 MoE 中要考虑的很重要的问题之一之前也有各种各样的方法尝试解决此问题比如基于线性规划的[4]比如基于 RL 算法的[5]比如基于固定规则的[6]比如基于最优传输理论的[7]和基于贪婪匹配的[8]。总之解决好稀疏 MoE 的这个离散优化问题的确是件不容易的事情。稀疏 MoE 的缺点有
训练不稳定Token Dropping 的问题较难扩展 Expert 的数量低效率的微调 图1Sparse MoE 和 Soft MoE 的区别左Sparse MoE给每个 Expert 分配一定的输入 token。右Soft MoE给每个 Expert 分配的是所有输入 token 的加权平均值
如下图1右所示Soft MoE 把稀疏 MoE Transformer 的这个离散优化问题变成了可微的优化问题。Soft MoE 觉得不必一定要 hard 地找到输入 token 和 Expert 之间的一一匹配而是可以 Soft 地混合输入 token 并且分给每一个 Expert。Soft MoE 给每个 Expert 分配的不是某几个输入 token而是所有输入 token 的加权平均值 (权重取决于 token 和 Expert)然后由这个对应的 Expert 去处理这个加权平均值。
1.2 变为可微的优化问题之后解决了之前稀疏 MoE 的什么问题
问题1 精心设计的 Expert-to-token 的路由机制通常并不比随机固定路由好。
Soft MoE 可以避免这个问题因为每个路由的参数都是基于每个输入 token 直接更新的。
问题2 训练不稳定 (LIMoE[3]这个工作观察到在训练期间可能有大部分 token 改变路由给训练带来一定挑战) 导致很多稀疏 MoE 方法的 Expert 都不可以设置得很多。
Soft MoE 可以避免这个问题扩展到数千个 Expert。
1.3 Soft MoE 算法描述
参数配置 整个过程如下图2所示。 图2Soft MoE 算法流程图
遵循稀疏 MoE 的常用设计思想作者用 Soft MoE 块替换了 Transformer 的一部分 MoE 块。slot 的总数是 Soft MoE 的关键超参数因为时间复杂度取决于 slot 的数量而不是 Expert 的数量。比如可以设置等于输入序列长度的 slot 数以匹配等效密集 Transformer 的 FLOP。
Soft MoE 的 JAX 代码
def soft_moe_layer(X, Phi, experts):# Compute the dispatch and combine weights.logits jnp.einsum(md,dnp-mnp, X, Phi)D jax.nn.softmax(logits, axis(0,))C jax.nn.softmax(logits, axis(1, 2))# The input slots are a weighted average of all the input tokens,# given by the dispatch weights.Xs jnp.einsum(md,mnp-npd, X, D)# Apply the corresponding expert function to each input slot.Ys jnp.stack([f_i(Xs[i, :, :]) for i, f_i in enumerate(experts)],axis0)# The output tokens are a weighted average of all the output slots,# given by the combine weights.Y jnp.einsum(npd,mnp-md, Ys, C)return Y
全部代码
https://github.com/google-research/vmoegithub.com/google-research/vmoe
1.4 Soft MoE 的一些关键性质
1) 完全可微
Sparse MoE 算法的通病是 token 和 Expert 之间存在的分配问题有时精心设计的 Expert-to-token 的路由机制通常并不比随机固定路由好。输入 token 和 Expert 之间的匹配 (token-to-expert match) 是 MoE 中要考虑的很重要的问题之一之前也有各种各样的方法尝试解决此问题比如基于线性规划的[4]比如基于 RL 算法的[5]比如基于固定规则的[6]比如基于最优传输理论的[7]和基于贪婪匹配的[8][9]。所有这些方法本质上都是离散不可微的。
Soft MoE 可以避免这个问题因为每个路由的参数都是基于每个输入 token 直接更新的。
2) 可以避免掉 Token Dropping 和 Expert Unbalance 的问题
MoE 算法里面每个 Expert 都会处理一些 token很自然地就会带来 Token Dropping (有的 token 不会分配给任何一个 Expert) 和 Expert Unbalance (一些 Expert 会比另一些 Expert 分配到更多 token) 的问题。
Soft MoE 可以避免这个问题因为每个 slot 的输入都是所有 token 的加权平均值。
3) 运算速度快
Soft MoE 的主要优点是完全避免了之前算法中的 token 排序或 top-k 操作因为这些操作的速度慢而且不太适合硬件加速器。因此Soft MoE 明显快于大多数 Sparse MoE 算法。
4) Soft MoE 算法是密集的 MoE 算法还是稀疏的 MoE 算法
要回答这个问题我们需要首先搞明白为什么 Sparse MoE 算法是稀疏的。Sparse MoE 是稀疏的这件事的根本原因是每个 Expert 的输入特征仅仅是一部分的 token而 Soft MoE 的输入是所有输入 token 的加权平均值因此不能算作是稀疏的。
Soft MoE 也不能算作是 Dense MoE 算法因为每个 Expert 仅仅会处理输入 token 的子集。
5) Soft MoE 算法需要归一化
Transformers 中MoE 层通常用于替换每个编码器块中的 FFN 层因此如果去遵循大部分 Transformer 架构的 Pre-Normalization 方法就需要使用归一化这里 Soft MoE 针对 的操作是
l2_normalize(X, axis1) scale * l2_normalize(Phi, axis0)
其中scale 是可学习的参数l2_normalize 的定义是
def l2_normalize(x, axis, eps1e-6):norm jnp.sqrt(jnp.square(x).sum(axisaxis, keepdimsTrue))return x * jnp.reciprocal(norm eps)
6) Soft MoE 算法和注意力机制 (Multi-Head Self-Attention) 的区别和联系 1.5 Soft MoE 算法的局限性
自回归解码 (Auto-regressive decoding)
因为 Soft MoE 算法要在运行过程中合并所有的输入 token因此很难实现自回归。因为自回归必须在训练期间保留过去的 token 和未来 token 之间的因果关系 (Causality)。
Self-Attention 解决这个问题的手段是依赖于注意力的掩码 (Mask) 机制。如果想在 Soft MoE 中实现这一点就需要特别小心 token 之间的依赖和相关关系。总之研究 Soft MoE 算法的自回归解码是个很有价值的方向。
内存消耗
Soft MoE 倾向于利用大量 Expert而其成本和 Dense Backbone 类似使得模型的内存需求可能变大。
1.6 图像分类实验结果
训练数据集
预训练数据集 JFT-4B一个私有数据集其最新版本包含超过 4B 张图像涵盖超过 29k 个类。预训练的过程中评价指标是 JFT-4B 上的上游验证精度 Precision-at-1 和 ImageNet 10-shot 精度 (冻结模型权重并用一个新的权重来计算的该数据集仅在包含来自 ImageNet-1K 的每个类包含 10 张图像的数据集上进行训练)。
微调数据集 ImageNet-1K 训练集。
验证集 ImageNet-1K 验证集。
模型尺寸
ViT-S/8, ViT-S/16, ViT-S/32, ViT-B/16, ViT-B/32, ViT-L/16, ViT-L/32, ViT-H/14。
方法
Token Choice, Expert Choice 和本文的 Soft MoE。
训练策略
300k steps, Batch Size 4096
Pareto Model 实验结果
如下图3所示是四种方法 Soft MoE, Experts Choice, Tokens Choice, Dense 在预训练过程中的 JFT-4B Precision-at-1 的结果和 ImageNet 10-shot 的精度的训练成本/性能帕累托边界。Soft MoE 算法在这两个指标上都优于之前的方法。 图3四种方法在预训练过程中的 JFT-4B Precision-at-1 的结果和 ImageNet 10-shot 的精度的训练成本/性能帕累托边界
更长的训练结果
本文还测试在更长的训练 step 下模型的性能如何把从 Small 到 Huge 的模型训练了 4K steps用 128 个 Expert 的 Soft MoE 替换 ViT S/16、B/16、L/16 和 H/14 中的最后一半 Block 中的 FFN每个 Expert 使用一个 slot。
由于模型并行性所需的额外数据传输Large Soft MoE 模型产生的 wall-clock time overhead 很小。所有变体都训练了 4M 步除了 H/14出于成本原因训练了 2M 步实验结果如下图4和5所示。
如下图4所示是 Soft MoE 和 ViT 的 JFT-4B 精度、ImageNet 10-shot 精度和 ImageNet 微调精度与 ExaFLOPS 的训练成本。 图4不同模型更长的训练 step 下的 JFT-4B 精度
如下图5所示是所有结果。对于给定的计算预算Soft MoE 模型大大优于 Vision Transformer 模型。比如 Soft MoE-S/16 在 JFT-4B 和 ImageNet 10-shot 上的表现优于 ViT-B/16它还提高了完整 ImageNet 数据的微调分数即使它的训练 (和推理) 成本要小得多。同样Soft MoE-B/16 在上游任务 JFT-4B 和 ImageNet 10 shot 的表现优于 ViT-L/16微调后仅落后 0.5同时速度快 3 倍所需的 FLOP 减少了近 4 倍。最后Soft MoE-L/16 模型优于 Dense H/14 模型同时在推理速度又快 3 倍左右。 图5不同模型更长的训练 step 下的实验结果
根据前面的实验结果较小的 Soft MoE 的性能可以匹配较大的视觉 Transformer作者因此继续训练小模型 Backbone希望以非常低的训练成本获得更高质量的模型。
作者观察到对于 Soft MoE 方法而言较长的 cooldown (学习率线性减小到零的时期) 可以很好地适用于 Soft MoE因此将 cooldown 从 50k steps 增加到 500k steps。
实验结果如下图6和7所示。Soft MoE-B/16 训练了 1k TPUv3 Days优于在相似时间预算上训练的 ViT-H/14而 Soft MoE-B 模型的 FLOPs 要低 10 倍wall-clock time 低 5.7 倍。即使将 ViT-H/14 的训练代价加倍Soft MoE-B 模型的性能也可以与之相匹配。Soft MoE-L/16 模型的在推断上比 ViT H/14 快近 2 倍的同时性能大大优于所有模型。 图6不同训练代价和尺寸的 Soft MoE 模型和 ViT 的 JFT-4B Precision-at-1 性能和 ImageNet 10-shot 性能 图7Soft MoE 模型和 ViT 的实验结果
视觉-文本对比学习实验结果
作者还验证了 Soft MoE 得到的模型在其他任务的性能。具体而言作者探索了一种流行的范式即图像语言对比学习这里遵循的是 LiT[10] 方法其中图像塔在图像分类任务上进行了预训练然后在在图像-文本对数据集上训练文本编码器时冻结。
视觉编码器作者重用了在 JFT 上训练的模型对比学习在 WebLI 上训练这是一个专有数据集由 10B 图像和从互联网上抓取的 ALT 文本组成。图像编码器被冻结而文本编码器从头开始训练。实验结果如下图8所示Soft MoE -L/16 在 Imagenet 和 Cifar-100 零样本上的性能分别比 ViT-L/16 高出 1% 和 2% 以上。 图8对比学习实验结果 #Transformers18~ Diffusion
还是Transformers,来自 UC 伯克利的 William Peebles 以及纽约大学的谢赛宁撰文揭秘扩散模型中架构选择的意义并为未来的生成模型研究提供经验基线。
近几年在 Transformer 的推动下机器学习正在经历复兴。过去五年中用于自然语言处理、计算机视觉以及其他领域的神经架构在很大程度上已被 transformer 所占据。
不过还有许多图像级生成模型仍然不受这一趋势的影响例如过去一年扩散模型在图像生成方面取得了惊人的成果几乎所有这些模型都使用卷积 U-Net 作为主干。这有点令人惊讶在过去的几年中深度学习的大事件一直是跨领域的 Transformer 的主导地位。U-Net 或卷积是否有什么特别之处使它们在扩散模型中表现得如此出色
将 U-Net 主干网络首次引入扩散模型的研究可追溯到 Ho 等人这种设计模式继承了自回归生成模型 PixelCNN只是稍微进行了一些改动。而 PixelCNN 由卷积层组成其包含许多的 ResNet 块。其与标准的 U-Net 相比PixelCNN 附加的空间自注意力块成为 transformer 中的基本组件。不同于其他人的研究Dhariwal 和 Nichol 等人消除了 U-Net 的几种架构选择例如使用自适应归一化层为卷积层注入条件信息和通道计数。
本文中来自 UC 伯克利的 William Peebles 以及纽约大学的谢赛宁撰文《 Scalable Diffusion Models with Transformers 》目标是揭开扩散模型中架构选择的意义并为未来的生成模型研究提供经验基线。该研究表明U-Net 归纳偏置对扩散模型的性能不是至关重要的并且可以很容易地用标准设计如 transformer取代。
这一发现表明扩散模型可以从架构统一趋势中受益例如扩散模型可以继承其他领域的最佳实践和训练方法保留这些模型的可扩展性、鲁棒性和效率等有利特性。标准化架构也将为跨领域研究开辟新的可能性。
论文地址https://arxiv.org/pdf/2212.09748.pdf项目地址https://github.com/facebookresearch/DiT论文主页https://www.wpeebles.com/DiT
该研究专注于一类新的基于 Transformer 的扩散模型Diffusion Transformers简称 DiTs。DiTs 遵循 Vision Transformers (ViTs) 的最佳实践有一些小但重要的调整。DiT 已被证明比传统的卷积网络例如 ResNet 具有更有效地扩展性。
具体而言本文研究了 Transformer 在网络复杂度与样本质量方面的扩展行为。研究表明通过在潜在扩散模型 (LDM) 框架下构建 DiT 设计空间并对其进行基准测试其中扩散模型在 VAE 的潜在空间内进行训练可以成功地用 transformer 替换 U-Net 主干。本文进一步表明 DiT 是扩散模型的可扩展架构网络复杂性由 Gflops 测量与样本质量由 FID 测量之间存在很强的相关性。通过简单地扩展 DiT 并训练具有高容量主干118.6 Gflops的 LDM可以在类条件 256 × 256 ImageNet 生成基准上实现 2.27 FID 的最新结果。
Diffusion Transformers
DiTs 是一种用于扩散模型的新架构目标是尽可能忠实于标准 transformer 架构以保留其可扩展性。DiT 保留了 ViT 的许多最佳实践图 3 显示了完整 DiT 体系架构。 DiT 的输入为空间表示 z对于 256 × 256 × 3 图像z 的形状为 32 × 32 × 4。DiT 的第一层是 patchify该层通过将每个 patch 线性嵌入到输入中以此将空间输入转换为一个 T token 序列。patchify 之后本文将标准的基于 ViT 频率的位置嵌入应用于所有输入 token。
patchify 创建的 token T 的数量由 patch 大小超参数 p 决定。如图 4 所示将 p 减半将使 T 翻四倍因此至少能使 transformer Gflops 翻四倍。本文将 p 2,4,8 添加到 DiT 设计空间。 DiT 块设计在 patchify 之后输入 token 由一系列 transformer 块处理。除了噪声图像输入之外扩散模型有时还会处理额外的条件信息例如噪声时间步长 t、类标签 c、自然语言等。本文探索了四种以不同方式处理条件输入的 transformer 块变体。这些设计对标准 ViT 块设计进行了微小但重要的修改。所有模块的设计如图 3 所示。
本文尝试了四种因模型深度和宽度而异的配置DiT-S、DiT-B、DiT-L 和 DiT-XL。这些模型配置范围从 33M 到 675M 参数Gflops 从 0.4 到 119 。
实验
研究者训练了四个最高 Gflop 的 DiT-XL/2 模型每个模型使用不同的 block 设计 ——in-context119.4Gflops、cross-attention137.6Gflops、adaptive layer normadaLN118.6Gflops或 adaLN-zero118.6Gflops。然后在训练过程中测量 FID图 5 为结果。 扩展模型大小和 patch 大小。图 2左给出了每个模型的 Gflops 和它们在 400K 训练迭代时的 FID 概况。可以发现增加模型大小和减少 patch 大小会对扩散模型产生相当大的改进。 图 6顶部展示了 FID 是如何随着模型大小的增加和 patch 大小保持不变而变化的。在四种设置中通过使 Transformer 更深、更宽训练的所有阶段都获得了 FID 的明显提升。同样图 6底部展示了 patch 大小减少和模型大小保持不变时的 FID。研究者再次观察到在整个训练过程中通过简单地扩大 DiT 处理的 token 数量并保持参数的大致固定FID 会得到相当大的改善。 图 8 中展示了 FID-50K 在 400K 训练步数下与模型 Gflops 的对比 SOTA 扩散模型 256×256 ImageNet。在对扩展分析之后研究者继续训练最高 Gflop 模型 DiT-XL/2步数为 7M。图 1 展示了该模型的样本并与类别条件生成 SOTA 模型进行比较表 2 中展示了结果。 当使用无分类器指导时DiT-XL/2 优于之前所有的扩散模型将之前由 LDM 实现的 3.60 的最佳 FID-50K 降至 2.27。如图 2右所示相对于 LDM-4103.6 Gflops这样的潜在空间 U-Net 模型来说DiT-XL/2118.6 Gflops计算效率高得多也比 ADM1120 Gflops或 ADM-U742 Gflops这样的像素空间 U-Net 模型效率高很多。 表 3 展示了与 SOTA 方法的比较。XL/2 在这一分辨率下再次胜过之前的所有扩散模型将 ADM 之前取得的 3.85 的最佳 FID 提高到 3.04。 #Diffusion Transformers (DiTs)
本文探索了一类新的基于 Transformer 的扩散模型 Diffusion Transformers (DiTs)。本文训练 latent diffusion models 时使用 Transformer 架构替换常用的 UNet 架构且 Transformer 作用于 latent patches 上。
本文探索了一类新的基于 Transformer 的扩散模型 Diffusion Transformers (DiTs)。本文训练 latent diffusion models 时使用 Transformer 架构替换常用的 UNet 架构且 Transformer 作用于 latent patches 上。
作者探索了 DiT 的缩放性发现具有较高 GFLOPs 的 DiT 模型通过增加 Transformer 宽度或者深度或者输入 token 数量始终有更好的 FID 值。最大的 DiT-XL/2 模型在 ImageNet 512×512 和 256×256 的测试中优于所有先前的扩散模型实现了 2.27 的 FID 值。
做了什么工作
探索了一类新的基于 Transformer 的 Diffusion Model称为 Diffusion Transformers (DiTs)。研究了 DiT 对于模型复杂度 (GFLOPs) 和样本质量 (FID) 的缩放性。证明了通过使用 Latent Diffusion Models (LDMs)[1]框架Diffusion Model 中的 U-Net 架构可以被 Transformer 替换。
1 DiTTransformer 构建扩散模型
论文名称Scalable Diffusion Models with Transformers (ICCV 2023, Oral)
论文地址https//arxiv.org/pdf/2212.09748.pdf
论文主页https//www.wpeebles.com/DiT.html
1 DiT 论文解读
1.1 把 Transformer 引入 Diffusion Models
机器学习正经历着 Transformer 架构带来的复兴NLPCV 等许多领域正在被 Transformer 模型覆盖。尽管 Transformer 在 Autoregressive Model 中得到广泛应用[2][3][4][5]但是这种架构在生成式模型中较少采用。比如作为图像领域生成模型的经典方法Diffusion Models[6][7]却一直使用基于卷积的 U-Net 架构作为骨干网络。
Diffusion Models 的开创性工作 DDPM [8]首次引入了基于 U-Net 骨干网络的扩散模型。U-Net 继承自 PixelCNN[9][10]变化很少。与标准 U-Net[11]相比额外的空间 Self-Attention 块 (Transformer 中必不可少的组件) 以较低分辨率穿插。[12]这个工作探索了 U-Net 的几种架构选择例如自适应归一化层 (Adaptive Normalization Layer[13]为卷积层注入条件信息和通道计数。然而DDPM 里面 U-Net 的高级设计在很大程度上都保持不变。
本文的目的是探索 Diffusion Models 架构选择的重要性并为未来生成式模型的研究提供基线。本文的结论表明 U-Net 架构设计对 Diffusion Models 的性能并不重要并且它们可以很容易地替换为 Transformers。
本文证明了 Diffusion Models 也可以受益于 Transformer 架构受益于其训练方案受益于其可扩展性受益于其鲁棒性和效率等等。标准化架构还将为跨域研究开辟了新的可能性。
1.2 Diffusion Models 简介
DDPM
高斯扩散模型假设有一个前向的加噪过程 (Forward Noising Process)在这个过程中逐渐将噪声应用于真实数据 这个优化的目标函数比较复杂最后通过 variational lower bound 方法得到的结论是优化下式 (此处详细推导可以参考开创性工作 DDPM[8]) 1.3 DiT 架构介绍
1.3.1 Patchify 过程 图1图片的 Patchify 操作。当 Patch 的大小 p 越小时token 的数量 T 越大
1.3.2 DiT Block 设计
在 Patchify 之后输入的 tokens 开始进入一系列 Transformer Block 中。除了噪声图像输入之外Diffusion Model 有时会处理额外的条件信息比如噪声时间步长 ttt , 类标签 ccc , 自然语言。
作者探索了4种不同类型的 Transformer Block以不同的方式处理条件输入。这些设计都对标准 ViT Block 进行了微小的修改所有 Block 的设计如下图2所示。 图2Diffusion Transformer (DiT) 架构
In-Context Conditioning 作者将以上几种方法 In-Context ConditioningCross-Attention BlockAdaptive Layer Norm (adaLN) BlockadaLN-Zero Block 的做法列入 DiT 的设计空间中。 1.3.3 模型尺寸 图3DiT 模型的详细配置 作者将以上几种配置列入了 DiT 的设计空间中。 1.3.4 Transformer Decoder
在最后一个 DiT Block 之后需要将 image tokens 的序列解码为输出噪声以及对角的协方差矩阵的预测结果。 最终完整 DiT 的设计空间是 Patch Size、DiT Block 的架构和模型大小。 1.4 DiT 训练策略
1.4.1 训练配方
作者在 ImageNet 数据集上训练了 class-conditional latent DiT 模型标准的实验设置。 数据增强技术只使用 horizontal flips。
作者发现 learning rate warmup 和 regularization对训练 DiT 模型而言不是必须的。
作者使用了 exponential moving average (EMA)参数为 0.9999 。
训练超参数基本都来自 ADM不调学习率, decay/warm-up schedules, Adam 参数以及 weight decay.
1.4.2 扩散模型配置 作者保留了 ADM 中使用的超参数。
1.5.1 DiT 架构设计
作者首先探索的是不同 Conditioning 策略的对比。对于一个 DiT-XL/2 模型其计算复杂度分别是in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops), adaLN-zero (118.6 Gflops)。实验结果如下图4所示。
adaLN-Zero 的 Block 架构设计取得了最低的 FID 结果同时在计算量上也是最高效的。在 400K 训练迭代中adaLN-Zero Block 架构得到的 FID 几乎是 In-Context 的一半表明 Condition 策略会严重影响模型的质量。
初始化同样也重要adaLN-Zero Block 架构在初始化时相当于恒等映射其性能也大大优于 adaLN Block 架构。 因此在后续实验中DiT 将一直使用 adaLN-Zero Block 架构。 图4不同 Conditioning 策略对比
1.5.2 缩放模型尺寸和 Patch Size
作者训练了12个 DiT 模型 (尺寸为 S, B, L, XLPatch Size 为 8,4,2)。下图是不同 DiT 模型的尺寸和 FID-50K 性能。如下图5所示是不同大小 DiT 模型的 GFLOPs 以及在 400K 训练迭代中的 FID 值。可以发现在增加模型大小或者减小 Batch Size 时可以显著改善 DiT 的性能。 图5不同尺寸 DiT 模型的 GFLOPs 以及它们在 400K 训练迭代中的 FID
下图6上方是 Patch Size 不变增加模型规模时 FID 的变化。当模型变深变宽时FID 会下降。
下方是模型规模不变减小 Patch Size 时 FID 的变化。当 Patch Size 下降时FID 出现显著改善。 图5缩放 DiT 模型可以改善训练各个阶段的 FID
1.5.3 GFLOPs 对性能很重要
上图5的结果表明参数量并不能唯一确定 DiT 模型的质量。当 Patch Size 减小时参数量仅仅是略有下降只有 GFLOPs 明显增加。这些结果都表明了缩放模型的 GFLOPs 才是性能提升的关键。为了印证这一点作者在下图6中绘制了不同 GFLOPs 模型在 400K 训练步骤时候的 FID-50K 结果。这些结果表明当不同 DiT 模型的总 GFLOPs 相似时它们的 FID 值也相似比如 DiT-S/2 和 DiT-B/4。
作者还发现 DiT 模型的 GFLOPs 和 FID-50K 之间存在很强的负相关关系。 图6GFLOPs 与 FID 密切相关
1.5.4 大模型更加计算高效 图7大模型更加计算高效
1.5.5 缩放结果可视化 图8缩放对于视觉质量的影响
1.6 DiT 实验结果
作者将 DiT 与最先进的生成模型进行了比较结果如图9所示。DiT-XL/2 优于所有先前的扩散模型将 LDM 实现的先前最佳 FID-50K 降低到 2.27。图5右侧显示 DiT-XL/2 (118.6 GFLOPs) 相对于 LDM-4 (103.6 GFLOPs) 等Latent Space U-Net 模型的计算效率很高并且比 Pixel Space U-Net 模型更高效例如 ADM (1120 GFLOPs) 或 ADM-U (742 GFLOPs)。 图9ImageNet 256×256 图像生成结果
作者在 ImageNet 上训练了一个新的 DiT-XL/2这次分辨率是 512×5123M training iterations超参数与 256×256 模型相同。这个模型 latent 的维度是 64×64×4然后 Patch Size 为2这样 Transformer 模型需要处理的 token 的数量就是 1024。如下图10所示是比较结果。DiT-XL/2 在此分辨率下再次优于所有先前的扩散模型将 ADM 实现的先前最佳 FID 提高了 3.85 到 3.04。即使 token 的数量增加了DiT-XL/2 的计算效率依然很高比如 ADM 使用 1983 GFLOPsADM-U 使用 2813 GFLOPsDiT-XL/2 仅仅使用 524.6 GFLOPs。 图10ImageNet 512×512 图像生成结果
缩放模型大小还是采样次数
Diffusion Model 的一个独特之处是它们可以通过在生成图像时增加采样步骤的数量来在训练期间使用额外的计算。也就是扩散模型的计算量既可以来自模型本身的缩放也可以来自采样次数的增加。因此作者在这里研究了通过使用更多的采样计算较小的 DiT 模型是否可以胜过更大的模型。
作者计算了所有的 12 个 DiT 模型在 400K training iteration 时候的 FID 值每张图分别使用 [16, 32, 64, 128, 256, 1000] sampling steps。
实验结果如下图11所示考虑使用 1000 个采样步骤的 DiT-L/2 和使用 128 步的 DiT-XL/2。在这种情况下
DiT-L/2 使用 80.7 TFLOPs 对每张图像进行采样。DiT-XL/2 使用 15.2 TFLOPs 对每张图像进行采样。
但尽管如此DiT-XL/2 具有更好的 FID-10K 结果。说明增加采样的计算量也无法弥补模型本身计算量的缺失。 图11增加采样的计算量也无法弥补模型本身计算量的缺失 #eventful-transformers
如何降低视觉Transformer计算成本时间冗余方法让人大吃一惊
在为语言领域带来变革之后Transformer 正在进军视觉领域但其也有着高计算成本的问题。近日威斯康星大学麦迪逊分校一个研究团队提出了 Eventful Transformer可通过在视觉 Transformer 中利用时间冗余来节省成本。
Transformer 一开始是为自然语言处理任务设计的但现在却已经被广泛用于视觉任务。视觉 Transformer 在一系列视觉识别任务上实现了出色的准确度并在图像分类、视频分类和目标检测等任务上取得了当前最优的表现。
视觉 Transformer 的一大缺点是计算成本高。典型的卷积网络CNN处理每张图像需要数十 GFlops而视觉 Transformer 所需的往往会多上一个数量级达到每张图像数百 GFlops。在处理视频时由于数据量巨大这个问题更为严重。高昂的计算成本让视觉 Transformer 难以被部署到资源有限或有严格延迟需求的设备上这就限制了这项技术的应用场景否则我们已经有一些激动人心的应用了。
在近期一篇论文中威斯康星大学麦迪逊分校的三位研究者 Matthew Dutson、Yin Li 和 Mohit Gupta 首先提出可以在后续输入之间使用时间冗余来降低视觉 Transformer 在视频应用中的成本。他们也发布了模型代码其中包含用于构建 Eventful Transformer 的 PyTorch 模块。
论文地址https://arxiv.org/pdf/2308.13494.pdf项目地址http://wisionlab.com/project/eventful-transformers
时间冗余首先假设有一个视觉 Transformer其可以逐帧或逐视频片段地处理视频序列。这个 Transformer 可能是简单的逐帧处理的模型如目标检测器或是某个时空模型的中间步骤如 ViViT 的分解式模型的第一步。不同于一个输入就是一个完整序列的语言处理 Transformer在这里研究者的做法是随时间为 Transformer 提供多个不同的输入帧或视频片段。
自然视频包含显著的时间冗余即后续帧之间的差异很小。尽管如此包括 Transformer 在内的深度网络通常都会「从头开始」计算每一帧。该方法会丢弃之前推理获得的潜在相关信息浪费极大。故而这三位研究者设想是否可以复用之前计算步骤的中间计算结果来提升处理冗余序列的效率
自适应推理对于视觉 Transformer 以及一般意义上的深度网络而言推理成本通常由架构决定。然而在现实应用中可用的资源可能会随时间而变化比如可能因为存在相竞争的进程或电源发生变化。如此一来可能就存在运行时修改模型计算成本的需求。在这项新成果中研究者设定的一大主要设计目标便是适应性 —— 其方法可实现对计算成本的实时控制。下图 1底部给出了在视频处理过程中修改计算预算的示例。 Eventful Transformer本文提出了 Eventful Transformer这类 Transformer 能利用输入之间的时间冗余来实现高效且自适应的推理。Eventful 这一术语的灵感来自事件相机event camera这种传感器能在场景变化时离散地记录影像。Eventful Transformer 会跟踪随时间发生的 token 层面的变化情况并在每个时间步骤有选择性地更新 token 表征和自注意力映射图。Eventful Transformer 的模块中包含一种门控模块用于控制运行时间被更新 token 的数量。
该方法可用于现成的模型通常无需再训练并且兼容许多视频处理任务。研究者也进行了实验论证结果表明 Eventful Transformer 可用于现有的当前最佳模型在极大降低它们的计算成本的同时还能维持其原有的准确度。
Eventful Transformer
这项研究的目标加速用于视频识别的视觉 Transformer。在这个场景中视觉 Transformer 需要反复处理视频帧或视频片段具体的任务包括视频目标检测和视频动作识别等。这里提出的关键思想是利用时间冗余即复用之前时间步骤的计算结果。下面将详细描述如何通过修改 Transformer 模块来使其具备感知时间冗余的能力。
token 门控检测冗余
这一小节将介绍研究者提出的两种新模块token 门和 token 缓冲器。这些模块让模型可以识别和更新自上次更新后有明显变化的 token。
门模块该门会从输入 token N 中选择一部分 M 发送给下游层执行计算。其记忆中维护着一个参照 token 集记为 u。这种参照向量包含每个 token 在其最近一次更新时的值。在每个时间步骤比较各个 token 与其对应的参照值其中与参照值相差较大的 token 获得更新。
现在将该门的当前输入记为 c。在每个时间步骤按照以下流程更新门的状态并决定其输出见下图 2 构建可感知冗余的 Transformer
为了利用上述时间冗余研究者提出了一种对 Transformer 模块的修改方案。下图 4 展示了 Eventful Transformer 模块的设计。该方法可以加速针对各个 token 的运算如 MLP以及查询 - 键值和注意力 - 值乘法。 在针对各个 token 的运算 Transformer 模块中很多运算都是针对各个 token 的也就是说它们不涉及到 token 之间的信息交换其中包括 MLP 和 MSA 中的线性变换。为了节省计算成本研究者表示可以跳过未被门选取的 token 的面向 token 的运算。由于 token 之间的独立性这不会改变对所选 token 的运算结果。参见图 3。
具体来说针对各个 token 的运算包括 W_qkv 变换、W_p 变换和 MLP的连续序列研究者使用了一对门 - 缓冲器。注意他们还在 skip 连接之前添加了缓冲器以确保两个加法操作数的 token 是正确对齐的。
针对各个 token 的运算的成本正比于 token 的数量。门可将这个数量从 N 降至 M也就将下游的针对各个 token 的运算的计算成本降低了 N/M 倍。
查询 - 键值的积现在来看看查询 - 键值积 B q k^T。
下图 5 展示了稀疏地更新查询 - 键值积 B 中一部分元素的方法。 这些更新的总体成本为 2NMD相较而言从头开始计算 B 的成本为 N^2D。注意新方法的成本正比于 M即门选取的 token 的数量。当 M N/2 时此时更新的 token 不到总量一半可节省计算量。
注意力 - 值的积研究者为此提出了一种基于增量 ∆ 的更新策略。
下图 6 展示了新提出的高效计算三个增量项的方法。 同样当 M N/2 时可节省计算量。
token 选取策略
Eventful Transformer 的一大重要设计是其 token 选取策略。给定一个门误差张量 e这样一个策略的目标是生成一个掩码 m其中指示了应当被更新的 token。具体的策略包括
Top-r 策略该策略选取 r 个误差 e 有最大范数的 token这里使用的是 L2 范数。
阈值策略该策略选取误差 e 的范数超过一个阈值 h 的所有 token。
其它策略更复杂精细的 token 选取策略可实现更好的准确度 - 成本权衡比如可以使用一个轻量级策略网络来学习一个策略。但是训练策略的决策机制的难度可能很大因为二元掩码 m 一般是不可微分的。另一个思路是使用重要度分数作为选取的参考信息。但这些想法都还有待未来研究。
实验
研究者用实验评估了新提出的方法具体使用的任务是视频目标检测和视频动作识别。
下图 7 展示了视频目标检测的实验结果。其中正轴是计算节省率负轴是新方法的 mAP50 分数的相对减少量。可以看到新方法用少量的准确度牺牲换来了显著的计算量节省。 下图 8 给出了在视频目标检测任务上的方法比较和消融实验结果。 下图 9 给出了视频动作识别的实验结果。 下表 2 给出了在一台 CPUXeon Silver 4214, 2.2 GHz和一台 GPUNVIDIA RTX3090上运行时间毫秒结果可以看到时间冗余在 GPU 上带来的速度提升可达 1.74 倍在 CPU 上带来的提升可达 2.47 倍。 #Llama~transformers搭建
本例从零开始基于transformers库逐模块搭建和解读Llama模型源码(中文可以翻译成羊驼)。
并且训练它来实现一个有趣的实例两数之和。
输入输出类似如下
输入1234554321
输出66666
我们把这个任务当做一个文本生成任务来进行。输入是一个序列的上半部分输出其下半部分.
这和文本生成的输入输出结构是类似的所以可以用Llama来做。
目前大部分开源LLM模型都是基于transformers库来做的它们的结构大部分都和Llama大同小异。
俗话说魔鬼隐藏在细节中深入理解Llama模型的的源码细节将会帮助你打通和开源LLM模型相关的基础原理(如旋转位置编码以及长度外推)并让你熟悉各种参数的配置和使用(如past_key_valueattention_mask的使用等等)。
一准备数据
import math
from typing import List, Optional, Tuple, Union import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING logger logging.get_logger(llama) config LlamaConfig( vocab_sizelen(vocab), hidden_size512, intermediate_size2752, num_hidden_layers8, num_attention_heads16, hidden_actsilu, max_position_embeddings128, initializer_range0.02, rms_norm_eps1e-06, use_cacheTrue, pad_token_id0, bos_token_id1, eos_token_id2, tie_word_embeddingsFalse
)
BOS39148356267350577333188294649883914835945564522721EOS
# 定义数据集
class TwoSumDataset(torch.utils.data.Dataset): def __init__(self,size 100000, min_length10,max_length20): super(Dataset, self).__init__() self.size size self.min_lengthmin_length self.max_lengthmax_length def __len__(self): return self.size def __getitem__(self, i): x,y self.get(i) # 编码成token context_ids [vocab[i] for i in x] target_ids [vocab[i] for i in y] input_ids context_ids target_ids #-100标志位后面会在计算loss时会被忽略不贡献损失我们集中优化target部分生成的loss labels [-100]*len(context_ids) target_ids masks [0 if tvocab[PAD] else 1 for t in input_ids] example {input_ids:input_ids, labels:labels,attention_mask:masks} return example def get(self,i): return get_data(self.min_length,self.max_length) def show_example(self,example): input_ids,labels example[input_ids],example[labels] x .join([vocab_r[a] for a,b in zip(input_ids,labels) if b-100]) y .join([vocab_r[a] for a,b in zip(input_ids,labels) if b!-100]) print(xy) ds_train TwoSumDataset(size 100000,min_length10,max_length20)
ds_val TwoSumDataset(size 10000,min_length10,max_length20)
example ds_train[0]
ds_train.show_example(example)
BOS128786839290489063661127441413067547712889958343179581843EOS
def data_collator(examples: list): len_ids [len(example[input_ids]) for example in examples] longest max(len_ids) #之后按照batch中最长的input_ids进行padding input_ids [] labels_list [] masks_list [] for length, example in sorted(zip(len_ids, examples), keylambda x: -x[0]): ids example[input_ids] labs example[labels] masks example[attention_mask] ids [vocab[PAD]] * (longest - length)ids labs [-100] * (longest - length)labs masks [0]*(longest - length)masks input_ids.append(torch.LongTensor(ids)) labels_list.append(torch.LongTensor(labs)) masks_list.append(torch.LongTensor(masks)) input_ids torch.stack(input_ids) labels torch.stack(labels_list) attention_mask torch.stack(masks_list) return { input_ids: input_ids, labels: labels, attention_mask:attention_mask } # 数据加载器
dl_train DataLoader(datasetds_train, batch_size200, drop_lastTrue, shuffleTrue, collate_fn data_collator ) dl_val DataLoader(datasetds_val, batch_size200, drop_lastTrue, shuffleFalse, collate_fn data_collator ) for batch in dl_train: break
batch {input_ids: tensor([[ 1, 11, 6, ..., 7, 11, 2], [ 0, 1, 6, ..., 5, 4, 2], [ 0, 1, 7, ..., 8, 8, 2], ..., [ 0, 0, 0, ..., 10, 11, 2], [ 0, 0, 0, ..., 12, 3, 2], [ 0, 0, 0, ..., 11, 12, 2]]), labels: tensor([[-100, -100, -100, ..., 7, 11, 2], [-100, -100, -100, ..., 5, 4, 2], [-100, -100, -100, ..., 8, 8, 2], ..., [-100, -100, -100, ..., 10, 11, 2], [-100, -100, -100, ..., 12, 3, 2], [-100, -100, -100, ..., 11, 12, 2]]), attention_mask: tensor([[1, 1, 1, ..., 1, 1, 1], [0, 1, 1, ..., 1, 1, 1], [0, 1, 1, ..., 1, 1, 1], ..., [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1]])}
二定义模型
下面我们会像搭积木建城堡那样从低往高地构建LLaMA模型。先构建4个基础组件旋转位置编码多头注意力、前馈网络、层归一化。类似用最基础的积木块搭建了 墙壁房顶房门窗户 这样的模块。然后用这4个基础组件构建中间成品: 解码层。类似用基础组件构建了房间。接着用多个中间成品解码层的堆叠组装成了LlamaModel完整模型相当于通过构建多个房间建成了城堡的主体结构。最后我们在LlamaModel基础上设计了两种不同的输出head一种是语言模型Head得到了LlamaForCausalLM可用于文本生成。另外一种是分类head得到了LlamaForSequenceClassification可用于文本分类。相当于我们在城堡主体结构完成的基础上设计了两种不同的装修风格一种是加装了一些游乐设施以便用于商业活动另一种则是加装了一些武器以便用于军事活动。
1, 旋转位置编码: RoPE (使用旋转矩阵实现的绝对位置编码可以起到相对位置编码的效果)
2, 多头注意力: LlamaAttention (用于融合不同token之间的信息)
3, 前馈网络: LlamaMLP (用于逐位置将多头注意力融合后的信息进行高维映射变换)
4, 层归一化: LlamaRMSNorm (用于稳定输入相当于保持每个词向量的方向不变但对模长标准化。)
5, Llama解码层: LlamaDecoderLayer (同时具备信息融合信息转换功能的基本结构单元)
6, Llama解码器: LlamaModel (多个解码层的堆叠)7Llama语言模型: LlamaForCausalLM (解码器加上语言模型head可用于文本生成)8Llama分类模型: LlamaForSequenceClassification (解码器加上分类head可用于文本分类)
import math
from typing import List, Optional, Tuple, Union import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING logger logging.get_logger(llama) config LlamaConfig( vocab_sizelen(vocab), hidden_size512, intermediate_size2752, num_hidden_layers8, num_attention_heads16, hidden_actsilu, max_position_embeddings128, initializer_range0.02, rms_norm_eps1e-06, use_cacheTrue, pad_token_id0, bos_token_id1, eos_token_id2, tie_word_embeddingsFalse
)
1旋转位置编码 RoPE
旋转位置编码即使用旋转矩阵表示位置编码(Rotary Position Encoding),简称RoPE。
关于RoPE的3个核心要点知识如下
RoPE的设计思想是使用绝对位置编码来达到相对位置编码的效果。
RoPE的实现方式是使用旋转矩阵来表示绝对位置编码。
使用NTK扩展方法可以让RoPE在短文本上训练并在长文本上做预测。
参考文章
《博采众长的旋转式位置编码》https://kexue.fm/archives/8265
《RoPE是一种进制编码》https://kexue.fm/archives/9675
1绝对位置编码和相对位置编码
位置编码一般可以分成绝对位置编码和相对位置编码。
绝对位置编码的优点是计算简单高效缺点是一般效果不如相对位置编码。
相对位置编码的优点是效果较好缺点是计算效率不如绝对位置编码。 在相对位置编码中注意力权重的结果仅仅和参与注意力计算的token向量的相对位置有关不和绝对位置直接关联。
这符合NLP领域在序列长度方向上具有平移不变性的特点所以相对位置编码一般效果会优于绝对位置编码。
不过绝对位置编码并非一无是处绝对位置编码只需要初始化时对序列的每个位置(数量正比于序列长度)赋予位置编码即可后续无需干预。
而相对位置编码要在计算过程中获取许多个(数量正比于序列长度平方)相对位置。
因此绝对位置编码更加简单高效。
2使用旋转矩阵表示位置编码
上述讨论可以看到绝对位置编码和相对位置编码互有优劣那么有没有什么办法能够对二者进行取长补短呢
有的这个方法就是RoPE它的设计思想就是使用绝对位置编码来达到相对位置编码的效果。
那么旋转位置编码如何使用绝对位置编码来达到相对位置编码的效果的呢答案是使用旋转矩阵来表示位置编码。 由于旋转矩阵是稀疏矩阵直接使用乘法计算会很浪费算力可以将旋转位置编码过程由矩阵乘法运算简化成两次向量的哈达玛积求和。 3旋转位置编码的长度扩展
在LLM的应用中有一个非常重要的参数叫做LLM支持的上下文长度(max context length)。
更长的上下文长度允许我们进行更多轮次的对话允许我们对更长的本文进行总结分析也允许我们生成更长的文章。
但是在训练LLM的时候我们的训练语料大部分是不够长的许多LLM训练时候设计的最大文本长度都是只有2k也就是最长2048个token。
那么能否在训练的时候使用较短的文本而在推理的时候扩展到长文本上呢
是有可能的我们可以对RoPE进行长度扩展。
我们介绍3种扩展方案。
第一种是直接外推直接外推其实就是继续沿用现有的位置编码公式不做任何修改。
在扩展长度不太长的时候例如由2k扩展到2.5k时这种方法可能对性能的影响并不大。
因为旋转位置编码只和相对位置m-n的大小有关一般具有远程衰减性即相对距离越大的两个token其相关性一般越弱。
因此如果我们的模型已经从训练数据那里学习到了token之间的相关性相对于相对距离在0-2k的一个合适的衰减规律的时候可以设想把这个规律应用到0-2.5k也是没有太大的问题的。
但是如果我们要扩展到更长的长度例如从2k扩展到32k这种直接外推的方案通常会严重地影响性能。因为我们学习到的衰减规律有可能在5k的那里就完全衰减截断基本降为0了这样我们就无法捕捉相对距离长于5k的两个token之间的相互作用外推就会导致性能下降。
总结一下直接外推对衰减规律在长距离情况下的使用容易出现问题导致性能下降。
为了减少长度外推对性能的影响我们可以让训练好的模型在更长的上下文上做少许步骤的微调。
第二种是线性内插线性内插需要改变位置编码公式等效于将位置序号等比例缩小。 线性内插没有改变模型学习到的衰减规律的应用范围不考虑微调的话其效果一般好于直接外推方案。
但是扩展倍数非常大的时候例如从2k扩展到32k其性能也会明显的受到影响。
因为在这种情况下衰减规律在短距离情况下的使用会受到较严重的影响本来距离为1的两个token长度扩展后相当于变成了距离为1/16衰减规律在短距离时可能具有非常大的变化率因此对相关性的评估可能会极端地偏离合理值。
应用线性内插时在长文本上做少许步骤的微调也能够明显地改善性能。
第三种是NTK扩展方式这种方式综合了外推和内插的优点做长度扩展后即使不微调也能够保持较好的性能。
前面的分析我们知道直接外推对衰减规律在长距离情况下的使用容易出问题在短距离情况下的使用不受影响。
而线性内插对衰减规律在短距离情况下的使用容易出现问题在长距离的情况下影响较小。
我们能否将它们综合起来在短距离情况下具有外推特性(与扩展前基本一致)在长距离情况下具有内插特性(缩放到扩展前的范围)从而使得长距离情况下和短距离情况下衰减规律的使用都不太受到影响呢。 NTK扩展方式的要点是高频外推低频内插实现方法是直接对底数base进行缩放类似进制编码转换。
采用NTK扩展到长文本即使不做微调性能会只会略有下降。
下面是RoPE以及三种长度扩展方式的实现。
class LlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings2048, base10000, deviceNone): super().__init__() self.dim dim self.max_position_embeddings max_position_embeddings self.base base inv_freq 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer(inv_freq, inv_freq, persistentFalse) #persistentFalse将不会作为state_dict # Build here to make torch.jit.trace work. self._set_cos_sin_cache( seq_lenmax_position_embeddings, deviceself.inv_freq.device, dtypetorch.get_default_dtype() ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached seq_len t torch.arange(self.max_seq_len_cached, devicedevice, dtypeself.inv_freq.dtype) freqs torch.einsum(i,j-ij, t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb torch.cat((freqs, freqs), dim-1) self.register_buffer(cos_cached, emb.cos()[None, None, :, :].to(dtype), persistentFalse) self.register_buffer(sin_cached, emb.sin()[None, None, :, :].to(dtype), persistentFalse) def forward(self, x, seq_lenNone): # x: [bs, num_attention_heads, seq_len, head_size] #超过预设的max_position_embeddings则重新计算更大的Rope缓存否则直接在缓存上切片 if seq_len self.max_seq_len_cached: self._set_cos_sin_cache(seq_lenseq_len, devicex.device, dtypex.dtype) return ( self.cos_cached[:, :, :seq_len, ...].to(dtypex.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtypex.dtype), ) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev def __init__(self, dim, max_position_embeddings2048, base10000, deviceNone, scaling_factor1.0): self.scaling_factor scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached seq_len t torch.arange(self.max_seq_len_cached, devicedevice, dtypeself.inv_freq.dtype) t t / self.scaling_factor #线性内插相当于将位置序号等比例缩小 freqs torch.einsum(i,j-ij, t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb torch.cat((freqs, freqs), dim-1) self.register_buffer(cos_cached, emb.cos()[None, None, :, :].to(dtype), persistentFalse) self.register_buffer(sin_cached, emb.sin()[None, None, :, :].to(dtype), persistentFalse) class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla def __init__(self, dim, max_position_embeddings2048, base10000, deviceNone, scaling_factor1.0): self.scaling_factor scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached seq_len if seq_len self.max_position_embeddings: base self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) #NTK扩展方式直接对base进行缩放 inv_freq 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer(inv_freq, inv_freq, persistentFalse) t torch.arange(self.max_seq_len_cached, devicedevice, dtypeself.inv_freq.dtype) freqs torch.einsum(i,j-ij, t, self.inv_freq) #此处处理逻辑与原始的ROPE有差异原始逻辑如下 #emb torch.cat((freqs, freqs), dim-1) #emb[...,0::2]freqs #emb[...,1::2]freqs # Different from paper, but it uses a different permutation in order to obtain the same calculation emb torch.cat((freqs, freqs), dim-1) self.register_buffer(cos_cached, emb.cos()[None, None, :, :].to(dtype), persistentFalse) self.register_buffer(sin_cached, emb.sin()[None, None, :, :].to(dtype), persistentFalse) def rotate_half(x): Rotates half the hidden dims of the input. #此处逻辑与原始的ROPE有所差异原始逻辑如下 #x1 x[..., 0::2] #x2 x[..., 1::2] #res torch.cat((x1, x2), dim-1) #res[...,0::2]-x2 #res[...,1::2]x1 #return res x1 x[..., : x.shape[-1] // 2] x2 x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can squeeze them. cos cos.squeeze(1).squeeze(0) # [seq_len, dim] sin sin.squeeze(1).squeeze(0) # [seq_len, dim] cos cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed (q * cos) (rotate_half(q) * sin) k_embed (k * cos) (rotate_half(k) * sin) return q_embed, k_embed x torch.randn(1,8,4,2)
rope LlamaRotaryEmbedding(dim8)
cos,sin rope.forward(x,seq_len4)
print(cos.shape)
print(cos)
torch.Size([1, 1, 4, 8])tensor([[[[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [ 0.5403, 0.9950, 0.9999, 1.0000, 0.5403, 0.9950, 0.9999, 1.0000], [-0.4161, 0.9801, 0.9998, 1.0000, -0.4161, 0.9801, 0.9998, 1.0000], [-0.9900, 0.9553, 0.9996, 1.0000, -0.9900, 0.9553, 0.9996, 1.0000]]]])
2多头注意力 LlamaAttention
这里的LlamaAttention 基本上和《Attention Is All You Need》论文里的是一致的主要差异有以下一些。
1k和v的head数量可以是q的head数量的几分之一类似分组卷积的思想可以减少参数规模。
2rope位置编码是每次做多头注意力时都进行一次而不是原论文只在输入的时候进行一次。
3允许传入key和value的states的缓存past_key_value这在多轮对话中可以减少重复计算起到加速效果。
4attention_mask是通过加法形式作用到softmax之前的attention矩阵上的。
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) - torch.Tensor: This is the equivalent of torch.repeat_interleave(x, dim1, repeatsn_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) batch, num_key_value_heads, slen, head_dim hidden_states.shape if n_rep 1: return hidden_states hidden_states hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class LlamaAttention(nn.Module): Multi-headed attention from Attention Is All You Need paper def __init__(self, config: LlamaConfig): super().__init__() self.config config self.hidden_size config.hidden_size self.num_heads config.num_attention_heads self.head_dim self.hidden_size // self.num_heads self.num_key_value_heads config.num_key_value_heads self.num_key_value_groups self.num_heads // self.num_key_value_heads self.max_position_embeddings config.max_position_embeddings if (self.head_dim * self.num_heads) ! self.hidden_size: raise ValueError( fhidden_size must be divisible by num_heads (got hidden_size: {self.hidden_size} f and num_heads: {self.num_heads}). ) self.q_proj nn.Linear(self.hidden_size, self.num_heads * self.head_dim, biasFalse) self.k_proj nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, biasFalse) self.v_proj nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, biasFalse) self.o_proj nn.Linear(self.num_heads * self.head_dim, self.hidden_size, biasFalse) self._init_rope() def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb LlamaRotaryEmbedding(self.head_dim, max_position_embeddingsself.max_position_embeddings) else: scaling_type self.config.rope_scaling[type] scaling_factor self.config.rope_scaling[factor] if scaling_type linear: self.rotary_emb LlamaLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddingsself.max_position_embeddings, scaling_factorscaling_factor ) elif scaling_type dynamic: self.rotary_emb LlamaDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddingsself.max_position_embeddings, scaling_factorscaling_factor ) else: raise ValueError(fUnknown RoPE scaling type {scaling_type}) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.LongTensor] None, past_key_value: Optional[Tuple[torch.Tensor]] None, output_attentions: bool False, use_cache: bool False, ) - Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ hidden_states.size() if self.config.pretraining_tp 1: key_value_slicing (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices self.q_proj.weight.split( (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim0 ) key_slices self.k_proj.weight.split(key_value_slicing, dim0) value_slices self.v_proj.weight.split(key_value_slicing, dim0) query_states [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states torch.cat(query_states, dim-1) key_states [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states torch.cat(key_states, dim-1) value_states [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states torch.cat(value_states, dim-1) else: query_states self.q_proj(hidden_states) key_states self.k_proj(hidden_states) value_states self.v_proj(hidden_states) query_states query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len key_states.shape[-2] if past_key_value is not None: kv_seq_len past_key_value[0].shape[-2] cos, sin self.rotary_emb(value_states, seq_lenkv_seq_len) query_states, key_states apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states torch.cat([past_key_value[0], key_states], dim2) value_states torch.cat([past_key_value[1], value_states], dim2) past_key_value (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads n_heads key_states repeat_kv(key_states, self.num_key_value_groups) value_states repeat_kv(value_states, self.num_key_value_groups) attn_weights torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() ! (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( fAttention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is f {attn_weights.size()} ) if attention_mask is not None: if attention_mask.size() ! (bsz, 1, q_len, kv_seq_len): raise ValueError( fAttention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()} ) attn_weights attn_weights attention_mask # upcast attention to fp32 attn_weights nn.functional.softmax(attn_weights, dim-1, dtypetorch.float32).to(query_states.dtype) attn_output torch.matmul(attn_weights, value_states) if attn_output.size() ! (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( fattn_output should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is f {attn_output.size()} ) attn_output attn_output.transpose(1, 2).contiguous() attn_output attn_output.reshape(bsz, q_len, self.hidden_size) if self.config.pretraining_tp 1: attn_output attn_output.split(self.hidden_size // self.config.pretraining_tp, dim2) o_proj_slices self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim1) attn_output sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output self.o_proj(attn_output) if not output_attentions: attn_weights None return attn_output, attn_weights, past_key_value
3前馈网络 LlamaMLP
前馈网络是一个2层的感知机MLP。
先从hidden_size维度up_proj到intermediate_size维度然后再down_proj还原为hidden_size维度。
这里的主要特色是引入了一个gate_proj配合激活函数来实现一个门控注意力的作用。
class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() self.config config self.hidden_size config.hidden_size self.intermediate_size config.intermediate_size self.gate_proj nn.Linear(self.hidden_size, self.intermediate_size, biasFalse) self.up_proj nn.Linear(self.hidden_size, self.intermediate_size, biasFalse) self.down_proj nn.Linear(self.intermediate_size, self.hidden_size, biasFalse) self.act_fn ACT2FN[config.hidden_act] def forward(self, x): if self.config.pretraining_tp 1: slice self.intermediate_size // self.config.pretraining_tp gate_proj_slices self.gate_proj.weight.split(slice, dim0) up_proj_slices self.up_proj.weight.split(slice, dim0) down_proj_slices self.down_proj.weight.split(slice, dim1) gate_proj torch.cat( [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim-1 ) up_proj torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim-1) intermediate_states (self.act_fn(gate_proj) * up_proj).split(slice, dim2) down_proj [ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) ] down_proj sum(down_proj) else: down_proj self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj
4层归一化 LlamaRMSNorm
这里的层归一化叫做RMSNorm和标准的LayerNorm有少许差异。
首先是没有移除均值直接除的RootMeanSquare然后也没有加上bias。
这两个小的修正可以保证在层归一化不会改变hidden_states对应的词向量的方向只会改变其模长。
在一定的意义上具有合理性。
class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps1e-6): LlamaRMSNorm is equivalent to T5LayerNorm super().__init__() self.weight nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon eps def forward(self, hidden_states): input_dtype hidden_states.dtype hidden_states hidden_states.to(torch.float32) variance hidden_states.pow(2).mean(-1, keepdimTrue) hidden_states hidden_states * torch.rsqrt(variance self.variance_epsilon) return self.weight * hidden_states.to(input_dtype)
5Llama解码层
解码层LlamaDecoderLayer由LlamaAttentionLlamaMLP以及两个LlamaRMSNorm组成并使用了两次残差结构。
class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size config.hidden_size self.self_attn LlamaAttention(configconfig) self.mlp LlamaMLP(config) self.input_layernorm LlamaRMSNorm(config.hidden_size, epsconfig.rms_norm_eps) self.post_attention_layernorm LlamaRMSNorm(config.hidden_size, epsconfig.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.LongTensor] None, past_key_value: Optional[Tuple[torch.Tensor]] None, output_attentions: Optional[bool] False, use_cache: Optional[bool] False, ) - Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Args: hidden_states (torch.FloatTensor): input to the layer of shape (batch, seq_len, embed_dim) attention_mask (torch.FloatTensor, *optional*): attention mask of size (batch, 1, tgt_len, src_len) where padding elements are indicated by very large negative values. output_attentions (bool, *optional*): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail. use_cache (bool, *optional*): If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). past_key_value (Tuple(torch.FloatTensor), *optional*): cached past key and value projection states residual hidden_states hidden_states self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value self.self_attn( hidden_stateshidden_states, attention_maskattention_mask, position_idsposition_ids, past_key_valuepast_key_value, output_attentionsoutput_attentions, use_cacheuse_cache, ) hidden_states residual hidden_states # Fully Connected residual hidden_states hidden_states self.post_attention_layernorm(hidden_states) hidden_states self.mlp(hidden_states) hidden_states residual hidden_states outputs (hidden_states,) if output_attentions: outputs (self_attn_weights,) if use_cache: outputs (present_key_value,) return outputs
6Llama解码器
LlamaModel由多个Llama解码层堆叠而成。
有几个理解上的要点
1_make_causal_mask用于构造下三角这种mask结构以实现语言模型的单向注意力。
2_expand_mask用于将传入的等特殊符号相关的mask信息展开成和attention矩阵相同的张量结构。
3设置gradient_checkpointingTrue可以节约显存。其主要应用了torch.utils.checkpoint.checkpoint方法。它的原理非常简单在对decoder_layer进行forward时不保存中间激活值从而节约显存backward时重新计算相关值从而通过时间换取了空间。
4gradient_checkpointing和use_cache不能同时设置为True前者是为了节约显存时间换空间的后者是为了节约时间空间换时间。
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int 0
): Make causal mask used for bi-directional self-attention. bsz, tgt_len input_ids_shape mask torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, devicedevice) mask_cond torch.arange(mask.size(-1), devicedevice) mask.masked_fill_(mask_cond (mask_cond 1).view(mask.size(-1), 1), 0) mask mask.to(dtype) if past_key_values_length 0: mask torch.cat([torch.zeros(tgt_len, past_key_values_length, dtypedtype, devicedevice), mask], dim-1) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len past_key_values_length) # Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] None): Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len]. bsz, src_len mask.size() tgt_len tgt_len if tgt_len is not None else src_len expanded_mask mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) add_start_docstrings( The bare LLaMA Model outputting raw hidden-states without any specific head on top., LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel): config_class LlamaConfig base_model_prefix model supports_gradient_checkpointing True _no_split_modules [LlamaDecoderLayer] _skip_keys_device_placement past_key_values def _init_weights(self, module): std self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean0.0, stdstd) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean0.0, stdstd) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, valueFalse): if isinstance(module, LlamaModel): module.gradient_checkpointing value add_start_docstrings( The bare LLaMA Model outputting raw hidden-states without any specific head on top., LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel): Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [LlamaDecoderLayer] Args: config: LlamaConfig def __init__(self, config: LlamaConfig): super().__init__(config) self.padding_idx config.pad_token_id self.vocab_size config.vocab_size self.embed_tokens nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm LlamaRMSNorm(config.hidden_size, epsconfig.rms_norm_eps) self.gradient_checkpointing False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens value # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] - [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask None if input_shape[-1] 1: combined_attention_mask _make_causal_mask( input_shape, inputs_embeds.dtype, deviceinputs_embeds.device, past_key_values_lengthpast_key_values_length, ) if attention_mask is not None: # [bsz, seq_len] - [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask _expand_mask(attention_mask, inputs_embeds.dtype, tgt_leninput_shape[-1]).to( inputs_embeds.device ) combined_attention_mask ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask combined_attention_mask ) return combined_attention_mask add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor None, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.LongTensor] None, past_key_values: Optional[List[torch.FloatTensor]] None, inputs_embeds: Optional[torch.FloatTensor] None, use_cache: Optional[bool] None, output_attentions: Optional[bool] None, output_hidden_states: Optional[bool] None, return_dict: Optional[bool] None, ) - Union[Tuple, BaseModelOutputWithPast]: output_attentions output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache use_cache if use_cache is not None else self.config.use_cache return_dict return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError(You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time) elif input_ids is not None: batch_size, seq_length input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ inputs_embeds.shape else: raise ValueError(You have to specify either decoder_input_ids or decoder_inputs_embeds) seq_length_with_past seq_length past_key_values_length 0 if past_key_values is not None: past_key_values_length past_key_values[0][0].shape[2] seq_length_with_past seq_length_with_past past_key_values_length if position_ids is None: device input_ids.device if input_ids is not None else inputs_embeds.device position_ids torch.arange( past_key_values_length, seq_length past_key_values_length, dtypetorch.long, devicedevice ) position_ids position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds self.embed_tokens(input_ids) # embed positions if attention_mask is None: attention_mask torch.ones( (batch_size, seq_length_with_past), dtypetorch.bool, deviceinputs_embeds.device ) attention_mask self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) hidden_states inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( use_cacheTrue is incompatible with gradient checkpointing. Setting use_cacheFalse... ) use_cache False # decoder layers all_hidden_states () if output_hidden_states else None all_self_attns () if output_attentions else None next_decoder_cache () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states (hidden_states,) past_key_value past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) return custom_forward layer_outputs torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids, None, ) else: layer_outputs decoder_layer( hidden_states, attention_maskattention_mask, position_idsposition_ids, past_key_valuepast_key_value, output_attentionsoutput_attentions, use_cacheuse_cache, ) hidden_states layer_outputs[0] if use_cache: next_decoder_cache (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns (layer_outputs[1],) hidden_states self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states (hidden_states,) next_cache next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_statehidden_states, past_key_valuesnext_cache, hidden_statesall_hidden_states, attentionsall_self_attns, )
7Llama语言模型
Llama语言模型 LlamaForCausalLM是在Llama解码器LlamaModel的基础上增加了一个lm_head作为Generator。
从而实现了一个完整的语言模型。
除此之外Llama语言模型还实现了以下重要功能。
1loss计算功能。当forward方法中传入labels时会自动计算语言模型的交叉熵损失。注意labels中的-100会被忽略不参与计算。
2文本生成generate方法。这个方法继承自PreTrainedModel可以设置model.generation_config.num_beams选择束搜索的束宽度默认为1即贪心搜索。
_CONFIG_FOR_DOC LlamaConfig class LlamaForCausalLM(LlamaPreTrainedModel): _tied_weights_keys [lm_head.weight] def __init__(self, config): super().__init__(config) self.model LlamaModel(config) self.vocab_size config.vocab_size self.lm_head nn.Linear(config.hidden_size, config.vocab_size, biasFalse) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head new_embeddings def set_decoder(self, decoder): self.model decoder def get_decoder(self): return self.model add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) replace_return_docstrings(output_typeCausalLMOutputWithPast, config_class_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor None, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.LongTensor] None, past_key_values: Optional[List[torch.FloatTensor]] None, inputs_embeds: Optional[torch.FloatTensor] None, labels: Optional[torch.LongTensor] None, use_cache: Optional[bool] None, output_attentions: Optional[bool] None, output_hidden_states: Optional[bool] None, return_dict: Optional[bool] None, ) - Union[Tuple, CausalLMOutputWithPast]: output_attentions output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs self.model( input_idsinput_ids, attention_maskattention_mask, position_idsposition_ids, past_key_valuespast_key_values, inputs_embedsinputs_embeds, use_cacheuse_cache, output_attentionsoutput_attentions, output_hidden_statesoutput_hidden_states, return_dictreturn_dict, ) hidden_states outputs[0] if self.config.pretraining_tp 1: lm_head_slices self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim0) logits [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits torch.cat(logits, dim-1) else: logits self.lm_head(hidden_states) logits logits.float() loss None if labels is not None: # Shift so that tokens n predict n shift_logits logits[..., :-1, :].contiguous() shift_labels labels[..., 1:].contiguous() # Flatten the tokens loss_fct CrossEntropyLoss() shift_logits shift_logits.view(-1, self.config.vocab_size) shift_labels shift_labels.view(-1) # Enable model parallelism shift_labels shift_labels.to(shift_logits.device) loss loss_fct(shift_logits, shift_labels) if not return_dict: output (logits,) outputs[1:] return (loss,) output if loss is not None else output return CausalLMOutputWithPast( lossloss, logitslogits, past_key_valuesoutputs.past_key_values, hidden_statesoutputs.hidden_states, attentionsoutputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_valuesNone, attention_maskNone, inputs_embedsNone, **kwargs ): if past_key_values: input_ids input_ids[:, -1:] position_ids kwargs.get(position_ids, None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask 0, 1) if past_key_values: position_ids position_ids[:, -1].unsqueeze(-1) # if inputs_embeds are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs {inputs_embeds: inputs_embeds} else: model_inputs {input_ids: input_ids} model_inputs.update( { position_ids: position_ids, past_key_values: past_key_values, use_cache: kwargs.get(use_cache), attention_mask: attention_mask, } ) return model_inputs staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past () for layer_past in past_key_values: reordered_past ( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past
8Llama分类模型
LlamaForSequenceClassification是一个序列分类模型。
这个分类模型可以用来训练RLHF流程中的Reward模型。
add_start_docstrings( The LLaMa Model transformer with a sequence classification head on top (linear layer). [LlamaForSequenceClassification] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a pad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (take the last value in each row of the batch). , LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels config.num_labels self.model LlamaModel(config) self.score nn.Linear(config.hidden_size, self.num_labels, biasFalse) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens value add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor None, attention_mask: Optional[torch.Tensor] None, position_ids: Optional[torch.LongTensor] None, past_key_values: Optional[List[torch.FloatTensor]] None, inputs_embeds: Optional[torch.FloatTensor] None, labels: Optional[torch.LongTensor] None, use_cache: Optional[bool] None, output_attentions: Optional[bool] None, output_hidden_states: Optional[bool] None, return_dict: Optional[bool] None, ) - Union[Tuple, SequenceClassifierOutputWithPast]: r labels (torch.LongTensor of shape (batch_size,), *optional*): Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels 1 a regression loss is computed (Mean-Square loss), If config.num_labels 1 a classification loss is computed (Cross-Entropy). return_dict return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs self.model( input_ids, attention_maskattention_mask, position_idsposition_ids, past_key_valuespast_key_values, inputs_embedsinputs_embeds, use_cacheuse_cache, output_attentionsoutput_attentions, output_hidden_statesoutput_hidden_states, return_dictreturn_dict, ) hidden_states transformer_outputs[0] logits self.score(hidden_states) if input_ids is not None: batch_size input_ids.shape[0] else: batch_size inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size ! 1: raise ValueError(Cannot handle batch sizes 1 if no padding token is defined.) if self.config.pad_token_id is None: sequence_lengths -1 else: if input_ids is not None: sequence_lengths (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( logits.device ) else: sequence_lengths -1 pooled_logits logits[torch.arange(batch_size, devicelogits.device), sequence_lengths] loss None if labels is not None: labels labels.to(logits.device) if self.config.problem_type is None: if self.num_labels 1: self.config.problem_type regression elif self.num_labels 1 and (labels.dtype torch.long or labels.dtype torch.int): self.config.problem_type single_label_classification else: self.config.problem_type multi_label_classification if self.config.problem_type regression: loss_fct MSELoss() if self.num_labels 1: loss loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss loss_fct(pooled_logits, labels) elif self.config.problem_type single_label_classification: loss_fct CrossEntropyLoss() loss loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type multi_label_classification: loss_fct BCEWithLogitsLoss() loss loss_fct(pooled_logits, labels) if not return_dict: output (pooled_logits,) transformer_outputs[1:] return ((loss,) output) if loss is not None else output return SequenceClassifierOutputWithPast( lossloss, logitspooled_logits, past_key_valuestransformer_outputs.past_key_values, hidden_statestransformer_outputs.hidden_states, attentionstransformer_outputs.attentions, )
三训练模型
下面我们来训练一个LlamaForCausalLM 实现两数之和的任务。
config LlamaConfig( vocab_sizelen(vocab), hidden_size512, intermediate_size2752, num_hidden_layers8, num_attention_heads16, num_key_value_heads4, rope_scaling None, hidden_actsilu, max_position_embeddings128, initializer_range0.02, rms_norm_eps1e-06, use_cacheTrue, pad_token_id0, bos_token_id1, eos_token_id2, tie_word_embeddingsFalse, pretraining_tp 1, max_new_tokens 100
) #试算一下
model LlamaForCausalLM(config)
out model.forward(**batch)
print(out.loss)
tensor(2.7630, grad_fn)from torchkeras import KerasModel
from accelerate import Accelerator class StepRunner: def __init__(self, net, loss_fn, acceleratorNone, stage train, metrics_dict None, optimizer None, lr_scheduler None ): self.net,self.loss_fn,self.metrics_dict,self.stage net,loss_fn,metrics_dict,stage self.optimizer,self.lr_scheduler optimizer,lr_scheduler self.accelerator accelerator if accelerator is not None else Accelerator() if self.stagetrain: self.net.train() else: self.net.eval) def __call__(self, batch): #loss with self.accelerator.autocast(): loss self.net(**batch).loss #backward() if self.stagetrain and self.optimizer is not None: self.accelerator.backward(loss) if self.accelerator.sync_gradients: self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0) self.optimizer.step() if self.lr_scheduler is not None: self.lr_scheduler.step() self.optimizer.zero_grad() all_loss self.accelerator.gather(loss).sum() #losses (or plain metrics that can be averaged) step_losses {self.stage_loss:all_loss.item()} #metrics (stateful metrics) step_metrics {} if self.stagetrain: if self.optimizer is not None: step_metrics[lr] self.optimizer.state_dict()[param_groups][0][lr] else: step_metrics[lr] 0.0 return step_losses,step_metrics KerasModel.StepRunner StepRunner
keras_model KerasModel(model,loss_fn None, optimizertorch.optim.AdamW(model.parameters(),lr3e-5)) #加载 之前训练过的权重
ckpt_path llama_twosum keras_model.fit(train_data dl_train, val_data dl_val, epochs100,patience5, monitorval_loss,modemin, ckpt_path ckpt_path, mixed_precisionfp16 ) 四使用模型
from transformers.generation.utils import GenerationConfig
model.generation_config GenerationConfig.from_dict({num_beams:1, max_new_tokens:100, max_length:200})
model.generation_config.num_beams1
model.generation_config.max_new_tokens 100
model.generation_config.max_length200
def get_ans(tensor) -str: s .join([vocab_r[i] for i in tensor.tolist()]) ans s[s.find()1:s.find(EOS)].replace(BOS,).replace(EOS,) return ans
x,y get_data()
print(x: .join(x).replace(BOS,))
print(y: .join(y).replace(EOS,))
x: 348134005090157504501803
y: 90160985841853
input_ids torch.tensor([[vocab[i] for i in x]])
out model.generate(inputsinput_ids)
out
tensor([[ 1, 5, 6, 10, 3, 5, 6, 12, 12, 7, 12, 13, 11, 12, 3, 7, 9, 7, 12, 6, 7, 12, 3, 10, 12, 5, 14, 11, 12, 3, 8, 12, 11, 10, 7, 10, 6, 3, 10, 7, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 12, 2, 2, 2, 2, 2, 2, 2, 2, 12, 3, 12, 3]])
get_ans(out[0])
90160985841853
五评估模型
from tqdm import tqdm
loop tqdm(range(1,201))
correct 0
for i in loop: x,y get_data() input_ids torch.tensor([[vocab[i] for i in x]]) out model.generate(inputsinput_ids) pred get_ans(out[0]) gt .join(y).replace(EOS,) if predgt: correct1 loop.set_postfix(acc correct/i) print(acc,correct/len(loop))
acc 0.99漂亮我们的测试准确率达到了99% #Replacing softmax with ReLU in Vision Transformers 对于视觉 Transformer将其 Self-Attention 中的 Softmax 操作替换为 ReLU/序列长度 (seqlen) 之后性能的下降问题有所缓解。本文在 ImageNet-21K 上训练了从 Small 级别到 Large 级别的视觉 Transformer证明了 ReLU-attention 可以在缩放性上接近或者匹配 Softmax-attention 的性能。Google出品使用ReLU取代SoftmaxViT性能不退化
本文的研究结论是对于视觉 Transformer将其 Self-Attention 中的 Softmax 操作替换为 ReLU/序列长度 (seqlen) 之后性能的下降问题有所缓解。本文在 ImageNet-21K 上训练了从 Small 级别到 Large 级别的视觉 Transformer证明了 ReLU-attention 可以在缩放性上接近或者匹配 Softmax-attention 的性能。
1 在 ViT 中使用 ReLU 取代 Softmax
论文名称 Replacing softmax with ReLU in Vision Transformers (Arxiv 2023)
论文地址https//arxiv.org/pdf/2309.08586.pdf
1.1 ReLU-attention 的新发现
Transformer 架构[1]在现代机器学习中无处不在。Attention 是 Transformer 的核心组件包括一个 Softmax 操作它在 token 上产生概率分布。Softmax 操作涉及到内部的计算所有输入的指数之和它的计算代价相当昂贵使得 Transformer 架构的并行化具有挑战性[2]。
本文作者探索了 Softmax 操作的 Point-wise 的替代方案该操作不一定输出概率分布。本文的核心贡献是观察到ReLU/序列长度(seqlen) 可以在缩放性方面接近或匹配传统的 Softmax 操作。这一结果为并行化提供了新的机会因为 ReLU-attention 相比传统的 Softmax-attention 可以使用更少的 gather 操作在序列长度维度实现并行化。
1.2 去掉 Softmax 的相关工作
替换 Softmax 的研究
ReLU 和 squared ReLU[3][4]把 Softmax 替换成了 ReLU[5]把 Softmax 替换成了 squared ReLU。但是这些方法不会除以序列长度本文通过实验发现对于达到与 Softmax 相当的准确度很重要。[6]仍然需要对序列长度轴进行归一化以确保注意力权重之和为1这依然需要 gather。
去掉激活函数的研究 1.3 ReLU-attention 方法
在进行 Self-attention 的操作时首先计算注意力权重 图1Scaled point-wise attention 实验结果
Sequence length scaling 1.4 实验结果
作者在 ImageNet-21K 上训练了 30 Epochs在 ImageNet-1K 上训练了 300 Epochs。作者使用了 ViT-22B[10]中提出的 qk-norm 技术因为这个技术被验证在扩大视觉模型时有益于优化稳定性但是作者发现在本文量级的模型这一技术没那么重要。
如下图2所示说明了 ReLU-attention 与 ImageNet-21K 训练的 Softmax-attention 的缩放趋势相匹配。x 轴表示实验所需的总 core hours。ReLU-attention 的优势是能够以比 Softmax-attention 以更少的 gather 操作对序列长度维度进行并行化。 图2Softmax 操作替换为 ReLU/seqlen 的缩放性能与传统带有 qk-layernorm 的 Transformer 的缩放性能匹配
1.5 qk-norm 实验结果
本文主要实验使用了 qk-norm其中 query 和 key 在计算注意力权重之前通过 LayerNorm 传递作者发现有必要在扩大模型大小时防止不稳定性。如图3所示是 qk-layernorm 的实验结果。结果表明qk-norm 对这些模型没有很大的影响。 图3qk-norm 实验结果
1.6 添加 gate 的影响
[11]这个工作删除了 Softmax 之后添加了一个门控单元并且不按序列长度缩放。具体而言在门控注意力单元中通过额外的投影层产生输出该输出在输出映射之前与注意力的结果做 Element-wise 的乘法。
如图4所示是添加 gate 的影响实验结果。作者研究了 gate 的存在是否消除了序列长度缩放的需要。总体而言作者观察到无论有没有 gate 的存在使用序列长度缩放都实现了最佳精度。注意到对于带有 ReLU 的 S/8 模型添加 gate 操作将实验所需的 core hour 增加了大约 9.3%。 图4添加 gate 的影响 #Transformer~目标检测算法汇总
都到了13了 ~~ 还是基于这个的么办法 自从VIT横空出世以来Transformer在CV界掀起了一场革新各个上下游任务都得到了长足的进步然后盘点一下基于Transformer的端到端目标检测算法
原始Tranformer检测器
DETRECCV2020
开山之作DETR
代码链接https://github.com/facebookresearch/detr
论文提出了一种将目标检测视为直接集预测问题的新方法。DETR简化了检测流程有效地消除了对许多人工设计组件的需求如NMS或anchor生成。新框架的主要组成部分称为DEtection TRansformer或DETR是一种基于集合的全局损失通过二分匹配强制进行一对一预测以及一种transformer encoder-decoder架构。给定一组固定的学习目标查询DETR分析了目标和全局图像上下文之间的关系以直接并行输出最后一组预测。与许多其他检测器不同新模型概念简单不需要专门的库。DETR在具有挑战性的COCO目标检测数据集上展示了与成熟且高度优化的Faster RCNN基线相当的准确性和运行时间。此外DETR可以很容易地推广到以统一的方式输出全景分割。
DETR的网络结构如下图所示从图中可以看出DETR由四个主要模块组成backbone编码器解码器以及预测头。主干网络是经典的CNN输出降采样32倍的feature。 实验结果如下所示性能上倒是还不错就是训练太慢了300 epochs。 DETR还展示了COCO上的全景分割结果可以看出实例区分能力还是比较有限中间的Bus。 Pix2seq谷歌Hinton
代码链接https://github.com/google-research/pix2seq
一句话总结一个简单而通用的目标检测新框架其将目标检测转换为语言建模任务大大简化了pipeline性能可比肩Faster R-CNN和DETR还可扩展到其他任务。
论文提出Pix2Seq一个简单而通用的目标检测框架与显式集成关于任务的先验知识的现有方法不同Pix2seq将目标检测作为一个基于观察到的像素输入的语言建模任务。目标描述例如边界框和类标签表示为离散token训练神经网络来感知图像并生成所需序列。Pix2seq主要基于这样一种直觉即如果神经网络知道目标的位置和内容我们只需要教它如何read them out。除了使用特定于任务的数据扩充Pix2seq对任务的假设最少但与高度专业化和优化的检测算法相比它在具有挑战性的COCO数据集上取得了有竞争力的结果。
网络主要包含四个组件
图像增强正如在训练计算机视觉模型中常见的那样论文使用图像增强来丰富一组固定的训练示例例如使用随机缩放和裁剪序列构造和扩充由于图像的目标注释通常表示为一组边界框和类标签论文将它们转换为一系列离散token架构使用编码器-解码器模型其中编码器感知像素输入解码器生成目标序列一次一个token目标/损失函数对模型进行训练以最大化基于图像和先前token的token的对数似然性使用softmax cross-entropy loss。 序列构造示意图 训练300 epochs实验结果 稀疏注意力Deformable DETRICLR 2021
代码链接https://github.com/fundamentalvision/Deformable-DETR
最近提出了DETR以消除在物体检测中对许多手动设计部件的需要同时证明了良好的性能。然而由于Transformer注意力模块在处理图像特征图时的限制它存在收敛速度慢和特征空间分辨率有限的问题。为了缓解这些问题论文提出了Deformable DETR其注意力模块只关注参考周围的一小组关键采样点。Deformable DETR可以实现比DETR更好的性能特别是在小目标上训练时间减少10倍。COCO基准的大量实验证明了算法的有效性。 DETR存在的问题
训练周期长相比faster rcnn慢10-20倍小目标性能差通常用多尺度特征来解小目标然而高分辨率的特征图大大提高DETR复杂度
- 存在上述问题的原因
初始化时attention model对于特征图上所有像素权重几乎是统一的即一个query与所有的k相乘的贡献图比较均匀理想状况是q与高度相关且稀疏的k相关性更强因此需要长时间学习更好的attention map处理高分辨率特征存在计算量过大存储复杂的特点
- Motivation
让encoder初始化的权重不再是统一分布即不再与所有key计算相似度而是与更有意义的key计算相似度可变形卷积就是一种有效关注稀疏空间定位的方式提出deformable DETR融合deformable conv的稀疏空间采样与transformer相关性建模能力在整体feature map像素中模型关注小序列的采样位置作为预滤波作为key。
实验结果 End-to-End Object Detection with Adaptive Clustering Transformer北大港中文代码链接https://github.com/gaopengcuhk/SMCA-DETR/DETR 本文的主要贡献如下
开发了一种称为自适应聚类TransformerACT的新方法该方法可以降低DETR的推理成本。ACT可以降低原始Transformer的二次复杂度同时ACT与原始Transformer完全兼容将DETR的FLOPS从73.4 Gflops减少到58.2 Gflops不包括骨干Resnet FLOPS而无需任何训练过程而AP的损失仅为0.7%通过多任务知识蒸馏MTKD进一步将AP的损失降低到0.2%该技术实现了ACT和原始Transformer之间的无缝切换。
实验结果如下 PnP-DETRICCV 2021
论文链接GitHub - twangnh/pnp-detr: Implementation of ICCV21 paper: PnP-DETR: Towards Efficient Visual Analysis with Transformers
DETR虽然有效但由于在某些区域如背景上的冗余计算转换完整的特征图可能代价高昂。在这项工作中论文将减少空间冗余的思想封装到一个新的poll and poolPnP采样模块中利用该模块构建了一个端到端PnP DETR架构该架构自适应地在空间上分配其计算以提高效率。具体地说PnP模块将图像特征映射抽象为精细的前景目标特征向量和少量粗略的背景上下文特征向量。Transformer对精细-粗糙特征空间内的信息交互进行建模并将特征转换为检测结果。此外通过改变采样特征长度PnP增强模型可以立即在单个模型的性能和计算之间实现各种期望的权衡而不需要像现有方法那样训练多个模型。因此它为具有不同计算约束的不同场景中的部署提供了更大的灵活性。论文进一步验证了PnP模块在全景分割上的泛化性以及最近基于Transformer的图像识别模型ViT[7]并显示出一致的效率增益。论文认为PnP-DETR为使用Transformer进行有效的视觉分析迈出了一步其中通常观察到空间冗余。 本文的主要贡献如下
分析了DETR模型中图像特征图的空间冗余问题该问题导致transformer网络计算量过大。因此提出对特征映射进行抽象以显著降低模型运算量设计了一种新颖的两步轮询池采样模块提取特征。该算法首先利用poll采样器提取前景精细特征向量然后利用pool采样器获取上下文粗特征向量构建了PnP-DETR该变换在抽象的细粗特征空间上进行操作并自适应地将计算分布在空间域。通过改变精细特征集的长度PnP-DETR算法效率更高在单一模型下实现了即时计算和性能折衷。PnP抽样模块是通用的是端到端学习的没有像RPN那样的明确监督。论文进一步在全景分割和最近的ViT模型上对其进行了验证并显示出一致的效率增益。这种方法为未来研究使用transformer的视觉任务的有效解决方案提供了有用的见解。实验结果如下
Sparse DETRICLR 2022
代码链接https://github.com/kakaobrain/sparse-detr
Deformable DETR使用多尺度特征来改善性能然而与DETR相比encoder tokens的数量增加了20倍encoder注意力的计算成本仍然是一个瓶颈。在本文的初步实验中发现即使只更新了encoder tokens的一部分检测性能也几乎不会恶化。受这一观察的启发论文提出了Sparse DETR它只选择性地更新decoder预期引用的令牌从而帮助模型有效地检测目标。此外在encoder中对所选token应用辅助检测损失可以提高性能同时最小化计算开销。本文验证了Sparse DETR即使在COCO数据集上只有10%的encoder tokens也比Deformable DETR获得更好的性能。尽管只有encoder tokens被稀疏化但与Deformable DETR相比总计算成本降低了38%FPS增加了42%。 论文的主要贡献如下
提出了一种有效的端到端目标检测器的编码器token稀疏化方法通过该方法减轻了编码器中的注意力复杂性。这种效率使得能够堆叠比Deformable DETR更多的编码器层从而在相同的计算量下提高性能提出了两个新的稀疏化标准来从整个token集合中采样信息子集Objectness ScoreOS和Decoder cross-Attention MapDAM。基于decoder cross-attention map标准稀疏模型即使在仅使用整个token的10%时也保持了检测性能仅对所选token采用编码器辅助损失。这种额外的损失不仅稳定了学习过程而且大大提高了性能只略微增加了训练时间。 实验结果如下 空间先验Fast Convergence of DETR with Spatially Modulated Co-AttentionICCV 2021
DETR的收敛速度较慢。从头开始训练DETR[4]需要500个epoch才能获得高精度。为了加速其收敛本文提出了一种简单而有效的改进DETR框架的方案即Spatially Modulated Co-AttentionSMCA机制。SMCA的核心思想是通过将co-attention响应限制在初始估计的边界框位置附近的较高区域在DETR中进行regression-aware co-attention。本文提出的SMCA通过替换decoder中的原始co-attention同时保持DETR中的其他操作不变提高了DETR的收敛速度。此外通过将multi-head和scale-selection注意力设计集成到SMCA中与基于空洞卷积的主干的DETR相比本文的SMCA可以实现更好的性能。论文对COCO数据集进行了广泛的消融研究以验证所提出的SMCA的有效性。 主要贡献如下
提出了一种新的空间调制共同注意SMCA它可以通过进行位置约束目标回归来加速DETR的收敛。SMCA是原始DETR中的即插即用模块。没有多尺度特征和多头注意力的SMCA的基本版本已经可以在50个epoch达到41.0 mAP在108个时期达到42.7 mAP。将SMCA的基本版本训练50个时期需要265个V100 GPU小时。完整SMCA进一步集成了多尺度特征和多头空间调制这可以通过更少的训练迭代进一步显著改进和超越DETR。SMCA在50个epoch可达到43.7mAP在108个时期可实现45.6mAP而DETR-DC5在500个时期可获得43.3mAP。将完整的SMCA训练50个epoch需要600 V100 GPU小时。对COCO 2017数据集进行了广泛的消融研究以验证所提出的SMCA模块和网络设计。
动机
为了加速DETR收敛本文通过动态预测一个2D的空间高斯weight map来跟co-attention feature maps相乘来达到加快收敛速度的目的。即插即用让DETR涨点明显。性能优于可变形DETR、DETR等网络。实验结果如下 Conditional DETRICCV 2021
本文针对DETR训练收敛缓慢这一关键问题提出了一种用于快速DETR训练的conditional cross-attention机制。动机是DETR中的cross-attention高度依赖内容嵌入来定位和预测box这增加了对高质量内容嵌入的需求从而增加了训练难度。
本文的方法称为Conditional DETR从解码器嵌入中学习条件空间query用于解码器multi-head cross-attention。好处在于通过条件空间query每个交叉注意力头能够关注包含不同区域的band例如一个目标末端或目标框内的区域。这缩小了用于定位目标分类和box回归的不同区域的空间范围从而放松了对内容嵌入的依赖并简化了训练。实验结果表明对于主干R50和R101Conditional DETR收敛速度快6.7倍对于更强的主干DC5-R50和DC5-R101收敛速度快10倍。 动机
为了分析 DETR 为什么收敛慢论文对 DETR decoder cross-attention 中的 spatial attention map 进行了可视化。 每个 head 的 spatial attention map 都在尝试找物体的一个 extremity 区域。论文认为DETR 在计算 cross-attention 时query 中的 content embedding 要同时和 key 中的 content embedding 以及 key 中的 spatial embedding 做匹配这就对 content embedding 的质量要求非常高。而训练了 50 epoch 的DETR因为 content embedding 质量不高无法准确地缩小搜寻物体的范围导致收敛缓慢。所以用一句话总结 DETR 收敛慢的原因就是DETR 高度依赖高质量的 content embedding 去定位物体的 extremity 区域而这部分区域恰恰是定位和识别物体的关键。
基于此提出Conditional DETR
实验结果如下
Anchor DETRAAAI 2022
代码链接https://github.com/megvii-research/AnchorDETR
本文提出了一种新的基于Transfomrer的目标检测查询机制。在以前的基于Transfomrer的检测器中object query是一组学习的嵌入。然而每个学习到的嵌入都没有明确的物理意义我们无法解释它将集中在哪里。由于每个object query的预测slot没有特定的模式因此很难进行优化。换句话说每个object query都不会关注特定区域。为了解决这些问题在本文的query设计中object query基于anchor point这在基于CNN的检测器中被广泛使用。因此每个object query都集中在anchor附近的目标上。此外本文的query设计可以在一个位置预测多个目标以解决困难“一个区域多个目标”。此外本文设计了一种注意力变体它可以降低内存成本同时实现与DETR中的标准注意力相似或更好的性能。由于query设计和注意力变体本文方法名为Anchor DETR可以实现比DETR更好的性能并且运行速度比DETR更快。 回顾基于CNN的检测器anchor与位置高度相关包含可解释的意义。受此启发作者提出了一种基于锚点anchor points的查询设计即将anchor points编码为目标查询。查询是锚点坐标的编码因此每个目标查询都具有显式的物理意义。
但是这个解决方案还有一个限制多个目标可能出现在一个位置 。在这种情况下只有这个位置的一个查询不能预测多个目标因此来自其他位置的查询必须协同预测这些目标。它将导致每个目标查询负责一个更大的区域。因此作者通过向每个锚点添加多个模式multiple patterns即一个锚点可以检测多个目标来改进目标查询设计以便每个锚点都可以预测多个目标
除了查询设计之外作者还设计了一个attention变体—行列解耦注意(Row-Column Decouple AttentionRCDA) 。它将二维key特征解耦为一维行特征和一维列特征然后依次进行行注意力和列注意力。RCDA可以降低计算成本同时实现与DETR中的标准注意力相似甚至更好的性能。
实验结果如下 Efficient DETR旷视
DETR和Deformable DETR具有堆叠6个解码器层的级联结构以迭代更新object query否则它们的性能会严重下降。本文研究了目标容器包括object query和reference point的随机初始化主要负责多次迭代的需求。基于论文的发现提出了Efficient DETR这是一种用于端到端目标检测的简单高效的管道。通过利用密集检测和稀疏集合检测Efficient DETR在初始化目标容器之前利用密集先验并消除了1解码器结构和6解码器结构之间的差距。在MS COCO上进行的实验表明本文的方法仅具有3个编码器层和1个解码器层与最先进的目标检测方法相比可以获得具有竞争力的性能。Efficient DETR在拥挤的场景中也很强大。它在CrowdHuman数据集上大大优于当期检测器。 实验结果如下 Dynamic DETRICCV 2021
本文提出了一种新的Dynamic DETRTransfomrer检测方法将动态注意力引入DETR的编码器和解码器阶段以打破其在小特征分辨率和训练收敛慢方面的两个限制。为了解决第一个限制这是由于Transformer编码器中的自注意力模块的二次计算复杂性论文提出了一种动态编码器以使用具有各种注意力类型的基于卷积的动态编码器来近似Transformer编码器的注意力机制。这种编码器可以基于诸如尺度重要性、空间重要性和表示即特征维度重要性的多个因素来动态调整注意力。为了减轻学习难度的第二个限制论文引入了一个动态解码器通过在Transformer解码器中使用基于ROI的动态注意力来替换交叉注意力模块。这种解码器有效地帮助Transfomrer从coarse-to-fine地关注ROI并显著降低学习难度从而实现更快的收敛。论文进行了一系列实验来证明我们的优势。Dynamic DETR显著缩短了训练时间减少了14倍但性能要好得多mAP提升3.6。 本文的主要贡献如下
提出了一种新的Dynamic DETR方法它相干地结合了基于动态卷积的编码器和基于动态Transformer的解码器。该方法显著提高了目标检测头的表示能力和学习效率而无需任何计算开销。与原始的DETR相比Dynamic DETR大大减少了训练时间减少了14倍但却显著提高了性能3.6 mAP如图1所示是第一个在标准1x设置中实现优于传统性能的端到端方法采用ResNet-50主干42.9mAP。 实验结果如下 结构重新设计Rethinking Transformer-based Set Prediction for Object DetectionICCV 2021
代码链接GitHub: Let’s build from hereEdward-Sun/TSP-Detection
DETR是最近提出的一种基于Transformer的方法它将目标检测视为一个集合预测问题并实现了最先进的性能但需要额外的训练时间来收敛。本文研究了DETR训练中优化困难的原因揭示了导致DETR缓慢收敛的几个因素主要是匈牙利损失和Transformer中co-attention的问题。为了克服这些问题本文提出了两种解决方案即TSP-FCOS使用FCOS的基于Transformer的集合预测和TSP-RCNN使用RCNN的基于Transformer集合预测。实验结果表明所提出的方法不仅比原始DETR收敛更快而且在检测精度方面显著优于DETR和其他基线。 TSP-FCOS在backbone和encoder之间加上了headTSP-RCNN在backbone和encoder之间加上了RoIAlign
实验结果如下 You Only Look at One Sequence: Rethinking Transformer in Vision through Object DetectionNeurIPS 2021
代码链接GitHub - hustvl/YOLOS: You Only Look at One Sequence (NeurIPS 2021)
Transformer能否在对2D空间结构了解最少的情况下从纯sequence-to-sequence的角度进行2D目标和区域级别的识别为了回答这个问题论文提出了“你只看一个序列”YOLOS这是一系列基于朴素视觉Transformer的目标检测模型具有最少的可能修改、区域优先级以及目标任务的归纳偏差。论文发现只有在中型ImageNet-1k数据集上预训练的YOLOS才能在COCO目标检测基准上获得相当有竞争力的性能例如直接采用BERT-Base架构的YOLOS-Base可以在COCO值上获得42.0 box AP。论文还通过YOLOS讨论了当前预训练方案和Transformer模型缩放策略的影响和局限性。 本文的主要贡献如下
使用中等大小的ImageNet-1k[51]作为唯一的预训练数据集并表明可以成功地迁移到普通ViT[21]以执行复杂的目标检测任务并在COCO[36]基准上以最少的可能修改即only looking at one sequenceYOLOS输出有竞争力的结果首次证明通过将固定大小的非重叠图像块序列作为输入可以以纯序列到序列的方式完成2D目标检测。在现有的物体检测器中YOLOS利用最小的2D感应偏置。对于朴素ViT论文发现目标检测结果对预训练方案非常敏感并且检测性能远未饱和。因此所提出的YOLOS也可以作为一项具有挑战性的基准任务以评估不同的标签监督和自监督ViT预训练策略。
实验结果如下 匹配优化DN-DETRCVPR 2022
代码链接https://github.com/FengLi-ust/DN-DETR
本文提出了一种新的去噪训练方法以加速DETRDEtection TRansformer训练并加深了对类DETR方法的收敛慢问题的理解。本文认为收敛缓慢是由于二分匹配的不稳定性导致的这在早期训练阶段导致了不一致的优化目标。为了解决这个问题除了匈牙利损失外论文还将带有噪声的GT框输入Transformer解码器并训练模型以重建原始框这有效地降低了二分匹配的难度并可以更快的收敛。本文的方法是通用的可以通过添加几十行代码轻松地插入到任何类DETR的方法中以实现显著的改进。因此DN-DETR在相同的设置下产生了显著的改进1.9AP。与相同设置下的基线相比DN-DETR在50%的训练时间内实现了可比的性能。 本文的主要贡献如下
设计了一种新的训练方法来加速DETR训练。实验结果表明我们的方法不仅加快了训练收敛而且导致了显著更好的训练结果—在12个epoch设置下在所有检测算法中获得最佳结果。此外我们的方法显示出比基线DAB-DETR显著的改进1.9AP并且可以很容易地集成到其他类DETR的方法中从一个新的角度分析了DETR的缓慢收敛并对DETR训练有了更深入的理解。设计了一个度量来评估二分匹配的不稳定性并验证了我们的方法可以有效地降低不稳定性进行了一系列消融研究以分析我们模型中不同组件的有效性如噪声、标签嵌入和注意力mask。
实验结果后如下 DINO
代码链接https://github.com/IDEACVR/DINO
本文提出DINO这是一种先进的端到端目标检测器。DINO通过使用对比的去噪训练方法、anchor初始化的混合query选择方法和box预测的look forward twice方案在性能和效率上改进了以前的类DETR模型。DINO在具有ResNet-50主干和多尺度特征的COCO上实现了12个epoch的49.4 AP和24个epoch的51.3AP与之前最好的类DETR的模型DN-DETR相比分别显著提高了6.0 AP和2.7 AP。DINO在模型大小和数据大小方面都具有很好的扩展性。没有任何trick在使用SwinL主干的Objects365数据集上进行预训练后DINO在COCO val 201763.2AP和测试集63.3AP上都获得了最好的结果。与排行榜上的其他模型相比DINO显著减少了其模型大小和预训练数据大小同时获得了更好的结果。 本文的主要贡献如下
设计了一种新的端到端类DETR的目标检测器采用了几种新技术包括对比DN训练、混合查询选择并对DINO模型的不同部分进行了两次前向。进行了深入的消融研究以验证DINO中不同设计选择的有效性。因此DINO通过ResNet-50和多尺度特征在12个epoch内达到49.4AP在24个epoch内实现51.3AP显著优于之前最好的类DETR的模型。特别是在12个epoch训练的DINO在小目标上表现出更显著的改善提高了7.5AP。不用任何trickDINO可以在公共基准上取得最好的成绩。在使用SwinL[23]主干对Objects365[33]数据集进行预训练后DINO在COCO val201763.2AP和测试集63.3AP基准上都取得了最好的结果。据我们所知这是端到端Transformer检测首次在COCO排行榜上超过最先进SOTA模型[1]。实验结果如下