绘制混淆矩阵怎么弄
时间: 2025-06-20 16:58:37 浏览: 16
### 如何绘制混淆矩阵
在机器学习领域,混淆矩阵是一个重要的工具,用于评估分类模型的性能。通过可视化混淆矩阵,可以更清晰地理解模型的预测准确性以及误分类的情况。
以下是使用 Python 绘制混淆矩阵的具体方法:
#### 使用 Matplotlib 和 Seaborn 库绘制混淆矩阵热力图
Matplotlib 是一个强大的绘图库,而 Seaborn 则是在 Matplotlib 的基础上构建的一个高级接口,特别适合于统计图表的绘制。下面是一段完整的代码示例,演示如何基于真实标签和预测标签生成并绘制混淆矩阵热力图。
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# 假设这是真实的类别标签和模型预测的结果
y_true = ["cat", "dog", "rabbit", "cat", "dog", "rabbit", "cat", "dog"]
y_pred = ["cat", "dog", "cat", "cat", "dog", "dog", "rabbit", "dog"]
# 计算混淆矩阵
labels = list(set(y_true)) # 获取所有的类别名
cm = confusion_matrix(y_true, y_pred, labels=labels)
# 将混淆矩阵转换为百分比形式(可选)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# 创建热力图
plt.figure(figsize=(8, 6))
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.title("Confusion Matrix Heatmap")
plt.ylabel("Actual Label")
plt.xlabel("Predicted Label")
# 显示图像
plt.show()
```
上述代码中,`confusion_matrix` 函数来自 `sklearn.metrics` 模块,用于计算混淆矩阵的实际数值[^1]。Seaborn 的 `heatmap` 方法则负责将这些数值转化为颜色深浅不一的热力图[^2]。
#### 可视化结果的意义
- **对角线上的值** 表示正确分类的数量或比例。
- **非对角线上的值** 表示错误分类的数量或比例。
这种表示方式使得研究人员能够快速定位哪些类别之间的区分较为困难。
#### 完整工作流中的应用实例
在一个典型的机器学习项目中,除了绘制混淆矩阵外,还可以结合其他指标一起分析模型表现。例如,在鸢尾花数据集上完成分类任务后,可以通过以下代码查看模型的整体表现[^3]:
```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score
# 加载数据
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.3, random_state=42)
# 构建逻辑回归模型
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)
# 预测测试集
predictions = model.predict(X_test)
# 输出准确率、分类报告和混淆矩阵
print(f"Accuracy: {accuracy_score(y_test, predictions)}\n")
print(classification_report(y_test, predictions))
# 绘制混淆矩阵
cm = confusion_matrix(y_test, predictions)
sns.heatmap(cm, annot=True, fmt="d", cmap="YlGnBu")
plt.show()
```
此代码片段展示了从数据准备到最终可视化的整个过程,有助于全面掌握模型的行为特性。
---
相关问题
阅读全文
相关推荐







