如何在Keras中使用`fit_generator`训练模型?
时间: 2024-09-21 22:08:06 浏览: 88
在Keras中,`fit_generator`方法主要用于处理大型数据集,尤其是那些无法一次加载到内存中的数据,如图片、文本等。以下是使用`fit_generator`的基本步骤:
1. **准备生成器**:创建一个生成器函数,它每次迭代返回一个小批量的训练样本。你可以使用`keras.preprocessing.image.ImageDataGenerator`或其他自定义生成器来做到这一点。例如:
```python
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255) # 对像素值做归一化
train_generator = train_datagen.flow_from_directory(
'train_data', target_size=(224, 224), batch_size=32, class_mode='categorical')
```
2. **定义模型**:确保你的模型已经被编译,并准备好接收输入数据。比如一个卷积神经网络(CNN):
```python
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
MaxPooling2D((2, 2)),
... # 其他层...
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
```
3. **开始训练**:现在你可以调用`fit_generator`并传入训练生成器以及可能需要的其它参数,如轮数epochs和验证生成器:
```python
model.fit_generator(train_generator, epochs=10, validation_data=val_generator)
```
注意,`validation_data`参数也是一个生成器,如果有的话,会用来评估模型性能。
阅读全文
相关推荐

















