⚠️ torchsummary
的限制
torchsummary.summary()
只能用于 标准的 torch.nn.Sequential
或类 CNN 网络结构,它接受的输入是标准 torch.Tensor
,不支持 list 作为输入(比如时间序列 x_seq)。
由于我的模型是 非标准结构,接受的是两个 list 类型的输入:
x_seq = [Tensor1, Tensor2, ..., Tensor10]
edge_seq = [edge_index1, ..., edge_index10]
这超出了 torchsummary
的使用范围,它会直接报错。
✅ 正确的替代方式(推荐)
如果我要查看 网络中每一层的参数数量、输出维度,推荐的方式是手动注册 forward
hook 或使用 torchinfo
(它支持 list input)。
如果不想使用 torchinfo
,可以用 print(model)
展示结构,同时统计总参数量:
✅ 代码示例:打印 xxxx的结构和参数量
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model) # 输出模型结构
print(f"\nTotal Trainable Parameters: {count_parameters(model):,}")
🧠 输出示例会是这样的:
DGAT(
(spatial_att): SpatialAttentionLayer(...)
(spatial_proj): Linear(in_features=64, out_features=32, bias=True)
(temporal_att): TemporalAttention(...)
(fc): Sequential(
(0): ReLU()
(1): Linear(32, 32)
(2): ReLU()
(3): Linear(32, 2)
)
)
Total Trainable Parameters: 12,345