keras鸾尾花MLP模型
时间: 2025-06-30 09:09:15 浏览: 9
### 使用 Keras 构建鸢尾花数据集的多层感知机(MLP)模型
以下是使用 Keras 构建鸢尾花数据集的多层感知机(MLP)模型的完整代码示例。该代码包括数据预处理、模型构建、训练和预测等步骤。
```python
# 导入必要的库
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelBinarizer
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
# 加载鸢尾花数据集
data = load_iris()
X = data.data
y = data.target
# 数据预处理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 将标签转换为 one-hot 编码
y_one_hot = to_categorical(y)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_one_hot, test_size=0.2, random_state=42)
# 构建 MLP 模型
model = Sequential()
model.add(Dense(10, input_dim=X_train.shape[1], activation='relu')) # 隐藏层,10个神经元
model.add(Dense(10, activation='relu')) # 第二个隐藏层,10个神经元
model.add(Dense(y_one_hot.shape[1], activation='softmax')) # 输出层,3个神经元(对应3类)
# 编译模型
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
# 训练模型
model.fit(X_train, y_train, epochs=100, batch_size=5, verbose=1, validation_split=0.2)
# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {accuracy * 100:.2f}%")
# 预测新样本
row = np.array([[5.1, 3.5, 1.4, 0.2]])
row_scaled = scaler.transform(row)
yhat = model.predict(row_scaled)
predicted_class = np.argmax(yhat)
print('Predicted: %s (class=%d)' % (yhat, predicted_class))
```
### 说明
1. **数据加载与预处理**
使用 `sklearn.datasets.load_iris` 加载鸢尾花数据集,并通过 `StandardScaler` 对特征进行标准化[^1]。同时,将目标变量转换为 one-hot 编码形式以适应多分类任务的需求[^4]。
2. **模型构建**
使用 Keras 的 `Sequential` API 构建一个简单的多层感知机模型。模型包含两个隐藏层,每层有 10 个神经元,激活函数为 ReLU。输出层有 3 个神经元(对应鸢尾花数据集的三个类别),激活函数为 Softmax[^3]。
3. **模型编译与训练**
模型使用交叉熵损失函数(`categorical_crossentropy`)和 Adam 优化器进行编译。训练过程中,使用 20% 的数据作为验证集,确保模型不会过拟合[^4]。
4. **模型评估与预测**
在测试集上评估模型性能,并对新样本进行预测。预测结果通过 `argmax` 函数转换为具体的类别标签[^1]。
阅读全文
相关推荐


















