PyTorch笔记04----Tensor维度变换

PyTorch的一些Tensor维度变换方式

View / Reshape

通用,要有物理意义,否则是污染数据

a = torch.rand(4, 1, 28, 28)
a.shape #torch.Size([4, 1, 28, 28])

a.view(4, 28 * 28).shape #torch.Size([4, 784])
#特别适合全连接层

b = a.view(4, 784) #会丢失原来的维度信息
b.view(4, 28, 28, 1) #造成了数据污染

  • Flewible but prone to corrupt 如果view的size不同,会报错
    a.view(4, 783)
    --------Error--------

Squeeze / Unsqueeze

Unsqueeze

a.shape         #torch.Size([4, 1, 28, 28])

a.unsqueeze(0).shape #torch.Size([1, 4, 1, 28, 28]) 在位置0处增加一维

a.unsqueeze(-1).shape #torch.Size([4, 1, 28, 28, 1])

a.unsqueeze(5).shape
---------Error--------
a = torch.tensor([1.2, 2.3])
a.unsqueeze(-1)
#tensor([[1.2000],
[2.3000]])
#shape由[2]变为[2,1]

a.unsqueeze(0)
#tensor([[1.2000, 2.3000]])
#shape由[2]变为[1,2]
  • Example > 将shape为[32]的bias增加到shape为[4, 32, 14, 14]的FeatureMap的channel中
b = torch.rand(32)
f = torch.rand(4, 32, 14, 14)

b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
#注意第二次的unsqueeze参数是按第一次操作过后的索引

b.shape #torch.Size([1, 32, 1, 1])

Squeeze

b.shape             #torch.Size([1, 32, 1, 1])
b.squeeze().shape #torch.Size([32]) 不传递参数时将能挤压的全部挤压
b.squeeze(0).shape #torch.Size([32, 1, 1])

b.squeeze(1).shape #torch.Size([1, 32, 1, 1]) 维度不是1,所以不能挤压,但不会报错

Expand / Repeat

维度扩展 之前的b: [1, 32, 1, 1]``f: [4, 32, 14, 14]仍不能相加,需要分别将第0、2、3维度扩展4、14、14倍

  • Expand:broadcasting. 只改变理解方式,不改变数据
  • Repeat:memory copied. 增加数据,都拷贝一遍.

推荐Expand,省略复制数据,只在操作必要时复制. 运行速度快且节约内存.

Expand

a = torch.rand(4, 32, 14, 14)
b.shape #torch.Size(1, 32, 1, 1)
b.expand(4, 32, 14, 14).shape #torch.Size(4, 32, 14, 14)
#前提:前后dim一致(这里都为4),1扩展到N,M扩展到M

b.expand(-1, 32, -1, -1).shape #torch.Size(1, 32, 1, 1) -1保持不变

b.expand(-1, 32, -1, -4).shape #torch.Size(1, 32, 1, -4) 是个bug,无意义

Repeat

b.shape         #torch.Size(1, 32, 1, 1)
b.repeat(4, 32, 1, 1).shape #torch.Size(4, 1024, 1, 1) repeat的参数是对应维度复制的次数

b.repeat(4, 1, 14, 14).shape #torch.Size(4, 32, 14, 14)

会使得你无法使用原来的数据,占用内存变多会重新申请一片空间

T 矩阵转置

a = torch.randn(3, 4)
a.t().shape #torch.Size([4, 3])
#仅支持2D,1D、3D...都不支持

Transpose 交换维度

a.shape = [4, 3, 32, 32]
a1 = a.transpose(1, 3).view(4, 3*32*32).view(4, 3, 32, 32)
-------------ERROR------------
#transpose操作之后会使得元素不连续,所以在view之前要加上contiguous操作

a1 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32) #污染数据
a1.shape #torch.Size([4, 3, 32, 32])

a2 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3) #与a等价,注意第二个view与a1不同
a2.shape #torch.Size([4, 3, 32, 32])

#验证,用eq函数比较各个数据是否一致,返回[4, 3, 32, 32]的张量,all函数判断其所有元素是否都为True
torch.all(torch.eq(a, a1)) #tensor(False)

torch.all(torch.eq(a, a2)) #tensor(True)

Permute

transpose只能两两交换

b = torch.rand(4, 3, 28, 32)
b.tanspose(1, 3).transpose(1, 2).shape
#torch.Size([4, 28, 32, 3])

b.permute(0, 2, 3, 1).shape
#torch.Size([4, 28, 32, 3])
#permute的参数为排列后的索引
> permute函数也会打乱内存顺序,需要时也要用到contifuous函数,也就是重新生成一片内存再复制过来.