PyTorch笔记03----索引与切片
PyTorch中的索引与切片.
Indexing
a = torch.rand(4, 3, 28, 28) |
select first / last N
a.shape #torch.Size([4, 3, 28, 28]) |
select by steps
a[:, :, 0:28:2, 0:28:2].shape |
同Python切片
select by specific index
- index_select()
第一个参数表示操作的维度,第二个参数直接给索引号(必须是Tensor)
a.index_select(0, torch.Tensor([0, 2])).shape
#torch.Size([2, 3, 28, 28])
a.index_select(1, torch.Tensor([1, 2])).shape
#torch.Size([4, 2, 28, 28])
a.index_select(2, torch.arange(8)).shape
#torch.Size([4, 3, 8, 28])
select by ...
...
代表任意维度,贪心匹配(?)
a[...].shape #等价于a[:, :, :, :].shape
a[0, ...].shape #等价于a[0].shape
a[0,...,::2].shape #torch.Size([1, 3, 28, 14])
a[:, 1, ...].shape #torch.Size([4, 1, 28, 28])
select by mask
使用掩码选择,会打平数据,用的不多
x = torch.randn(3, 4) |
select by flatten index
- take函数也会打平数据,用的不多
src = torch.tensor([[4, 3, 5], |