# data/dataset.py
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from plyfile import PlyData
import open3d as o3d
def read_ply_points(file_path):
"""
使用 plyfile 解析 .ply 文件,提取 x, y, z 字段
返回一个 numpy 数组 (N, 3)
"""
ply = PlyData.read(file_path)
# 提取 x, y, z 字段
x = ply['vertex']['x']
y = ply['vertex']['y']
z = ply['vertex']['z']
# 拼接为点云数组
points = np.vstack([x, y, z]).T
points = points.astype(np.float32)
return points
class UnifiedPointCloudDataset(Dataset):
def __init__(self, root_dirs, file_exts=['.ply', '.stl', '.obj'], num_points=1024):
"""
:param root_dirs: 包含所有数据文件夹的列表
:param file_exts: 支持的点云格式
:param num_points: 每个点云采样点数
"""
self.file_list = []
self.num_points = num_points
# 收集所有点云文件路径
for root_dir in root_dirs:
if not os.path.exists(root_dir):
raise FileNotFoundError(f"❌ 数据目录不存在: {root_dir}")
for root, _, files in os.walk(root_dir):
for file in files:
if any(file.lower().endswith(ext) for ext in file_exts):
full_path = os.path.join(root, file)
if os.path.exists(full_path): # 确保文件真实存在
self.file_list.append(full_path)
print(f"✅ 共发现 {len(self.file_list)} 个点云文件,用于训练")
if len(self.file_list) == 0:
raise ValueError("⚠️ 没有发现任何点云文件,请检查路径和文件格式")
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
path = self.file_list[idx]
ext = os.path.splitext(path)[1].lower()
try:
if ext == '.ply':
# 使用 plyfile 读取 .ply 文件
points = read_ply_points(path)
elif ext == '.stl':
# 使用 open3d 读取 STL 文件并采样成点云
mesh = o3d.io.read_triangle_mesh(path)
pcd = mesh.sample_points_uniformly(number_of_points=100000)
points = np.asarray(pcd.points)
elif ext == '.obj':
# 使用 open3d 读取 OBJ 文件
pcd = o3d.io.read_point_cloud(path)
if not pcd.has_points():
raise ValueError(f"点云为空或损坏: {path}")
points = np.asarray(pcd.points)
else:
raise ValueError(f"不支持的格式: {ext}")
# 检查点云是否为空
if len(points) < 10:
raise ValueError(f"点云为空或损坏: {path}")
# 固定点数采样
if len(points) < self.num_points:
indices = np.random.choice(len(points), self.num_points, replace=True)
else:
indices = np.random.choice(len(points), self.num_points, replace=False)
points = points[indices]
return torch.tensor(points, dtype=torch.float32)
except Exception as e:
print(f"❌ 读取失败: {path},错误: {str(e)}")
return self.__getitem__((idx + 1) % len(self.file_list))
# models/ballmamba.py
import torch
import torch.nn as nn
from torch_geometric.nn import knn_graph, radius_graph
from models.mamba_block import MambaBlock
from utils.pointcloud_utils import farthest_point_sample
class FPS_Causal(nn.Module):
def __init__(self, in_channels, hidden_channels, k=512):
super(FPS_Causal, self).__init__()
self.k = k
self.downsample = nn.Linear(in_channels, hidden_channels)
self.mamba = MambaBlock(hidden_channels)
def forward(self, x, pos):
batch_size, num_points, _ = pos.shape
idxs = []
for i in range(batch_size):
idx = farthest_point_sample(pos[i], self.k) # 现在支持 (N, 3)
idxs.append(idx)
idx = torch.stack(idxs, dim=0) # (B, k)
# 使用 gather 正确采样
x_sampled = torch.gather(pos, 1, idx.unsqueeze(-1).expand(-1, -1, 3)) # (B, k, 3)
x_sampled = self.downsample(x_sampled) # (B, k, hidden_channels)
x_sampled = self.mamba(x_sampled) # (B, k, hidden_channels)
return x_sampled
class BallQuery_Sort(nn.Module):
def __init__(self, radius=0.3, k=64):
super(BallQuery_Sort, self).__init__()
self.radius = radius
self.k = k
def forward(self, pos, x):
edge_index = radius_graph(pos, r=self.radius)
row, col = edge_index
neighbor_pos = pos[col]
neighbor_x = x[col]
dist = torch.norm(neighbor_pos - pos[row], dim=1)
sorted_indices = torch.argsort(dist.view(-1, self.k), dim=1)
neighbors = neighbor_x.view(-1, self.k, neighbor_x.shape[1])
return neighbors.gather(1, sorted_indices.unsqueeze(-1).expand(-1, -1, neighbor_x.shape[1]))
class BallMambaModel(nn.Module):
def __init__(self, in_channels=3, num_keypoints=1024):
super(BallMambaModel, self).__init__()
self.fps = FPS_Causal(in_channels, 64, k=num_keypoints)
self.ball_query = BallQuery_Sort(radius=0.1, k=64)
self.mamba = MambaBlock(64)
self.decoder = nn.Linear(64, 3)
def forward(self, pos):
print("pos shape:", pos.shape) # 添加此行,查看 pos 的形状
x = self.fps(None, pos)
x = self.ball_query(pos, x)
x = x.mean(dim=1)
x = self.mamba(x.unsqueeze(0)).squeeze(0)
out = self.decoder(x)
return out
import torch
import torch.nn as nn
import torch.nn.functional as F
class MambaBlock(nn.Module):
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dt_rank: int or str = "auto"):
"""
完整的 MambaBlock 实现(无 LSTM,基于状态空间模型 SSM)
:param d_model: 输入特征维度(通道数)
:param d_state: 状态维度(SSM 中的 N)
:param d_conv: 卷积核大小(用于局部依赖建模)
:param expand: 扩展因子,中间维度 = d_model * expand
:param dt_rank: 离散时间步秩,控制参数量,若为 "auto" 则自动计算
"""
super(MambaBlock, self).__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = int(expand * d_model)
self.dt_rank = dt_rank if dt_rank != "auto" else max(1, self.d_model // 16)
# 输入投影:将输入特征升维
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
# 卷积分支(局部建模)
self.conv = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
bias=True,
groups=self.d_inner,
padding=d_conv - 1, # 保证 causal
)
# x_proj 将 x 映射到 dt、B、C
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * d_state, bias=False)
# dt_proj 将 dt 映射到 d_inner
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
# A 和 D 参数(状态矩阵和跳跃连接)
A = torch.arange(1, d_state + 1, dtype=torch.float32)[None, :].repeat(self.d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner)) # (d_inner, )
# 输出投影
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
前向传播
:param x: 输入张量,形状 (B, L, d_model)
:return: 输出张量,形状 (B, L, d_model)
"""
batch, seqlen, dim = x.shape
# 1. 输入投影 + 拆分
xz = self.in_proj(x) # (B, L, 2 * d_inner)
x, z = torch.split(xz, [self.d_inner, self.d_inner], dim=-1) # (B, L, d_inner)
# 2. 卷积处理
x = x.transpose(1, 2) # (B, d_inner, L)
x = self.conv(x)[:, :, :seqlen] # (B, d_inner, L)
x = x.transpose(1, 2) # (B, L, d_inner)
# 3. x_proj 分割出 dt、B、C
x_dbl = self.x_proj(x) # (B, L, dt_rank + 2 * d_state)
dt, B, C = torch.split(
x_dbl,
[self.dt_rank, self.d_state, self.d_state],
dim=-1
)
# 4. dt_proj 映射 dt
dt = F.softplus(self.dt_proj(dt)) # (B, L, d_inner)
# 5. 获取 A
A = -torch.exp(self.A_log) # (d_inner, d_state)
A = A.unsqueeze(0).unsqueeze(0) # (1, 1, d_inner, d_state)
B = B.unsqueeze(dim=2) # (B, L, 1, d_state)
C = C.unsqueeze(dim=2) # (B, L, 1, d_state)
# 调试信息
print(f"[MambaBlock] A.shape = {A.shape}")
print(f"[MambaBlock] B.shape = {B.shape}")
print(f"[MambaBlock] C.shape = {C.shape}")
print(f"[MambaBlock] dt.shape = {dt.shape}")
states = torch.zeros(batch, self.d_inner, self.d_state, device=x.device)
outputs = []
for t in range(seqlen):
# 更新状态
states = states + x[:, t:t+1, :, None] * B[:, t:t+1, :, :]
states = states * torch.exp(A * dt[:, t:t+1, :, None]) # 添加 dt 到状态更新
# 获取当前时间步的 C 并进行 einsum
current_C = C[:, t] # (B, 1, d_state)
current_C = current_C.squeeze(1) # (B, d_state)
# 使用广播机制
y = torch.einsum("binc,bc->bin", states, current_C) # bc 会广播为 binc
outputs.append(y)
y = torch.stack(outputs, dim=1) # (B, L, d_inner)
# ✅ 修复:self.D 扩展为 (1, 1, d_inner) 以便广播
y = y + x * self.D.view(1, 1, -1) # 加上跳跃连接
# 激活 + 输出
y = y * F.silu(z)
out = self.out_proj(y)
# 调试信息
print(f"[MambaBlock] y.shape = {y.shape}")
print(f"[MambaBlock] out.shape = {out.shape}")
return out
# train.py
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from data.dataset import UnifiedPointCloudDataset
from models.ballmamba import BallMambaModel
from utils.loss import ChamferLoss
# 👇 Windows 多进程训练必须放在 if __name__ == '__main__' 里面
if __name__ == '__main__':
# 设置多进程启动方式(Windows 下推荐 spawn)
torch.multiprocessing.set_start_method('spawn')
# ✅ 修改为你自己的路径
ROOT_DIRS = [
r"D:\桌面\point\data1\part1",
r"D:\桌面\point\data1\part2",
r"D:\桌面\point\data1\part3",
r"D:\桌面\point\data1\part4",
r"D:\桌面\point\data1\part5",
r"D:\桌面\point\data1\part6",
r"D:\桌面\point\data1\part7",
r"D:\桌面\point\data1\part8",
r"D:\桌面\point\data1\part9",
r"D:\桌面\point\data1\part10",
r"D:\桌面\point\data1\part11",
r"D:\桌面\point\data1\part12",
r"D:\桌面\point\data1\part13",
r"D:\桌面\point\data1\part14",
r"D:\桌面\point\data1\part15",
r"D:\桌面\point\data1\part16",
r"D:\桌面\point\data1\part17",
r"D:\桌面\point\data1\part18",
r"D:\桌面\point\data1\part19",
r"D:\桌面\point\data1\part20",
]
# ✅ 创建 Dataset
dataset = UnifiedPointCloudDataset(
root_dirs=ROOT_DIRS,
file_exts=['.ply', '.stl'],
num_points=1024
)
print(f"✅ 共发现 {len(dataset)} 个点云文件,用于训练")
if len(dataset) == 0:
raise ValueError("⚠️ 没有发现任何点云文件,请检查路径和文件格式")
# ✅ 创建 DataLoader(num_workers=0 可临时绕过问题)
loader = DataLoader(
dataset,
batch_size=16,
shuffle=True,
num_workers=0, # 👈 Windows 下训练先设置为 0,后续再尝试 4
pin_memory=True
)
# ✅ 模型初始化
device = torch.device("cpu")
model = BallMambaModel(in_channels=3, num_keypoints=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = ChamferLoss().to(device)
# ✅ 训练循环
for epoch in range(50):
model.train()
total_loss = 0
for i, points in enumerate(loader):
points = points.to(device)
# 输入输出一致(重构任务)
recon_points = model(points)
loss = criterion(recon_points, points)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
if i % 10 == 0:
print(f"Epoch [{epoch+1}/50], Batch [{i+1}/{len(loader)}], Loss: {loss.item():.4f}")
print(f"Epoch [{epoch+1}/50] 完成,平均 Loss: {total_loss / len(loader):.4f}")
torch.save(model.state_dict(), f"models/ballmamba_epoch_{epoch+1}.pth")
上述是我的项目模型训练的代码,现在运行后出现问题C:\ProgramData\miniconda3\envs\torch\python.exe D:\桌面\point\scripts\train_model.py
✅ 共发现 907 个点云文件,用于训练
✅ 共发现 907 个点云文件,用于训练
pos shape: torch.Size([16, 1024, 3])
[MambaBlock] A.shape = torch.Size([1, 1, 128, 16])
[MambaBlock] B.shape = torch.Size([16, 512, 1, 16])
[MambaBlock] C.shape = torch.Size([16, 512, 1, 16])
[MambaBlock] dt.shape = torch.Size([16, 512, 128])
Traceback (most recent call last):
File "D:\桌面\point\scripts\train_model.py", line 76, in <module>
recon_points = model(points)
File "C:\ProgramData\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\桌面\point\models\ballmamba.py", line 63, in forward
x = self.fps(None, pos)
File "C:\ProgramData\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\桌面\point\models\ballmamba.py", line 30, in forward
x_sampled = self.mamba(x_sampled) # (B, k, hidden_channels)
File "C:\ProgramData\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\桌面\point\models\mamba_block.py", line 111, in forward
y = y + x * self.D.view(1, 1, -1) # 加上跳跃连接
RuntimeError: The size of tensor a (16) must match the size of tensor b (512) at non-singleton dimension 2
进程已结束,退出代码为 1
应该如何修改代码?给我修改后的完整代码
最新发布