遥感图像语义分割的全流程实现的核心目标是通过深度学习模型(U-Net)对遥感图像进行像素级分类(如区分背景、植被、建筑等),并进一步分析地物面积分布。以下从算法原理和代码实现进行详细说明:
1. 自定义数据集类(RemoteSensingDataset类)
本节介绍了加载遥感 RGB 图像和对应的单通道掩膜(mask)。掩膜中每个像素值为类别索引(如0=背景,1=植被,2=建筑),通过 transform 统一缩放(256×256)和归一化(ImageNet 均值/标准差),确保输入模型的一致性。
代码如下:
# 自定义数据集类(适配遥感RGB+单通道掩膜)
class RemoteSensingDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.images = sorted(os.listdir(image_dir))
self.masks = sorted(os.listdir(mask_dir))
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = os.path.join(self.images_dir, self.images[idx])
mask_path = os.path.join(self.masks_dir, self.masks[idx])
image = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path).convert("L") # 单通道掩膜
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
return image, mask.squeeze(0).long() # 返回RGB图像和类别索引
# 数据预处理
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载数据集(示例路径需替换为实际路径)
train_dataset = RemoteSensingDataset(
image_dir="./data/train/images",
mask_dir="./data/train/masks",
transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
2. U-Net 模型架构(UNet类)
U-Net 是一种专为医学图像分割设计的卷积神经网络,但其对称的“编码器-解码器”结构使其在遥感图像分割中同样有效。代码中的 U-Net 实现包含以下关键部分:
(1)编码器(Encoder函数)
通过两层卷积(Conv2d)和激活函数(ReLU)提取图像特征,随后使用最大池化(MaxPool2d)下采样,逐步减少空间分辨率并增加特征深度(从3通道→64通道)。
(2)解码器(Decoder函数)
通过转置卷积(ConvTranspose2d)上采样恢复分辨率,再通过卷积层将特征映射到类别数(如3类),最终输出与输入图像尺寸相同的分割掩膜。
代码如下:
class UNet(nn.Module):
def __init__(self, n_channels=3, n_classes=3):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Conv2d(n_channels, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 解码器
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(64, n_classes, kernel_size=1)
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
model = UNet(n_classes=3) # 假设有3类:背景、植被、建筑
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)