批输入任务中,数据长度无法整除batch_size的处理方法

在神经网络训练时,常常需要采用批输入数据的方法,为此需要设定每次输入的批数据大小batch_size,而当样本数量无法整除batch_size时,往往会丢弃掉后面的若干个样本。在实际做项目时,我曾经这么做过,对于样本数目多的数据集,这样做影响不大,但是当数据集太小时,多余的那些样本就无法用于学习更新网络,对本来就缺少数据的任务来说这样做有些不合理,如例1所示。
例1.直接丢弃后面的样本

def get_batch_data(input_list,label,batch_size,num_epochs,shuffle=False):

        """This is a list version"""
        size = len(input_list)
        num_batches_per_epoch = int(size/batch_size)
        input_list = np.array(input_list)
        for epoch in range(num_epochs):
            if shuffle:
                shuffle_indices = np.random.permutation(np.arange(size))
                shuffle_indices = shuffle_indices.astype(np.int32)
                shuffled_data = input_list[shuffle_indices]
                shuffled_label = label[shuffle_indices,:]  # 此处根据label是否为one-hot进行修改
            else:
                shuffled_data = input_list
                shuffled_label = label

        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num*batch_size
            end_index = min((batch_num+1)*batch_size,size)        # 此处当size大于(batch_num+1)*batch_size时,就舍掉了后面的样本

            yield shuffled_data[start_index:end_index],shuffled_label[start_index:end_index]

为此做一些改变,主要是将最后一个batch人为地补满,当然补什么数据可以自己调节,具体代码如下:
例2.人为地将最后一个batch补满

def get_batch_data(input_list,label,batch_size,num_epochs,shuffle=False):

        """This is a list version"""
        size = len(input_list)
        num_batches_per_epoch = int(size/batch_size) + 1
        input_list = np.array(input_list)
        for epoch in range(num_epochs):
            if shuffle:
                shuffle_indices = np.random.permutation(np.arange(size))
                shuffle_indices = shuffle_indices.astype(np.int32)
                shuffled_data = input_list[shuffle_indices]
                shuffled_label = label[shuffle_indices,:]
            else:
                shuffled_data = input_list
                shuffled_label = label

        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num*batch_size
            # end_index = min((batch_num+1)*batch_size,size)
            end_index = (batch_num + 1) * batch_size
            if end_index <= size:                                     # 当end_index不比样本数大时,持续产生Batch
                 yield shuffled_data[start_index:end_index],shuffled_label[start_index:end_index]
            else:                                                    # 当end_index大于样本数时,人为地补满差值
                rest_dat = shuffled_data[start_index:size]
                rest_label = shuffled_label[start_index:size]
                diff = end_index - size
                add_dat = shuffled_data[0:diff]
                add_label = shuffled_label[0:diff]
                dat = np.hstack((rest_dat,add_dat))
                label = np.vstack((rest_label,add_label))
                yield  dat,label

上述标红的地方是改变的代码,使用了numpy模块。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值