Keras之model.fit_generator()的使用
model.fit_generator()是利用生成器,分批次向模型送入数据的方式,可以有效节省单次内存的消耗
一、使用方式
1.引入库
代码如下(示例):
from keras import models
model.fit_generator(
train_gen,
steps_per_epoch=int(len(train[0]) / batch_size),
epochs=MAX_EPOCHS,
validation_data=dev_gen,
validation_steps=int(len(dev[0]) / batch_size),
callbacks=[checkpointer, reduce_lr, stopping])
2.参数解释
代码如下(示例):
fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, \
callbacks=None, validation_data=None, validation_steps=None,\
class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)
generator:指需要训练的训练集
steps_per_epochs:是指在每个epoch中生成器执行生成数据的次数
epochs:指训练迭代的次数
verbose:日志的显示模式,取 1 时表示“进度条模式”,取2时表示“每轮一行”,取0时表示“安静模式”;
validation_data:验证集
validation_steps:指验证集的情况,类似于steps_per_epoch