【pytorch】torch.stack()/F.interpolate/unsqueeze_(0)/squeeze()

文章介绍了PyTorch中的torch.stack函数,用于将张量序列拼接成新的张量,以及F.interpolate函数,用于对张量进行插值以改变尺寸。torch.stack适用于处理如视频帧的数据,而F.interpolate支持多种插值算法,如最近邻、双线性等,并能在保持图像特征的同时调整尺寸。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.torch.stack()

torch.stack() 函数会将一个元素为张量的序列(列表、元组等)沿着新的维度进行拼接,生成一个新的张量。例如,如果有一个包含 n 个 k维张量的序列,那么它们拼接后的张量就会是 (n, k) 维的。

如果我们将视频的每一帧作为一个张量,存储在列表 frames 中,torch.stack(frames) 的作用就是将所有的帧拼接成一个新的张量 x。假设视频总共有 m 帧,那么 x 的形状就是 (m, c, h, w),其中 c、h、w 分别表示图像的通道数、高度、宽度。因为视频中的每一帧都具有相同的形状,所以 torch.stack() 可以正确地将这些张量拼接在一起,形成一个四维张量。

2.F.interpolate

F.interpolate()是PyTorch中的函数,用于对张量进行插值操作,从而改变其尺寸。具体来说,该函数可以对二维或三维输入张量进行插值,支持各种插值算法,例如最近邻插值、双线性插值和三线性插值等。

函数的参数说明如下:

  • input:输入的张量;
  • size:输出张量的尺寸;
  • scale_factor:输出张量与输入张量的比例系数;
  • mode:插值算法,可以是“nearest”(最近邻插值)、“bilinear”(双线性插值)、“bicubic”(三次样条插值)、“trilinear”(三线性插值)等;
  • align_corners:在进行插值时,是否将像素点放在输入张量的角落。如果该参数为True,则像素点会放在输入张量的角落,否则会放在输入张量的中心。

划重点:

  1. 输入张量的形状必须是4维的,具体形状是 (batch_size, channels, height, width)。

  2. 其中,batch_size和channels是可以为1的,表示只有1个样本和1个通道。

  3. 要对输入张量进行的插值操作必须在height和width两个维度上进行,即必须在后两个维度上指定插值后的目标大小。

  4. 目标大小的形状是(height, width)。

在进行插值操作时,F.interpolate()函数会自动根据目标大小和原始大小计算出插值的比例,并对输入张量进行插值操作。

函数的返回值为插值后的张量。

举个例子:

假设有一个 2x2 的张量 x,其值为:

tensor([[1, 2],
        [3, 4]])

现在我们想把它插值到 4x4 的大小,使用双线性插值的方式,可以这样实现:

import torch.nn.functional as F

x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)

# 对张量进行插值,大小为 4x4
y = F.interpolate(x.unsqueeze(0).unsqueeze(0), size=(4, 4), mode='bilinear', align_corners=False)

# 打印输出结果
print(y.squeeze())

》》》
tensor([[1.0000, 1.3333, 1.6667, 2.0000],
        [1.6667, 2.0000, 2.3333, 2.6667],
        [2.3333, 2.6667, 3.0000, 3.3333],
        [3.0000, 3.3333, 3.6667, 4.0000]])

3.unsqueeze_(0)

unsqueeze_(0) 是 PyTorch 中的一个方法,它会在张量的第 0 维(即最前面)增加一个维度。例如,如果原来一个张量的形状为 (3, 5),那么执行 unsqueeze_(0) 后,它的形状就变成了 (1, 3, 5)

4.squeeze()

squeeze()是一个PyTorch张量操作,用于从张量中删除尺寸为1的维度,并返回新的张量。

如果张量中有一个或多个维度的大小为1,则可以使用squeeze()将这些维度删除。例如,如果我们有一个形状为(1, 3, 1, 5)的张量,使用squeeze()将删除尺寸为1的维度,并返回一个形状为(3, 5)的新张量。

如果你想指定要删除的维度,则可以使用squeeze(dim),其中dim是要删除的维度的索引。

 

# data/dataset.py import os import numpy as np import torch from torch.utils.data import Dataset from plyfile import PlyData import open3d as o3d def read_ply_points(file_path): """ 使用 plyfile 解析 .ply 文件,提取 x, y, z 字段 返回一个 numpy 数组 (N, 3) """ ply = PlyData.read(file_path) # 提取 x, y, z 字段 x = ply['vertex']['x'] y = ply['vertex']['y'] z = ply['vertex']['z'] # 拼接为点云数组 points = np.vstack([x, y, z]).T points = points.astype(np.float32) return points class UnifiedPointCloudDataset(Dataset): def __init__(self, root_dirs, file_exts=['.ply', '.stl', '.obj'], num_points=1024): """ :param root_dirs: 包含所有数据文件夹的列表 :param file_exts: 支持的点云格式 :param num_points: 每个点云采样点数 """ self.file_list = [] self.num_points = num_points # 收集所有点云文件路径 for root_dir in root_dirs: if not os.path.exists(root_dir): raise FileNotFoundError(f"❌ 数据目录不存在: {root_dir}") for root, _, files in os.walk(root_dir): for file in files: if any(file.lower().endswith(ext) for ext in file_exts): full_path = os.path.join(root, file) if os.path.exists(full_path): # 确保文件真实存在 self.file_list.append(full_path) print(f"✅ 共发现 {len(self.file_list)} 个点云文件,用于训练") if len(self.file_list) == 0: raise ValueError("⚠️ 没有发现任何点云文件,请检查路径和文件格式") def __len__(self): return len(self.file_list) def __getitem__(self, idx): path = self.file_list[idx] ext = os.path.splitext(path)[1].lower() try: if ext == '.ply': # 使用 plyfile 读取 .ply 文件 points = read_ply_points(path) elif ext == '.stl': # 使用 open3d 读取 STL 文件并采样成点云 mesh = o3d.io.read_triangle_mesh(path) pcd = mesh.sample_points_uniformly(number_of_points=100000) points = np.asarray(pcd.points) elif ext == '.obj': # 使用 open3d 读取 OBJ 文件 pcd = o3d.io.read_point_cloud(path) if not pcd.has_points(): raise ValueError(f"点云为空或损坏: {path}") points = np.asarray(pcd.points) else: raise ValueError(f"不支持的格式: {ext}") # 检查点云是否为空 if len(points) < 10: raise ValueError(f"点云为空或损坏: {path}") # 固定点数采样 if len(points) < self.num_points: indices = np.random.choice(len(points), self.num_points, replace=True) else: indices = np.random.choice(len(points), self.num_points, replace=False) points = points[indices] return torch.tensor(points, dtype=torch.float32) except Exception as e: print(f"❌ 读取失败: {path},错误: {str(e)}") return self.__getitem__((idx + 1) % len(self.file_list)) # models/ballmamba.py import torch import torch.nn as nn from torch_geometric.nn import knn_graph, radius_graph from models.mamba_block import MambaBlock from utils.pointcloud_utils import farthest_point_sample class FPS_Causal(nn.Module): def __init__(self, in_channels, hidden_channels, k=512): super(FPS_Causal, self).__init__() self.k = k self.downsample = nn.Linear(in_channels, hidden_channels) self.mamba = MambaBlock(hidden_channels) def forward(self, x, pos): batch_size, num_points, _ = pos.shape idxs = [] for i in range(batch_size): idx = farthest_point_sample(pos[i], self.k) # 现在支持 (N, 3) idxs.append(idx) idx = torch.stack(idxs, dim=0) # (B, k) # 使用 gather 正确采样 x_sampled = torch.gather(pos, 1, idx.unsqueeze(-1).expand(-1, -1, 3)) # (B, k, 3) x_sampled = self.downsample(x_sampled) # (B, k, hidden_channels) x_sampled = self.mamba(x_sampled) # (B, k, hidden_channels) return x_sampled class BallQuery_Sort(nn.Module): def __init__(self, radius=0.3, k=64): super(BallQuery_Sort, self).__init__() self.radius = radius self.k = k def forward(self, pos, x): edge_index = radius_graph(pos, r=self.radius) row, col = edge_index neighbor_pos = pos[col] neighbor_x = x[col] dist = torch.norm(neighbor_pos - pos[row], dim=1) sorted_indices = torch.argsort(dist.view(-1, self.k), dim=1) neighbors = neighbor_x.view(-1, self.k, neighbor_x.shape[1]) return neighbors.gather(1, sorted_indices.unsqueeze(-1).expand(-1, -1, neighbor_x.shape[1])) class BallMambaModel(nn.Module): def __init__(self, in_channels=3, num_keypoints=1024): super(BallMambaModel, self).__init__() self.fps = FPS_Causal(in_channels, 64, k=num_keypoints) self.ball_query = BallQuery_Sort(radius=0.1, k=64) self.mamba = MambaBlock(64) self.decoder = nn.Linear(64, 3) def forward(self, pos): print("pos shape:", pos.shape) # 添加此行,查看 pos 的形状 x = self.fps(None, pos) x = self.ball_query(pos, x) x = x.mean(dim=1) x = self.mamba(x.unsqueeze(0)).squeeze(0) out = self.decoder(x) return out import torch import torch.nn as nn import torch.nn.functional as F class MambaBlock(nn.Module): def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dt_rank: int or str = "auto"): """ 完整的 MambaBlock 实现(无 LSTM,基于状态空间模型 SSM) :param d_model: 输入特征维度(通道数) :param d_state: 状态维度(SSM 中的 N) :param d_conv: 卷积核大小(用于局部依赖建模) :param expand: 扩展因子,中间维度 = d_model * expand :param dt_rank: 离散时间步秩,控制参数量,若为 "auto" 则自动计算 """ super(MambaBlock, self).__init__() self.d_model = d_model self.d_state = d_state self.d_inner = int(expand * d_model) self.dt_rank = dt_rank if dt_rank != "auto" else max(1, self.d_model // 16) # 输入投影:将输入特征升维 self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) # 卷积分支(局部建模) self.conv = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, kernel_size=d_conv, bias=True, groups=self.d_inner, padding=d_conv - 1, # 保证 causal ) # x_proj 将 x 映射到 dt、B、C self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * d_state, bias=False) # dt_proj 将 dt 映射到 d_inner self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) # A 和 D 参数(状态矩阵和跳跃连接) A = torch.arange(1, d_state + 1, dtype=torch.float32)[None, :].repeat(self.d_inner, 1) self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(self.d_inner)) # (d_inner, ) # 输出投影 self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """ 前向传播 :param x: 输入张量,形状 (B, L, d_model) :return: 输出张量,形状 (B, L, d_model) """ batch, seqlen, dim = x.shape # 1. 输入投影 + 拆分 xz = self.in_proj(x) # (B, L, 2 * d_inner) x, z = torch.split(xz, [self.d_inner, self.d_inner], dim=-1) # (B, L, d_inner) # 2. 卷积处理 x = x.transpose(1, 2) # (B, d_inner, L) x = self.conv(x)[:, :, :seqlen] # (B, d_inner, L) x = x.transpose(1, 2) # (B, L, d_inner) # 3. x_proj 分割出 dt、B、C x_dbl = self.x_proj(x) # (B, L, dt_rank + 2 * d_state) dt, B, C = torch.split( x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1 ) # 4. dt_proj 映射 dt dt = F.softplus(self.dt_proj(dt)) # (B, L, d_inner) # 5. 获取 A A = -torch.exp(self.A_log) # (d_inner, d_state) A = A.unsqueeze(0).unsqueeze(0) # (1, 1, d_inner, d_state) B = B.unsqueeze(dim=2) # (B, L, 1, d_state) C = C.unsqueeze(dim=2) # (B, L, 1, d_state) # 调试信息 print(f"[MambaBlock] A.shape = {A.shape}") print(f"[MambaBlock] B.shape = {B.shape}") print(f"[MambaBlock] C.shape = {C.shape}") print(f"[MambaBlock] dt.shape = {dt.shape}") states = torch.zeros(batch, self.d_inner, self.d_state, device=x.device) outputs = [] for t in range(seqlen): # 更新状态 states = states + x[:, t:t+1, :, None] * B[:, t:t+1, :, :] states = states * torch.exp(A * dt[:, t:t+1, :, None]) # 添加 dt 到状态更新 # 获取当前时间步的 C 并进行 einsum current_C = C[:, t] # (B, 1, d_state) current_C = current_C.squeeze(1) # (B, d_state) # 使用广播机制 y = torch.einsum("binc,bc->bin", states, current_C) # bc 会广播为 binc outputs.append(y) y = torch.stack(outputs, dim=1) # (B, L, d_inner) # ✅ 修复:self.D 扩展为 (1, 1, d_inner) 以便广播 y = y + x * self.D.view(1, 1, -1) # 加上跳跃连接 # 激活 + 输出 y = y * F.silu(z) out = self.out_proj(y) # 调试信息 print(f"[MambaBlock] y.shape = {y.shape}") print(f"[MambaBlock] out.shape = {out.shape}") return out # train.py import os import torch import torch.nn as nn from torch.utils.data import DataLoader from data.dataset import UnifiedPointCloudDataset from models.ballmamba import BallMambaModel from utils.loss import ChamferLoss # 👇 Windows 多进程训练必须放在 if __name__ == '__main__' 里面 if __name__ == '__main__': # 设置多进程启动方式(Windows 下推荐 spawn) torch.multiprocessing.set_start_method('spawn') # ✅ 修改为你自己的路径 ROOT_DIRS = [ r"D:\桌面\point\data1\part1", r"D:\桌面\point\data1\part2", r"D:\桌面\point\data1\part3", r"D:\桌面\point\data1\part4", r"D:\桌面\point\data1\part5", r"D:\桌面\point\data1\part6", r"D:\桌面\point\data1\part7", r"D:\桌面\point\data1\part8", r"D:\桌面\point\data1\part9", r"D:\桌面\point\data1\part10", r"D:\桌面\point\data1\part11", r"D:\桌面\point\data1\part12", r"D:\桌面\point\data1\part13", r"D:\桌面\point\data1\part14", r"D:\桌面\point\data1\part15", r"D:\桌面\point\data1\part16", r"D:\桌面\point\data1\part17", r"D:\桌面\point\data1\part18", r"D:\桌面\point\data1\part19", r"D:\桌面\point\data1\part20", ] # ✅ 创建 Dataset dataset = UnifiedPointCloudDataset( root_dirs=ROOT_DIRS, file_exts=['.ply', '.stl'], num_points=1024 ) print(f"✅ 共发现 {len(dataset)} 个点云文件,用于训练") if len(dataset) == 0: raise ValueError("⚠️ 没有发现任何点云文件,请检查路径和文件格式") # ✅ 创建 DataLoader(num_workers=0 可临时绕过问题) loader = DataLoader( dataset, batch_size=16, shuffle=True, num_workers=0, # 👈 Windows 下训练先设置为 0,后续再尝试 4 pin_memory=True ) # ✅ 模型初始化 device = torch.device("cpu") model = BallMambaModel(in_channels=3, num_keypoints=512).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = ChamferLoss().to(device) # ✅ 训练循环 for epoch in range(50): model.train() total_loss = 0 for i, points in enumerate(loader): points = points.to(device) # 输入输出一致(重构任务) recon_points = model(points) loss = criterion(recon_points, points) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() if i % 10 == 0: print(f"Epoch [{epoch+1}/50], Batch [{i+1}/{len(loader)}], Loss: {loss.item():.4f}") print(f"Epoch [{epoch+1}/50] 完成,平均 Loss: {total_loss / len(loader):.4f}") torch.save(model.state_dict(), f"models/ballmamba_epoch_{epoch+1}.pth") 上述是我的项目模型训练的代码,现在运行后出现问题C:\ProgramData\miniconda3\envs\torch\python.exe D:\桌面\point\scripts\train_model.py ✅ 共发现 907 个点云文件,用于训练 ✅ 共发现 907 个点云文件,用于训练 pos shape: torch.Size([16, 1024, 3]) [MambaBlock] A.shape = torch.Size([1, 1, 128, 16]) [MambaBlock] B.shape = torch.Size([16, 512, 1, 16]) [MambaBlock] C.shape = torch.Size([16, 512, 1, 16]) [MambaBlock] dt.shape = torch.Size([16, 512, 128]) Traceback (most recent call last): File "D:\桌面\point\scripts\train_model.py", line 76, in <module> recon_points = model(points) File "C:\ProgramData\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "D:\桌面\point\models\ballmamba.py", line 63, in forward x = self.fps(None, pos) File "C:\ProgramData\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "D:\桌面\point\models\ballmamba.py", line 30, in forward x_sampled = self.mamba(x_sampled) # (B, k, hidden_channels) File "C:\ProgramData\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "D:\桌面\point\models\mamba_block.py", line 111, in forward y = y + x * self.D.view(1, 1, -1) # 加上跳跃连接 RuntimeError: The size of tensor a (16) must match the size of tensor b (512) at non-singleton dimension 2 进程已结束,退出代码为 1 应该如何修改代码?给我修改后的完整代码
最新发布
07-25
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值