Module source
[ISBI 22] [link] [code] Duplex Contextual Relation Network for Polyp Segmentation

Module name
Exterior Contextual-Relation Module (ECRM)

Module function
Memory-based feature enhancement module

Module structure

Module idea
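From the paper: in clinical settings, polyps show synchronized visual patterns across different samples. Based on this key observation, region features that belong to the same semantic class anywhere in the training data should be contextually related. The authors therefore propose a novel module that explores contextual relations across different samples.

Concretely, the global feature produced by the last encoder layer (the red block in the figure) is enhanced twice.

The first enhancement feeds the global feature directly into a $1 \times 1$ convolution (the light purple part of the figure) to obtain a coarse segmentation mask. Multiplying this mask with the global feature yields an enhanced feature with the background filtered out (the part to the left of "enqueue" in the figure).

The second enhancement relies on historical context that the network has stored from other training samples (the Cross-Batch Memory in the figure). The current feature and the features held in the Memory go through a Cross Attention operation, so that past experience is used to complete the current state.

Module code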
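Before the full listing, the first enhancement step can be illustrated in isolation. The following is a minimal plain-Python sketch, not the authors' implementation: it assumes the coarse mask is normalized with a softmax over spatial positions and then used as weights for pooling the features, which is how `get_prototype` in the listing appears to read. The function names `softmax` and `masked_prototype` and the toy numbers are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def masked_prototype(features, mask_logits):
    """Pool per-position feature vectors into a single prototype.

    features:    list of HW vectors, each with C channels
    mask_logits: list of HW scores from the coarse segmentation head
    """
    weights = softmax(mask_logits)  # weights over positions, summing to 1
    channels = len(features[0])
    return [
        sum(w * feat[c] for w, feat in zip(weights, features))
        for c in range(channels)
    ]


# Hypothetical 2x2 feature map (HW = 4) with C = 2 channels:
# channel 0 fires on "foreground" positions, channel 1 on "background".
features = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
# The coarse head scores the first two positions as foreground.
mask_logits = [10.0, 10.0, -10.0, -10.0]

prototype = masked_prototype(features, mask_logits)
print(prototype)  # close to [1.0, 0.0]: background positions barely contribute
```

Because the weights come from the mask, positions that the coarse head scores as background contribute almost nothing to the pooled vector, which is the "filtering" effect described above.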
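A few additional points about the implementation:
- The aux_out returned by the module should receive side supervision so that the coarse mask stays accurate.
- The Memory maintains the network's historical information. To keep that information from being corrupted, it does not take part in gradient updates.
- At test time the Memory is no longer updated and the history stored during training is used directly, an idea similar to BatchNorm.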
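The second and third points can be sketched on a much smaller module. This is an illustrative example under stated assumptions, not part of the original code: `TinyMemory`, its `write` method, and the sizes are hypothetical. It only uses `register_buffer`, `torch.no_grad`, and the `training` flag, which is the same mechanism BatchNorm relies on for its running statistics.

```python
import torch
from torch import nn


class TinyMemory(nn.Module):
    """Toy module with one learnable layer and one non-learnable memory."""

    def __init__(self, slots=4, channels=8):
        super().__init__()
        self.proj = nn.Linear(channels, channels)  # learnable parameters
        # A buffer is saved in state_dict but is not returned by parameters(),
        # so the optimizer never touches it.
        self.register_buffer("bank", torch.zeros(slots, channels))

    @torch.no_grad()
    def write(self, feats):
        # Writing under no_grad keeps the memory out of the autograd graph.
        n = min(feats.shape[0], self.bank.shape[0])
        self.bank[:n] = feats[:n]

    def forward(self, feats):
        out = self.proj(feats)
        if self.training:  # the memory is only updated in training mode
            self.write(out)
        return out + self.bank.mean(dim=0)


mem = TinyMemory()

mem.train()
mem(torch.randn(2, 8)).sum().backward()
param_names = [name for name, _ in mem.named_parameters()]
bank_after_train = mem.bank.clone()

mem.eval()
mem(torch.randn(2, 8))  # evaluation pass: the bank is read but not written

print(param_names)                            # the bank is not listed
print(torch.equal(mem.bank, bank_after_train))
```

Keeping the memory as a buffer also means it moves with the module on `.to(device)` and is restored when a checkpoint is loaded.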
import torch
from torch import nn


def conv2d(in_channel, out_channel, kernel_size):
    # Conv -> BN -> ReLU on 2D feature maps; padding keeps the spatial size.
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


def conv1d(in_channel, out_channel):
    # Pointwise Conv -> BN -> ReLU on flattened (B, C, N) features.
    layers = [
        nn.Conv1d(in_channel, out_channel, 1, bias=False),
        nn.BatchNorm1d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


class ECRM(nn.Module):
    def __init__(self, bank_size=20, feat_channels=512, num_classes=1):
        super().__init__()
        # BANK CONFIG
        self.bank_size = bank_size
        self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long))  # memory bank pointer
        self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes))  # memory bank
        self.bank_full = False
        # ATTENTION CONFIG
        self.feat_channels = feat_channels
        self.L = nn.Conv2d(feat_channels, num_classes, 1)  # coarse segmentation head
        self.X = conv2d(feat_channels, 512, 3)
        self.phi = conv1d(512, 256)
        self.psi = conv1d(512, 256)
        self.delta = conv1d(512, 256)
        self.rho = conv1d(256, 512)
        self.g = conv2d(512 + 512, 512, 1)

    def init(self):
        # Reset the memory bank state.
        self.bank_ptr[0] = 0
        self.bank_full = False

    @torch.no_grad()
    def update_bank(self, x):
        # Write the new prototypes into the bank without tracking gradients.
        ptr = int(self.bank_ptr)
        batch_size = x.shape[0]
        vacancy = self.bank_size - ptr
        if batch_size >= vacancy:
            self.bank_full = True
        pos = min(batch_size, vacancy)
        self.bank[ptr:ptr + pos] = x[0:pos].clone()
        # update pointer
        ptr = (ptr + pos) % self.bank_size
        self.bank_ptr[0] = ptr

    def enhance_by_memory(self, bank, X_flat, X):
        batch, _, height, width = X.shape
        # query: S x C' (S = number of memory entries)
        query = self.phi(bank).squeeze(dim=2)
        # key: B x C' x HW
        key = self.psi(X_flat)
        # logit: HW x S x B (cross-image relation)
        logit = torch.matmul(query, key).transpose(0, 2)
        # attn: HW x S x B, normalized over the last axis
        attn = torch.softmax(logit, 2)
        # delta: S x C'
        delta = self.delta(bank).squeeze(dim=2)
        # attn_sum: HW x C' x B after the final transpose
        attn_sum = torch.matmul(attn.transpose(1, 2), delta).transpose(1, 2)
        # X_obj: viewed as B x C x H x W
        X_obj = self.rho(attn_sum).view(batch, -1, height, width)
        concat = torch.cat([X, X_obj], 1)
        out = self.g(concat)
        return out

    def get_prototype(self, input):
        L = self.L(input)
        aux_out = L  # coarse mask logits, returned for side supervision
        batch, n_class, _, _ = L.shape
        l_flat = L.view(batch, n_class, -1)
        M = torch.softmax(l_flat, -1)  # softmax over spatial positions
        X = self.X(input)
        channel = X.shape[1]
        X_flat = X.view(batch, channel, -1)
        # Mask-weighted pooling: B x C x n_class
        f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
        return aux_out, f_k, X_flat, X

    def forward(self, x, flag='train'):
        # x = [3, 512, 11, 11]
        # patch = [3, 512, 1]
        aux_out, patch, feats_flat, feats = self.get_prototype(x)
        if flag == 'train':
            self.update_bank(patch)
            ptr = int(self.bank_ptr)
            if self.bank_full:
                out = self.enhance_by_memory(self.bank, feats_flat, feats)
            else:
                # Only the slots written so far are valid.
                out = self.enhance_by_memory(self.bank[0:ptr], feats_flat, feats)
        elif flag == 'test':
            # NOTE: as written, this branch attends over the current batch's
            # prototypes (patch), not over the stored self.bank.
            out = self.enhance_by_memory(patch, feats_flat, feats)
        return out, aux_out


if __name__ == '__main__':
    x = torch.randn([3, 512, 11, 11])
    ecrm = ECRM()
    out = ecrm(x)
    print(out[0].shape)  # 3, 512, 11, 11
    print(out[1].shape)  # 3, 1, 11, 11
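The pointer arithmetic in `update_bank` is easier to follow on a toy example. The sketch below mirrors that logic with a plain Python list in place of the tensor bank. `RingBank` and the string entries are hypothetical stand-ins, and the sketch assumes the `>=` comparison and the `ptr + pos` slice reconstructed in the listing above.

```python
class RingBank:
    """List-based stand-in for the memory bank (illustration only)."""

    def __init__(self, size):
        self.size = size
        self.slots = [None] * size
        self.ptr = 0
        self.full = False

    def update(self, batch):
        vacancy = self.size - self.ptr
        if len(batch) >= vacancy:
            self.full = True
        # Only as many entries as fit before the end of the bank are written;
        # the remainder of the batch is dropped rather than wrapped around.
        count = min(len(batch), vacancy)
        self.slots[self.ptr:self.ptr + count] = batch[:count]
        self.ptr = (self.ptr + count) % self.size

    def valid(self):
        # Before the bank has filled once, only slots [0, ptr) hold data.
        return self.slots if self.full else self.slots[:self.ptr]


bank = RingBank(size=5)

bank.update(["a0", "a1", "a2"])  # ptr moves 0 -> 3
first = list(bank.valid())

bank.update(["b0", "b1", "b2"])  # two slots left: b2 is dropped, ptr wraps to 0
second = list(bank.valid())

bank.update(["c0", "c1", "c2"])  # overwrites the oldest entries
third = list(bank.valid())

print(first)   # ['a0', 'a1', 'a2']
print(second)  # ['a0', 'a1', 'a2', 'b0', 'b1']
print(third)   # ['c0', 'c1', 'c2', 'b0', 'b1']
```

One consequence of this scheme is that a batch arriving when the bank is nearly full contributes only part of its prototypes, so choosing a bank size that is a multiple of the batch size avoids discarding entries.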