目录
1、ROC曲线与PR曲线常用来评估二分类问题中分类器的性能。
一、ROC曲线与PR曲线理解
1、ROC曲线与PR曲线常用来评估二分类问题中分类器的性能。
分类器是将一个实例映射到一个特定类的过程。ROC分析的是二元分类模型,也就是输出结果只有两种类别的模型,例如:(阳性/阴性)(有病/没病)(垃圾邮件/非垃圾邮件)(敌军/非敌军)。(来自维基百科)
2、混淆矩阵
首先我们需要了解混淆矩阵。
以医生诊断高血压为例,只诊断有无。其中医生就是分类器,诊断就相当于分类器的预测。
那么就有四种结局:
- 真阳性(TP):诊断为有,实际上也有高血压。
- 伪阳性(FP):诊断为有,实际却没有高血压。
- 真阴性(TN):诊断为没有,实际上也没有高血压。
- 伪阴性(FN):诊断为没有,实际却有高血压。
假如有一组数据,其包含预测值和真实值,我们去计算这四种结局的数量,那么我们就可以得到混淆矩阵:
3、ROC曲线
由此我们就可以绘制ROC曲线了。ROC曲线由TPR和FPR组成,其中FPR作为横轴,TPR作为纵轴。其计算公式为:
我们把TPR叫做命中率即在所有实际为阳性的样本中,被正确地判断为阳性之比率,FPR叫做错误命中率即在所有实际为阴性的样本中,被错误地判断为阳性之比率。
比如有10个样本,其中7个是表示阳性,3个表示阴性。其中在所有阳性样本中被正确预测为阳性的有3个,在所有阴性样本中错误预测为阳性的有2个。那么TPR=3/7,FPR=2/3.那么对于一个好的分类器,肯定要正确预测越多越好。所以当FPR小,TPR大时,ROC曲线就会偏向左上角。
4、PR曲线
同样的,以上述混淆矩阵为例。关于PR曲线。它需要精确率(precision)和召回率(recall),召回率为横轴,精确率为纵轴。其公式为:
其中精确率(precision)表示在分类器预测为阳性的样本中实际为阳性的比率,召回率(recall)表示在实际为阳性的样本中被预测正确为阳性的比率。
比如有10个样本,7个阳性,3个阴性。其中在所有阳性样本中被正确预测为阳性的有5个,在所有预测中被正确预测阳性的有5个。那么P=5/7,R=5/10。对于一个好的分类器,用PR曲线来评估的话,肯定需要在R大的同时P也大.所以PR曲线越靠右上角表示分类器效果越好。
二、ROC和PR曲线python简单实现
import matplotlib.pyplot as plt
import numpy as np
# 假设您有以下预测概率和实际标签
y_pred = np.array([0.1, 0.78, 0.35, 0.96, 0.23, 0.04, 0.64, 0.4, 0.8,0.52])
y_true = np.array([0, 1, 0, 1, 1, 0 ,1 ,0 , 0, 1])
# 计算TPR和FPR
thresholds = np.sort(y_pred)#按顺序排序
print(thresholds)
tprs = []
fprs = []
for thresh in thresholds:
# 预测结果
pred = (y_pred > thresh).astype(int)#获取不同预测值
print(pred)
# 计算TP, FP, FN, TN
TP = np.sum((pred == 1) & (y_true == 1))
FP = np.sum((pred == 1) & (y_true == 0))
FN = np.sum((pred == 0) & (y_true == 1))
TN = np.sum((pred == 0) & (y_true == 0))
# 计算TPR和FPR
TPR = TP / (TP + FN)
FPR = FP / (FP + TN)
tprs.append(TPR)
fprs.append(FPR)
# 绘制ROC曲线
plt.figure(figsize=(6,6))
plt.plot(fprs, tprs, marker='o')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve')
plt.show()
# 计算Precision和Recall
precisions = []
recalls = []
for thresh in thresholds:
# 预测结果
pred = (y_pred > thresh).astype(int)
# 计算TP, FP, FN
TP = np.sum((pred == 1) & (y_true == 1))
FP = np.sum((pred == 1) & (y_true == 0))
FN = np.sum((pred == 0) & (y_true == 1))
# 计算Precision和Recall
precision = TP / (TP + FP)
recall = TP / (TP + FN)
precisions.append(precision)
recalls.append(recall)
# 绘制PR曲线
plt.figure(figsize=(6,6))
plt.plot(recalls, precisions, marker='o')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('PR curve')
plt.show()
在该代码中,并没有直接定义多组数据,而是通过设置概率值,通过循环模拟多组预测数据。这只是一个简单实现代码。仅供参考。
各个数据预测值。
ROC曲线
PR曲线
三、总结
这只是ROC和PR曲线的简单理解和实现,并没有太多的深入。像置信度等等内容没有去说明,以及各个曲线的意义。主要理解其用法和绘制。