网站建设网站模版,科技教育司,深圳微商城网站制作联系电话,西安做网站公司哪个好6、关于Medical-Transformer
Axial-Attention原文链接#xff1a;Axial-attention Medical-Transformer原文链接#xff1a;Medical-Transformer Medical-Transformer实际上是Axial-Attention在医学领域的运行#xff0c;只是在这基础上增加了门机制#xff0c;实际上也就…6、关于Medical-Transformer
Axial-Attention原文链接Axial-attention Medical-Transformer原文链接Medical-Transformer Medical-Transformer实际上是Axial-Attention在医学领域的运行只是在这基础上增加了门机制实际上也就是在原来Axial-attention基础之上增加权重机制虚弱位置信息对于数据的影响发现虚弱之后的效果比Axial-Attention机制效果更好 Axial-Attention Axial-Attention与传统Transformer的self-attention相比较,将2D计算转成1D计算Axial-attention机制对于qkv的计算做出了简化仅仅某个点的横竖两个方向上的特殊同时在qkv的基础上加上了各自位置特征这些特征都是更新学习的。 Axial-attention模型架构图 左图为传统的self-attention机制右图为Axial-attention机制对于qkv都加上rqrkrv这样的位置参数这些参数都是可以更新的也就是说每个的q在和自己对应的横竖轴反向进行计算的时候q会和自己rq先进行权重计算同样的k和v也会进行同样的计算随后进行q和k进行计算得到权重计算过程和原来的self-attention机制是一样的。 class AxialAttention(nn.Module):def forward(self, x):# 前向传播函数# 如果设置了 width 参数调整张量维度顺序if self.width:x x.permute(0, 2, 1, 3) # 调整维度顺序else:x x.permute(0, 3, 1, 2) # N, W, C, H 调整为 N, C, H, WN, W, C, H x.shape # 获取张量形状x x.contiguous().view(N * W, C, H) # 重新调整形状合并 N 和 W 维度# 通过x获得对应的qkv 批归一化后计算 qkvqkv self.bn_qkv(self.qkv_transform(x)) q, k, v torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),[self.group_planes // 2, self.group_planes // 2, self.group_planes], dim2) # 将 qkv 拆分为 q, k, v# 计算位置嵌入all_embeddings torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)q_embedding, k_embedding, v_embedding torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim0) # 拆分嵌入# 计算 QR, KR, QK 相似性分别计算得出rqrkqr torch.einsum(bgci,cij-bgij, q, q_embedding) # QR: q 和 q_embedding 的爱因斯坦求和kr torch.einsum(bgci,cij-bgij, k, k_embedding).transpose(2, 3) # KR: k 和 k_embedding 的爱因斯坦求和并转置# q和k进行计算得到最后的权重qk torch.einsum(bgci, bgcj-bgij, q, k) # QK: q 和 k 之间的点积# 将 QR, KR, QK 相似性进行堆叠连在一起进行计算stacked_similarity torch.cat([qk, qr, kr], dim1) # 将 qk, qr, kr 连接起来stacked_similarity self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim1) # 批归一化并调整形状# similarity为q和k计算得出权重关系similarity F.softmax(stacked_similarity, dim3) # 在第 3 维度上计算 softmax# 将q和v计算出来权重和v加权求和sv torch.einsum(bgij,bgcj-bgci, similarity, v) # 将相似度与 v 进行求和# v与位置信息结合sve torch.einsum(bgij,cij-bgci, similarity, v_embedding) # 将similarity与 v_embedding 进行求和# 将位置加权后的v和q和k计算结果与v加权的合并并调整形状输出stacked_output torch.cat([sv, sve], dim-1).view(N * W, self.out_planes * 2, H) # 合并 sv 和 sve并调整形状output self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim-2) # 批归一化并调整形状# 恢复维度顺序if self.width:output output.permute(0, 2, 1, 3) # 调整维度顺序else:output output.permute(0, 2, 3, 1) # 调整维度顺序# 如果步长大于 1应用池化操作if self.stride 1:output self.pooling(output) # 池化return output # 返回输出横竖轴计算过程 先通过卷积把特征图缩小然后横竖轴计算时是将横轴一起进行计算然后再进行纵轴计算的完成计算后通过1x1卷积将特征图还原为原来的大小在传入下一层进行计算。 Medical-Transformer
Medical-Transformer架构图 Medical-Transformer实际就是Axial-attention在医学图像分割领域的应用medical-tranformer大模型架构采用整个图像进行Axial-attention特征提取同时也将图像分成多个窗口对每个窗口进行axial-attention特征提取窗口由于计算量小可以多进行几层Axial-attention最终将整个图像特征和窗口特征融合完成整个的特征提取值得一提的是在进行窗口Axial-attention时qkv都没有加上位置编码(也就是下面部分的图像)。 主体架构
class medt_net(nn.Module):def _forward_impl(self, x):xin x.clone() # 保存输入数据的副本x self.conv1(x) # 第一个卷积层x self.bn1(x) # 第一个批归一化层x self.relu(x) # ReLU 激活函数x self.conv2(x) # 第二个卷积层x self.bn2(x) # 第二个批归一化层x self.relu(x) # ReLU 激活函数x self.conv3(x) # 第三个卷积层x self.bn3(x) # 第三个批归一化层x self.relu(x) # ReLU 激活函数x1 self.layer1(x) # 第一个残差层 实际上就是 Gated Axial Attention Layerx2 self.layer2(x1) # 第二个残差层 同样是 Gated Axial Attention Layer# 对输入进行插值放大并通过解码器处理x F.relu(F.interpolate(self.decoder4(x2), scale_factor(2, 2), modebilinear))x torch.add(x, x1) # 将放大的特征图与 x1 相加x F.relu(F.interpolate(self.decoder5(x), scale_factor(2, 2), modebilinear))# 以上完成就是图上方整个图像的卷积过程# -------------------------------------------------------------------------------------------x_loc x.clone() # 生成一个本地副本# 下面对图像进行切分分别对每个窗口进行局部处理实际上是16个窗口for i in range(0, 4):for j in range(0, 4):x_p xin[:, :, 32 * i:32 * (i 1), 32 * j:32 * (j 1)] # 提取32x32的局部patch# 逐层卷积处理patchx_p self.conv1_p(x_p)x_p self.bn1_p(x_p)x_p self.relu(x_p)x_p self.conv2_p(x_p)x_p self.bn2_p(x_p)x_p self.relu(x_p)x_p self.conv3_p(x_p)x_p self.bn3_p(x_p)x_p self.relu(x_p)# 进行四个x1_p self.layer1_p(x_p) # 第一个残差层patch-wise 这里进行的axial-attention在进行qkv计算时qkv都没有加入位置信息计算x2_p self.layer2_p(x1_p) # 第二个残差层patch-wisex3_p self.layer3_p(x2_p) # 第三个残差层patch-wisex4_p self.layer4_p(x3_p) # 第四个残差层patch-wise# 对patch进行插值放大并通过解码器处理x_p F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor(2, 2), modebilinear))x_p torch.add(x_p, x4_p) # 将放大的特征图与 x4_p 相加x_p F.relu(F.interpolate(self.decoder2_p(x_p), scale_factor(2, 2), modebilinear))x_p torch.add(x_p, x3_p) # 将放大的特征图与 x3_p 相加x_p F.relu(F.interpolate(self.decoder3_p(x_p), scale_factor(2, 2), modebilinear))x_p torch.add(x_p, x2_p) # 将放大的特征图与 x2_p 相加x_p F.relu(F.interpolate(self.decoder4_p(x_p), scale_factor(2, 2), modebilinear))x_p torch.add(x_p, x1_p) # 将放大的特征图与 x1_p 相加x_p F.relu(F.interpolate(self.decoder5_p(x_p), scale_factor(2, 2), modebilinear))x_loc[:, :, 32 * i:32 * (i 1), 32 * j:32 * (j 1)] x_p # 将局部处理后的结果放回原始位置# 将整个图片的axial-attention和每个窗口得出的结果进行结合x torch.add(x, x_loc) # 将全局和局部特征进行融合x F.relu(self.decoderf(x)) # 通过最终的解码器层x self.adjust(F.relu(x)) # 调整输出return x # 返回最终输出Gated Axial Attention Layer 从架构图中可以看出就是在Axial-attention的基础上加上了门机制说白了也就是在qkv和各自的rqrkrv计算完成后再进行下一步计算之前进行了一个加权计算虚弱了位置变量对特征提取结果的影响。 横向或纵向Gated Axial-attention过程 注意里面qrkr实际上就是图片中的rqrk而 class AxialAttention_dynamic(nn.Module):def forward(self, x):# 判断是否需要对宽度维度进行变换if self.width:x x.permute(0, 2, 1, 3) # 交换维度顺序形状变为 [N, C, W, H]else:x x.permute(0, 3, 1, 2) # 交换维度顺序形状变为 [N, W, C, H]N, W, C, H x.shape # 获取输入张量的形状x x.contiguous().view(N * W, C, H) # 将张量变形为 [N * W, C, H]print(x.shape) # 输出形状: [64, 16, 64]# 变换操作qkv self.bn_qkv(self.qkv_transform(x)) # 对qkv进行批归一化print(qkv.shape) # 输出形状: [64, 32, 64]# 将qkv张量拆分为q、k、v分别表示查询、键和值q, k, v torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim2)print(q.shape) # 输出q的形状: [64, 8, 1, 64]print(k.shape) # 输出k的形状: [64, 8, 1, 64]print(v.shape) # 输出v的形状: [64, 8, 2, 64]v有两份# 计算位置嵌入all_embeddings torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)print(all_embeddings.shape) # 输出嵌入的形状: [4, 64, 64]共有4份q_embedding, k_embedding, v_embedding torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim0)print(q_embedding.shape) # 输出q的位置嵌入形状: [1, 64, 64]print(k_embedding.shape) # 输出k的位置嵌入形状: [1, 64, 64]print(v_embedding.shape) # 输出v的位置嵌入形状: [2, 64, 64]v有两份位置编码# 计算q与位置嵌入的乘积qr torch.einsum(bgci,cij-bgij, q, q_embedding)print(qr.shape) # 输出qr的形状: [64, 8, 64, 64]# 计算k与位置嵌入的乘积并进行转置kr torch.einsum(bgci,cij-bgij, k, k_embedding).transpose(2, 3)print(kr.shape) # 输出kr的形状: [64, 8, 64, 64]# 计算q和k的点积qk torch.einsum(bgci, bgcj-bgij, q, k)print(qk.shape) # 输出qk的形状: [64, 8, 64, 64]# 对qr和kr进行初始化使用self.f_qr和self.f_kr作为初始化的权重qr torch.mul(qr, self.f_qr)print(qr.shape) # 输出qr的形状: [64, 8, 64, 64]kr torch.mul(kr, self.f_kr)print(kr.shape) # 输出kr的形状: [64, 8, 64, 64]# 将qk、qr和kr拼接起来stacked_similarity torch.cat([qk, qr, kr], dim1)print(stacked_similarity.shape) # 输出拼接后的形状: [64, 24, 64, 64]# 进行批归一化重新变形为[N * W, 3, groups, H, H]并对维度1求和stacked_similarity self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim1)print(stacked_similarity.shape) # 输出归一化后的形状: [64, 8, 64, 64]# 计算相似度similarity F.softmax(stacked_similarity, dim3)print(similarity.shape) # 输出相似度的形状: [64, 8, 64, 64]# 使用相似度与v相乘获得加权后的值sv torch.einsum(bgij,bgcj-bgci, similarity, v)print(sv.shape) # 输出加权后的形状: [64, 8, 2, 64]# 使用相似度与v的位置嵌入相乘sve torch.einsum(bgij,cij-bgci, similarity, v_embedding)print(sve.shape) # 输出位置嵌入加权后的形状: [64, 8, 2, 64]# 对sv和sve进行初始化sv torch.mul(sv, self.f_sv)print(sv.shape) # 输出sv的形状: [64, 8, 2, 64]sve torch.mul(sve, self.f_sve)print(sve.shape) # 输出sve的形状: [64, 8, 2, 64]# 将sv和sve拼接在一起并重新变形为[N * W, out_planes * 2, H]stacked_output torch.cat([sv, sve], dim-1).view(N * W, self.out_planes * 2, H)print(stacked_output.shape) # 输出拼接后的形状: [64, 32, 64]# 进行批归一化并变形为[N, W, out_planes, 2, H]对维度-2求和output self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim-2)print(output.shape) # 输出归一化后的形状: [1, 64, 16, 64]# 根据宽度调整维度顺序if self.width:output output.permute(0, 2, 1, 3)else:output output.permute(0, 2, 3, 1)print(output.shape) # 输出最终的形状: [1, 16, 64, 64]# 如果步幅大于1进行池化操作if self.stride 1:output self.pooling(output)return output