troch.nn.Linear
时间: 2025-05-15 19:06:25 浏览: 15
### PyTorch `nn.Linear` 的用法与实现
#### 1. 基本定义
PyTorch 中的 `torch.nn.Linear` 是一种用于全连接层(Fully Connected Layer)的操作。它接受输入张量并将其线性变换为指定大小的输出张量。其核心操作可以表示为 \( y = xA^T + b \),其中 \( A \) 和 \( b \) 分别是权重矩阵和偏置向量。
该模块继承自 `nn.Module`,因此可以通过调用 `.forward()` 方法完成前向传播计算[^4]。
---
#### 2. 参数说明
以下是 `torch.nn.Linear` 构造函数中的参数及其含义:
- **in_features**: 输入特征的数量。
- **out_features**: 输出特征的数量。
- **bias (可选)**: 如果设置为 `True`,则会添加一个偏置项,默认值为 `True`。
代码示例如下:
```python
import torch
from torch import nn
linear_layer = nn.Linear(in_features=10, out_features=5, bias=True)
input_tensor = torch.randn(3, 10) # 批次大小为3,每条数据有10个特征
output_tensor = linear_layer(input_tensor)
print(output_tensor.shape) # 结果形状应为 [3, 5]
```
上述代码创建了一个具有 10 个输入节点和 5 个输出节点的全连接层,并通过随机初始化的数据验证了它的功能。
---
#### 3. 权重与偏置的访问
在实际应用中,可能需要手动调整或查看模型的权重和偏置。这些属性可以直接通过 `.weight` 和 `.bias` 访问。
示例代码如下:
```python
weights = linear_layer.weight # 形状为 [5, 10]
biases = linear_layer.bias # 形状为 [5]
print(weights.shape) # 应打印 [5, 10]
print(biases.shape) # 应打印 [5]
```
需要注意的是,在某些情况下可能会遇到原地操作引发的问题,这涉及到 PyTorch 自动求导机制的设计细节[^3]。
---
#### 4. 实现原理
实际上,`nn.Linear` 可以被简单理解为以下形式的手写实现:
```python
class MyLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(MyLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
# 初始化权重和偏置
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
# 使用均匀分布初始化权重和偏置
nn.init.kaiming_uniform_(self.weight, a=torch.sqrt(torch.tensor(5)))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
return torch.addmm(self.bias, input, self.weight.t()) if self.bias is not None \
else torch.mm(input, self.weight.t())
```
这段代码展示了如何从零构建一个简单的线性层,包括权重初始化以及前向传播的核心逻辑。
---
#### 5. 注意事项
当使用 `nn.Linear` 进行训练时,需注意以下几点:
- 确保输入张量的最后一维等于 `in_features`,否则会抛出维度不匹配错误。
- 若涉及梯度更新,则避免对张量执行破坏性的原地修改操作,以免影响自动求导过程。
---
阅读全文
相关推荐















