python 卷积神经网络 分类
时间: 2025-05-24 08:00:52 浏览: 12
### 使用 Python 构建卷积神经网络 (CNN) 完成分类任务
以下是基于 TensorFlow 和 Keras 的 CNN 实现分类任务的完整代码示例:
#### 数据准备阶段
在构建模型之前,通常需要加载并预处理数据集。假设我们使用的是 CIFAR-10 数据集。
```python
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# 归一化像素值到[0, 1]范围
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# 将标签转换为one-hot编码
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
```
#### 模型构建阶段
通过 `Sequential` API 或者 Functional API 来定义 CNN 结构。
```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
model = Sequential([
# 第一层卷积层 + ReLU激活函数 + 最大池化层
Conv2D(filters=32, kernel_size=(3, 3), padding='same', input_shape=(32, 32, 3), activation='relu'),
MaxPooling2D(pool_size=(2, 2)),
# 第二层卷积层 + ReLU激活函数 + 最大池化层
Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu'),
MaxPooling2D(pool_size=(2, 2)),
# 展平层
Flatten(),
# 全连接层 + Dropout正则化
Dense(units=512, activation='relu'),
Dropout(rate=0.5),
# 输出层 + Softmax分类
Dense(units=10, activation='softmax')
])
```
此部分实现了卷积操作、ReLU 激活函数以及最大池化功能[^1]。同时,全连接层和 softmax 分类被用于最终的类别预测[^1]。
#### 编译与训练阶段
配置损失函数、优化器及评估指标,并启动训练流程。
```python
# 配置模型编译选项
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 开始训练模型
history = model.fit(x_train, y_train,
validation_split=0.2,
epochs=20,
batch_size=64)
# 测试模型性能
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f'Test Accuracy: {test_accuracy}')
```
这段代码展示了如何利用 Adam 作为优化算法来最小化交叉熵损失函数[^4],并通过验证集调整超参数以提高泛化能力。
#### 可视化结果
为了更好地理解模型的表现情况,可以绘制训练过程中精度变化曲线图。
```python
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.title('Accuracy over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
```
---
###
阅读全文
相关推荐


















