【Pytorch数据处理】之Torch.index_select()、Torch.gather()详解
进行数据处理的时候需要根据索引值返回tensor对应的值(value),摸索了好久,终于搞清楚了Torch.index_select和Torch.gather的区别和使用,记录并分享。
Torch.index_select(input, dim, index)
Torch.gather(input, dim, index)
可以看到,这两个方法的参数都完全一致,类型也完全一致(input: TensorBase, dim:int,idex:Tensor),input为要查找的初始张量,dim为按照什么维度进行查找,index也是张量形式,指查找的索引号。但是这两个方法的使用结果完全不一样,简而言之,Index_select()返回的是某一维度的所有值,gather()是选张量中某一位置的具体值。
下面展示具体示例帮助理解:
Torch.index_select(input, dim, index)
#初始化4*4的tensor
x = torch.randn(4, 4)
#构建需要查找的索引
index = torch.tensor([0, 2])
print(x)
>>>
tensor([[-1.3314, 0.7889, -0.3725, -0.2781],
[ 1.2508, 1.8838, 0.0600, 0.6576],
[ 2.4463, -1.6365, 1.0068, 0.5613],
[-1.6658, 0.6191, -0.9908, -0.0912]])
#按照dim=0取索引为index的张量
print(torch.index_select(x, 0, index))
>>