需要修改的代码文件有:
- partition/partition_script.py 修改对 custom_dataset 的引用。
- learning/custom_dataset.py 根据你的网络架构和设计选择,修改模板函数。
- partition/provider.py 创建 read_custom_format 函数。添加你的数据集的颜色映射到 get_color_from_label 函数。
- learning/main.py 修改对 custom_dataset 的引用。
- learning/spg.py 修改第 212 行。
- 模型配置选项(运行时参数设置) gru_10,f_K,其中 K 是你的数据集中的类别数量,gru_10_0,f_K,使用矩阵边缘过滤器(适用于较大的数据集)
1. 由于我的数据是txt格式的,所以先转为ply格式:
代码:
import numpy as np
import os
def create_ply_with_labels(input_dir, output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for area in os.listdir(input_dir):
area_path = os.path.join(input_dir, area)
if not os.path.isdir(area_path):
continue
for tree in os.listdir(area_path):
tree_path = os.path.join(area_path, tree)
tree_file = os.path.join(tree_path, f"{
tree}.txt")
annotations_dir = os.path.join(tree_path, "Annotations")
if os.path.isfile(tree_file) and os.path.isdir(annotations_dir):
prunned_file = os.path.join(annotations_dir, "prunned_tree.txt")
remained_file = os.path.join(annotations_dir, "remained_tree.txt")
# Load data
tree_data = np.loadtxt(tree_file)
prunned_data = np.loadtxt(prunned_file)
remained_data = np.loadtxt(remained_file)
# Assign labels
prunned_data = np.hstack([prunned_data, np.ones((prunned_data.shape[0], 1), dtype=np.int32)])
remained_data = np.hstack([remained_data, np.zeros((remained_data.shape[0], 1), dtype=np.int32)])
# Combine data
labeled_data = np.vstack([prunned_data, remained_data])
# Extract xyz, rgb, and labels
xyz = labeled_data[:, :3]
rgb = labeled_data[:, 3:6].astype(np.uint8) # Ensure RGB is in uint8 format
labels = labeled_data[:, 6].astype(np.int32) # Ensure labels are integers
# Write to .ply file
output_file = os.path.join(output_dir, f"{
area}_{