pytorch之torch.gather方法

本文详细解析了PyTorch中的torch.gather函数的功能与用法,通过实例演示了如何使用该函数来根据索引从输入张量中选取特定维度的数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先,先给出torch.gather函数的函数定义:

        torch.gather(input, dim, index, out=None) → Tensor

官方给出的解释是这样的:

                    沿给定轴dim,将输入索引张量index指定位置的值进行聚合。

                    对一个3维张量,输出可以定义为:

                out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0

                out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1

                out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3

刚开始看上去有点难以理解,但经过研究之后发现原来这个想表述的很简单,先给出几个代码例子让大家自行体会一下。

>>> import torch
>>> a = torch.Tensor([[1,2],[3,4]])
>>> a
 1  2
 3  4
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))
>>> b
 1  1
 4  3
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[1,0],[1,0]]))
>>> b
 2  1
 4  3
[torch.FloatTensor of size 2x2]
>>> b = torch.gather(a,1,torch.LongTensor([[1,1],[1,0]]))
>>> b
 2  2
 4  3
[torch.FloatTensor of size 2x2]


很容易就会发现 torch.gather(input, dim, index, out=None)中的dim表示的就是第几维度,在这个二维例子中,如果dim=0,

那么它表示的就是你接下来的操作是对于第一维度进行的,也就是行;如果dim=1,那么它表示的就是你接下来的操作是对于第

二维度进行的,也就是列。index的大小和input的大小是一样的,他表示的是你所选择的维度上的操作,比如这个例子中

a = torch.Tensor([[1,2],[3,4]])  b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]])) 其中, dim=1,表示的是在第二维度上操作。
index = torch.LongTensor([[0,0],[1,0]]),[0,0]就是第一行对应元素的下标,也就是对应的是[1,1]; [1,0]就是第二行对应元素的下标,也就是对应的是[4,3]。

!!!特别注意一下,index的类型必须是LongTensor类型的。

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值