网站环境搭建教程,cn 域名网站,网站导航,wordpress自定义结构空白页飞桨的scatter函数#xff0c;是通过基于 updates 来更新选定索引 index 上的输入来获得输出#xff0c;具体官网api文档见#xff1a;
scatter-API文档-PaddlePaddle深度学习平台 官网给的例子如下#xff1a; import paddle x paddle.to_tens…飞桨的scatter函数是通过基于 updates 来更新选定索引 index 上的输入来获得输出具体官网api文档见
scatter-API文档-PaddlePaddle深度学习平台 官网给的例子如下 import paddle x paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtypefloat32) index paddle.to_tensor([2, 1, 0, 1], dtypeint64) updates paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtypefloat32) output1 paddle.scatter(x, index, updates, overwriteFalse) print(output1)Tensor(shape[3, 2], dtypefloat32, placePlace(cpu), stop_gradientTrue,[[3., 3.],[6., 6.],[1., 1.]]) output2 paddle.scatter(x, index, updates, overwriteTrue) # CPU device: # [[3., 3.], # [4., 4.], # [1., 1.]] # GPU device maybe have two results because of the repeated numbers in index # result 1: # [[3., 3.], # [4., 4.], # [1., 1.]] # result 2: # [[3., 3.], # [2., 2.], # [1., 1.]]但是如果是初学者看官网的例子可能还是无法明白scatter的运算方式下面就结合一个更加明白的例子来说明
import paddle
x paddle.to_tensor([[100, 200], [300, 400], [500, 600]], dtypefloat32)
index paddle.to_tensor([2, 1, 0, 1], dtypeint64)
updates paddle.to_tensor([[10, 11], [21, 22], [33, 34], [40, 41]], dtypefloat32)output1 paddle.scatter(x, index, updates, overwriteFalse)
print(output1)
output2 paddle.scatter(x, index, updates, overwriteTrue)
print(output2)
输出结果
Tensor(shape[3, 2], dtypefloat32, placePlace(cpu), stop_gradientTrue,[[33., 34.],[61., 63.],[10., 11.]])
Tensor(shape[3, 2], dtypefloat32, placePlace(cpu), stop_gradientTrue,[[33., 34.],[40., 41.],[10., 11.]])
scatter详解
输入是三个值源值x 索引index 变量updates分析函数输出可以得出以下结论
1 scatter函数的输出shape是和x一致的
2 函数输出的值跟x没关系
3 函数输出值跟变量updates值有关
4 输出值跟updates的有关具体取值的索引跟index有关
具体来说就是不需要x的值只使用了它的维度信息然后根据索引将变量updates的值填入x的维度中。比如index值是[2, 1, 0, 1]第一位是2,那么就把updates的第一组数[10, 11]也就是updates[0]取出来放到x[2]里index第二位是1 就把的updates的第2组数[21, 22]也就是updates[1]取出来放到x[1]里。以此类推index第三位是0, 那么就把updates的第3组数[33, 34]也就是updates[2]放到x[0]里。
到了index第四位数它是1 那么就需要把updates的第四组数[40, 41]也就是updates[3],放入到x[1]中 。这时候有个问题就是前面x[1]中已经放入了[21,22]。这时候就看函数的overwrite参数的设置了如果设置overwriteTrue ,那么直接用现在的值[40, 41]取代以前的值最终函数返回结果就是[[33, 34], [40, 41], [10, 11]] 。如果函数设为overwriteFalse 那么就将值[40, 41]与以前的x[1]21, 22相加结果是[61, 63]最终返回值就是[[33, 34], [61, 63], [10, 11]] 好了这样大家就明白scatter的运算机制了吧 小贴士
飞桨官网给出了scatter函数的python代码实现其中因为用了巧妙的思路来提高速度可读性略有下降 import paddle #input: x paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtypefloat32) index paddle.to_tensor([2, 1, 0, 1], dtypeint64) # shape of updates should be the same as x # shape of updates with dim 1 should be the same as input updates paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtypefloat32) overwrite False # calculation: if not overwrite:... for i in range(len(index)):... x[index[i]] paddle.zeros([2]) for i in range(len(index)):... if (overwrite):... x[index[i]] updates[i]... else:... x[index[i]] updates[i] # output: out paddle.to_tensor([[3, 3], [6, 6], [1, 1]]) print(out.shape)[3, 2]scatter的运算机制不管overwrite是否为Truex的值都不参与运算理论上应该都清除也就是在循环里置0 x[index[i]] paddle.zeros([2])
实际上如果overwrite是True那么在赋值的时候本身可以直接写入 x[index[i]] updates[i]这样就可以省略x[index[i]] paddle.zeros([2])这句这就是为什么这段置0代码放到了条件if not overwrite: 这句里面的原因。