河南郑州做网站h汉狮,长沙网站排名提升,汕头建站模板搭建,网站运营的作用Pytorch常用的函数(九)torch.gather()用法
torch.gather() 就是在指定维度上收集value。
torch.gather() 的必填也是最常用的参数有三个#xff0c;下面引用官方解释#xff1a;
input (Tensor) – the source tensordim (int) – the axis along which to indexindex (Lo…Pytorch常用的函数(九)torch.gather()用法
torch.gather() 就是在指定维度上收集value。
torch.gather() 的必填也是最常用的参数有三个下面引用官方解释
input (Tensor) – the source tensordim (int) – the axis along which to indexindex (LongTensor) – the indices of elements to gather
一句话概括 gather 操作就是根据 index 在 input 的 dim 维度上收集 value。
1、举例直观理解
# 1、我们有input_tensor如下input_tensor torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下index_tensor torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
) # 3、我们通过torch.gather()函数获取out_tensorout_tensor torch.gather(input_tensor, dim1, indexindex_tensor)
tensor([[[ 0, 1, 2, 3],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])我们以out_tensor中[0,1,0]8为例解释下如何利用dim和index从input_tensor中获得8。 根据上图我们很直观的了解根据 index 在 input 的 dim 维度上收集 value的过程。
假设 input 和 index 均为三维数组那么输出 tensor 每个位置的索引是列表 [i, j, k] 正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可但是由于 dim 的存在以及 input.shape 可能不等于 index.shape 所以直接取值可能就会报错 所以我们是将索引列表的相应位置替换为 dim 再去 input 取值。在上面示例中由于dim1那么我们就替换索引列表第1个值即[i,dim,k]因此由原来的[0,1,0]替换为[0,2,0]后再去input_tensor中取值。pytorch官方文档的写法如下同一个意思。
out[i][j][k] input[index[i][j][k]][j][k] # if dim 0
out[i][j][k] input[i][index[i][j][k]][k] # if dim 1
out[i][j][k] input[i][j][index[i][j][k]] # if dim 22、反推法再理解
# 1、我们有input_tensor如下input_tensor torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下out_tensor
tensor([[[ 0, 1, 2, 3],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])、# 3、如何知道dim 和 index_tensor呢
# 首先我们要记住out_tensor的shape index_tensor的shape# 从 output_tensor 的第一个位置开始
# 此时[i, j, k]一样看不出来 dim 应该是多少
output_tensor[0, 0, :] input_tensor[0, 0, :] 0
# 同理可知此时index都为0
output_tensor[0, 0, 1] input_tensor[0, 0, 1] 1
output_tensor[0, 0, 2] input_tensor[0, 0, 2] 2
output_tensor[0, 0, 3] input_tensor[0, 0, 3] 3# 我们从下一行的第一个位置开始
# 这里我们看到维度 1 发生了变化1 变成了 2所以 dim 应该是 1而 index 应为 2
output_tensor[0, 1, 0] input_tensor[0, 2, 0] 8
# 同理可知此时index都为2
output_tensor[0, 1, 1] input_tensor[0, 2, 1] 9
output_tensor[0, 1, 2] input_tensor[0, 2, 2] 10
output_tensor[0, 1, 3] input_tensor[0, 2, 3] 11# 根据上面推导我们易知dim1,index_tensor为index_tensor torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
) 3、实际案例
在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中多次使用了torch.gather() 函数。
论文链接https://arxiv.org/pdf/2111.06377官方源码https://github.com/facebookresearch/mae
在MAE中根据预设的掩码比例(paper 中提倡的是 75%)使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder另一部分mask 掉。采样25%作为unmasked tokens过程中使用了torch.gather() 函数。
# models_mae.pyimport torchdef random_masking(x, mask_ratio0.75):Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequenceN, L, D x.shape # batch, length, dimlen_keep int(L * (1 - mask_ratio)) # 计算unmasked的片数# 利用0-1均匀分布进行采样避免潜在的【中心归纳偏好】noise torch.rand(N, L, devicex.device) # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle torch.argsort(noise, dim1) # ascend: small is keep, large is removeids_restore torch.argsort(ids_shuffle, dim1)# keep the first subsetids_keep ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked torch.gather(x, dim1, indexids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask torch.ones([N, L], devicex.device)mask[:, :len_keep] 0# unshuffle to get the binary maskmask torch.gather(mask, dim1, indexids_restore)return x_masked, mask, ids_restoreif __name__ __main__:x torch.arange(64).reshape(1, 16, 4)random_masking(x)# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量实际上一般为(img_size/patch_size)^2 (224 / 16)^2 14*14196
# 4是一个patch中像素个数这里只是模拟实际上一般为in_chans * patch_size * patch_size 3*16*16 768x torch.arange(64).reshape(1, 16, 4)
tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]])
# dim1, index相当于index_tensorindex
tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4, 4, 4, 4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中随机获取25%(4个patch)的unmasked tokens) x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])