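When training a neural network, data is usually fed in batches, which requires choosing the size of each batch, batch_size. When the number of samples is not divisible by batch_size, the trailing samples are often simply dropped. I have done this in real projects. For datasets with many samples the effect is small, but when the dataset is too small, the leftover samples are never used to update the network, which is hard to justify for a task that is already short on data. Example 1 shows this approach.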
Example 1. Dropping the trailing samples directly
import numpy as np

def get_batch_data(input_list, label, batch_size, num_epochs, shuffle=False):
    """This is a list version"""
    size = len(input_list)
    num_batches_per_epoch = size // batch_size  # floor: the remainder is never yielded
    input_list = np.array(input_list)
    label = np.array(label)
    for epoch in range(num_epochs):
        if shuffle:
            shuffle_indices = np.random.permutation(size)
            shuffled_data = input_list[shuffle_indices]
            # Row indexing works for both one-hot (2-D) and integer (1-D) labels
            shuffled_label = label[shuffle_indices]
        else:
            shuffled_data = input_list
            shuffled_label = label
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            # With a floor-based batch count, (batch_num + 1) * batch_size never exceeds size,
            # so min() has no effect; samples past num_batches_per_epoch * batch_size are dropped
            end_index = min((batch_num + 1) * batch_size, size)
            yield shuffled_data[start_index:end_index], shuffled_label[start_index:end_index]
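To make the cost concrete, the sketch below counts how many samples a floor-based batch count leaves unused in each epoch. The helper name `dropped_per_epoch` and the dataset sizes are illustrative assumptions, not part of the original code.

```python
def dropped_per_epoch(size, batch_size):
    """Return how many samples are never yielded when only full batches are produced."""
    num_batches = size // batch_size  # same floor as int(size / batch_size)
    return size - num_batches * batch_size


# Hypothetical dataset sizes, chosen only for illustration.
for size in (10000, 103, 10):
    lost = dropped_per_epoch(size, batch_size=32)
    print(f"size={size:5d}  dropped={lost:2d}  fraction={lost / size:.1%}")
```

Without shuffling, the same trailing samples are skipped in every epoch; with shuffle=True, a different remainder is skipped each epoch.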
To address this, make a small change: pad the last batch until it is full. What data to pad with is up to you. The code is as follows:
Example 2. Manually padding the last batch until it is full
import numpy as np

def get_batch_data(input_list, label, batch_size, num_epochs, shuffle=False):
    """This is a list version"""
    size = len(input_list)
    # Ceiling division: an extra batch is added only when there is a remainder
    num_batches_per_epoch = (size + batch_size - 1) // batch_size
    input_list = np.array(input_list)
    label = np.array(label)
    for epoch in range(num_epochs):
        if shuffle:
            shuffle_indices = np.random.permutation(size)
            shuffled_data = input_list[shuffle_indices]
            shuffled_label = label[shuffle_indices]
        else:
            shuffled_data = input_list
            shuffled_label = label
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            # end_index = min((batch_num + 1) * batch_size, size)  # Example 1 version
            end_index = (batch_num + 1) * batch_size
            if end_index <= size:  # a full batch is available
                yield shuffled_data[start_index:end_index], shuffled_label[start_index:end_index]
            else:  # the last batch runs past the end, so pad the difference
                rest_dat = shuffled_data[start_index:size]
                rest_label = shuffled_label[start_index:size]
                diff = end_index - size
                # Pad with the first diff samples (assumes diff <= size)
                add_dat = shuffled_data[0:diff]
                add_label = shuffled_label[0:diff]
                # Join along the sample axis; separate names keep the label
                # argument intact for the next epoch
                batch_dat = np.concatenate((rest_dat, add_dat), axis=0)
                batch_label = np.concatenate((rest_label, add_label), axis=0)
                yield batch_dat, batch_label
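The changes relative to Example 1 are the batch count and the branch that pads the final batch; the code relies on the numpy module.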
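As a variation on the padding idea, the same result can be expressed through index arithmetic: take the positions of each batch modulo the sample count so that the final batch wraps around to the start. The following is a minimal sketch under that assumption, not the code from Example 2. The name `padded_batches` is hypothetical, it covers a single epoch, and it assumes `data` and `labels` can be converted to NumPy arrays with samples along the first axis.

```python
import numpy as np


def padded_batches(data, labels, batch_size, shuffle=False, rng=None):
    """Yield full batches for one epoch, wrapping around to pad the last one."""
    data = np.asarray(data)
    labels = np.asarray(labels)
    size = len(data)
    if size == 0:
        return
    order = np.arange(size)
    if shuffle:
        rng = np.random.default_rng() if rng is None else rng
        order = rng.permutation(size)
    num_batches = -(-size // batch_size)  # ceiling division
    for batch_num in range(num_batches):
        # Positions past the end wrap to the start of the (possibly shuffled) order.
        positions = np.arange(batch_num * batch_size, (batch_num + 1) * batch_size) % size
        idx = order[positions]
        yield data[idx], labels[idx]


# Toy data, for illustration only: 10 samples, batch size 4.
x = np.arange(10)
y = np.eye(2)[x % 2]  # one-hot labels with 2 classes
batches = list(padded_batches(x, y, batch_size=4))
for bx, by in batches:
    print(bx, by.shape)
```

Because the wrap uses a modulo, this sketch also fills a batch when batch_size is larger than the whole dataset, at the price of repeating some samples more than once.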