PyTorch的一些Tensor维度变换方式
View / Reshape
通用,要有物理意义,否则是污染数据 a = torch.rand(4, 1, 28, 28) a.shape #torch.Size([4, 1, 28, 28])
a.view(4, 28 * 28).shape #torch.Size([4, 784]) #特别适合全连接层
b = a.view(4, 784) #会丢失原来的维度信息 b.view(4, 28, 28, 1) #造成了数据污染
|
- Flewible but prone to corrupt 如果view的size不同,会报错
a.view(4, 783) --------Error--------
|
Squeeze / Unsqueeze
Unsqueeze
a.shape #torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape #torch.Size([1, 4, 1, 28, 28]) 在位置0处增加一维
a.unsqueeze(-1).shape #torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(5).shape ---------Error--------
|
a = torch.tensor([1.2, 2.3]) a.unsqueeze(-1) #tensor([[1.2000], [2.3000]]) #shape由[2]变为[2,1]
a.unsqueeze(0) #tensor([[1.2000, 2.3000]]) #shape由[2]变为[1,2]
|
- Example > 将shape为[32]的bias增加到shape为[4, 32, 14,
14]的FeatureMap的channel中
b = torch.rand(32) f = torch.rand(4, 32, 14, 14)
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0) #注意第二次的unsqueeze参数是按第一次操作过后的索引
b.shape #torch.Size([1, 32, 1, 1])
|
Squeeze
b.shape #torch.Size([1, 32, 1, 1]) b.squeeze().shape #torch.Size([32]) 不传递参数时将能挤压的全部挤压 b.squeeze(0).shape #torch.Size([32, 1, 1])
b.squeeze(1).shape #torch.Size([1, 32, 1, 1]) 维度不是1,所以不能挤压,但不会报错
|
Expand / Repeat
维度扩展
之前的b: [1, 32, 1, 1]``f: [4, 32, 14, 14]
仍不能相加,需要分别将第0、2、3维度扩展4、14、14倍
- Expand:broadcasting. 只改变理解方式,不改变数据
- Repeat:memory copied. 增加数据,都拷贝一遍.
推荐Expand,省略复制数据,只在操作必要时复制.
运行速度快且节约内存.
Expand
a = torch.rand(4, 32, 14, 14) b.shape #torch.Size(1, 32, 1, 1) b.expand(4, 32, 14, 14).shape #torch.Size(4, 32, 14, 14) #前提:前后dim一致(这里都为4),1扩展到N,M扩展到M
b.expand(-1, 32, -1, -1).shape #torch.Size(1, 32, 1, 1) -1保持不变
b.expand(-1, 32, -1, -4).shape #torch.Size(1, 32, 1, -4) 是个bug,无意义
|
Repeat
b.shape #torch.Size(1, 32, 1, 1) b.repeat(4, 32, 1, 1).shape #torch.Size(4, 1024, 1, 1) repeat的参数是对应维度复制的次数
b.repeat(4, 1, 14, 14).shape #torch.Size(4, 32, 14, 14)
|
会使得你无法使用原来的数据,占用内存变多会重新申请一片空间
T 矩阵转置
a = torch.randn(3, 4) a.t().shape #torch.Size([4, 3]) #仅支持2D,1D、3D...都不支持
|
Transpose 交换维度
a.shape = [4, 3, 32, 32] a1 = a.transpose(1, 3).view(4, 3*32*32).view(4, 3, 32, 32) -------------ERROR------------ #transpose操作之后会使得元素不连续,所以在view之前要加上contiguous操作
a1 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32) #污染数据 a1.shape #torch.Size([4, 3, 32, 32])
a2 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3) #与a等价,注意第二个view与a1不同 a2.shape #torch.Size([4, 3, 32, 32])
#验证,用eq函数比较各个数据是否一致,返回[4, 3, 32, 32]的张量,all函数判断其所有元素是否都为True torch.all(torch.eq(a, a1)) #tensor(False)
torch.all(torch.eq(a, a2)) #tensor(True)
|
Permute
transpose只能两两交换 b = torch.rand(4, 3, 28, 32) b.tanspose(1, 3).transpose(1, 2).shape #torch.Size([4, 28, 32, 3])
b.permute(0, 2, 3, 1).shape #torch.Size([4, 28, 32, 3]) #permute的参数为排列后的索引
|
>
permute函数也会打乱内存顺序,需要时也要用到contifuous函数,也就是重新生成一片内存再复制过来.