# 导入混淆矩阵 from sklearn.metrics import confusion_matrix # 填空:输出模型在测试集上的混淆矩阵
时间: 2025-06-23 20:28:11 浏览: 20
### 如何使用 `sklearn` 输出模型在测试集上的混淆矩阵
为了计算并展示分类模型在测试集上的混淆矩阵,可以利用 scikit-learn 库中的 `confusion_matrix` 函数[^2]。此过程涉及几个关键步骤的操作。
下面是一个完整的 Python 示例代码,展示了如何加载必要的库、训练一个简单的分类器以及输出该分类器针对给定测试数据的混淆矩阵:
```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# 加载示例数据集
data = load_iris()
X, y = data.data, data.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 训练简单分类器 (这里以K近邻为例)
from sklearn.neighbors import KNeighborsClassifier
clf = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)
# 预测测试集标签
y_pred = clf.predict(X_test)
# 使用 confusion_matrix 函数获取混淆矩阵
cm = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(cm)
# 可视化混淆矩阵
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=data.target_names, yticklabels=data.target_names)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.title('Confusion Matrix Heatmap')
plt.show()
```
上述脚本不仅实现了基本的功能需求——即通过调用 `confusion_matrix()` 来获得预测结果与实际类别之间的对比情况;还进一步借助 Matplotlib 和 Seaborn 工具包来绘制热力图形式的可视化表示,使得理解不同类别的误判状况更加直观[^1]。
阅读全文
相关推荐



















