【pytorch】nn.linear 中为什么是y=xA^T+b

本文解释了PyTorch库中nn.linear函数采用y=xA^T+b的右乘表示线性变化的原因,考虑到输入可能是多样本行向量,以及与传统教材中列向量表示的差异。作者通过实例和图解帮助读者理解这种实现逻辑和维度处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我记得读教材的时候是y=Wx+b, 左乘矩阵W,这样才能表示线性变化。
但是pytorch中的nn.linear中,计算方式是y=xA^T+b,其中A是权重矩阵。
为什么右乘也能表示线性变化操作呢?因为pytorch中,照顾到输入是多个样本一起算的(第一个维度是多个样本数,所以输入默认是行向量),所以用y=xA^T+b,输出的y也是行向量。

在这里插入图片描述

我们的教材中默认输入是列向量的,而pytorch为了用户方便,输入当作列向量,维度为(batch, dim),每行是特征

m = nn.Linear(20, 30)
input = torch.randn(128, 20
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值