Tensor 的拼接与拆分操作
cat
Statisics about scores - [class1-4, students, scores] - [class5-9,
students, scores]
a = torch.rand(4, 32, 8) b = torch.rand(5, 32, 8)
torch.cat([a, b], dim = 0).shape #torch.Size([9, 32, 8]) #第一个参数是一个list,包含了所有需要拼接的Tensor #第二个参数dim决定了合并的维度 #其他维度要一样 c = torch.rand(3, 32, 1) torch.cat([a, c], dim = 0).shape --------ERROR--------
|
stack
create new dim
a = torch.rand(32, 8) b = torch.rand(32, 8)
torch.stack([a, b], dim = 0).shape #torch.Size([2, 32, 8]) torch.stack([a, b], dim = 1).shape #torch.Size([32, 2, 8])
|
例如一个老师统计了一个班的成绩:
[students, scores],另一个老师也是,那么使用 stack 得到 [2, students,
scores],而不是把 students 那一维度拼接起来
stack 必须维度一致
a = torch.rand(30, 8) b = torch.rand(32, 8) torch.stack([a, b], dim = 0) --------ERROR--------
|
split: by len
b = torch.rand(32, 8) a.shape #torch.Size([32, 8]) c = torch.stack([a, b], dim = 0) c.shape #torch.Size([2, 32, 8])
aa, bb = c.split(1, dim = 0) #第一个参数是长度 aa.shape, bb.shape #torch.Size([1, 32, 8]), torch.Size([1, 32, 8])
aa, bb = c.split([1, 1], dim = 0) #第一个参数是长度的list #结果同上
aa, bb = c.split(2, dim = 0) -------ERROR--------
|
chunk: by num
b = torch.rand(32, 8) a.shape #torch.Size([32, 8]) c = torch.stack([a, b], dim = 0) c.shape #torch.Size([2, 32, 8])
aa, bb = c.chunk(2, dim = 0) #平均分成2块 aa.shape, bb.shape #torch.Size([1, 32, 8]), torch.Size([1, 32, 8])
|
Be the first person to leave a comment!