PyTorch笔记09----高阶OP

记录了PyTorch中两个高阶操作:WhereGather

Where

torch.where(condition, x, y)

满足condition的 返回x,否则返回y

x、y、condition的shape相同

condition矩阵:

tensor([[1, 0],
[0, 1]])
那么A的地方取x对应元素,B的地方取y对应元素

example

cond        #tensor([[0.6769, 0.7271],
[0.8884, 0.4163]])
a #tensor([[0., 0.],
[0., 0.]])
b #tensor([[1., 1.],
[1., 1.]])

torch.where(cond > 0.5, a, b)
#tensor([[0., 0.],
[0., 1.]])

Gather

torch.gather(input, dim, index, out = None) > index是表,dim是查表的维度,index是需要查的index

查表操作,如:表[x1, x2, x3],那么我们gather需要查的index:[0, 1, 0, 2]得到[x1, x2, x1, x3]

四条数据输入神经网络得到一个输出:[4, 10] 取出每条数据概率最大的index:[[1], [2], [0], [9]] 使用gather查表得到对应的元素

prob = torch.rand(4, 10)
idx = prob.topk(dim = 1, k = 3)[1]
idx
#tensor([[7, 4, 9],
[8, 1, 3],
[2, 8, 4],
[8, 6, 0]])
label = torch.arange(10) + 100
label
#tensor([100, 101, ..., 109])

torch.gather(label.expand(4, 10), dim = 1, index = idx.long())
#tensor([[107, 104, 109],
[108, 101, 103],
[102, 108, 104],
[108, 106, 100]])
#.long()转化为LongTensor类型,不加也可以得到相同结果