PyTorch笔记06----拼接与拆分

Tensor的拼接与拆分操作

  • Cat
  • Stack
  • Split
  • Chunk

cat

Statisics about scores - [class1-4, students, scores] - [class5-9, students, scores]

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)

torch.cat([a, b], dim = 0).shape
#torch.Size([9, 32, 8])
#第一个参数是一个list,包含了所有需要拼接的Tensor
#第二个参数dim决定了合并的维度
#其他维度要一样
c = torch.rand(3, 32, 1)
torch.cat([a, c], dim = 0).shape
--------ERROR--------

stack

create new dim

a = torch.rand(32, 8)
b = torch.rand(32, 8)

torch.stack([a, b], dim = 0).shape
#torch.Size([2, 32, 8])
torch.stack([a, b], dim = 1).shape
#torch.Size([32, 2, 8])
例如一个老师统计了一个班的成绩: [students, scores],另一个老师也是,那么使用stack得到[2, students, scores],而不是把students那一维度拼接起来

stack必须维度一致

a = torch.rand(30, 8)
b = torch.rand(32, 8)
torch.stack([a, b], dim = 0)
--------ERROR--------

split: by len

b = torch.rand(32, 8)
a.shape #torch.Size([32, 8])
c = torch.stack([a, b], dim = 0)
c.shape #torch.Size([2, 32, 8])

aa, bb = c.split(1, dim = 0) #第一个参数是长度
aa.shape, bb.shape
#torch.Size([1, 32, 8]), torch.Size([1, 32, 8])

aa, bb = c.split([1, 1], dim = 0) #第一个参数是长度的list
#结果同上

aa, bb = c.split(2, dim = 0)
-------ERROR--------

chunk: by num

b = torch.rand(32, 8)
a.shape #torch.Size([32, 8])
c = torch.stack([a, b], dim = 0)
c.shape #torch.Size([2, 32, 8])

aa, bb = c.chunk(2, dim = 0) #平均分成2块
aa.shape, bb.shape
#torch.Size([1, 32, 8]), torch.Size([1, 32, 8])