PyTorch Notes 03 ---- Indexing and Slicing

Indexing and slicing in PyTorch.

Indexing

import torch

a = torch.rand(4, 3, 28, 28)

a[0].shape #torch.Size([3, 28, 28])
a[0, 0].shape #torch.Size([28, 28])
a[0, 0, 2, 4] #tensor(0.8082)
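
Each scalar index consumes one dimension; a quick check of this (a minimal sketch, assuming the same a as above):

a.dim()             #4
a[0].dim()          #3
a[0, 0].dim()       #2
a[0, 0, 2, 4].dim() #0, a scalar tensor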

select first / last N

a.shape         #torch.Size([4, 3, 28, 28])
a[:2].shape #torch.Size([2, 3, 28, 28]) slices are left-closed, right-open (start inclusive, end exclusive)
a[:2, :1, :,:] #torch.Size([2, 1, 28, 28])
a[:2, 1:, :,:] #torch.Size([2, 2, 28, 28])
a[:2, -1:, :,:] #torch.Size([2, 1, 28, 28])
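
Negative indices count from the end, so -1: keeps only the last element along that dimension. A small sketch, assuming the same a:

a[-1:].shape           #torch.Size([1, 3, 28, 28]) last element along dim 0
a[:2, -2:, :, :].shape #torch.Size([2, 2, 28, 28]) first 2 along dim 0, last 2 along dim 1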

select by steps

a[:, :, 0:28:2, 0:28:2].shape
#torch.Size([4, 3, 14, 14]) take every 2nd element along the last two dimensions; 2 is the step
a[:, :, ::2, ::2].shape #equivalent to the line above

This works the same as Python slicing.
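
The general form is start:end:step, exactly as for Python lists; a small sketch on a 1-D tensor:

t = torch.arange(10) #tensor([0, 1, ..., 9])
t[1:9:3] #tensor([1, 4, 7])
t[::3]   #tensor([0, 3, 6, 9])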

select by specific index

  • index_select(): the first argument is the dimension to index along, the second gives the indices themselves (must be an integer Tensor, e.g. torch.tensor([...]); a float torch.Tensor([...]) will raise an error)
    a.index_select(0, torch.tensor([0, 2])).shape
    #torch.Size([2, 3, 28, 28])
    a.index_select(1, torch.tensor([1, 2])).shape
    #torch.Size([4, 2, 28, 28])

    a.index_select(2, torch.arange(8)).shape
    #torch.Size([4, 3, 8, 28])
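
For a 1-D integer index along dim 0, index_select gives the same result as fancy indexing (a sketch, assuming the same a):

idx = torch.tensor([0, 2])
torch.equal(a.index_select(0, idx), a[idx]) #True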

select by ...

... stands for as many full slices (:) as needed; it expands to cover all dimensions that are not indexed explicitly.

a[...].shape         #equivalent to a[:, :, :, :].shape
a[0, ...].shape      #equivalent to a[0].shape
a[0, ..., ::2].shape #torch.Size([3, 28, 14])
a[:, 1, ...].shape   #torch.Size([4, 28, 28])
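
Because ... expands to however many full slices the tensor needs, the same expression works for tensors of different rank (a sketch; b is a hypothetical 2-D tensor):

b = torch.rand(2, 3)
b[..., 0].shape #torch.Size([2]), same as b[:, 0]
a[..., 0].shape #torch.Size([4, 3, 28]), same as a[:, :, :, 0]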

select by mask

Selecting with a mask flattens the data; it is not used very often.

x = torch.randn(3, 4)
#tensor([[-1.3911, -0.7871, -1.6558, -0.2542],
#        [-0.9011,  0.5404, -0.6612,  0.3917],
#        [-0.3854,  0.2968,  0.6040,  1.5771]])

mask = x.ge(0.5) #positions with values >= 0.5 are marked 1 (True)
#tensor([[0, 0, 0, 0],
#        [0, 1, 0, 0],
#        [0, 0, 1, 1]], dtype=torch.uint8)

torch.masked_select(x, mask)
#tensor([0.5404, 0.6040, 1.5771])
torch.masked_select(x, mask).shape
#torch.Size([3]) the result is 1-D; its length depends only on how many elements pass the mask, not on the original shape
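
Indexing directly with the mask gives the same flattened result (a sketch, assuming x and mask from above; in recent PyTorch versions ge() returns a bool mask):

torch.equal(x[mask], torch.masked_select(x, mask)) #True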

select by flatten index

  • torch.take() also flattens the data; it is not used very often
    src = torch.tensor([[4, 3, 5],
                        [6, 7, 8]])
    torch.take(src, torch.tensor([0, 2, 5]))
    #tensor([4, 5, 8])
    #the [2, 3] tensor is first flattened to [6], then the elements at indices 0, 2, 5 are selected
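
torch.take() is equivalent to flattening the tensor first and then indexing it (a sketch, assuming src from above):

idx = torch.tensor([0, 2, 5])
torch.equal(torch.take(src, idx), src.reshape(-1)[idx]) #True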