TensorFlow绘制混淆矩阵图
时间: 2025-05-05 20:28:01 浏览: 24
### 使用 TensorFlow 绘制混淆矩阵图
为了使用 TensorFlow 绘制混淆矩阵图,可以采用 `tf.math.confusion_matrix` 函数来生成混淆矩阵数据,并利用 Matplotlib 或 Seaborn 库来进行可视化。下面展示了一个完整的流程,包括创建预测标签、真实标签以及最终绘制混淆矩阵图表。
#### 创建并获取预测与实际标签
首先定义两个列表分别存储模型对于测试集样本做出的类别预测结果(preds)和对应的正确答案即真是标签(labels)。这些通常来自于已经训练好的神经网络模型对一批新输入图片经过前向传播后的输出处理得到softmax概率分布取最大值得到的结果[^1]。
```python
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
# 假设这是通过模型获得的一批预测值 (batch_size, num_classes)
logits = model.predict(test_images)
# 获取最高分数索引作为预测类别
predictions = np.argmax(logits, axis=1)
# 这里 test_labels 是真实的标签数组
true_labels = test_labels.flatten()
```
#### 计算混淆矩阵
接着调用 `tf.math.confusion_matrix` 来计算给定的真实标签和预测标签之间的混淆矩阵。此函数返回的是一个二维张量表示不同类别的交叉统计情况[^2]。
```python
confusion_mtx = tf.math.confusion_matrix(true_labels, predictions).numpy()
```
#### 可视化混淆矩阵
最后一步就是将上述所得的数据绘制成热力图形式以便更直观地理解各个分类间的错误模式。这里推荐使用Seaborn库中的heatmap方法完成这一步骤[^3]。
```python
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_mtx, annot=True, fmt='d', cmap="Blues", cbar=False,
xticklabels=['Class_0', 'Class_1'], yticklabels=['Class_0', 'Class_1'])
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.title('Confusion Matrix')
plt.show()
```
以上代码片段展示了如何基于TensorFlow框架实现混淆矩阵图形化的全过程。值得注意的是,在具体应用时还需要根据实际情况调整参数设置比如颜色映射表(cmap),显示数值格式(fmt)等选项以满足特定需求。
阅读全文
相关推荐


















