Following up on the previous article http://myhz0606.com/article/vllm_any_resolution, this note adds some formula derivations.
1 Why a block-diagonal attention mask prevents cross-image attention during packing
Given a sequence obtained by concatenating the tokens of multiple images along the sequence dimension:
$$
x = \begin{bmatrix} p_1 \\ p_2 \\ \vdots \\ p_n \end{bmatrix}, \quad p_i \in \mathbb{R}^{L_i \times d}, \quad L = \sum_{i=1}^{n} L_i, \quad x \in \mathbb{R}^{L \times d} \tag{1}
$$
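As a concrete illustration (a minimal PyTorch sketch; the hidden size `d` and the per-image lengths are made-up numbers, not values from the article), packing is just a concatenation of the per-image patch sequences along the sequence dimension:

```python
import torch

d = 64                                   # hidden size (assumed for illustration)
lengths = [100, 256, 144]                # L_i: number of tokens per image (made up)
patches = [torch.randn(L_i, d) for L_i in lengths]   # p_i in R^{L_i x d}

# Packing, eq. (1): concatenate along the sequence dimension
x = torch.cat(patches, dim=0)            # x in R^{L x d}, L = sum(L_i) = 500
```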
Given the block-diagonal attention mask
$$
\begin{aligned}
\mathcal{M} &= \begin{bmatrix} m_{11} & m_{12} & \cdots & m_{1n} \\ m_{21} & m_{22} & \cdots & m_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ m_{n1} & m_{n2} & \cdots & m_{nn} \end{bmatrix} \\
m_{ij} &= \begin{cases} \mathbf{0}_{L_i \times L_i} & \text{if } i = j \\ -\infty \cdot \mathbf{1}_{L_i \times L_j} & \text{if } i \neq j \end{cases}
\end{aligned} \tag{2}
$$
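A minimal sketch of how the mask in eq. (2) can be materialized from the per-image lengths (the helper name `block_diagonal_mask` is hypothetical, not from the article):

```python
import torch

def block_diagonal_mask(lengths):
    """Additive mask of eq. (2): 0 inside each image's diagonal block,
    -inf everywhere else; shape (L, L) with L = sum(lengths)."""
    L = sum(lengths)
    mask = torch.full((L, L), float("-inf"))
    start = 0
    for L_i in lengths:
        mask[start:start + L_i, start:start + L_i] = 0.0
        start += L_i
    return mask

M = block_diagonal_mask([100, 256, 144])   # same lengths as the packing sketch
```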
Now compute the attention.
First compute Q, K, V with linear layers. A linear layer is a channel-mixing operation; it involves no information exchange between tokens.
$$
Q = \mathrm{Linear}(x) = \begin{bmatrix} q_1 \\ \vdots \\ q_n \end{bmatrix}, \quad K = \mathrm{Linear}(x) = \begin{bmatrix} k_1 \\ \vdots \\ k_n \end{bmatrix}, \quad V = \mathrm{Linear}(x) = \begin{bmatrix} v_1 \\ \vdots \\ v_n \end{bmatrix} \tag{3}
$$
Then the attention under the block-diagonal attention mask is:
$$
\begin{aligned}
\mathrm{Attention}(x) &= \mathrm{Softmax}\Bigl(\frac{QK^T}{\sqrt{d}} + \mathcal{M}\Bigr)V \\
&= \mathrm{Softmax}\Biggl(\frac{1}{\sqrt{d}} \begin{bmatrix} q_1 \\ \vdots \\ q_n \end{bmatrix} \begin{bmatrix} k_1^T & \cdots & k_n^T \end{bmatrix} + \begin{bmatrix} m_{11} & \cdots & m_{1n} \\ \vdots & \ddots & \vdots \\ m_{n1} & \cdots & m_{nn} \end{bmatrix}\Biggr) \begin{bmatrix} v_1 \\ \vdots \\ v_n \end{bmatrix} \\
&= \mathrm{Softmax}\Biggl(\frac{1}{\sqrt{d}} \begin{bmatrix} q_1k_1^T + m_{11} & \cdots & q_1k_n^T + m_{1n} \\ \vdots & \ddots & \vdots \\ q_nk_1^T + m_{n1} & \cdots & q_nk_n^T + m_{nn} \end{bmatrix}\Biggr) \begin{bmatrix} v_1 \\ \vdots \\ v_n \end{bmatrix} && \text{since } \tfrac{\mathcal{M}}{\sqrt{d}} = \mathcal{M} \text{ (entries are } 0 \text{ or } -\infty) \\
&= \begin{bmatrix} \frac{\exp\{(q_1k_1^T + m_{11})/\sqrt{d}\}}{\sum_{j=1}^{L} \exp\{(q_1k_j^T + m_{1j})/\sqrt{d}\}} & \cdots & \frac{\exp\{(q_1k_n^T + m_{1n})/\sqrt{d}\}}{\sum_{j=1}^{L} \exp\{(q_1k_j^T + m_{1j})/\sqrt{d}\}} \\ \vdots & \ddots & \vdots \\ \frac{\exp\{(q_nk_1^T + m_{n1})/\sqrt{d}\}}{\sum_{j=1}^{L} \exp\{(q_nk_j^T + m_{nj})/\sqrt{d}\}} & \cdots & \frac{\exp\{(q_nk_n^T + m_{nn})/\sqrt{d}\}}{\sum_{j=1}^{L} \exp\{(q_nk_j^T + m_{nj})/\sqrt{d}\}} \end{bmatrix} \begin{bmatrix} v_1 \\ \vdots \\ v_n \end{bmatrix} \\
&= \begin{bmatrix} \frac{\exp\{q_1k_1^T/\sqrt{d}\}}{\sum_{j=1}^{L_1} \exp\{q_1k_j^T/\sqrt{d}\}} & \cdots & \mathbf{0} \\ \vdots & \ddots & \vdots \\ \mathbf{0} & \cdots & \frac{\exp\{q_nk_n^T/\sqrt{d}\}}{\sum_{j=L-L_n+1}^{L} \exp\{q_nk_j^T/\sqrt{d}\}} \end{bmatrix} \begin{bmatrix} v_1 \\ \vdots \\ v_n \end{bmatrix} && \exp(-\infty) = 0 \\
&= \begin{bmatrix} \frac{\exp\{q_1k_1^T/\sqrt{d}\}}{\sum_{j=1}^{L_1} \exp\{q_1k_j^T/\sqrt{d}\}} v_1 \\ \vdots \\ \frac{\exp\{q_nk_n^T/\sqrt{d}\}}{\sum_{j=L-L_n+1}^{L} \exp\{q_nk_j^T/\sqrt{d}\}} v_n \end{bmatrix} \\
&= \begin{bmatrix} \mathrm{Attention}(p_1) \\ \mathrm{Attention}(p_2) \\ \vdots \\ \mathrm{Attention}(p_n) \end{bmatrix}
\end{aligned} \tag{4}
$$
The derivation above shows that the block-diagonal attention mask cleanly prevents token interaction across images.
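Eq. (4) can also be checked numerically. Below is a minimal single-head sketch (shapes and the shared projection layers are made up; it reuses `block_diagonal_mask` from the sketch above) comparing packed attention under the mask with per-image attention:

```python
import torch

torch.manual_seed(0)
d, lengths = 64, [100, 256, 144]
patches = [torch.randn(L_i, d) for L_i in lengths]
x = torch.cat(patches, dim=0)

# Shared projections, eq. (3): channel mixing only, no token mixing
Wq, Wk, Wv = (torch.nn.Linear(d, d, bias=False) for _ in range(3))

def attention(t, mask=None):
    q, k, v = Wq(t), Wk(t), Wv(t)
    scores = q @ k.T / d ** 0.5
    if mask is not None:
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v

packed = attention(x, block_diagonal_mask(lengths))           # masked, packed
separate = torch.cat([attention(p) for p in patches], dim=0)  # one image at a time
print(torch.allclose(packed, separate, atol=1e-5))            # True
```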
2 Why cross-attention fusion needs an attention mask
In the cross-attention fusion scheme (Section 3.2 of the previous article), text features serve as the attention query and image features as the key and value. Under dynamic-resolution batched inference, each image produces a feature sequence of a different length, so the sequences must be padded to a common length when forming a batch. To keep the padding tokens from polluting the attention scores, an attention mask is needed.
Using a single sample in the batch as an example, the mask works as follows:
Q is the text feature, with shape $M \times d$.
K, V are the image features, with shape $L \times d$, where the slice $[1:L_1]$ holds the non-padded features.
$$
Q = \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_M \end{bmatrix}, \quad K = \begin{bmatrix} k_1 \\ k_2 \\ \vdots \\ k_L \end{bmatrix}, \quad V = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_L \end{bmatrix} \tag{5}
$$
To keep the features at the padding positions from affecting the attention scores, introduce the following mask (every one of the $M$ rows of $\mathcal{M}$ is the same row vector $m$):
$$
\begin{aligned}
\mathcal{M} &= \begin{bmatrix} \overbrace{0, 0, \cdots, 0}^{\#L_1}, \overbrace{-\infty, \cdots, -\infty}^{\#L_2} \\ 0, 0, \cdots, 0, -\infty, \cdots, -\infty \\ \vdots \\ 0, 0, \cdots, 0, -\infty, \cdots, -\infty \end{bmatrix} \in \mathbb{R}^{M \times L} \\
m &= \begin{bmatrix} \overbrace{0, 0, \cdots, 0}^{\#L_1}, \overbrace{-\infty, \cdots, -\infty}^{\#L_2} \end{bmatrix}
\end{aligned} \tag{6}
$$
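A sketch of the mask in eq. (6) (the text length, valid image length `L1`, and padding length `L2` are made-up values); every text query sees the same row `m`, which leaves the first `L1` keys untouched and masks the trailing `L2` padded keys:

```python
import torch

M_txt, L1, L2 = 32, 196, 60        # text tokens, valid image tokens, padded tokens
L = L1 + L2

# Row m of eq. (6): 0 over valid key positions, -inf over padding
m = torch.cat([torch.zeros(L1), torch.full((L2,), float("-inf"))])
mask = m.expand(M_txt, L)          # M in R^{M x L}: each row is m
```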
Proof:
$$
\begin{aligned}
\mathrm{Attention}(Q, K, V) = \begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_M \end{bmatrix} &= \mathrm{Softmax}\Bigl(\frac{QK^T}{\sqrt{d}} + \mathcal{M}\Bigr)V \\
&= \mathrm{Softmax}\Biggl(\frac{1}{\sqrt{d}} \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_M \end{bmatrix} \begin{bmatrix} k_1^T, k_2^T, \cdots, k_L^T \end{bmatrix} + \mathcal{M}\Biggr) \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_L \end{bmatrix}
\end{aligned} \tag{7}
$$
Looking at a single output $y_t$:
$$
\begin{aligned}
y_t &= \sum_{i=1}^{L} \mathrm{Softmax}\Bigl(\frac{q_t k_i^T}{\sqrt{d}} + m_i\Bigr) v_i \\
&= \sum_{i=1}^{L} \frac{\exp\{\frac{q_t k_i^T}{\sqrt{d}} + m_i\}}{\sum_{j=1}^{L} \exp\{\frac{q_t k_j^T}{\sqrt{d}} + m_j\}} v_i \\
&= \frac{1}{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}} + 0\} + \underbrace{\sum_{j=L_1+1}^{L} \exp\{\frac{q_t k_j^T}{\sqrt{d}} - \infty\}}_{=0}} \left[\sum_{i=1}^{L_1} \exp\{\frac{q_t k_i^T}{\sqrt{d}} + 0\} v_i + \underbrace{\sum_{i=L_1+1}^{L} \exp\{\frac{q_t k_i^T}{\sqrt{d}} - \infty\} v_i}_{=0}\right] \\
&= \frac{1}{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\}} \sum_{i=1}^{L_1} \exp\{\frac{q_t k_i^T}{\sqrt{d}}\} v_i \\
&= \sum_{i=1}^{L_1} \mathrm{Softmax}\Bigl(\frac{q_t k_i^T}{\sqrt{d}}\Bigr) v_i
\end{aligned} \tag{8}
$$
With the mask of eq. (6) we get $y_t = \sum_{i=1}^{L} \mathrm{Softmax}(\frac{q_t k_i^T}{\sqrt{d}} + m_i)v_i = \sum_{i=1}^{L_1} \mathrm{Softmax}(\frac{q_t k_i^T}{\sqrt{d}})v_i$, i.e. the computation is equivalent to attending over the non-padded tokens only.
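This equivalence can be checked numerically as well; a minimal sketch with made-up shapes (the padded keys and values are zero vectors, matching the padding scheme used below):

```python
import torch

torch.manual_seed(0)
d, M_txt, L1, L2 = 64, 32, 196, 60
Q = torch.randn(M_txt, d)                          # text queries
K1, V1 = torch.randn(L1, d), torch.randn(L1, d)    # valid image keys/values
K = torch.cat([K1, torch.zeros(L2, d)])            # zero-padded to length L
V = torch.cat([V1, torch.zeros(L2, d)])
m = torch.cat([torch.zeros(L1), torch.full((L2,), float("-inf"))])

def cross_attention(q, k, v, mask=None):
    scores = q @ k.T / d ** 0.5
    if mask is not None:
        scores = scores + mask                     # row mask broadcasts over queries
    return torch.softmax(scores, dim=-1) @ v

y_masked = cross_attention(Q, K, V, m)             # padded inputs + mask, eq. (8)
y_ref = cross_attention(Q, K1, V1)                 # attention over valid tokens only
print(torch.allclose(y_masked, y_ref, atol=1e-5))  # True
```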
Now consider what changes if the mask is omitted.
Write the un-masked output as $y_t'$. Take zero padding as an example, i.e. $K = [\underbrace{k_1, \cdots, k_{L_1}}_{\#L_1}, \underbrace{0, \cdots, 0}_{\#L_2}]^T, \; V = [\underbrace{v_1, \cdots, v_{L_1}}_{\#L_1}, \underbrace{0, \cdots, 0}_{\#L_2}]^T$.
$$
\begin{aligned}
y_t' &= \sum_{i=1}^{L} \mathrm{Softmax}\Bigl(\frac{q_t k_i^T}{\sqrt{d}}\Bigr) v_i \\
&= \frac{1}{\sum_{j=1}^{L} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\}} \sum_{i=1}^{L} \exp\{\frac{q_t k_i^T}{\sqrt{d}}\} v_i \\
&= \frac{1}{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\} + \underbrace{\sum_{j=L_1+1}^{L} \exp\{\frac{\overbrace{q_t k_j^T}^{=0}}{\sqrt{d}}\}}_{=L-L_1=L_2}} \left[\sum_{i=1}^{L_1} \exp\{\frac{q_t k_i^T}{\sqrt{d}}\} v_i + \underbrace{\sum_{i=L_1+1}^{L} \exp\{\frac{q_t k_i^T}{\sqrt{d}}\} v_i}_{=0}\right] \\
&= \frac{1}{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\} + L_2} \sum_{i=1}^{L_1} \exp\{\frac{q_t k_i^T}{\sqrt{d}}\} v_i \\
&= \frac{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\}}{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\} + L_2} \sum_{i=1}^{L_1} \mathrm{Softmax}\Bigl(\frac{q_t k_i^T}{\sqrt{d}}\Bigr) v_i \\
&= \frac{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\}}{\sum_{j=1}^{L_1} \exp\{\frac{q_t k_j^T}{\sqrt{d}}\} + L_2}\, y_t
\end{aligned} \tag{9}
$$
The derivation shows that the two results differ by the factor $\frac{\sum_{i=1}^{L_1} \exp\{q_t k_i^T/\sqrt{d}\}}{\sum_{i=1}^{L_1} \exp\{q_t k_i^T/\sqrt{d}\} + L_2}$: without the mask, the padding tokens leak into the softmax normalization.
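The factor in eq. (9) can be observed directly by continuing the sketch above (it reuses `Q`, `K`, `V`, `K1`, `y_ref`, `L2`, and `cross_attention` from that sketch): each zero-padded key contributes $\exp(0)=1$ to the softmax denominator, shrinking every output by a query-dependent amount.

```python
y_unmasked = cross_attention(Q, K, V)      # no mask, eq. (9)

logits = Q @ K1.T / d ** 0.5               # (M_txt, L1) scores on valid keys only
denom = logits.exp().sum(dim=-1)
factor = denom / (denom + L2)              # the per-query factor from eq. (9)
print(torch.allclose(y_unmasked, factor[:, None] * y_ref, atol=1e-5))  # True
```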
