深度学习之用CelebA_Spoof数据集搭建一个活体检测-模型验证与测试

1. 项目概述

本文将基于CelebA_Spoof数据集,详细介绍如何对训练好的2D活体检测模型进行模型验证和测试。我们使用的代码和训练的模型来自上一篇文章:深度学习之用CelebA_Spoof数据集搭建一个活体检测-模型搭建和训练
补充,上传一个可用的训练结果:用CelebA-Spoof数据集搭建一个活体检测的训练结果,在验证集的ACC为93.47%,并转成了onnx,方便后续的使用。

2. 数据集准备

CelebA_Spoof是一个大规模人脸活体检测数据集,包含:

3. 模型训练与验证

3.1 验证指标

以下是评估指标中各项指标的解释、重要性和计算公式:

  1. 准确率(Accuracy)

    • 解释:模型预测正确的样本占总样本的比例
    • 重要性:衡量整体分类效果的基础指标
    • 公式 A c c u r a c y = T P + T N T P + T N + F P + F N Accuracy = \frac{TP + TN}{TP + TN + FP + FN} Accuracy=TP+TN+FP+FNTP+TN
  2. 精确率(Precision)

    • 解释:预测为正样本中实际为正的比例
    • 重要性:衡量模型预测正类的准确性
    • 公式 P r e c i s i o n = T P T P + F P Precision = \frac{TP}{TP + FP} Precision=TP+FPTP
  3. 召回率(Recall)

    • 解释:实际为正样本中被正确预测的比例
    • 重要性:衡量模型发现正类的能力
    • 公式 R e c a l l = T P T P + F N Recall = \frac{TP}{TP + FN} Recall=TP+FNTP
  4. F1 Score

    • 解释:精确率和召回率的调和平均数
    • 重要性:综合评估模型性能的指标
    • 公式 F 1 = 2 × P r e c i s i o n × R e c a l l P r e c i s i o n + R e c a l l F1 = 2 \times \frac{Precision \times Recall}{Precision + Recall} F1=2×Precision+RecallPrecision×Recall
  5. 假正率(FPR)

    • 解释:实际为负样本中被错误预测为正的比例
    • 重要性:衡量模型误报程度
    • 公式 F P R = F P F P + T N FPR = \frac{FP}{FP + TN} FPR=FP+TNFP
  6. ROC AUC

    • 解释:ROC曲线下面积
    • 重要性:衡量模型整体区分能力
    • 公式:计算所有阈值下的TPR和FPR曲线下面积
  7. 等错误率(EER)

    • 解释:FPR=FNR时的错误率
    • 重要性:安全关键系统的重要指标
    • 公式 E E R = F P R w h e r e F P R = 1 − T P R EER = FPR \quad where \quad FPR = 1 - TPR EER=FPRwhereFPR=1TPR
  8. PR AUC

    • 解释:精确率-召回率曲线下面积
    • 重要性:对不平衡数据更敏感
    • 公式:计算所有阈值下的Precision-Recall曲线下面积
  9. 平均精度(AP)

    • 解释:PR曲线的加权平均值
    • 重要性:综合评估精确率和召回率
    • 公式 A P = ∑ k ( R e c a l l k − R e c a l l k − 1 ) × P r e c i s i o n k AP = \sum_{k}(Recall_k - Recall_{k-1}) \times Precision_k AP=k(RecallkRecallk1)×Precisionk

这些指标在活体检测任务中各有侧重,建议重点关注:

  • 安全关键场景:FPR@1%、EER
  • 平衡数据集:Accuracy、F1
  • 不平衡数据集:PR AUC、AP

3.1 验证代码

在训练的过程中添加验证的代码:

import numpy as np
import torch
from sklearn.metrics import (
    confusion_matrix,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score
)

def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_scores = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to('cuda'))
            probs = torch.softmax(outputs, dim=1)
            
            all_scores.extend(probs[:,1].cpu().numpy())
            all_preds.extend(probs.argmax(1).cpu().numpy())
            all_labels.extend(labels.numpy())
    
    return calculate_metrics(np.array(all_labels), 
                          np.array(all_preds), 
                          np.array(all_scores))

def calculate_metrics(y_true, y_pred, y_score):
    # 基础指标
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    
    metrics = {
        'Accuracy': (tp + tn) / (tp + tn + fp + fn),
        'Precision': tp / (tp + fp),
        'Recall': tp / (tp + fn),
        'F1': 2 * (tp / (tp + fp)) * (tp / (tp + fn)) / (tp / (tp + fp) + tp / (tp + fn)),
        'FPR': fp / (fp + tn)
    }
    
    # ROC曲线相关
    fpr, tpr, _ = roc_curve(y_true, y_score)
    metrics.update({
        'ROC_AUC': auc(fpr, tpr),
        'EER': fpr[np.argmin(np.abs(fpr - (1 - tpr)))]
    })
    
    # PR曲线相关
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    metrics.update({
        'PR_AUC': auc(recall, precision),
        'AP': average_precision_score(y_true, y_score)
    })
    
    return metrics

在训练的过程中调用验证的方法:

        # 定期评估
        if epoch % args.eval_interval == 0 and gpu == 0:
            eval_results = evaluate_model(student, val_loader)
            print(f"[Epoch {epoch}] 评估结果: {eval_results}")
            with open(os.path.join(args.log_dir, "eval_metrics.log"), "a") as f:
                f.write(f"Epoch {epoch}: {eval_results}\n")

epoch为5的时候的一次验证结果:

[Epoch 5] 评估结果: {'Accuracy': 0.8008932559178205, 'Precision': 0.7954640614096301, 'Recall': 0.9650764085848538, 'F1': 0.8720999177552933, 'FPR': 0.5884360570166633, 'ROC_AUC': 0.8934977190511431, 'EER': 0.18329652680184702, 'PR_AUC': 0.9504641050256947, 'AP': 0.9504655592096921}

4. 模型测试与ONNX导出

直接给出代码

import argparse
import torch
import os
import cv2
import numpy as np
from utils.model_utils import ModelFactory
from torchvision import transforms

def parse_args():
    parser = argparse.ArgumentParser(description='模型导出ONNX')
    parser.add_argument('--model-path', type=str, required=True, 
                      help='模型检查点路径')
    parser.add_argument('--output-path', type=str, default='live_spoof.onnx',
                      help='ONNX输出路径')
    parser.add_argument('--test-dir', type=str, default=None,
                      help='测试图片目录路径')
    parser.add_argument('--test-num', type=int, default=10,
                      help='测试图片数量')
    return parser.parse_args()

def preprocess_image(img_path):
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return transform(img).unsqueeze(0)

def main():
    args = parse_args()
    
    # 加载并导出模型
    model = ModelFactory.build_student().eval()
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['student_state_dict'])
    
    dummy_input = torch.randn(1, 3, 112, 112)
    
    torch.onnx.export(
        model,
        dummy_input,
        args.output_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
        opset_version=11
    )
    print(f"ONNX模型已保存到: {args.output_path}")

    # 测试转换后的模型
    if args.test_dir:
        print("\n开始测试ONNX模型...")
        import onnxruntime as ort
        ort_session = ort.InferenceSession(args.output_path)
        
        test_count = 0
        for filename in sorted(os.listdir(args.test_dir)):
            if test_count >= args.test_num:
                break
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(args.test_dir, filename)
                input_tensor = preprocess_image(img_path)
                
                # PyTorch推理
                with torch.no_grad():
                    pt_output = model(input_tensor)
                    pt_prob = torch.softmax(pt_output, dim=1)[0]
                
                # ONNX推理
                ort_inputs = {'input': input_tensor.numpy()}
                ort_output = ort_session.run(None, ort_inputs)
                ort_prob = np.exp(ort_output[0]) / np.sum(np.exp(ort_output[0]), axis=1, keepdims=True)
                ort_prob = ort_prob[0] 
                print(f"图片: {filename}")
                print(f"PyTorch结果 - 活体: {pt_prob[0]:.4f} | 伪造: {pt_prob[1]:.4f}")
                print(f"ONNX结果 - 活体: {ort_prob[0]:.4f} | 伪造: {ort_prob[1]:.4f}")
                print("-"*50)
                test_count += 1

if __name__ == '__main__':
    main()

5. 结果展示

ONNX模型已保存到: models_1/checkpoint_epoch_187_best_model.onnx

开始测试ONNX模型...
图片: 494410.png
PyTorch结果 - 活体: 0.8424 | 伪造: 0.1576
ONNX结果 - 活体: 0.8424 | 伪造: 0.1576
--------------------------------------------------
图片: 494415.png
PyTorch结果 - 活体: 0.8993 | 伪造: 0.1007
ONNX结果 - 活体: 0.8993 | 伪造: 0.1007
--------------------------------------------------
图片: 494418.png
PyTorch结果 - 活体: 0.9308 | 伪造: 0.0692
ONNX结果 - 活体: 0.9308 | 伪造: 0.0692
--------------------------------------------------
图片: 494421.png
PyTorch结果 - 活体: 0.9029 | 伪造: 0.0971
ONNX结果 - 活体: 0.9029 | 伪造: 0.0971
--------------------------------------------------
图片: 494422.png
PyTorch结果 - 活体: 0.9034 | 伪造: 0.0966
ONNX结果 - 活体: 0.9034 | 伪造: 0.0966
--------------------------------------------------
图片: 494424.png
PyTorch结果 - 活体: 0.9148 | 伪造: 0.0852
ONNX结果 - 活体: 0.9148 | 伪造: 0.0852
--------------------------------------------------

6. 总结

本文详细介绍了基于CelebA_Spoof数据集的活体检测系统实现,重点讲解了模型验证和测试环节的关键技术点。通过ONNX导出和运行时验证,确保了模型在不同平台的一致性表现。后续可以针对模型和需要进行移植和封装。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值