当前位置: 首页 > news >正文

河南郑州做网站h汉狮长沙网站排名提升

河南郑州做网站h汉狮,长沙网站排名提升,汕头建站模板搭建,网站运营的作用Pytorch常用的函数(九)torch.gather()用法 torch.gather() 就是在指定维度上收集value。 torch.gather() 的必填也是最常用的参数有三个#xff0c;下面引用官方解释#xff1a; input (Tensor) – the source tensordim (int) – the axis along which to indexindex (Lo…Pytorch常用的函数(九)torch.gather()用法 torch.gather() 就是在指定维度上收集value。 torch.gather() 的必填也是最常用的参数有三个下面引用官方解释 input (Tensor) – the source tensordim (int) – the axis along which to indexindex (LongTensor) – the indices of elements to gather 一句话概括 gather 操作就是根据 index 在 input 的 dim 维度上收集 value。 1、举例直观理解 # 1、我们有input_tensor如下input_tensor torch.arange(24).reshape(2, 3, 4) tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下index_tensor torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]] ) # 3、我们通过torch.gather()函数获取out_tensorout_tensor torch.gather(input_tensor, dim1, indexindex_tensor) tensor([[[ 0, 1, 2, 3],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])我们以out_tensor中[0,1,0]8为例解释下如何利用dim和index从input_tensor中获得8。 根据上图我们很直观的了解根据 index 在 input 的 dim 维度上收集 value的过程。 假设 input 和 index 均为三维数组那么输出 tensor 每个位置的索引是列表 [i, j, k] 正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可但是由于 dim 的存在以及 input.shape 可能不等于 index.shape 所以直接取值可能就会报错 所以我们是将索引列表的相应位置替换为 dim 再去 input 取值。在上面示例中由于dim1那么我们就替换索引列表第1个值即[i,dim,k]因此由原来的[0,1,0]替换为[0,2,0]后再去input_tensor中取值。pytorch官方文档的写法如下同一个意思。 out[i][j][k] input[index[i][j][k]][j][k] # if dim 0 out[i][j][k] input[i][index[i][j][k]][k] # if dim 1 out[i][j][k] input[i][j][index[i][j][k]] # if dim 22、反推法再理解 # 1、我们有input_tensor如下input_tensor torch.arange(24).reshape(2, 3, 4) tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下out_tensor tensor([[[ 0, 1, 2, 3],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])、# 3、如何知道dim 和 index_tensor呢 # 首先我们要记住out_tensor的shape index_tensor的shape# 从 output_tensor 的第一个位置开始 # 此时[i, j, k]一样看不出来 dim 应该是多少 output_tensor[0, 0, :] input_tensor[0, 0, :] 0 # 同理可知此时index都为0 output_tensor[0, 0, 1] input_tensor[0, 0, 1] 1 output_tensor[0, 0, 2] input_tensor[0, 0, 2] 2 output_tensor[0, 0, 3] input_tensor[0, 0, 3] 3# 我们从下一行的第一个位置开始 # 这里我们看到维度 1 发生了变化1 变成了 2所以 dim 应该是 1而 index 应为 2 output_tensor[0, 1, 0] input_tensor[0, 2, 0] 8 # 同理可知此时index都为2 output_tensor[0, 1, 1] input_tensor[0, 2, 1] 9 output_tensor[0, 1, 2] input_tensor[0, 2, 2] 10 output_tensor[0, 1, 3] input_tensor[0, 2, 3] 11# 根据上面推导我们易知dim1,index_tensor为index_tensor torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]] ) 3、实际案例 在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中多次使用了torch.gather() 函数。 论文链接https://arxiv.org/pdf/2111.06377官方源码https://github.com/facebookresearch/mae 在MAE中根据预设的掩码比例(paper 中提倡的是 75%)使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder另一部分mask 掉。采样25%作为unmasked tokens过程中使用了torch.gather() 函数。 # models_mae.pyimport torchdef random_masking(x, mask_ratio0.75):Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequenceN, L, D x.shape # batch, length, dimlen_keep int(L * (1 - mask_ratio)) # 计算unmasked的片数# 利用0-1均匀分布进行采样避免潜在的【中心归纳偏好】noise torch.rand(N, L, devicex.device) # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle torch.argsort(noise, dim1) # ascend: small is keep, large is removeids_restore torch.argsort(ids_shuffle, dim1)# keep the first subsetids_keep ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked torch.gather(x, dim1, indexids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask torch.ones([N, L], devicex.device)mask[:, :len_keep] 0# unshuffle to get the binary maskmask torch.gather(mask, dim1, indexids_restore)return x_masked, mask, ids_restoreif __name__ __main__:x torch.arange(64).reshape(1, 16, 4)random_masking(x)# x模拟一张图片经过patch_embedding后的序列 # x相当于input_tensor # 16是patch数量实际上一般为(img_size/patch_size)^2 (224 / 16)^2 14*14196 # 4是一个patch中像素个数这里只是模拟实际上一般为in_chans * patch_size * patch_size 3*16*16 768x torch.arange(64).reshape(1, 16, 4) tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]]) # dim1, index相当于index_tensorindex tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4, 4, 4, 4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中随机获取25%(4个patch)的unmasked tokens) x_masked相当于out_tensor tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])
http://www.w-s-a.com/news/346300/

相关文章:

  • 北京网站建设公司分形科技简述营销网站建设策略
  • 汉中网站建设有限公司vue网站开发
  • 网站备案背景幕布阳江东莞网站建设
  • 北京网站建设要多少钱html网站标签
  • 做兼职做网站的是什么公司网站怎么修改
  • 舆情监控都有哪些内容西安seo网站公司
  • 网站有域名没备案天津网络营销
  • 哈巴狗模式网站开发电子商务平台建设与运营技术
  • 摄影网站源码wordpress内涵段子
  • 实验一 电子商务网站建设与维护图片做网站
  • 网站策划书模板大全中国建设部官方网站资格证查询
  • vps绑定多个网站创意咨询策划公司
  • 做qq图片的网站网页制作与网站建设江西
  • 做爰全过程的视频网站网络文化经营许可证怎么办
  • 常德市网站建设网站开发用哪个软件好
  • 网站文章怎么更新时间重庆勘察设计网
  • 外卖网站设计企业网站优化做法
  • 专业的营销型网站制作wordpress版权年份
  • 程序员会搭建非法网站吗怎么把wordpress字去掉
  • 牡丹江营商环境建设监督局网站中国档案网站建设的特点
  • 网站欣赏网站欣赏知名企业网站搭建
  • 书店网站建设可行性分析为大型企业设计网络营销方案
  • 北京教育云平台网站建设中国服装设计网站
  • 网络公司专业做网站豌豆荚app下载
  • 网站建设属于什么岗位济宁网站建设_云科网络
  • wordpress网站监测fwa 网站 欣赏
  • 用jsp做的可运行的网站推广网络
  • 电商网站设计论文wordpress子文件夹建站
  • 临沂网站优化如何如何做公司的网站建设
  • 建设部网站 光纤到户沈阳网页设计兼职