给网站添加百度地图,绵阳做绵阳做网站网站,那个网站做玉石最专业,上海公司注销流程1.引言
1.1. MobileViT是什么#xff1f;
MobileViT是一种基于Transformer的轻量级视觉模型#xff0c;专为移动端设备上的图像分类任务而设计。
背景与目的#xff1a;
MobileViT由Google在2021年提出#xff0c;旨在解决移动设备上的实时图像分类需求。与传统的卷积神…1.引言
1.1. MobileViT是什么
MobileViT是一种基于Transformer的轻量级视觉模型专为移动端设备上的图像分类任务而设计。
背景与目的
MobileViT由Google在2021年提出旨在解决移动设备上的实时图像分类需求。与传统的卷积神经网络CNN相比MobileViT在保持高性能的同时显著降低了计算复杂度和内存需求从而更适应移动设备的计算能力。
技术特点
轻量级与移动友好MobileViT通过引入轻量级的Transformer模块和有效的降维策略大幅减少了模型的参数数量和计算复杂度使其能够在移动设备上高效运行。基于TransformerMobileViT采用了Transformer架构通过自注意力机制捕获图像的全局上下文信息提高了模型的泛化能力和准确性。优化方法MobileViT采用了一系列优化方法如混合精度训练和自适应模型调整等进一步提高了在移动设备上的运行效率。
性能表现
在多个图像分类数据集上MobileViT均取得了与现有轻量级CNN模型相当或更优的性能。例如在ImageNet-1k数据集上MobileViT在大约600万个参数的情况下达到了78.4%的Top-1准确率。MobileViT显示出更好的泛化能力即使在使用大量数据增强的情况下也能更好地预测未知数据集上的表现。与其他基于Transformer的模型相比MobileViT对超参数的调整相对健壮对L2正则化等超参数的敏感度较低。
适用场景MobileViT特别适用于需要实时图像分类的移动端应用如智能手机、平板电脑等。其轻量级和高效的特点使得它成为移动视觉任务中的理想选择。
综合而言MobileViT通过结合轻量级的Transformer架构和优化的设计成功实现了在移动设备上高效、准确的图像分类。其优秀的性能、泛化能力和对超参数的鲁棒性使得它在移动视觉领域具有广泛的应用前景。
1.2.Transformer架构的特点
Transformer架构最初由谷歌大脑在2017年的论文《Attention Is All You Need》中提出是一种基于自注意力机制的序列到序列Seq2Seq模型。自提出以来该模型在自然语言处理NLP和计算机视觉CV等领域取得了显著的成功并多次达到该领域内的最佳效果SOTA。
核心思想
Transformer架构的核心思想是使用自注意力机制self-attention mechanism来建立输入序列的表示。相比于传统的循环神经网络RNN架构Transformer能够并行地处理整个序列而不是按顺序逐步处理从而提高了计算效率。
架构组成
Transformer架构主要由两个主要组件组成编码器Encoder和解码器Decoder。 编码器Encoder 主要负责将输入序列转化为一种中间表示形式这种表示形式能够捕捉输入序列中的上下文信息。 编码器由多个相同的层堆叠而成每个层都包含自注意力机制和前馈神经网络Feed-Forward Neural Network。 自注意力机制允许模型在序列内的任意位置间直接建立依赖从而更好地理解数据的上下文关系。 位置编码Positional Encoding用于提供关于单词在序列中位置的信息因为Transformer不使用基于顺序的结构。 解码器Decoder 主要负责根据编码器的输出和之前的解码输出生成新的序列。 解码器同样由多个相同的层堆叠而成其结构与编码器类似但还包含了一个额外的自注意力层和一个编码器-解码器注意力层。 编码器-解码器注意力层允许解码器关注编码器输出的不同位置从而帮助生成准确的输出序列。
特点与优势
并行处理能力Transformer能够并行地处理整个序列而不是像RNN那样按顺序逐步处理这大大提高了计算效率。长距离依赖建模能力通过自注意力机制Transformer能够建模输入序列中的长距离依赖关系这在处理长序列时尤为重要。多头注意力机制Transformer采用多头注意力机制允许模型同时学习数据的不同表示每个“头”关注序列的不同部分这有助于模型捕捉更丰富的信息。灵活性Transformer架构非常灵活可以应用于各种序列生成任务如机器翻译、文本摘要、语音识别等。
综合而言Transformer架构在自然语言处理领域特别流行例如BERT和GPT等预训练语言模型就是从Transformer中衍生出来的。此外Transformer架构也被广泛应用于计算机视觉领域如图像分类、目标检测等任务。在智能驾驶领域Transformer架构也被用于感知、预测和决策等各个环节。
1.3. 研究内容
在本文的例子中我们将介绍并实现MobileViT架构该架构是由Mehta等人提出的它融合了Transformer由Vaswani等人开创和卷积神经网络的优点。通过TransformerMobileViT能够捕获图像中的长距离依赖关系从而生成全局表示而卷积操作则帮助模型捕捉图像中的局部空间关系。
MobileViT的设计不仅结合了Transformer和卷积的特性还作为一个通用且移动友好的骨干网络适用于各种图像识别任务。据研究结果显示在性能方面MobileViT相比其他复杂度相近或更高的模型如MobileNetV3具有优势同时保持了在移动设备上的高效运行。
请注意为了成功运行这个示例您需要安装TensorFlow 2.13或更高版本。
1.4. 研究意义
随着移动设备应用的广泛普及图像分类等计算机视觉任务在移动设备上的需求日益增长。然而传统的深度学习模型特别是基于卷积神经网络CNN的模型往往面临着计算资源和存储需求的限制难以在移动设备上高效运行。因此开发轻量级、高效的深度学习模型成为了一个迫切的研究需求。
MobileViT模型通过融合Transformer和卷积神经网络的优势为解决移动设备上的图像分类问题提供了新的思路。它利用Transformer的自注意力机制捕捉图像中的长距离依赖关系同时结合卷积操作捕捉局部空间关系从而在保持高性能的同时降低了计算复杂度和内存需求。相比传统的轻量级CNN模型MobileViT在多个图像分类数据集上均取得了优异的性能证明了其在移动设备图像分类任务中的有效性和实用性。
MobileViT的研究不仅具有理论价值还具有重要的实际应用前景。它能够为移动设备上的实时图像处理任务提供高效的解决方案为用户带来更好的使用体验。随着移动设备性能的不断提升和计算资源的持续优化MobileViT有望在更多领域得到应用推动移动设备上的计算机视觉技术向前发展。同时MobileViT的研究也为其他轻量级深度学习模型的设计和优化提供了有益的参考。
2. 部署MobileViT
2.1.设置
2.1.1.导入函数库
# 导入必要的库
import os
import tensorflow as tf # 设置Keras的后端为TensorFlow虽然Keras现在默认后端就是TensorFlow但这里显式设置以确保环境配置正确
os.environ[KERAS_BACKEND] tensorflow # 导入Keras库以及相关的layers和backend模块
import keras
from keras import layers
from keras import backend as K # 导入tensorflow_datasets库用于加载数据集
import tensorflow_datasets as tfds # 禁用tensorflow_datasets在加载数据时的进度条显示以避免在输出中显示额外的进度信息
tfds.disable_progress_bar()2.2.2.设置超参数
# 这些值来自表4。
patch_size 4 # 2x2用于Transformer块。
image_size 256 # 输入图像的尺寸。
expansion_factor 2 # MobileNetV2块的扩展因子。这段代码定义了三个变量分别用于设置Transformer块的Patch大小、输入图像的尺寸以及MobileNetV2块的扩展因子。这些参数对于构建MobileViT模型是必要的。
2.2.构建MobileViT
MobileViT架构是一个专为移动设备设计的图像分类模型它巧妙地结合了Transformer和卷积神经网络的优点以实现高效且准确的图像识别。
1. 输入处理
在模型的初始阶段输入图像首先通过一系列带步长的3x3卷积层进行处理。这些卷积层不仅用于提取图像的初步特征还通过调整步长来逐步降低特征图的分辨率从而减少后续层的计算量。
2. MobileNetV2风格倒置残差块
在特征提取的过程中MobileViT采用了MobileNetV2风格的倒置残差块进行特征转换和降采样。这些倒置残差块首先通过1x1卷积进行通道扩展然后利用深度可分离卷积进行空间特征提取最后再通过1x1卷积将特征图通道数恢复到原始大小。通过这种方式倒置残差块能够在不增加过多计算量的前提下有效地提高模型的特征提取能力。
3. MobileViT块
MobileViT架构的核心在于其独特的MobileViT块。这些块结合了Transformer和卷积神经网络的优点旨在捕获图像中的长距离依赖关系和局部空间关系。具体来说MobileViT块首先通过自注意力机制如多头自注意力计算特征图中不同位置之间的相关性从而捕获长距离依赖关系。然后它利用卷积操作对特征图进行局部空间特征的提取和融合。通过这种方式MobileViT块能够同时利用Transformer的全局建模能力和卷积神经网络的局部特征提取能力从而实现更高效、更准确的图像识别。
4. 输出层
经过多个MobileViT块的堆叠后模型最终通过全局平均池化层将特征图转换为固定长度的特征向量。然后这些特征向量被送入一个全连接层进行分类。全连接层的输出节点数与类别数相同通过softmax函数计算每个类别的概率分布。
总体而言MobileViT架构通过结合Transformer和卷积神经网络的优点实现了在移动设备上进行高效、准确的图像分类。其独特的MobileViT块能够有效地捕获图像中的长距离依赖关系和局部空间关系从而提高了模型的性能。同时MobileViT架构还采用了MobileNetV2风格的倒置残差块进行特征转换和降采样进一步提高了模型的计算效率。这些特点使得MobileViT成为了一个优秀的移动设备图像分类模型。
2.2.1.构建MobileViT
# 定义卷积块函数用于构建卷积层。
def conv_block(x, filters16, kernel_size3, strides2):# 创建二维卷积层。conv_layer layers.Conv2D(filters, # 过滤器数量kernel_size, # 卷积核大小stridesstrides, # 步长activationkeras.activations.swish, # 激活函数paddingsame, # 填充方式)return conv_layer(x) # 返回卷积后的输出# 根据输入尺寸和卷积核大小计算正确的填充量。
def correct_pad(inputs, kernel_size):# 根据图像数据格式确定图像维度。img_dim 2 if backend.image_data_format() channels_first else 1input_size inputs.shape[img_dim : (img_dim 2)]# 将卷积核大小转换为元组如果它是一个整数。if isinstance(kernel_size, int):kernel_size (kernel_size, kernel_size)# 计算调整值用于确保卷积后尺寸的正确性。if input_size[0] is None:adjust (1, 1)else:adjust (1 - input_size[0] % 2, 1 - input_size[1] % 2)correct (kernel_size[0] // 2, kernel_size[1] // 2)# 返回需要添加的填充量。return ((correct[0] - adjust[0], correct[0]),(correct[1] - adjust[1], correct[1]),)# 定义反残差块用于构建轻量级卷积神经网络中的反残差结构。
def inverted_residual_block(x, expanded_channels, output_channels, strides1):# 使用1x1卷积进行通道扩展。m layers.Conv2D(expanded_channels, 1, paddingsame, use_biasFalse)(x)m layers.BatchNormalization()(m)m keras.activations.swish(m)# 如果步长大于1则使用零填充。if strides 2:m layers.ZeroPadding2D(paddingcorrect_pad(m, 3))(m)# 使用深度可分离卷积进行空间维度的降采样。m layers.DepthwiseConv2D(3, stridesstrides, paddingsame if strides 1 else valid, use_biasFalse)(m)m layers.BatchNormalization()(m)m keras.activations.swish(m)# 使用1x1卷积将通道数降至输出通道数。m layers.Conv2D(output_channels, 1, paddingsame, use_biasFalse)(m)m layers.BatchNormalization()(m)# 如果步长为1且输入输出通道数相同则使用残差连接。if keras.ops.equal(x.shape[-1], output_channels) and strides 1:return layers.Add()([m, x])return m# 定义多层感知机MLP函数用于Transformer中的前馈网络。
def mlp(x, hidden_units, dropout_rate):for units in hidden_units:x layers.Dense(units, activationkeras.activations.swish)(x)x layers.Dropout(dropout_rate)(x)return x# 定义Transformer块函数用于构建Transformer模型中的自注意力机制。
def transformer_block(x, transformer_layers, projection_dim, num_heads2):for _ in range(transformer_layers):# 第一层归一化。x1 layers.LayerNormalization(epsilon1e-6)(x)# 创建多头注意力层。attention_output layers.MultiHeadAttention(num_headsnum_heads, key_dimprojection_dim, dropout0.1)(x1, x1)# 第一个残差连接。x2 layers.Add()([attention_output, x])# 第二层归一化。x3 layers.LayerNormalization(epsilon1e-6)(x2)# MLP。x3 mlp(x3,hidden_units[x.shape[-1] * 2, x.shape[-1]],dropout_rate0.1,)# 第二个残差连接。x layers.Add()([x3, x2])return x# 定义MobileViT块结合了局部特征提取和全局特征提取。
def mobilevit_block(x, num_blocks, projection_dim, strides1):# 使用卷积进行局部特征提取。local_features conv_block(x, filtersprojection_dim, stridesstrides)local_features conv_block(local_features, filtersprojection_dim, kernel_size1, stridesstrides)# 将特征图划分为不重叠的patches并通过Transformer块处理。num_patches int((local_features.shape[1] * local_features.shape[2]) / patch_size)non_overlapping_patches layers.Reshape((patch_size, num_patches, projection_dim))(local_features)global_features transformer_block(non_overlapping_patches, num_blocks, projection_dim)# 将Transformer的输出重新整理成特征图的形状。folded_feature_map layers.Reshape((*local_features.shape[1:-1], projection_dim))(global_features)# 使用1x1卷积将特征图的通道数调整为与输入匹配并与输入特征图进行拼接。folded_feature_map conv_block(folded_feature_map, filtersx.shape[-1], kernel_size1, stridesstrides)local_global_features layers.Concatenate(axis-1)([x, folded_feature_map])# 使用卷积层融合局部和全局特征。local_global_features conv_block(local_global_features, filtersprojection_dim, stridesstrides)return local_global_features上述代码定义了一系列用于构建和操作深度学习模型特别是MobileViT模型的函数。 conv_block: 功能创建一个卷积块包含卷积层、激活函数Swish和批量归一化。用途用于提取图像特征可以作为更复杂模型的一部分。 correct_pad: 功能计算进行卷积操作时所需的填充量以确保输出尺寸正确。用途在对输入图像进行卷积操作之前调整边界填充。 inverted_residual_block: 功能实现MobileNetV2中的反残差结构包含点卷积、深度卷积和批量归一化。用途构建轻量级网络结构用于减少模型参数和计算量。 mlp: 功能实现多层感知机MLP用于Transformer中的前馈网络部分。用途在Transformer模型中进行特征的非线性变换。 transformer_block: 功能构建Transformer块包含多头自注意力机制和前馈网络。用途处理序列数据捕获长距离依赖关系用于图像的全局特征提取。 mobilevit_block: 功能结合局部特征提取通过卷积和全局特征提取通过Transformer的MobileViT块。用途作为MobileViT模型的核心组件实现图像的高效特征提取和表示。
整体来看这些函数共同构成了一个深度学习模型的框架特别是针对移动设备优化的视觉Transformer模型MobileViT。它们涵盖了从数据预处理如填充和归一化到特征提取卷积和Transformer操作的各个步骤最终实现图像分类或其他视觉任务。
2.2.2.实例化MobileViT块
关于MobileViT块的深入解析
在MobileViT架构中MobileViT块是关键组成部分它融合了卷积和Transformer的优势。首先输入的特征表示A通过一系列卷积层这些卷积层专注于捕获图像中的局部细节和空间关系。这些特征图的典型形状是h, w, num_channels其中h代表高度w代表宽度num_channels是通道数。
随后这些特征图被分割成一系列非重叠的小补丁patches每个补丁的大小为p×p其中p表示补丁的边长。这些小补丁被重新组织成一个二维数组形状为p^2, n, num_channels其中n表示整个图像中被分割成的补丁数量计算公式为n (h * w) / (p * p)。这个过程可以看作是“展开”操作将二维特征图转化为一个包含多个补丁的一维序列。
接下来这个一维序列通过Transformer块进行处理。Transformer块利用自注意力机制来捕获补丁之间的全局依赖关系从而能够捕捉图像中的长距离依赖。这种全局建模能力是Transformer架构的核心优势尤其对于理解复杂图像结构和识别高级别概念非常有效。
经过Transformer块处理后输出向量B再次被“折叠”回二维特征图的形状h, w, num_channels。这个过程与之前的“展开”操作相反它将一维序列重新组织成二维特征图以便后续处理。
最后原始的特征表示A和经过Transformer处理后的特征表示B通过两个额外的卷积层进行融合。这两个卷积层的作用是将局部和全局特征进行结合生成更加丰富的特征表示。值得注意的是在这个过程中特征图的空间分辨率保持不变这有助于保持模型对图像细节的敏感度。
从某种角度来看MobileViT块可以被视为一种特殊的卷积块它结合了卷积的局部特征提取能力和Transformer的全局建模能力。这种设计使得MobileViT架构能够在保持较低计算复杂度的同时实现较高的图像分类准确率。
在构建MobileViT架构时多个MobileViT块被组合在一起形成一个完整的网络结构。以下是从原始论文中引用的示意图展示了MobileViT架构的一个具体实例如XXS变体请注意由于这里不能直接插入图像我们将省略具体的示意图。
def create_mobilevit(num_classes5):# 定义输入层假设输入图像大小为 image_size x image_size具有3个颜色通道。inputs keras.Input((image_size, image_size, 3))# 对输入图像进行归一化处理将像素值缩放到0到1之间。x layers.Rescaling(scale1.0 / 255)(inputs)# 开始卷积干线部分使用 conv_block 函数创建第一个卷积层。x conv_block(x, filters16)# 使用 inverted_residual_block 函数创建 MobileNetV2 风格的反残差块。x inverted_residual_block(x, expanded_channels16 * expansion_factor, output_channels16)# 使用 MV2 块进行下采样。# 第一次下采样步长为2输出通道数增加到24。x inverted_residual_block(x, expanded_channels16 * expansion_factor, output_channels24, strides2)# 继续使用 MV2 块进行特征提取保持通道数不变。x inverted_residual_block(x, expanded_channels24 * expansion_factor, output_channels24)# 再次使用 MV2 块进行特征提取。x inverted_residual_block(x, expanded_channels24 * expansion_factor, output_channels24)# 第一个 MV2 块到 MobileViT 块的转换。# 第二次下采样步长为2输出通道数增加到48。x inverted_residual_block(x, expanded_channels24 * expansion_factor, output_channels48, strides2)# 使用 mobilevit_block 函数创建 MobileViT 块包含2个 Transformer 层。x mobilevit_block(x, num_blocks2, projection_dim64)# 第二个 MV2 块到 MobileViT 块的转换。# 继续下采样步长为2输出通道数增加到64。x inverted_residual_block(x, expanded_channels64 * expansion_factor, output_channels64, strides2)# 使用 mobilevit_block 函数创建 MobileViT 块包含4个 Transformer 层。x mobilevit_block(x, num_blocks4, projection_dim80)# 第三个 MV2 块到 MobileViT 块的转换。# 再次下采样步长为2输出通道数增加到80。x inverted_residual_block(x, expanded_channels80 * expansion_factor, output_channels80, strides2)# 使用 mobilevit_block 函数创建 MobileViT 块包含3个 Transformer 层。x mobilevit_block(x, num_blocks3, projection_dim96)# 使用 conv_block 进行1x1卷积用于通道数的调整。x conv_block(x, filters320, kernel_size1, strides1)# 分类头使用全局平均池化层和全连接层进行分类。x layers.GlobalAvgPool2D()(x)outputs layers.Dense(num_classes, activationsoftmax)(x)# 创建 Keras 模型输入为之前定义的 inputs输出为分类结果 outputs。return keras.Model(inputs, outputs)# 实例化 MobileViT 模型类别数默认为5。
mobilevit_xxs create_mobilevit()
# 打印模型的概述信息包括每层的输出形状和参数数量。
mobilevit_xxs.summary()这段代码定义了一个创建MobileViT模型的函数 create_mobilevit并实例化了这个模型然后打印出了模型的概述。 函数定义: create_mobilevit: 这个函数接受一个参数 num_classes表示分类任务的类别数默认为5。 输入层: inputs: 使用 keras.Input 定义模型的输入假设输入图像的大小是 image_size x image_size具有3个颜色通道。 数据预处理: Rescaling: 对输入图像进行重缩放归一化到[0,1]区间。 初始卷积层: conv_block: 应用一个卷积块作为模型的起始部分。 反残差块: inverted_residual_block: 使用MobileNetV2中的反残差结构进行下采样和特征提取。 MobileViT块: mobilevit_block: 结合了卷积和Transformer结构的MobileViT块用于提取局部和全局特征。 分类头: GlobalAvgPool2D: 使用全局平均池化层来减少特征的空间维度。Dense: 使用全连接层进行分类激活函数为Softmax输出类别概率。 模型实例化: mobilevit_xxs: 调用 create_mobilevit 函数实例化MobileViT模型。 模型概述: summary: 打印模型的概述信息包括每层的名称、输出形状和参数数量。
这个函数构建了一个轻量级的深度学习模型适用于移动设备上的图像分类任务。模型结合了卷积神经网络的局部特征提取能力和Transformer的全局特征提取能力通过多个MobileViT块和反残差块进行特征提取最终通过分类头输出预测结果。通过调用 mobilevit_xxs.summary()用户可以快速了解模型的结构和参数量。
2.3 数据预处理
2.3.1.加载数据
我们将使用 tf_flowers 数据集来演示该模型。与其他基于Transformer的架构不同MobileViT使用了一个简单的数据增强流程这主要是因为它具有CNN卷积神经网络的特性。
# 定义批次大小和自动调优参数
batch_size 64
auto tf.data.AUTOTUNE
# 定义在训练时使用的更大的图像尺寸
resize_bigger 280
# 定义类别数
num_classes 5# 定义数据预处理函数
def preprocess_dataset(is_trainingTrue):# 定义内部函数用于处理单个图像和标签def _pp(image, label):if is_training:# 如果是在训练阶段先将图像调整到更大的分辨率然后随机裁剪到所需的尺寸image tf.image.resize(image, (resize_bigger, resize_bigger))image tf.image.random_crop(image, (image_size, image_size, 3))# 随机水平翻转图像image tf.image.random_flip_left_right(image)else:# 如果是在测试或验证阶段直接将图像调整到所需的尺寸image tf.image.resize(image, (image_size, image_size))# 将标签转换为独热编码label tf.one_hot(label, depthnum_classes)return image, label# 返回内部函数return _pp# 定义数据集准备函数
def prepare_dataset(dataset, is_trainingTrue):# 如果是在训练阶段先对数据集进行洗牌if is_training:dataset dataset.shuffle(batch_size * 10)# 使用映射函数并行地应用预处理函数dataset dataset.map(preprocess_dataset(is_training), num_parallel_callsauto)# 将数据集分批并使用预取操作优化性能return dataset.batch(batch_size).prefetch(auto)这段代码定义了两个函数preprocess_dataset 和 prepare_dataset用于准备和预处理数据集。preprocess_dataset 函数根据是否处于训练阶段对图像执行不同的预处理操作包括调整图像大小、随机裁剪、随机水平翻转和标签的独热编码。prepare_dataset 函数则用于对整个数据集应用预处理函数并进行洗牌、分批处理和预取操作以优化数据加载过程。
2.3.2. 数据预处理
# 使用 TensorFlow Datasets 库加载 tf_flowers 数据集分为训练集和验证集。
# 训练集占90%验证集占10%。
train_dataset, val_dataset tfds.load(tf_flowers, split[train[:90%], train[90%:]], as_supervisedTrue
)# 获取训练集和验证集的样本数量。
num_train train_dataset.cardinality()
num_val val_dataset.cardinality()# 打印训练集和验证集的样本数量。
print(fNumber of training examples: {num_train}) # 训练样本数
print(fNumber of validation examples: {num_val}) # 验证样本数# 使用之前定义的 prepare_dataset 函数准备训练集和验证集。
# 训练集使用 is_trainingTrue 进行数据增强。
train_dataset prepare_dataset(train_dataset, is_trainingTrue)
# 验证集使用 is_trainingFalse不进行数据增强。
val_dataset prepare_dataset(val_dataset, is_trainingFalse)这段代码首先使用 TensorFlow Datasets (TFDS) 库加载了 tf_flowers 数据集并将其划分为训练集和验证集其中训练集占据了90%验证集占据了剩余的10%。as_supervisedTrue 参数意味着数据集中的标签已经是监督信号不需要进一步处理。
接着通过调用 cardinality() 方法获取了训练集和验证集中样本的数量并打印出来以便了解数据集的规模。
最后调用 prepare_dataset 函数对训练集和验证集进行进一步的准备包括数据增强、批处理和预取操作。训练集的 is_training 参数设置为 True 以应用数据增强而验证集的 is_training 参数设置为 False通常不进行数据增强以保持数据的原始分布。
2.4.训练MobileViT 模型
# 设置学习率和标签平滑因子。
learning_rate 0.002
label_smoothing_factor 0.1# 设置训练周期数。
epochs 30# 创建 Adam 优化器并设置学习率。
optimizer keras.optimizers.Adam(learning_ratelearning_rate)# 创建分类交叉熵损失函数并设置标签平滑因子。
loss_fn keras.losses.CategoricalCrossentropy(label_smoothinglabel_smoothing_factor)# 定义运行实验的函数。
def run_experiment(epochsepochs):# 创建 MobileViT 模型实例。mobilevit_xxs create_mobilevit(num_classesnum_classes)# 编译模型指定优化器、损失函数和评价指标。mobilevit_xxs.compile(optimizeroptimizer, lossloss_fn, metrics[accuracy])# 设置检查点回调函数保存验证准确率最高的模型权重。checkpoint_filepath /tmp/checkpoint.weights.h5checkpoint_callback keras.callbacks.ModelCheckpoint(checkpoint_filepath,monitorval_accuracy,save_best_onlyTrue,save_weights_onlyTrue, # 指定仅保存模型权重)# 训练模型使用训练数据集和验证数据集。mobilevit_xxs.fit(train_dataset,validation_dataval_dataset,epochsepochs,callbacks[checkpoint_callback],)# 加载最佳模型权重。mobilevit_xxs.load_weights(checkpoint_filepath)# 在验证数据集上评估模型并打印准确率。_, accuracy mobilevit_xxs.evaluate(val_dataset)print(fValidation accuracy: {round(accuracy * 100, 2)}%) # 打印验证准确率return mobilevit_xxs# 调用实验函数开始训练和评估过程。
mobilevit_xxs run_experiment()这段代码首先设置了模型训练所需的一些关键参数包括学习率、标签平滑因子和训练周期数。然后定义了优化器和损失函数其中损失函数采用了标签平滑技术有助于提高模型的泛化能力。
run_experiment 函数负责创建 MobileViT 模型实例、编译模型、设置检查点回调、训练模型以及在验证集上评估模型的性能。训练过程中使用了早停法Early Stopping来保存最佳模型权重避免过拟合。最后函数返回训练好的模型并打印出验证集上的准确率。
2.5.MobileViT与TFLite
结果和TFLite转换使用大约一百万个参数在256x256分辨率下达到约85%的top-1准确率是一个出色的结果。这款MobileViT模型与TensorFlow LiteTFLite完全兼容可以使用以下代码进行转换
以下是添加了中文注释的代码
# 将模型序列化为 SavedModel 格式并保存到 mobilevit_xxs 文件夹。
tf.saved_model.save(mobilevit_xxs, mobilevit_xxs)# 将 SavedModel 转换为 TFLite 格式。这里使用的是 TFLite 中的后训练动态范围量化。
converter tf.lite.TFLiteConverter.from_saved_model(mobilevit_xxs)
# 设置优化类型为默认优化。
converter.optimizations [tf.lite.Optimize.DEFAULT]
# 设置支持的操作集包括 TensorFlow Lite 内置操作和 TensorFlow 操作。
converter.target_spec.supported_ops [tf.lite.OpsSet.TFLITE_BUILTINS, # 启用 TensorFlow Lite 操作。tf.lite.OpsSet.SELECT_TF_OPS, # 启用 TensorFlow 操作。
]# 执行转换操作得到 TFLite 模型。
tflite_model converter.convert()# 将转换后的 TFLite 模型写入文件 mobilevit_xxs.tflite。
with open(mobilevit_xxs.tflite, wb) as f:f.write(tflite_model)这段代码首先将训练好的 mobilevit_xxs 模型序列化并保存为 SavedModel 格式。SavedModel 是 TensorFlow 的一种模型格式可以保存模型结构、权重和训练配置。
然后使用 tf.lite.TFLiteConverter 将 SavedModel 转换为 TFLite 格式。TFLite 是 TensorFlow 的轻量级解决方案适用于移动和嵌入式设备。在转换过程中设置了优化选项以减小模型大小并提高运行效率并指定了模型支持的操作集。
最后将转换后的 TFLite 模型写入到一个名为 “mobilevit_xxs.tflite” 的文件中。这样得到的 TFLite 模型可以被部署到移动设备或其它边缘设备上进行高效的推理计算。
3. 总结与展望
3.1 总结
本文详细介绍了MobileViT模型的设计原理、架构组成以及在移动设备图像分类任务中的应用。MobileViT作为一种结合了Transformer和CNN优势的轻量级模型已经在多个标准数据集上展现出了卓越的性能。以下是对全文内容的总结 MobileViT架构介绍了MobileViT的基本概念包括其设计背景、技术特点、性能表现和适用场景。MobileViT通过轻量化的Transformer模块和有效的降维策略实现了在移动设备上的高效运行。 Transformer架构特点分析了Transformer架构的核心思想、组成组件和优势特别是在并行处理能力和长距离依赖建模方面的表现。 研究内容探讨了MobileViT的研究意义包括其在移动设备上的应用需求、理论价值和实际应用前景。 模型部署提供了使用TensorFlow和TFLite部署MobileViT模型的详细步骤包括数据预处理、模型构建、训练和转换为TFLite格式。 模型结构与训练结果附录中列出了MobileViT模型的具体结构和参数量以及模型训练过程中的损失和准确率变化情况。
3.2 展望
虽然MobileViT在移动设备图像分类任务上取得了显著的成果但仍有诸多方向值得未来的研究和探索 模型优化尽管MobileViT已经进行了轻量化设计但仍有进一步优化模型结构和参数空间的潜力以适应更多样化的移动设备。 多任务学习将MobileViT扩展到多任务学习框架中例如同时进行图像分类、目标检测和分割等任务。 跨领域应用探索MobileViT在其他领域的应用如视频处理、医疗影像分析等以验证其泛化能力。 鲁棒性研究研究MobileViT在不同环境和条件下的性能表现提高模型的鲁棒性。 实时性能针对实时应用场景进一步优化MobileViT的推理速度和能耗效率。 模型压缩与加速研究模型剪枝、量化等模型压缩技术以减小模型大小和加速推理过程。 开源社区贡献通过开源项目和社区合作推动MobileViT的进一步开发和应用。
综上所述MobileViT作为一种新型的移动视觉模型不仅在理论上具有创新性而且在实际应用中具有广泛的前景。随着移动设备计算能力的不断提升和深度学习技术的不断进步MobileViT有望在未来的移动视觉领域发挥更大的作用。
附录1模型结构
Model: model
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to input_1 (InputLayer) [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
rescaling (Rescaling) (None, 256, 256, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 128, 128, 16) 448 rescaling[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 128, 128, 32) 512 conv2d[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 32) 128 conv2d_1[0][0]
__________________________________________________________________________________________________
tf.nn.silu (TFOpLambda) (None, 128, 128, 32) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
depthwise_conv2d (DepthwiseConv (None, 128, 128, 32) 288 tf.nn.silu[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 32) 128 depthwise_conv2d[0][0]
__________________________________________________________________________________________________
tf.nn.silu_1 (TFOpLambda) (None, 128, 128, 32) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 128, 128, 16) 512 tf.nn.silu_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128, 128, 16) 64 conv2d_2[0][0]
__________________________________________________________________________________________________
add (Add) (None, 128, 128, 16) 0 batch_normalization_2[0][0] conv2d[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 128, 128, 32) 512 add[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 32) 128 conv2d_3[0][0]
__________________________________________________________________________________________________
tf.nn.silu_2 (TFOpLambda) (None, 128, 128, 32) 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D) (None, 129, 129, 32) 0 tf.nn.silu_2[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_1 (DepthwiseCo (None, 64, 64, 32) 288 zero_padding2d[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 32) 128 depthwise_conv2d_1[0][0]
__________________________________________________________________________________________________
tf.nn.silu_3 (TFOpLambda) (None, 64, 64, 32) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 64, 64, 24) 768 tf.nn.silu_3[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 24) 96 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 48) 1152 batch_normalization_5[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 48) 192 conv2d_5[0][0]
__________________________________________________________________________________________________
tf.nn.silu_4 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_2 (DepthwiseCo (None, 64, 64, 48) 432 tf.nn.silu_4[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 64, 64, 48) 192 depthwise_conv2d_2[0][0]
__________________________________________________________________________________________________
tf.nn.silu_5 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 64, 64, 24) 1152 tf.nn.silu_5[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 64, 64, 24) 96 conv2d_6[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 64, 64, 24) 0 batch_normalization_8[0][0] batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 64, 64, 48) 1152 add_1[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 64, 64, 48) 192 conv2d_7[0][0]
__________________________________________________________________________________________________
tf.nn.silu_6 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_3 (DepthwiseCo (None, 64, 64, 48) 432 tf.nn.silu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 64, 64, 48) 192 depthwise_conv2d_3[0][0]
__________________________________________________________________________________________________
tf.nn.silu_7 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 64, 64, 24) 1152 tf.nn.silu_7[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 24) 96 conv2d_8[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 64, 64, 24) 0 batch_normalization_11[0][0] add_1[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 64, 64, 48) 1152 add_2[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 64, 64, 48) 192 conv2d_9[0][0]
__________________________________________________________________________________________________
tf.nn.silu_8 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_12[0][0]
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 65, 65, 48) 0 tf.nn.silu_8[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_4 (DepthwiseCo (None, 32, 32, 48) 432 zero_padding2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 32, 32, 48) 192 depthwise_conv2d_4[0][0]
__________________________________________________________________________________________________
tf.nn.silu_9 (TFOpLambda) (None, 32, 32, 48) 0 batch_normalization_13[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 32, 32, 48) 2304 tf.nn.silu_9[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 48) 192 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 64) 27712 batch_normalization_14[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 32, 32, 64) 4160 conv2d_11[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (None, 4, 256, 64) 0 conv2d_12[0][0]
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 4, 256, 64) 128 reshape[0][0]
__________________________________________________________________________________________________
multi_head_attention (MultiHead (None, 4, 256, 64) 33216 layer_normalization[0][0] layer_normalization[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 4, 256, 64) 0 multi_head_attention[0][0] reshape[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 4, 256, 64) 128 add_3[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 4, 256, 128) 8320 layer_normalization_1[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 4, 256, 128) 0 dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 4, 256, 64) 8256 dropout[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 4, 256, 64) 0 dense_1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 4, 256, 64) 0 dropout_1[0][0] add_3[0][0]
__________________________________________________________________________________________________
layer_normalization_2 (LayerNor (None, 4, 256, 64) 128 add_4[0][0]
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe (None, 4, 256, 64) 33216 layer_normalization_2[0][0] layer_normalization_2[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 4, 256, 64) 0 multi_head_attention_1[0][0] add_4[0][0]
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 4, 256, 64) 128 add_5[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 4, 256, 128) 8320 layer_normalization_3[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 4, 256, 128) 0 dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 4, 256, 64) 8256 dropout_2[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 4, 256, 64) 0 dense_3[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 4, 256, 64) 0 dropout_3[0][0] add_5[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 32, 32, 64) 0 add_6[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 32, 32, 48) 3120 reshape_1[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 32, 32, 96) 0 batch_normalization_14[0][0] conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 32, 32, 64) 55360 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 32, 32, 128) 8192 conv2d_14[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 32, 32, 128) 512 conv2d_15[0][0]
__________________________________________________________________________________________________
tf.nn.silu_10 (TFOpLambda) (None, 32, 32, 128) 0 batch_normalization_15[0][0]
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 33, 33, 128) 0 tf.nn.silu_10[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_5 (DepthwiseCo (None, 16, 16, 128) 1152 zero_padding2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 16, 16, 128) 512 depthwise_conv2d_5[0][0]
__________________________________________________________________________________________________
tf.nn.silu_11 (TFOpLambda) (None, 16, 16, 128) 0 batch_normalization_16[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 16, 16, 64) 8192 tf.nn.silu_11[0][0]
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 16, 16, 64) 256 conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 16, 16, 80) 46160 batch_normalization_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 16, 16, 80) 6480 conv2d_17[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape) (None, 4, 64, 80) 0 conv2d_18[0][0]
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 4, 64, 80) 160 reshape_2[0][0]
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_4[0][0] layer_normalization_4[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 4, 64, 80) 0 multi_head_attention_2[0][0] reshape_2[0][0]
__________________________________________________________________________________________________
layer_normalization_5 (LayerNor (None, 4, 64, 80) 160 add_7[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 4, 64, 160) 12960 layer_normalization_5[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, 4, 64, 160) 0 dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 4, 64, 80) 12880 dropout_4[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout) (None, 4, 64, 80) 0 dense_5[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 4, 64, 80) 0 dropout_5[0][0] add_7[0][0]
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 4, 64, 80) 160 add_8[0][0]
__________________________________________________________________________________________________
multi_head_attention_3 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_6[0][0] layer_normalization_6[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 4, 64, 80) 0 multi_head_attention_3[0][0] add_8[0][0]
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 4, 64, 80) 160 add_9[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 4, 64, 160) 12960 layer_normalization_7[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout) (None, 4, 64, 160) 0 dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 4, 64, 80) 12880 dropout_6[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout) (None, 4, 64, 80) 0 dense_7[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 4, 64, 80) 0 dropout_7[0][0] add_9[0][0]
__________________________________________________________________________________________________
layer_normalization_8 (LayerNor (None, 4, 64, 80) 160 add_10[0][0]
__________________________________________________________________________________________________
multi_head_attention_4 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_8[0][0] layer_normalization_8[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 4, 64, 80) 0 multi_head_attention_4[0][0] add_10[0][0]
__________________________________________________________________________________________________
layer_normalization_9 (LayerNor (None, 4, 64, 80) 160 add_11[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 4, 64, 160) 12960 layer_normalization_9[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout) (None, 4, 64, 160) 0 dense_8[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 4, 64, 80) 12880 dropout_8[0][0]
__________________________________________________________________________________________________
dropout_9 (Dropout) (None, 4, 64, 80) 0 dense_9[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 4, 64, 80) 0 dropout_9[0][0] add_11[0][0]
__________________________________________________________________________________________________
layer_normalization_10 (LayerNo (None, 4, 64, 80) 160 add_12[0][0]
__________________________________________________________________________________________________
multi_head_attention_5 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_10[0][0] layer_normalization_10[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 4, 64, 80) 0 multi_head_attention_5[0][0] add_12[0][0]
__________________________________________________________________________________________________
layer_normalization_11 (LayerNo (None, 4, 64, 80) 160 add_13[0][0]
__________________________________________________________________________________________________
dense_10 (Dense) (None, 4, 64, 160) 12960 layer_normalization_11[0][0]
__________________________________________________________________________________________________
dropout_10 (Dropout) (None, 4, 64, 160) 0 dense_10[0][0]
__________________________________________________________________________________________________
dense_11 (Dense) (None, 4, 64, 80) 12880 dropout_10[0][0]
__________________________________________________________________________________________________
dropout_11 (Dropout) (None, 4, 64, 80) 0 dense_11[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 4, 64, 80) 0 dropout_11[0][0] add_13[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape) (None, 16, 16, 80) 0 add_14[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 16, 16, 64) 5184 reshape_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 16, 16, 128) 0 batch_normalization_17[0][0] conv2d_19[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 16, 16, 80) 92240 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 16, 16, 160) 12800 conv2d_20[0][0]
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 16, 16, 160) 640 conv2d_21[0][0]
__________________________________________________________________________________________________
tf.nn.silu_12 (TFOpLambda) (None, 16, 16, 160) 0 batch_normalization_18[0][0]
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 17, 17, 160) 0 tf.nn.silu_12[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_6 (DepthwiseCo (None, 8, 8, 160) 1440 zero_padding2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 8, 8, 160) 640 depthwise_conv2d_6[0][0]
__________________________________________________________________________________________________
tf.nn.silu_13 (TFOpLambda) (None, 8, 8, 160) 0 batch_normalization_19[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 8, 8, 80) 12800 tf.nn.silu_13[0][0]
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 8, 8, 80) 320 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 8, 8, 96) 69216 batch_normalization_20[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 8, 8, 96) 9312 conv2d_23[0][0]
__________________________________________________________________________________________________
reshape_4 (Reshape) (None, 4, 16, 96) 0 conv2d_24[0][0]
__________________________________________________________________________________________________
layer_normalization_12 (LayerNo (None, 4, 16, 96) 192 reshape_4[0][0]
__________________________________________________________________________________________________
multi_head_attention_6 (MultiHe (None, 4, 16, 96) 74400 layer_normalization_12[0][0] layer_normalization_12[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 4, 16, 96) 0 multi_head_attention_6[0][0] reshape_4[0][0]
__________________________________________________________________________________________________
layer_normalization_13 (LayerNo (None, 4, 16, 96) 192 add_15[0][0]
__________________________________________________________________________________________________
dense_12 (Dense) (None, 4, 16, 192) 18624 layer_normalization_13[0][0]
__________________________________________________________________________________________________
dropout_12 (Dropout) (None, 4, 16, 192) 0 dense_12[0][0]
__________________________________________________________________________________________________
dense_13 (Dense) (None, 4, 16, 96) 18528 dropout_12[0][0]
__________________________________________________________________________________________________
dropout_13 (Dropout) (None, 4, 16, 96) 0 dense_13[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 4, 16, 96) 0 dropout_13[0][0] add_15[0][0]
__________________________________________________________________________________________________
layer_normalization_14 (LayerNo (None, 4, 16, 96) 192 add_16[0][0]
__________________________________________________________________________________________________
multi_head_attention_7 (MultiHe (None, 4, 16, 96) 74400 layer_normalization_14[0][0] layer_normalization_14[0][0]
__________________________________________________________________________________________________
add_17 (Add) (None, 4, 16, 96) 0 multi_head_attention_7[0][0] add_16[0][0]
__________________________________________________________________________________________________
layer_normalization_15 (LayerNo (None, 4, 16, 96) 192 add_17[0][0]
__________________________________________________________________________________________________
dense_14 (Dense) (None, 4, 16, 192) 18624 layer_normalization_15[0][0]
__________________________________________________________________________________________________
dropout_14 (Dropout) (None, 4, 16, 192) 0 dense_14[0][0]
__________________________________________________________________________________________________
dense_15 (Dense) (None, 4, 16, 96) 18528 dropout_14[0][0]
__________________________________________________________________________________________________
dropout_15 (Dropout) (None, 4, 16, 96) 0 dense_15[0][0]
__________________________________________________________________________________________________
add_18 (Add) (None, 4, 16, 96) 0 dropout_15[0][0] add_17[0][0]
__________________________________________________________________________________________________
layer_normalization_16 (LayerNo (None, 4, 16, 96) 192 add_18[0][0]
__________________________________________________________________________________________________
multi_head_attention_8 (MultiHe (None, 4, 16, 96) 74400 layer_normalization_16[0][0] layer_normalization_16[0][0]
__________________________________________________________________________________________________
add_19 (Add) (None, 4, 16, 96) 0 multi_head_attention_8[0][0] add_18[0][0]
__________________________________________________________________________________________________
layer_normalization_17 (LayerNo (None, 4, 16, 96) 192 add_19[0][0]
__________________________________________________________________________________________________
dense_16 (Dense) (None, 4, 16, 192) 18624 layer_normalization_17[0][0]
__________________________________________________________________________________________________
dropout_16 (Dropout) (None, 4, 16, 192) 0 dense_16[0][0]
__________________________________________________________________________________________________
dense_17 (Dense) (None, 4, 16, 96) 18528 dropout_16[0][0]
__________________________________________________________________________________________________
dropout_17 (Dropout) (None, 4, 16, 96) 0 dense_17[0][0]
__________________________________________________________________________________________________
add_20 (Add) (None, 4, 16, 96) 0 dropout_17[0][0] add_19[0][0]
__________________________________________________________________________________________________
reshape_5 (Reshape) (None, 8, 8, 96) 0 add_20[0][0]
__________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, 8, 8, 80) 7760 reshape_5[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 8, 8, 160) 0 batch_normalization_20[0][0] conv2d_25[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 8, 8, 96) 138336 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_27 (Conv2D) (None, 8, 8, 320) 31040 conv2d_26[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 320) 0 conv2d_27[0][0]
__________________________________________________________________________________________________
dense_18 (Dense) (None, 5) 1605 global_average_pooling2d[0][0] Total params: 1,307,621
Trainable params: 1,305,077
Non-trainable params: 2,544附录2模型训练结果
Epoch 1/30
52/52 [] - 47s 459ms/step - loss: 1.3397 - accuracy: 0.4832 - val_loss: 1.7250 - val_accuracy: 0.1662
Epoch 2/30
52/52 [] - 21s 404ms/step - loss: 1.1167 - accuracy: 0.6210 - val_loss: 1.9844 - val_accuracy: 0.1907
Epoch 3/30
52/52 [] - 21s 403ms/step - loss: 1.0217 - accuracy: 0.6709 - val_loss: 1.8187 - val_accuracy: 0.1907
Epoch 4/30
52/52 [] - 21s 409ms/step - loss: 0.9682 - accuracy: 0.7048 - val_loss: 2.0329 - val_accuracy: 0.1907
Epoch 5/30
52/52 [] - 21s 408ms/step - loss: 0.9552 - accuracy: 0.7196 - val_loss: 2.1150 - val_accuracy: 0.1907
Epoch 6/30
52/52 [] - 21s 407ms/step - loss: 0.9186 - accuracy: 0.7318 - val_loss: 2.9713 - val_accuracy: 0.1907
Epoch 7/30
52/52 [] - 21s 407ms/step - loss: 0.8986 - accuracy: 0.7457 - val_loss: 3.2062 - val_accuracy: 0.1907
Epoch 8/30
52/52 [] - 21s 408ms/step - loss: 0.8831 - accuracy: 0.7542 - val_loss: 3.8631 - val_accuracy: 0.1907
Epoch 9/30
52/52 [] - 21s 408ms/step - loss: 0.8433 - accuracy: 0.7714 - val_loss: 1.8029 - val_accuracy: 0.3542
Epoch 10/30
52/52 [] - 21s 408ms/step - loss: 0.8489 - accuracy: 0.7763 - val_loss: 1.7920 - val_accuracy: 0.4796
Epoch 11/30
52/52 [] - 21s 409ms/step - loss: 0.8256 - accuracy: 0.7884 - val_loss: 1.4992 - val_accuracy: 0.5477
Epoch 12/30
52/52 [] - 21s 407ms/step - loss: 0.7859 - accuracy: 0.8123 - val_loss: 0.9236 - val_accuracy: 0.7330
Epoch 13/30
52/52 [] - 21s 409ms/step - loss: 0.7702 - accuracy: 0.8159 - val_loss: 0.8059 - val_accuracy: 0.8011
Epoch 14/30
52/52 [] - 21s 403ms/step - loss: 0.7670 - accuracy: 0.8153 - val_loss: 1.1535 - val_accuracy: 0.7084
Epoch 15/30
52/52 [] - 21s 408ms/step - loss: 0.7332 - accuracy: 0.8344 - val_loss: 0.7746 - val_accuracy: 0.8147
Epoch 16/30
52/52 [] - 21s 404ms/step - loss: 0.7284 - accuracy: 0.8335 - val_loss: 1.0342 - val_accuracy: 0.7330
Epoch 17/30
52/52 [] - 21s 409ms/step - loss: 0.7484 - accuracy: 0.8262 - val_loss: 1.0523 - val_accuracy: 0.7112
Epoch 18/30
52/52 [] - 21s 408ms/step - loss: 0.7209 - accuracy: 0.8450 - val_loss: 0.8146 - val_accuracy: 0.8174
Epoch 19/30
52/52 [] - 21s 409ms/step - loss: 0.7141 - accuracy: 0.8435 - val_loss: 0.8016 - val_accuracy: 0.7875
Epoch 20/30
52/52 [] - 21s 410ms/step - loss: 0.7075 - accuracy: 0.8435 - val_loss: 0.9352 - val_accuracy: 0.7439
Epoch 21/30
52/52 [] - 21s 406ms/step - loss: 0.7066 - accuracy: 0.8504 - val_loss: 1.0171 - val_accuracy: 0.7139
Epoch 22/30
52/52 [] - 21s 405ms/step - loss: 0.6913 - accuracy: 0.8532 - val_loss: 0.7059 - val_accuracy: 0.8610
Epoch 23/30
52/52 [] - 21s 408ms/step - loss: 0.6681 - accuracy: 0.8671 - val_loss: 0.8007 - val_accuracy: 0.8147
Epoch 24/30
52/52 [] - 21s 409ms/step - loss: 0.6636 - accuracy: 0.8747 - val_loss: 0.9490 - val_accuracy: 0.7302
Epoch 25/30
52/52 [] - 21s 408ms/step - loss: 0.6637 - accuracy: 0.8722 - val_loss: 0.6913 - val_accuracy: 0.8556
Epoch 26/30
52/52 [] - 21s 406ms/step - loss: 0.6443 - accuracy: 0.8837 - val_loss: 1.0483 - val_accuracy: 0.7139
Epoch 27/30
52/52 [] - 21s 407ms/step - loss: 0.6555 - accuracy: 0.8695 - val_loss: 0.9448 - val_accuracy: 0.7602
Epoch 28/30
52/52 [] - 21s 409ms/step - loss: 0.6409 - accuracy: 0.8807 - val_loss: 0.9337 - val_accuracy: 0.7302
Epoch 29/30
52/52 [] - 21s 408ms/step - loss: 0.6300 - accuracy: 0.8910 - val_loss: 0.7461 - val_accuracy: 0.8256
Epoch 30/30
52/52 [] - 21s 408ms/step - loss: 0.6093 - accuracy: 0.8968 - val_loss: 0.8651 - val_accuracy: 0.7766
6/6 [] - 0s 65ms/step - loss: 0.7059 - accuracy: 0.8610
Validation accuracy: 86.1%