SPP (Spatial Pyramid Pooling)
import torch
import torch.nn as nn


# Define the SPP module
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        # Adaptive average pooling down to fixed 4x4, 2x2 and 1x1 grids
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))
        # M selects the pyramid levels and is expected as a string, e.g. '[1,2,4]' or '[1,2]'
        self.M = M
        print(f"Initialized M: {self.M}")

    def forward(self, x):
        print(f"Shape of input x: {x.shape}")
        # 4x4 pooling
        x_4x4 = self.pooling_4x4(x)
        print(f"Shape of x_4x4 after 4x4 pooling: {x_4x4.shape}")
        # 2x2 pooling (applied to the 4x4 map)
        x_2x2 = self.pooling_2x2(x_4x4)
        print(f"Shape of x_2x2 after 2x2 pooling: {x_2x2.shape}")
        # 1x1 pooling (applied to the 4x4 map)
        x_1x1 = self.pooling_1x1(x_4x4)
        print(f"Shape of x_1x1 after 1x1 pooling: {x_1x1.shape}")
        # Flatten the spatial dimensions of each pooled map: (N, C, H, W) -> (N, C, H*W)
        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        print(f"Shape of x_4x4_flatten after flattening: {x_4x4_flatten.shape}")
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        print(f"Shape of x_2x2_flatten after flattening: {x_2x2_flatten.shape}")
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)
        print(f"Shape of x_1x1_flatten after flattening: {x_1x1_flatten.shape}")
        # Concatenate the pyramid levels along the spatial dimension according to M
        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
            print(f"Shape of x_feature after concatenation (M=[1,2,4]): {x_feature.shape}")
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
            print(f"Shape of x_feature after concatenation (M=[1,2]): {x_feature.shape}")
        else:
            raise NotImplementedError('ERROR M')
        # Compute the feature strength: rearrange (N, C, L) -> (L, N, C)
        x_strength = x_feature.permute((2, 0, 1))
        print(f"Shape of x_strength after permutation: {x_strength.shape}")
        return x_strength  # assumed return value; the original snippet is truncated at this point
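
As a quick sanity check, the sketch below is my addition and not part of the original snippet. It instantiates SPP with M='[1,2,4]' and feeds it a randomly generated feature map; the batch size, channel count and spatial size are arbitrary illustration values. With an input of shape (2, 512, 7, 7) the concatenated pyramid has 1 + 4 + 16 = 21 positions, so x_feature comes out as (2, 512, 21) and the permuted x_strength as (21, 2, 512).

# Minimal usage sketch (shapes chosen only for illustration)
if __name__ == '__main__':
    spp = SPP(M='[1,2,4]')
    dummy = torch.randn(2, 512, 7, 7)  # (batch, channels, height, width)
    out = spp(dummy)
    # If forward returns x_strength as assumed above, this prints torch.Size([21, 2, 512])
    print(out.shape)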