一文彻底弄懂 PyTorch 的 `F.grid_sample`

在深度学习和计算机视觉任务中,经常需要对图像或特征图进行采样和变换。PyTorch 提供的 F.grid_sample 函数非常方便,用于根据指定的坐标从输入张量中采样特定点的值。本文将详细介绍如何使用 F.grid_sample,并通过两个具体例子解释其工作原理。

什么是 F.grid_sample

F.grid_sample 是 PyTorch 中的一个函数,用于根据给定的坐标网格对输入张量进行采样。它常用于图像变形、数据增强等任务。函数的核心思想是使用双线性插值从输入张量中提取指定坐标的值。

示例代码

以下代码展示了如何使用 F.grid_sample 从一个 4x4 的输入张量中采样特定点的值:

import torch
import torch.nn.functional as F

# 定义一个 4x4 的输入张量
input_tensor = torch.tensor([[[[1, 2, 3, 4],
                               [5, 6, 7, 8],
                               [9, 10, 11, 12],
                               [13, 14, 15, 16]]]], dtype=torch.float)

# 定义采样点,归一化坐标在 [-1, 1] 范围内
# 这里使用小数坐标进行采样
grid = torch.tensor([[[[-0.5, -0.5], [0.5, -0.5]],
                      [[-0.5, 0.5], [0.5, 0.5]]]], dtype=torch.float)

# 使用 F.grid_sample 进行采样
output = F.grid_sample(input_tensor, grid, align_corners=True)

print(output)

计算过程

假设输入张量的尺寸为 (4, 4),采样点坐标的归一化范围在 [-1, 1],我们将其转换为张量坐标的范围 [0, 3]

归一化坐标转换公式

归一化坐标转换公式如下:
x input = ( x grid + 1 ) ⋅ ( W − 1 ) 2 x_{\text{input}} = \frac{(x_{\text{grid}} + 1) \cdot (W - 1)}{2} xinput=2(xgrid+

### 使用TPS进行深度学习网络校正 #### TPS原理概述及其应用背景 薄板样条插值(TPS)是一种广泛应用于计算机视觉中的变形模型,能够有效地处理非线性形变。该方法通过最小化能量函数来找到最优的映射关系,在保持全局形状的同时允许局部细节调整[^1]。 对于深度学习而言,TPS被引入至空间变换网络(Spatial Transformer Networks, STN),作为其中的关键组件之一参与图像预处理工作。具体来说,《Robust Scene Text Recognition with Automatic Rectification》一文中提到的空间变换模块即采用了基于TPS的空间几何矫正机制,旨在解决场景文字识别任务中存在的视角扭曲等问题[^2]。 #### 实现过程 为了在神经网络框架内集成TPS功能,开发者们通常会借助于流行的机器学习库如PyTorch完成相应操作。下面给出了一种可能的方式: - **定义控制点集**:选取源图片和平面上若干对应的特征点作为控制点; - **构建参数矩阵W**:依据上述选定的控制点对建立方程组求解未知数w_i; - **计算径向基函数Φ(r)**:根据欧氏距离r确定每一对控制点间的相互影响程度; - **预测目标位置**:利用得到的权重系数和径向基函数表达式估计新坐标的精确值。 以下是采用Python编写的一个简单例子,展示了如何利用PyTorch实现这一流程: ```python import torch from scipy.spatial import distance_matrix def tps_transform(points_src, points_dst, grid_size=20): """Compute the TPS transformation matrix.""" def compute_U(distsq): return distsq * torch.log(distsq + 1e-7) num_points = len(points_src) P = torch.ones((num_points, 3)) P[:, 1:] = points_src K = compute_U(distance_matrix(points_src.cpu().numpy(), points_src.cpu().numpy())) L = torch.zeros(num_points + 3, num_points + 3).to(K.device) L[:num_points, :num_points] = K + torch.eye(num_points)*1e-3 L[:num_points, -3:] = P L[-3:, :num_points] = P.t() V = torch.cat([points_dst, torch.zeros(3, 2)], dim=0) weights = torch.solve(V.unsqueeze(-1), L)[0][:num_points] def transform_grid(grid_coords): U_dist = compute_U(((grid_coords.view(-1, 1, 2)-points_src)**2).sum(dim=-1)).view(*grid_coords.shape[:-1], -1) affine_part = (torch.matmul(P[:,:2].t(),weights[num_points:].unsqueeze(-1))).squeeze() nonrigid_part = ((U_dist*weights[:num_points]).sum(dim=-1)+affine_part).reshape_as(grid_coords) return nonrigid_part x_range = torch.linspace(min(points_src[:, 0]), max(points_src[:, 0]), steps=grid_size) y_range = torch.linspace(min(points_src[:, 1]), max(points_src[:, 1]), steps=grid_size) xv, yv = torch.meshgrid(x_range, y_range) coords = torch.stack([xv.flatten(), yv.flatten()],dim=-1).float().cuda() if next(model.parameters()).is_cuda else .float() transformed_coords = transform_grid(coords) warped_image = F.grid_sample(image_tensor, transformed_coords.reshape(y_range.size()[0], x_range.size()[0], 2)) return warped_image ``` 此段代码实现了基本的TPS转换逻辑,并将其应用于给定网格上的像素重定位。需要注意的是实际应用场景下还需考虑更多因素比如边界条件处理等。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值