python机器学习ROC-AUC
时间: 2025-02-16 12:12:17 浏览: 56
### Python 中机器学习 ROC 和 AUC 的概念
ROC (Receiver Operating Characteristic) 曲线用于展示不同阈值下分类器的真阳性率(True Positive Rate, TPR)和假阳性率(False Positive Rate, FPR)[^1]。TPR 表示实际为正类的情况下预测为正的概率;FPR 则表示实际为负类却错误地预测为正的比例。
AUC (Area Under the Curve),即曲线下面积,衡量的是整个二分类模型的好坏程度。理想的分类器其 AUC 值接近于 1,意味着具有完美的区分能力;而当 AUC 接近 0.5,则表明该分类器几乎不具备任何判别力[^2]。
### 计算方法与实现方式
为了计算并绘制 ROC 及 AUC,在 Python 中通常会借助 `scikit-learn` 库中的函数来完成这一过程:
#### 导入必要的库
```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import roc_curve, auc
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
```
#### 加载数据集并预处理
这里以鸢尾花(Iris)数据为例说明多类别情况下的操作流程:
```python
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 将标签转换成二进制形式
Y = label_binarize(y, classes=[0, 1, 2])
n_classes = Y.shape[1]
# 数据分割
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=.5,
random_state=0)
```
#### 构建分类器并训练
采用随机森林作为基础估计器构建 OvR 多标签分类器:
```python
classifier = OneVsRestClassifier(RandomForestClassifier(n_estimators=100))
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)
```
#### 绘制 ROC 曲线及计算 AUC
对于每一个类别分别获取对应的 FPR、TPR 并求得各自的 AUC 值:
```python
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# 宏平均 ROC 曲线及其 AUC
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
macro_roc_auc = auc(all_fpr, mean_tpr)
plt.figure(figsize=(8, 6), dpi=100)
lw = 2
colors = ['aqua', 'darkorange', 'cornflowerblue']
for i, color in zip(range(n_classes), colors):
plt.plot(fpr[i], tpr[i],
color=color,
lw=lw,
label='ROC curve of class {0} (area = {1:0.2f})'
''.format(i, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
```
上述代码展示了如何针对多分类问题通过宏平均法得到整体表现指标,并给出了具体绘图的方法[^3]。
阅读全文
相关推荐
















