Data-Free Network Quantization With Adversarial Knowledge Distillation

Data-Free Network Quantization With Adversarial Knowledge Distillation

1. Introduction

在本文中,我们提出了一个对抗性知识提炼框架,在无法获得原始训练数据的损失时,通过对抗性学习使最坏情况下的可能损失(最大损失)最小化。与[36]的关键区别在于,给定任何元数据,我们利用它们来约束对抗性学习框架中的发生器。为了避免额外的努力来制作新的元数据来分享,我们使用存储在批量规范化层中的统计数据来约束生成器,使其产生模仿原始训练数据的合成样本。此外,我们建议通过使用多个生成器来产生不同的合成样本。我们还通过经验表明,对多个学生同时进行对抗性KD会产生更好的结果。图1概述了拟议的无数据对抗性KD框架。

在这里插入图片描述

图1:无数据的对抗性知识提炼。我们最小化教师和学生输出之间的Kullback-Leibler(KL)分歧的最大值。在训练生成器产生对抗性图像的最大化步骤中,生成器被约束为通过匹配教师的批量归一化层的统计数据来产生与原始数据类似的合成图像。

对于模型压缩,我们在两种情况下进行了实验,(1)无数据KD和(2)无数据网络量化。与之前的工作[35, 36, 41]相比,所提出的方案在SVHN[39]、CIFAR-10、CIFAR-100[40]和Tiny-ImageNet1的残差网络[37]和wide residual networks[38]上表现出最先进的无数据KD性能。据我们所知,无数据网络量化(无数据量化感知训练)以前还没有被研究过。

我们使用TensorFlow的量化感知训练[24, 42]作为基线方案,我们评估了在各种数据集上训练的残差网络、宽残差网络和MobileNet的性能,此时量化感知训练是用我们的无数据KD框架产生的合成数据进行的。实验结果显示,与使用原始训练数据集的情况相比,拟议的无数据框架的性能损失不大。

2. Related work

Data-free KD and quantization. 无数据KD吸引了人们的兴趣,因为需要压缩预训练的模型,以便在资源有限的移动或边缘平台上部署,而由于隐私和许可问题,共享原始训练数据往往受到限制。

一些解决这个问题的早期尝试建议使用元数据,即从预训练的模型中收集的中间特征的统计数据。从一个预先训练好的模型中收集的元数据[33,34]。例如。例如,建议收集选定的中间层的激活输出的平均值和方差,并假定其为 提供,而不是原始数据集。鉴于任何 元数据,他们通过直接推断图像域中的样本,找到有助于训练学生网络的样本,这样 当它们产生类似于元数据的统计数据时 送给老师。

然而,最近的方法旨在解决这个问题,没有专门为无数据KD任务设计的元数据。在[43]中,类的相似性是从最后一个全连接层的权重计算出来的,它们被用来代替元数据。最近,有人提议使用存储在批量规范化层中的统计数据,而不需要额外的成本来制作新的元数据[41]。

另一方面,以前的一些方法引入了另一个网络,称为生成器,产生用于训练学生网络的合成样本[35, 36, 44]。他们优化生成器,使生成器的输出在送入预训练的教师时产生高精确度。引入对抗性学习是为了产生动态样本,对于这些样本,教师和学生的分类输出不匹配。在他们的分类输出中不匹配,并在这些对抗性样本上执行 对这些对抗性样本进行KD[36]。

就我们所知,关于无数据网络量化的工作很少。在[45]中提出了无数据权重量化的权重均衡和偏差校正,但没有考虑无数据激活量化。权重均衡是一个程序,通过在各层重新分配(均衡)其权重,将预训练的模型转化为量化友好的模型,使其在各层的偏差更小,量化误差更小。

由于权重量化而在激活中引入的偏差是在没有数据的情况下计算和修正的,而是基于存储在批量归一化层的统计数据。我们注意到,[45]中没有产生合成数据,[45]中也没有考虑data-free quantization-aware training。我们在表1中比较了无数据KD和量化方案。

表1:无数据KD和网络量化方案的比较,基于(1)它们如何产生合成数据和(2)它们是否依赖元数据。

在这里插入图片描述

Robust optimization. 稳健优化是优化的一个子领域,解决优化问题中的数据不确定性(例如,见[46,47])。在这个框架下,目标函数和约束函数被假定为属于某些集合,称为“不确定性集合”,目标是做出一个无论约束结果如何都可行的决策,并且对于最坏情况下的目标函数是最优的。在不提供数据的情况下,我们将无数据KD问题转化为一个鲁棒优化问题,而不确定性集是基于预先训练好的教师在其BN层使用统计信息来确定的。

Adversarial attacks. 生成愚弄预先训练模型的合成数据与对抗性攻击问题密切相关(例如,见[48])。尽管他们的目的与我们完全不同,但生成合成数据(或对抗性样本)的方法遵循类似的程序。在对抗性攻击中,还有两种方法,即:(1)直接在图像域生成敌对图像[49–51]和(2)使用生成器生成对抗图像[52–54]。

Deep image prior. 我们还注意到,由一系列卷积层组成的生成器网络可以作为一个很好的正则器,我们可以将其作为先验[55]施加于图像生成。因此,我们采用生成器,而不是添加任何先验的正则化[56],在[41]中采用这种方法来获得没有生成器的合成图像。

Generative adversarial networks (GANs). 对抗性学习在GANs中也是众所周知的[57]。GANs对图像合成问题的深入学习非常感兴趣。模式崩溃是GANs中众所周知的问题之一(例如,见[58])。克服模式崩溃的一个简单而有效的方法是引入多个发生器和/或多个鉴别器[59–62]。我们还发现,使用多个生成器和/或多个学生(在我们的案例中,一个学生充当鉴别器)有助于产生不同的样本,并避免在我们的无数据KD框架中过度拟合。

3. Data-free model compression

3.1. Knowledge distillation (KD)

tθt_{θ}tθ 是一种用于分类的通用非线性神经网络,其目的是产生一个分类概率分布Pθ(y∣x)P_θ(y|x)Pθ(yx)表示标签集C上输入x的标签y,即tθ(x)=[Pθ(y∣x)]y∈Ct_θ(x) =[P_θ(y | x)]_{y∈\mathcal C}tθ(x=[Pθ(yx)]yC。让y成为输入x的集合C上的one-hot code的ground truth标签y。网络tθt_θtθ 用用概率分布p(x, y)的标记数据集(称为训练数据集)进行预训练,如下所示:

在这里插入图片描述

其中Ep(x,y)\mathbb E_{p(x,y)}E

<think>嗯,用户现在遇到了一个错误,是关于在量化过程中没有收集到数据的问题。错误信息是“ValueError: No data is collected.”,并且还有一个警告建议在量化前运行预处理。用户之前已经询问过关于排除输出层的问题,现在这个问题可能和校准数据有关。 首先,我需要理解这个错误的原因。当使用ONNX Runtime的量化工具时,尤其是静态量化(quantize_static),需要校准数据来计算模型中各层的动态范围。如果校准数据读取器(calibration_data_reader)没有正确提供数据,就会抛出这个错误。 用户之前设置了YOLOv5CalibrationDataReader,可能是在初始化时传入了图像目录和图像大小。但可能的问题包括:路径是否正确?图像是否存在?数据读取器是否正确实现?可能用户提供的路径有误,或者数据读取器在读取时出了问题,导致没有数据被收集到。 另外,警告提到要运行预处理,可能用户的数据没有经过正确的预处理步骤,比如归一化、调整大小等。YOLOv5模型通常需要特定的预处理,比如将图像缩放为正方形、归一化到0-1范围等。如果数据读取器没有正确应用这些预处理步骤,模型在量化时可能无法正确处理输入,导致数据收集失败。 接下来,我需要考虑用户可能的解决步骤。首先检查路径是否正确,确保校准图像存在。然后检查数据读取器的实现是否正确,特别是__getitem__方法是否返回正确的数据格式。可能用户的数据读取器没有正确生成输入数据,或者批次处理有问题。 另外,预处理步骤是否正确集成到数据读取器中也很重要。可能需要用户确保在数据加载时应用了与模型训练时相同的预处理。例如,YOLOv5的预处理通常包括除以255、转换为RGB、调整大小等,这些步骤是否在数据读取器中正确实现。 用户可能没有正确实现数据读取器的接口,导致无法生成校准数据。ONNX Runtime的校准数据读取器需要实现特定的方法,比如get_next(),可能用户的自定义读取器没有正确实现这些方法,导致数据无法传递。 可能的解决方案包括:检查路径和图像文件是否存在,确保数据读取器正确生成输入数据,添加预处理步骤,确保数据格式与模型输入匹配,以及验证数据读取器的实现是否符合ONNX Runtime的要求。 还需要提醒用户关于数据量的问题,校准通常需要一定数量的图像(比如100-200张),如果用户提供的图像数量不足,也可能导致问题。此外,输入数据的形状是否正确,比如模型期望的是640x640的RGB图像,而数据读取器是否提供了这样的数据。 最后,建议用户逐步调试,比如打印数据读取器的输出,检查是否有数据生成,或者使用示例数据测试数据读取器,确保其正常工作。如果可能,提供一个简单的数据读取器示例供用户参考,帮助其排查问题。</think>--- ### 错误原因分析 当使用ONNX Runtime的`quantize_static`进行模型量化时,出现警告`Please consider to run pre-processing before quantization`和错误`ValueError: No data is collected`,通常由以下原因导致: 1. **校准数据未正确加载** - 数据读取器(`calibration_data_reader`)未生成有效的输入数据。 - 数据路径错误或校准图像未正确读取。 2. **预处理缺失** - 校准数据未经过与训练/推理时相同的预处理(如归一化、尺寸调整等)。 3. **数据格式不匹配** - 数据读取器输出的张量形状或数据类型与模型输入不匹配(例如YOLOv5输入为$[1,3,640,640]$的`float32`张量)。 --- ### 解决方案 #### 步骤1:检查校准数据路径和格式 - 确保`image_dir`路径正确,且包含校准图像(建议至少100-200张)。 - 验证图像格式是否支持(如JPEG、PNG),并检查图像能否正常解码。 #### 步骤2:完善数据预处理逻辑 YOLOv5的预处理通常包括: - 图像缩放到$640 \times 640$(保持长宽比,填充边缘)。 - 归一化到$[0,1]$范围(`像素值/255`)。 - 转换为`RGB`格式(如果训练时使用RGB)。 - 转置为$[C, H, W]$通道顺序并添加批次维度$[1,3,640,640]$。 **示例代码修正:** ```python class YOLOv5CalibrationDataReader: def __init__(self, image_dir, img_size=640): self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)] self.img_size = img_size self.current_index = 0 def preprocess(self, image): # YOLOv5标准预处理 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转RGB image = letterbox(image, self.img_size)[0] # 缩放并填充 image = image.astype(np.float32) / 255.0 # 归一化 image = np.transpose(image, (2, 0, 1)) # HWC -> CHW image = np.expand_dims(image, axis=0) # 添加批次维度 [1,3,640,640] return image def get_next(self): if self.current_index >= len(self.image_paths): return None # 数据遍历结束 image = cv2.imread(self.image_paths[self.current_index]) self.current_index += 1 return {"images": self.preprocess(image)} # 名称需与模型输入节点名称一致 ``` #### 步骤3:验证数据读取器输出 在量化前手动测试数据读取器: ```python reader = YOLOv5CalibrationDataReader(image_dir="path/to/images", img_size=640) sample = reader.get_next() print("数据形状:", sample["images"].shape) # 应为 (1, 3, 640, 640) print("数据类型:", sample["images"].dtype) # 应为 float32 ``` #### 步骤4:检查ONNX模型输入名称 确保`get_next()`返回的字典键名(如`"images"`)与模型输入节点名称一致。可通过以下代码查看模型输入: ```python import onnx model = onnx.load("yolov5s.onnx") input_name = model.graph.input[0].name print("模型输入名称:", input_name) # 通常为"images"或"input" ``` --- ### 附:完整量化代码示例 ```python from onnxruntime.quantization import quantize_static, QuantType, QuantFormat import cv2 import numpy as np import os class YOLOv5CalibrationDataReader: def __init__(self, image_dir, img_size=640): self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))] self.img_size = img_size self.current_index = 0 def letterbox(self, im, color=(114, 114, 114)): # YOLOv5的缩放填充函数 h, w = im.shape[:2] scale = min(self.img_size / h, self.img_size / w) new_h, new_w = int(h * scale), int(w * scale) im_resized = cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_LINEAR) dw = self.img_size - new_w dh = self.img_size - new_h top, bottom = dh // 2, dh - (dh // 2) left, right = dw // 2, dw - (dw // 2) im_padded = cv2.copyMakeBorder(im_resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return im_padded, (scale, (dw, dh)) def preprocess(self, image): image, _ = self.letterbox(image) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = image.astype(np.float32) / 255.0 image = np.transpose(image, (2, 0, 1)) return np.expand_dims(image, axis=0) def get_next(self): if self.current_index >= len(self.image_paths): return None image = cv2.imread(self.image_paths[self.current_index]) self.current_index += 1 return {"images": self.preprocess(image)} # 键名需匹配模型输入名称 # 初始化并验证数据读取器 calibration_reader = YOLOv5CalibrationDataReader(image_dir="path/to/calibration/images") print("校准数据数量:", len(calibration_reader.image_paths)) # 确保非空 # 执行量化 quantize_static( model_input="yolov5s.onnx", model_output="yolov5s_int8.onnx", calibration_data_reader=calibration_reader, quant_format=QuantFormat.QDQ, activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, per_channel=True, optimize_model=True, nodes_to_exclude=["/model.24"] ) ``` --- ### 关键点总结 1. **预处理一致性**:校准数据必须与模型训练时预处理完全一致。 2. **数据格式验证**:确保输出张量形状为$[1,3,640,640]$,类型为`float32`。 3. **输入名称匹配**:数据读取器的输出字典键名需与ONNX模型输入节点名称一致。 4. **校准数据量**:建议使用100-200张多样化图像覆盖典型场景。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值