本文将介绍如何使用LSTM训练一个能够创作诗歌的模型。为了训练出效果优秀的模型,我整理了来自网络的4万首诗歌数据集。我们的模型可以直接使用预先训练好的参数,这意味着您无需从头开始训练,即可在自己的电脑上体验AI作诗的乐趣。我已经为您准备好了这些训练好的参数,让您能够轻松地在自己的设备上开始创作。本文将详细讲解如何在个人电脑上运行该模型,即使您没有机器学习方面的背景知识,也能轻松驾驭,让您的AI模型在自己的电脑上运行起来,体验AI创作诗歌的乐趣.所有的代码和资料都在仓库:https://gitee.com/yw18791995155/generate_poetry.git
秋风吹拂,窗外的树叶似灵动的舞者翩翩而舞,落日余晖将天际晕染成一片醉人的橘红。
与此同时,AI 于知识的瀚海中遨游,遍览数千篇文章后,开启了它的首次创作之旅。
在对近 4 万首唐诗深度学习之后,赋诗如下:
此诗颇具韵味,实乃勤勉研习之硕果。汲取全唐诗之精华,方成就这般非凡之能,常人岂易企及?
本博客将简要分析其中的技术细节,若有阐释未尽之处,在此诚挚欢迎诸君于评论区畅所欲言,各抒己见。先呈上仓库链接https://gitee.com/yw18791995155/generate_poetry.git
若诸位无暇详阅,不妨为该项目点亮 star 或进行 fork,诸君的每一份支持都将如熠熠星光,化作我砥砺前行之强劲动力源泉。言归正传,让我们一同开启打造AI诗人的旅程吧
01 环境配置
在开始之前,确保你的电脑已经安装了必要的依赖库:PyTorch 和 NumPy。安装命令如下:
pip install torch torchvision torchaudio numpy
一切就绪,我们可以开始了!
02 初识LSTM
长短期记忆网络 LSTM是一种特殊类型的循环神经网络(RNN),它被设计用来解决传统RNN在处理长序列数据时遇到的长期依赖性问题(梯度消失和梯度爆炸问题)。
LSTM的核心优势在于其能够学习并记住长期的信息依赖关系。这种能力使得LSTM在处理长文本内容时比普通RNN更为出色。LSTM网络中包含了四个主要的组件,它们通过门控机制来控制信息的流动:
- 遗忘门(Forget Gate):决定哪些信息应该被遗忘,不再保留在单元状态中。
- 输入门(Input Gate):决定哪些新信息将被存储在单元状态中。
- 单元状态(Cell State):携带数据穿越时间的信息带,可以看作是LSTM的“记忆”。
- 输出门(Output Gate):决定哪些信息将从单元状态输出到下一个隐藏状态。
这些门控机制使得LSTM能够有选择性地保留或遗忘信息,从而有效地捕捉和利用长期依赖性。这种设计灵感来源于对传统RNN在处理长序列时遗忘信息的挑战的回应,LSTM通过这些门控结构,使得网络能够更加灵活地处理时间序列数据。
03处理数据
接下来,首先要做的就是读取准备好的诗歌数据。然后对数据进行清洗,剔除那些包含特殊字符或长度不符合要求的诗歌。清洗完数据后,我们会为每首诗加上开始和结束的标志,确保生成的诗歌有明确的起止符号。
然后,我们会构建词典,为每个词分配一个唯一的索引,同时建立词汇到索引、索引到词汇的映射关系。最后,把每首诗转换成数字序列,这样就能让模型进行处理了。
import collections
import numpy as np
import torch
# 定义起始和结束标记
start_token = 'B'
end_token = 'E'
def process_poems(file_name):
"""
处理诗歌文件,将诗歌转换为数字序列,并构建词汇表。
:param file_name: 诗歌文件的路径
:return:
- poems_vector: 诗歌的数字序列列表
- word_to_idx: 词汇到索引的映射字典
- idx_to_word: 索引到词汇的映射列表
"""
# 初始化诗歌列表
poems = []
# 读取文件并处理每一行
with open(file_name, "r", encoding='utf-8') as f:
for line in f.readlines():
try:
# 分割标题和内容
title, content = line.strip().split(':')
content = content.replace(' ', '')
# 过滤掉包含特殊字符的诗歌
if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
start_token in content or end_token in content:
continue
# 过滤掉长度不符合要求的诗歌
if len(content) < 5 or len(content) > 79:
continue
# 添加起始和结束标记
content = start_token + content + end_token
poems.append(content)
except ValueError as e:
pass
# 统计所有单词的频率
all_words = [word for poem in poems for word in poem]
counter = collections.Counter(all_words)
words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)
# 添加空格作为填充符
words.append(' ')
words_length = len(words)
# 构建词汇到索引和索引到词汇的映射
word_to_idx = {
word: i for i, word in enumerate(words)}
idx_to_word = [word for word