1. 项目概述
本文将基于CelebA_Spoof数据集,详细介绍如何对训练好的2D活体检测模型进行模型验证和测试。我们使用的代码和训练的模型来自上一篇文章:深度学习之用CelebA_Spoof数据集搭建一个活体检测-模型搭建和训练。
补充,上传一个可用的训练结果:用CelebA-Spoof数据集搭建一个活体检测的训练结果,在验证集的ACC为93.47%,并转成了onnx,方便后续的使用。
2. 数据集准备
CelebA_Spoof是一个大规模人脸活体检测数据集,包含:
- 活体样本:真实人脸图像
- 伪造样本:打印攻击、视频重放攻击等
前面的文章已经进行了数据的预处理,详见:深度学习之用CelebA_Spoof数据集搭建一个活体检测-数据处理
3. 模型训练与验证
3.1 验证指标
以下是评估指标中各项指标的解释、重要性和计算公式:
-
准确率(Accuracy)
- 解释:模型预测正确的样本占总样本的比例
- 重要性:衡量整体分类效果的基础指标
- 公式: A c c u r a c y = T P + T N T P + T N + F P + F N Accuracy = \frac{TP + TN}{TP + TN + FP + FN} Accuracy=TP+TN+FP+FNTP+TN
-
精确率(Precision)
- 解释:预测为正样本中实际为正的比例
- 重要性:衡量模型预测正类的准确性
- 公式: P r e c i s i o n = T P T P + F P Precision = \frac{TP}{TP + FP} Precision=TP+FPTP
-
召回率(Recall)
- 解释:实际为正样本中被正确预测的比例
- 重要性:衡量模型发现正类的能力
- 公式: R e c a l l = T P T P + F N Recall = \frac{TP}{TP + FN} Recall=TP+FNTP
-
F1 Score
- 解释:精确率和召回率的调和平均数
- 重要性:综合评估模型性能的指标
- 公式: F 1 = 2 × P r e c i s i o n × R e c a l l P r e c i s i o n + R e c a l l F1 = 2 \times \frac{Precision \times Recall}{Precision + Recall} F1=2×Precision+RecallPrecision×Recall
-
假正率(FPR)
- 解释:实际为负样本中被错误预测为正的比例
- 重要性:衡量模型误报程度
- 公式: F P R = F P F P + T N FPR = \frac{FP}{FP + TN} FPR=FP+TNFP
-
ROC AUC
- 解释:ROC曲线下面积
- 重要性:衡量模型整体区分能力
- 公式:计算所有阈值下的TPR和FPR曲线下面积
-
等错误率(EER)
- 解释:FPR=FNR时的错误率
- 重要性:安全关键系统的重要指标
- 公式: E E R = F P R w h e r e F P R = 1 − T P R EER = FPR \quad where \quad FPR = 1 - TPR EER=FPRwhereFPR=1−TPR
-
PR AUC
- 解释:精确率-召回率曲线下面积
- 重要性:对不平衡数据更敏感
- 公式:计算所有阈值下的Precision-Recall曲线下面积
-
平均精度(AP)
- 解释:PR曲线的加权平均值
- 重要性:综合评估精确率和召回率
- 公式: A P = ∑ k ( R e c a l l k − R e c a l l k − 1 ) × P r e c i s i o n k AP = \sum_{k}(Recall_k - Recall_{k-1}) \times Precision_k AP=∑k(Recallk−Recallk−1)×Precisionk
这些指标在活体检测任务中各有侧重,建议重点关注:
- 安全关键场景:FPR@1%、EER
- 平衡数据集:Accuracy、F1
- 不平衡数据集:PR AUC、AP
3.1 验证代码
在训练的过程中添加验证的代码:
import numpy as np
import torch
from sklearn.metrics import (
confusion_matrix,
roc_curve,
auc,
precision_recall_curve,
average_precision_score
)
def evaluate_model(model, dataloader):
model.eval()
all_preds = []
all_scores = []
all_labels = []
with torch.no_grad():
for inputs, labels in dataloader:
outputs = model(inputs.to('cuda'))
probs = torch.softmax(outputs, dim=1)
all_scores.extend(probs[:,1].cpu().numpy())
all_preds.extend(probs.argmax(1).cpu().numpy())
all_labels.extend(labels.numpy())
return calculate_metrics(np.array(all_labels),
np.array(all_preds),
np.array(all_scores))
def calculate_metrics(y_true, y_pred, y_score):
# 基础指标
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
metrics = {
'Accuracy': (tp + tn) / (tp + tn + fp + fn),
'Precision': tp / (tp + fp),
'Recall': tp / (tp + fn),
'F1': 2 * (tp / (tp + fp)) * (tp / (tp + fn)) / (tp / (tp + fp) + tp / (tp + fn)),
'FPR': fp / (fp + tn)
}
# ROC曲线相关
fpr, tpr, _ = roc_curve(y_true, y_score)
metrics.update({
'ROC_AUC': auc(fpr, tpr),
'EER': fpr[np.argmin(np.abs(fpr - (1 - tpr)))]
})
# PR曲线相关
precision, recall, _ = precision_recall_curve(y_true, y_score)
metrics.update({
'PR_AUC': auc(recall, precision),
'AP': average_precision_score(y_true, y_score)
})
return metrics
在训练的过程中调用验证的方法:
# 定期评估
if epoch % args.eval_interval == 0 and gpu == 0:
eval_results = evaluate_model(student, val_loader)
print(f"[Epoch {epoch}] 评估结果: {eval_results}")
with open(os.path.join(args.log_dir, "eval_metrics.log"), "a") as f:
f.write(f"Epoch {epoch}: {eval_results}\n")
epoch为5的时候的一次验证结果:
[Epoch 5] 评估结果: {'Accuracy': 0.8008932559178205, 'Precision': 0.7954640614096301, 'Recall': 0.9650764085848538, 'F1': 0.8720999177552933, 'FPR': 0.5884360570166633, 'ROC_AUC': 0.8934977190511431, 'EER': 0.18329652680184702, 'PR_AUC': 0.9504641050256947, 'AP': 0.9504655592096921}
4. 模型测试与ONNX导出
直接给出代码
import argparse
import torch
import os
import cv2
import numpy as np
from utils.model_utils import ModelFactory
from torchvision import transforms
def parse_args():
parser = argparse.ArgumentParser(description='模型导出ONNX')
parser.add_argument('--model-path', type=str, required=True,
help='模型检查点路径')
parser.add_argument('--output-path', type=str, default='live_spoof.onnx',
help='ONNX输出路径')
parser.add_argument('--test-dir', type=str, default=None,
help='测试图片目录路径')
parser.add_argument('--test-num', type=int, default=10,
help='测试图片数量')
return parser.parse_args()
def preprocess_image(img_path):
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((112, 112)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return transform(img).unsqueeze(0)
def main():
args = parse_args()
# 加载并导出模型
model = ModelFactory.build_student().eval()
checkpoint = torch.load(args.model_path)
model.load_state_dict(checkpoint['student_state_dict'])
dummy_input = torch.randn(1, 3, 112, 112)
torch.onnx.export(
model,
dummy_input,
args.output_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=11
)
print(f"ONNX模型已保存到: {args.output_path}")
# 测试转换后的模型
if args.test_dir:
print("\n开始测试ONNX模型...")
import onnxruntime as ort
ort_session = ort.InferenceSession(args.output_path)
test_count = 0
for filename in sorted(os.listdir(args.test_dir)):
if test_count >= args.test_num:
break
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(args.test_dir, filename)
input_tensor = preprocess_image(img_path)
# PyTorch推理
with torch.no_grad():
pt_output = model(input_tensor)
pt_prob = torch.softmax(pt_output, dim=1)[0]
# ONNX推理
ort_inputs = {'input': input_tensor.numpy()}
ort_output = ort_session.run(None, ort_inputs)
ort_prob = np.exp(ort_output[0]) / np.sum(np.exp(ort_output[0]), axis=1, keepdims=True)
ort_prob = ort_prob[0]
print(f"图片: {filename}")
print(f"PyTorch结果 - 活体: {pt_prob[0]:.4f} | 伪造: {pt_prob[1]:.4f}")
print(f"ONNX结果 - 活体: {ort_prob[0]:.4f} | 伪造: {ort_prob[1]:.4f}")
print("-"*50)
test_count += 1
if __name__ == '__main__':
main()
5. 结果展示
ONNX模型已保存到: models_1/checkpoint_epoch_187_best_model.onnx
开始测试ONNX模型...
图片: 494410.png
PyTorch结果 - 活体: 0.8424 | 伪造: 0.1576
ONNX结果 - 活体: 0.8424 | 伪造: 0.1576
--------------------------------------------------
图片: 494415.png
PyTorch结果 - 活体: 0.8993 | 伪造: 0.1007
ONNX结果 - 活体: 0.8993 | 伪造: 0.1007
--------------------------------------------------
图片: 494418.png
PyTorch结果 - 活体: 0.9308 | 伪造: 0.0692
ONNX结果 - 活体: 0.9308 | 伪造: 0.0692
--------------------------------------------------
图片: 494421.png
PyTorch结果 - 活体: 0.9029 | 伪造: 0.0971
ONNX结果 - 活体: 0.9029 | 伪造: 0.0971
--------------------------------------------------
图片: 494422.png
PyTorch结果 - 活体: 0.9034 | 伪造: 0.0966
ONNX结果 - 活体: 0.9034 | 伪造: 0.0966
--------------------------------------------------
图片: 494424.png
PyTorch结果 - 活体: 0.9148 | 伪造: 0.0852
ONNX结果 - 活体: 0.9148 | 伪造: 0.0852
--------------------------------------------------
6. 总结
本文详细介绍了基于CelebA_Spoof数据集的活体检测系统实现,重点讲解了模型验证和测试环节的关键技术点。通过ONNX导出和运行时验证,确保了模型在不同平台的一致性表现。后续可以针对模型和需要进行移植和封装。