cnn,rnn,ganlstm,transformer
时间: 2025-03-11 21:03:57 浏览: 46
### 不同神经网络架构的特点及应用案例
#### 卷积神经网络 (CNN)
卷积神经网络专门用于处理具有网格状拓扑结构的数据,如图像和视频。CNN能够有效提取数据中的局部特征并通过层次化的方式构建更复杂的表示形式[^2]。
- **特点**: CNN利用卷积核在输入数据上滑动操作来检测空间模式,并通过池化层减少维度;参数共享机制使得模型更加高效。
- **应用场景**: 图像分类、目标检测、语义分割等领域广泛应用。例如,在自动驾驶汽车中识别交通标志或行人等物体。
```python
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
return x
```
#### 循环神经网络 (RNN) 和长短时记忆网络 (LSTM)
RNN适用于序列建模任务,因为它们能够在时间步之间传递信息。然而,标准的RNN难以捕捉长期依赖关系。为此设计了改进版本—LSTM,它引入了特殊的门控单元以更好地控制信息流[^1]。
- **特点**: LSTM增加了遗忘门、输入门以及输出门三个组件,允许选择性地保留或者丢弃某些历史状态的信息,从而解决了梯度消失问题。
- **应用场景**: 自然语言处理(NLP),比如机器翻译、情感分析、语音识别等方面表现优异。
```python
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
super(LSTMModel, self).__init__()
self.hidden_dim = hidden_dim
self.layer_dim = layer_dim
self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
def forward(self, x):
h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
out, _ = self.lstm(x, (h0.detach(), c0.detach()))
return out[:, -1, :]
```
#### 变压器 (Transformer)
变压器摒弃了传统的递归结构,完全基于自注意力机制工作。这不仅提高了并行计算效率,还增强了对远程上下文的理解能力。
- **特点**: 使用多头自注意力模块代替循环连接,让每个位置可以直接访问整个序列的内容,极大地提升了表达力。
- **应用场景**: 大规模预训练语言模型BERT就是基于此架构开发而成,广泛应用于各种NLP下游任务当中。
```python
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
```
#### 生成对抗网络 (GAN)
由两个相互竞争的部分组成:一个是负责创造样本的生成器G,另一个是用来判断真假的判别器D。两者共同进化直至达到纳什均衡点。
- **特点**: GAN可以合成逼真的图像、音频甚至是文本等内容,而且不需要标注好的监督信号就能完成学习过程。
- **应用场景**: 艺术创作领域可用来制作风格迁移作品;医学影像方面可用于增强低质量扫描件的质量。
```python
import torch.optim as optim
from torchvision.utils import save_image
def train_gan(generator, discriminator, dataloader, num_epochs=5):
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters())
optimizer_D = optim.Adam(discriminator.parameters())
for epoch in range(num_epochs):
for i, data in enumerate(dataloader):
real_images = data[0].to(device)
valid = Variable(Tensor(real_images.size(0), 1).fill_(1.0), requires_grad=False)
fake = Variable(Tensor(real_images.size(0), 1).fill_(0.0), requires_grad=False)
# Train Generator
optimizer_G.zero_grad()
z = Variable(torch.randn((real_images.shape[0], latent_dim))).to(device)
gen_imgs = generator(z)
g_loss = criterion(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
# Train Discriminator
optimizer_D.zero_grad()
real_loss = criterion(discriminator(real_images), valid)
fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
batches_done = epoch * len(dataloader) + i
if batches_done % sample_interval == 0:
save_image(gen_imgs.data[:25], f"images/{batches_done}.png", nrow=5, normalize=True)
```
阅读全文
相关推荐


















