输出网络结构,该用`torchsummary还是torchinfo呢?最终得出torchsummary.summary()用于结构化数据,torchinfo可以用于非结构化数据

⚠️ torchsummary 的限制

torchsummary.summary() 只能用于 标准的 torch.nn.Sequential 或类 CNN 网络结构,它接受的输入是标准 torch.Tensor不支持 list 作为输入(比如时间序列 x_seq)

由于我的模型是 非标准结构,接受的是两个 list 类型的输入:

x_seq = [Tensor1, Tensor2, ..., Tensor10]
edge_seq = [edge_index1, ..., edge_index10]

这超出了 torchsummary 的使用范围,它会直接报错。


✅ 正确的替代方式(推荐)

如果我要查看 网络中每一层的参数数量、输出维度,推荐的方式是手动注册 forward hook 或使用 torchinfo(它支持 list input)。

如果不想使用 torchinfo,可以用 print(model) 展示结构,同时统计总参数量:


✅ 代码示例:打印 xxxx的结构和参数量

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(model)  # 输出模型结构
print(f"\nTotal Trainable Parameters: {count_parameters(model):,}")

🧠 输出示例会是这样的:

DGAT(
  (spatial_att): SpatialAttentionLayer(...)
  (spatial_proj): Linear(in_features=64, out_features=32, bias=True)
  (temporal_att): TemporalAttention(...)
  (fc): Sequential(
    (0): ReLU()
    (1): Linear(32, 32)
    (2): ReLU()
    (3): Linear(32, 2)
  )
)

Total Trainable Parameters: 12,345

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值