bPytorch系列(1):torch.gather()_c-minus的博客-CSDN博客_torch。gather
a.gather(dim,b) --以dim和b为索引,从a中提取数据组成新的tensor,shape与b一致。
其中a为tensor类型,b为longtensor类型。
dim = 0 ,按列提取;dim = 1 ,按行提取。
总结:确定好按行or列提取之后,b中每行数据代表从a中提取每行or列的第几个数。(若b第一行为[1,2],则表示提取a中第一行or列的第2,3个数。以此类推)
import torch
a = [
[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]
]
a = torch.tensor(a)
b = torch.LongTensor([[4],[3],[5],[1]])
#b之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(a, 1, b-1)
out