pytorch view、permute、linspace repeat

本文详细介绍了PyTorch中关键的张量操作,包括view和permute用于数据维度的重塑与排序,linspace用于生成等差数列,以及repeat进行维度上的重复。这些操作对于深度学习模型的数据预处理和后处理至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

view:取出所有元素后,按新维度重新排序

permute:交换维度的值  ,然后重新排序

 

linspace:(start,end,分成几份)

grid_x=torch.linspace(0,9,10)
print(grid_x)

输出:tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

 

repeat:参数:表示该维度重复几次

 

 

 

这段代码的主要功能是在三维空间中生成一组归一化的坐标点,并将其组织成适合深度学习模型处理的形式(如PyTorch中的张量)。以下是对每个步骤的详细说明: ### 代码详解 ```python if dim == '3d': ``` 检查变量`dim`是否为字符串'3d',如果是,则执行后续操作。 --- #### **1. 创建Z轴方向的坐标** ```python zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype, device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z ``` - `torch.linspace(start=0.5, end=Z-0.5, steps=num_points_in_pillar)` 在区间 [0.5, Z-0.5] 上均匀地创建了`num_points_in_pillar`个值。 - `.view(-1, 1, 1)` 将结果调整形状为 `(num_points_in_pillar, 1, 1)` 的张量。 - `.expand(num_points_in_pillar, H, W)` 扩展维度,将该张量广播到大小为`(num_points_in_pillar, H, W)`。 - `/ Z` 对所有元素进行归一化处理,确保它们位于 `[0, 1]` 区间内。 最终得到的是沿 Z 轴分布的一组归一化坐标值。 --- #### **2. 创建X轴方向的坐标** ```python xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W ``` - 类似于上一步,这里使用 `torch.linspace` 来生成 X 方向上的坐标值,在范围 `[0.5, W-0.5]` 内均匀采样出 `W` 个值。 - `.view(1, 1, W)` 将其转为尺寸 `(1, 1, W)` 的张量。 - `.expand(num_points_in_pillar, H, W)` 广播到与 Z 坐标相同的形状 `(num_points_in_pillar, H, W)`。 - 最后除以 `W` 进行归一化。 --- #### **3. 创建Y轴方向的坐标** ```python ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype, device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H ``` - 同理,这步用于生成 Y 轴方向的坐标。在范围 `[0.5, H-0.5]` 中均匀采样出 `H` 个值。 - 经过 reshape 和 expand 操作之后,形成一个形状为 `(num_points_in_pillar, H, W)` 的张量,并对所有的数值做归一化 (`/ H`)。 --- #### **4. 构建完整的三维坐标系** ```python ref_3d = torch.stack((xs, ys, zs), -1) ``` - 使用 `torch.stack` 函数按最后一个维度 (即新的第 4 维度)堆叠三个坐标矩阵 `xs`, `ys`, `zs`。此时输出形状变为 `(num_points_in_pillar, H, W, 3)`,其中每一点表示了一个三维坐标的 (x,y,z) 分量。 --- #### **5. 数据重排** ```python ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1) ``` - 首先调用 `permute(0, 3, 1, 2)` 改变数据排列顺序,使得通道维度移到第二维的位置,类似于从 NHWC 到 NCHW 格式转换。 - 接着利用 `flatten(2)` 把高宽两个维度合并成一个新的单维度。 - 最后再通过一次 `permute(0, 2, 1)` 调整维度顺序,满足特定的数据结构需求。 --- #### **6. Batch扩展并返回结果** ```python ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) return ref_3d ``` - 添加额外的一个批处理维度并通过 `repeat(bs,...)` 方法复制 bs 次,从而适配批量输入的要求。 - 返回最终计算好的参考三维网格。 --- ### 解释 上述逻辑的核心在于构建一个基于像素位置以及高度分层信息的空间坐标系统,通常会应用于计算机视觉任务比如立体匹配或体素特征提取等场景下作为几何先验供网络预测时使用。此外,这些归一化的坐标能够帮助简化卷积神经网络训练过程中涉及的距离比较和变换运算等问题。 --- ### 示例运行假设参数下的完整版示例代码: ```python import torch def generate_ref_3d(dim='3d', Z=8, W=16, H=12, num_points_in_pillar=4, dtype=torch.float32, device='cpu', bs=2): if dim != '3d': return None # Define the z coordinate points. zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,device=device)\ .view(-1, 1, 1).expand(num_points_in_pillar, H, W)/Z # Define x coordinates across width dimension xs = torch.linspace(0.5,W -0.5 , W,dtype=dtype,device=device)\ .view(1, 1, W).expand(num_points_in_pillar,H,W)/W #Define y coordiantes accross height direction ys = torch.linspace(0.5,H -0.5 , H,dtype=dtype,device=device)\ .view(1,H,1).expand(num_points_in_pillar,H,W)/H # Combine into a single tensor for all three dimensions ref_3d = torch.stack((xs,ys,zs),-1) # Reshape to desired format ref_3d = ref_3d.permute(0,3,1,2).flatten(2).permute(0,2,1) # Add batch size and repeat accordingly ref_3d = ref_3d[None].repeat(bs,1,1,1 ) return ref_3d print(generate_ref_3d()) ``` --- **
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值