python绘制混淆矩阵百分比
时间: 2025-03-26 16:30:48 浏览: 25
### 如何用 Python 绘制带有百分比的混淆矩阵
为了实现这一目标,可以利用 `sklearn` 库中的 `confusion_matrix` 函数来计算混淆矩阵,并通过 `matplotlib` 和 `seaborn` 来可视化这个矩阵。下面是一个完整的例子:
```python
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
def plot_confusion_matrix_with_percentage(y_true, y_pred, labels=None):
"""
计算并绘制带百分比标签的混淆矩阵
参数:
y_true : array-like of shape (n_samples,)
正确的目标值
y_pred : array-like of shape (n_samples,)
预测的目标值
labels : list or None, default=None
类别的名称列表,默认为None表示自动推断类别名
"""
cm = confusion_matrix(y_true, y_pred, labels=labels)
fig, ax = plt.subplots(figsize=(8, 6))
# 归一化处理得到比例
row_sums = cm.sum(axis=1).astype(float)
normalized_cm = cm / row_sums[:, np.newaxis]
# 使用热力图展示混淆矩阵
heatmap = sns.heatmap(normalized_cm, annot=True, fmt='.2%', cmap='Blues', cbar=False,
xticklabels=labels if labels is not None else True,
yticklabels=labels if labels is not None else True)
# 设置图表属性
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
plt.show()
```
此函数接受真实标签 (`y_true`) 和预测标签 (`y_pred`) 的数组作为输入参数,并可选地接收类别的名称列表用于标注坐标轴。该方法会先调用 `confusion_matrix()` 得到原始计数形式的混淆矩阵;接着对其进行按行求和再除以总和的方式完成归一化操作从而转换成占比的形式;最后借助于 Seaborn 的 `heatmap()` 方法创建图形对象。
阅读全文
相关推荐















