python 构建RNN分类器对“名字-国家”分类、并使用Visdom绘制loss值和acc
前言
为什么使用RNN对文本的分类?
RNN考虑当前输入以及先前接收的输入,前文出现内容将影响接下来要出现的内容。因此RNN是文本和语音分析的理想选择。
Visdom的安装与使用:
刘二大人b站讲解视频:
提示:以下是本篇文章正文内容,下面案例可供参考
一、名字对应的国家分类
如图所示:通过训练,快速的预测各个姓名对应的国家。
二、整体程序
1.代码
代码中使用的数据集连接如下
链接:https://pan.baidu.com/s/1jCmi8Qj6lzzb4R4rUgMEew
提取码:trac
代码如下(示例):
import torch
import time
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import gzip
import csv
import matplotlib.pyplot as plt
import numpy as np
import math
from torch.nn.utils.rnn import pack_padded_sequence
# 引入可视化工具visdom
import visdom
#国家-人名进行分类
HIDDEN_SIZE = 100
BATCH_SIZE =256
N_LAYER =2
N_EPOCHS = 100
N_CHARS = 128
USE_GPU = True
# 为训练添加环境变量,之后程序产生的所有数据都会出现在这个环境变量中,visdom就可以监听数据实现可视化
# 而且在visdom中还可以按照环境变量名选择性监听,便于分类管理
viz = visdom.Visdom(env='countries_names')
# 创建线图初始点。loss是差值图,acc是精度图
viz.line([0], [0], win='train_loss', opts=dict(title='train loss'))
viz.line([0], [0], win='acc', opts=dict(title='test acc'))
#读取数据集
class NameDataset(Dataset):
def __init__(self, is_train_set = True):
# 判断是否要训练集还是测试集
filename = 'names_train.csv.gz' if is_train_set else 'names_test.csv.gz'
with gzip.open(filename, 'rt') as f:
reader = csv.reader(f)
rows = list(reader)
# 提取数据集中的名字
self.names = [row[0] for row in rows]
self.len = len(self.names)
self.countries = [row[1] for row in rows]
# set变成集合去除重复国家 sort排序 list 生成列表
self.country_list = list(sorted(set(self.countries)))
# 将列表变成词典
self.country_dict = self.getCountryDict()
self.country_n