首先,先给出torch.gather函数的函数定义:
torch.gather(input, dim, index, out=None) → Tensor
官方给出的解释是这样的:
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
刚开始看上去有点难以理解,但经过研究之后发现原来这个想表述的很简单,先给出几个代码例子让大家自行体会一下。
>>> import torch
>>> a = torch.Tensor([[1,2],[3,4]])
>>> a
1 2
3 4
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))
>>> b
1 1
4 3
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[1,0],[1,0]]))
>>> b
2 1
4 3
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[1,1],[1,0]]))
>>> b
2 2
4 3
[torch.FloatTensor of size 2x2]
很容易就会发现 torch.gather(input, dim, index, out=None)中的dim表示的就是第几维度,在这个二维例子中,如果dim=0,
那么它表示的就是你接下来的操作是对于第一维度进行的,也就是行;如果dim=1,那么它表示的就是你接下来的操作是对于第
二维度进行的,也就是列。index的大小和input的大小是一样的,他表示的是你所选择的维度上的操作,比如这个例子中
a = torch.Tensor([[1,2],[3,4]]) b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])) 其中, dim=1,表示的是在第二维度上操作。
index = torch.LongTensor([[0,0],[1,0]]),[0,0]就是第一行对应元素的下标,也就是对应的是[1,1]; [1,0]就是第二行对应元素的下标,也就是对应的是[4,3]。
!!!特别注意一下,index的类型必须是LongTensor类型的。