PyTorch笔记09----高阶OP
记录了PyTorch中两个高阶操作:Where
与Gather
Where
torch.where(condition, x, y)
满足condition的 返回x,否则返回y
x、y、condition的shape相同
condition矩阵:
那么A的地方取x对应元素,B的地方取y对应元素tensor([[1, 0],
[0, 1]])
example
cond #tensor([[0.6769, 0.7271], |
Gather
torch.gather(input, dim, index, out = None)
>
index是表,dim是查表的维度,index是需要查的index
查表操作,如:表[x1, x2, x3],那么我们gather需要查的index:[0, 1, 0, 2]得到[x1, x2, x1, x3]
四条数据输入神经网络得到一个输出:[4, 10] 取出每条数据概率最大的index:[[1], [2], [0], [9]] 使用gather查表得到对应的元素
prob = torch.rand(4, 10) |