深入理解Xilinx Brevitas量化神经网络框架
概述
Xilinx Brevitas是一个基于PyTorch的量化神经网络框架,专注于为FPGA和ASIC等硬件平台提供高效的量化支持。本文将深入解析Brevitas的核心概念和使用方法,帮助开发者快速掌握这一强大的量化工具。
环境准备与安装
Brevitas要求Python 3.8+和PyTorch 1.5.0+环境,可以通过pip直接安装:
pip install brevitas
基础概念:量化线性层
QuantLinear简介
brevitas.nn.QuantLinear
是torch.nn.Linear
的量化版本,属于QuantWeightBiasInputOutputLayer
(QuantWBIOL)类,支持对权重(weight)、偏置(bias)、输入(input)和输出(output)的量化。类似的量化层还包括QuantConv1d
、QuantConv2d
等。
class QuantLinear(Linear, QuantWBIOL):
def __init__(
self,
in_features: int,
out_features: int,
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
output_quant: Optional[ActQuantType] = None,
return_quant_tensor: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
默认情况下,weight_quant=Int8WeightPerTensorFloat
表示权重使用8位有符号整数量化,而其他量化选项默认关闭。
默认权重量化示例
torch.manual_seed(0)
quant_linear = QuantLinear(2, 4, bias=True)
print(f"原始浮点权重:\n {quant_linear.weight}")
print(f"量化后权重:\n {quant_linear.quant_weight()}")
print(f"整数表示:\n {quant_linear.quant_weight().int()}")
输出显示权重被量化为8位整数,同时保留了浮点缩放因子(scale),这个缩放因子是基于权重张量的最大绝对值计算得到的,并且是可微分的。
量化类型详解
1. 固定点量化
对于需要硬件友好的定点量化(缩放因子限制为2的幂次),可以使用Int8WeightPerTensorFixedPoint
:
from brevitas.quant import Int8WeightPerTensorFixedPoint
import math
quant_linear = QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFixedPoint)
print(f"量化权重:\n {quant_linear.quant_weight()}")
print(f"定点位置: {-math.log2(quant_linear.quant_weight().scale)}")
2. 二值量化
对于极致的压缩,可以使用二值量化SignedBinaryWeightPerTensorConst
:
from brevitas.quant import SignedBinaryWeightPerTensorConst
quant_linear = QuantLinear(2, 4, weight_quant=SignedBinaryWeightPerTensorConst)
print(f"二值量化权重:\n {quant_linear.quant_weight()}")
二值量化将权重限制为-α和+α两个值,默认α=0.1。
量化共享机制
Brevitas支持在不同层之间共享量化器实例(不仅仅是相同的量化配置),这意味着这些层将强制使用相同的scale、zero-point和bit-width:
# 创建第一个量化层
quant_linear1 = QuantLinear(2, 4)
print(f"共享前scale: {quant_linear1.quant_weight().scale:.4f}")
# 第二个层共享第一个层的量化器
quant_linear2 = QuantLinear(2, 4, weight_quant=quant_linear1.weight_quant)
print(f"共享后scale: {quant_linear1.quant_weight().scale:.4f}")
这种机制在需要保持多层级联时量化参数一致性的场景非常有用。
输入输出量化
输入量化
通过设置input_quant
可以量化输入:
from brevitas.quant import Int8ActPerTensorFloat
quant_linear = QuantLinear(2, 4, input_quant=Int8ActPerTensorFloat)
float_input = torch.randn(3, 2)
quant_output = quant_linear(float_input)
返回量化张量
默认返回解量化后的浮点张量,如需获取量化张量需设置return_quant_tensor=True
:
quant_linear = QuantLinear(2, 4, input_quant=Int8ActPerTensorFloat,
return_quant_tensor=True)
quant_output = quant_linear(float_input)
print(quant_output) # 返回QuantTensor对象
实际应用建议
- 渐进式量化:建议从权重量化开始,逐步添加输入/输出量化
- 量化感知训练:Brevitas支持端到端的量化感知训练
- 硬件友好配置:面向硬件部署时,优先考虑定点量化
- 精度-效率权衡:根据应用需求选择合适的位宽(从8-bit开始尝试)
通过本文的介绍,开发者应该能够掌握Brevitas的基本使用方法。在实际应用中,建议结合具体硬件平台的要求,进一步探索更复杂的量化策略和优化技巧。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考