在深度学习和计算机视觉任务中,经常需要对图像或特征图进行采样和变换。PyTorch 提供的 F.grid_sample
函数非常方便,用于根据指定的坐标从输入张量中采样特定点的值。本文将详细介绍如何使用 F.grid_sample
,并通过两个具体例子解释其工作原理。
什么是 F.grid_sample
?
F.grid_sample
是 PyTorch 中的一个函数,用于根据给定的坐标网格对输入张量进行采样。它常用于图像变形、数据增强等任务。函数的核心思想是使用双线性插值从输入张量中提取指定坐标的值。
示例代码
以下代码展示了如何使用 F.grid_sample
从一个 4x4 的输入张量中采样特定点的值:
import torch
import torch.nn.functional as F
# 定义一个 4x4 的输入张量
input_tensor = torch.tensor([[[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]]], dtype=torch.float)
# 定义采样点,归一化坐标在 [-1, 1] 范围内
# 这里使用小数坐标进行采样
grid = torch.tensor([[[[-0.5, -0.5], [0.5, -0.5]],
[[-0.5, 0.5], [0.5, 0.5]]]], dtype=torch.float)
# 使用 F.grid_sample 进行采样
output = F.grid_sample(input_tensor, grid, align_corners=True)
print(output)
计算过程
假设输入张量的尺寸为 (4, 4)
,采样点坐标的归一化范围在 [-1, 1]
,我们将其转换为张量坐标的范围 [0, 3]
。
归一化坐标转换公式
归一化坐标转换公式如下:
x input = ( x grid + 1 ) ⋅ ( W − 1 ) 2 x_{\text{input}} = \frac{(x_{\text{grid}} + 1) \cdot (W - 1)}{2} xinput=2(xgrid+