用torch.nn实现transformer
时间: 2025-07-06 20:53:20 浏览: 7
### 使用 `torch.nn` 库实现 Transformer 模型
为了构建一个基于 `torch.nn` 的 Transformer 模型,可以创建一个新的 Python 类并继承自 `nn.Module`。这个类会封装整个模型架构以及前向传播逻辑。
下面是一个简单的例子展示如何利用 PyTorch 提供的功能来搭建这样一个网络结构:
```python
import math
import torch
from torch import nn
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: Tensor, shape [seq_len, batch_size, embedding_dim]
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
class TransformerModel(nn.Module):
def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
nlayers: int, dropout: float = 0.5):
super().__init__()
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(d_model, dropout)
encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=d_hid, dropout=dropout, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layers, num_layers=nlayers)
self.encoder = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.decoder = nn.Linear(d_model, ntoken)
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src: torch.Tensor, src_mask: torch.Tensor | None = None) -> torch.Tensor:
"""
Arguments:
src: Tensor, shape ``[batch_size, seq_len]``
src_mask: Tensor, shape ``[seq_len, seq_len]``
Returns:
output Tensor of shape ``[batch_size, seq_len, ntoken]``
"""
src = self.encoder(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output
```
在这个实现中,定义了一个名为 `PositionalEncoding` 的位置编码模块[^3],这有助于给定输入序列中的标记提供关于它们相对或绝对位置的信息;接着实现了主要的 `TransformerModel` 类,其中包括嵌入层、位置编码器、多层变压器编码器和线性解码器[^4]。
对于每一层的 `nn.TransformerEncoderLayer`,指定了维度参数如隐藏单元数 (`d_hid`) 和注意力头的数量 (`nhead`) 来控制模型复杂度和性能表现。
最后,在初始化权重函数里设置了合理的初始范围以帮助优化过程更加稳定有效。
阅读全文
相关推荐


















