PyTorch Notes 08 ---- Statistics

Notes on some common statistical operations in PyTorch.

Common statistical operations:
  • norm (the p-norm)
  • mean, sum
  • prod (product of all elements)
  • min, max, argmin, argmax (position of the min/max value)
  • kthvalue, topk

norm

Not to be confused with normalize (normalization). Also note that matrix norms and vector norms are defined differently.

norm-p: the p-norm

import torch

a = torch.full([8], 1.)   # float tensor of ones; norm() requires a floating-point dtype
b = a.view(2, 4)
c = a.view(2, 2, 2)

a.norm(1), b.norm(1), c.norm(1)
#tensor(8.), tensor(8.), tensor(8.)

a.norm(2), b.norm(2), c.norm(2)
#tensor(2.8284), tensor(2.8284), tensor(2.8284)

b.norm(1, dim = 1)
#tensor([4., 4.])

c.norm(1, dim = 0)
#tensor([[2., 2.],
#        [2., 2.]])
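
As a quick sanity check on the definition, the p-norm is (sum_i |x_i|^p)^(1/p). The snippet below is a sketch, reusing a and b from above, that verifies .norm() matches a manual computation.

import torch

a = torch.full([8], 1.)
b = a.view(2, 4)

# 2-norm: square root of the sum of squares
print(torch.allclose(a.norm(2), a.pow(2).sum().sqrt()))      # True

# 1-norm along dim 1: sum of absolute values in each row
print(torch.allclose(b.norm(1, dim=1), b.abs().sum(dim=1)))  # True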

mean / sum / min / max / prod / argmax / argmin

a = torch.arange(8).view(2, 4).float()
#tensor([[0., 1., 2., 3.],
#        [4., 5., 6., 7.]])

a.min(), a.max(), a.mean(), a.prod(), a.sum()
#tensor(0.), tensor(7.), tensor(3.5000), tensor(0.), tensor(28.)  # prod is 0 because the tensor contains a 0

a.max(dim = 1)
#tensor([3., 7.]), tensor([3, 3])

a.argmax(), a.argmin()
#tensor(7), tensor(0)  # the tensor is flattened first, then the index of the max/min is returned

a.argmax(dim = 1)
#tensor([3, 3])
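
Because argmax()/argmin() without dim flatten the tensor first, recovering the row and column of the maximum needs a small index conversion. A minimal sketch (the divmod-style conversion below is my own illustration, not from the original notes):

import torch

a = torch.arange(8).view(2, 4).float()

flat_idx = a.argmax()           # index into the flattened tensor, here tensor(7)
row = flat_idx // a.size(1)     # 7 // 4 = 1
col = flat_idx % a.size(1)      # 7 %  4 = 3
print(row.item(), col.item())   # 1 3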

dim / keepdim

a = torch.randn(4, 10)
a.max(dim = 1)
#tensor([0.8362, 1.7015, 1.1297, 0.6386]), tensor([3, 8, 6, 4])

a.max(dim = 1, keepdim = True)  # keepdim=True keeps the result's number of dims consistent with the original tensor
#values:  tensor([[0.8362], [1.7015], [1.1297], [0.6386]])
#indices: tensor([[3], [8], [6], [4]])

a.argmax(dim = 1, keepdim = True)
#tensor([[3], [8], [6], [4]])
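
A common reason to use keepdim=True is broadcasting: keeping the reduced dimension as size 1 lets the result combine with the original tensor directly. A minimal sketch (assuming the same a = torch.randn(4, 10) as above):

import torch

a = torch.randn(4, 10)

row_max = a.max(dim=1, keepdim=True).values   # shape [4, 1]
shifted = a - row_max                         # broadcasts against [4, 10], e.g. the shift step of a row-wise softmax
print(row_max.shape, shifted.shape)           # torch.Size([4, 1]) torch.Size([4, 10])
# without keepdim the result shape is [4], and the subtraction above would raise a shape error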

top-k / k-th

  • topk returns the largest k values and their indices
a = torch.randn(4, 10)
a.topk(3, dim = 1)
#values:  tensor([[0.8362, 0.3913, -0.1830],
#                 [1.7832, 1.4828, 1.2393],
#                 [0.6392, 0.3824, 0.2227],
#                 [0.9928, 0.1215, -0.3927]])
#indices: tensor([[3, 8, 9], [8, 6, 5], [2, 3, 6], [5, 7, 9]])

a.topk(3, dim = 1, largest = False)  # returns the smallest k values instead
  • kthvalue returns the k-th smallest value and its index
a.kthvalue(8, dim = 1)
#tensor([-0.1830, 1.2393, 0.2227, -0.3927]), tensor([9, 5, 6, 9])
#returns the 8th-smallest element and its index (with 10 elements per row, that is the 3rd-largest)
a.kthvalue(8)  # same result: dim defaults to the last dimension
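
A typical use of topk is top-k classification accuracy: a prediction counts as correct if the true class is among the k highest-scoring ones. A sketch with made-up logits and labels (not from the original notes):

import torch

logits = torch.randn(4, 10)           # hypothetical scores: 4 samples, 10 classes
labels = torch.tensor([3, 1, 0, 7])   # hypothetical ground-truth classes

_, pred = logits.topk(3, dim=1)                     # indices of the top-3 classes per sample
correct = (pred == labels.unsqueeze(1)).any(dim=1)  # True if the label appears in the top 3
print(correct.float().mean())                       # top-3 accuracy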

compare

>, >=, <, <=, !=, ==

a > 0
#tensor([[0, 0, 0, 1, ...],
#        [0, 0, 0, 0, ...],
#        [0, 1, 1, 0, ...],
#        [0, 0, 0, 0, ...]])
#(recent PyTorch versions return a bool tensor of True/False instead of 0/1)
torch.gt(a, 0)  # same result as above

b = torch.rand(2, 2)
torch.eq(b, b)     # element-wise comparison, returns a tensor
#tensor([[1, 1],
#        [1, 1]])
torch.equal(b, b)  # compares the tensors as a whole, returns a Python bool
#True
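
Comparison results are most useful as boolean masks, e.g. for counting or selecting elements. A small sketch (assuming a tensor like the a above):

import torch

a = torch.randn(4, 10)

mask = a > 0                                 # bool tensor with the same shape as a
print(mask.sum())                            # number of positive elements
print(a[mask].shape)                         # 1-D tensor of just the positive elements
print(torch.masked_select(a, mask).shape)    # equivalent to a[mask]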