Storm of Formulae(3)

附录C

仍是一期后续的公式推导及代码验证,让我们开始吧。



4. ∂ ∂ A t r ( A B A T ) = A ( B + B T ) \frac{\partial }{\partial A}tr(ABA^{T})=A(B+B^T) Atr(ABAT)=A(B+BT)


4.1 证明

对矩阵求偏导体现了线性代数的整洁性,体现了别样的优雅,希望能在证明中大家有所感受。
证明等价于
∂ ∂ A i j t r ( A B A T ) = [ A ( B + B T ) ] i j \frac{\partial }{\partial A_{ij}}tr(ABA^{T})=[A(B+B^T)]_{ij} Aijtr(ABAT)=[A(B+BT)]ij
接下来我们对 A B A T ABA^{T} ABAT做切分,其中 B A T BA^{T} BAT包含 A i j A_{ij} Aij的元素为其第 i i i列元素,故若把 A B A T ABA^{T} ABAT看作矩阵 A A A与矩阵 B A T BA^{T} BAT的乘积,仅有第 i i i行或第 i i i列元素会包含我们追踪的元素 A i j A_{ij} Aij1
因为迹是对角线元素的和,进而我们将 A i j A_{ij} Aij引起的变化部分锁定为 ( A B A T ) i i (ABA^{T})_{ii} (ABAT)ii,将公式展开有
∑ k A i k ( B A T ) k i = ∑ k A i k ( ∑ m B k m A m i ) = \sum_{k}A_{ik}(BA^{T})_{ki}=\sum_{k}A_{ik}(\sum_{m}B_{km}A_{mi})= kAik(BAT)ki=kAik(mBkmAmi)=
∑ k ∑ m A i k B k m A i m \sum_{k}\sum_{m}A_{ik}B_{km}A_{im} kmAikBkmAim2
到这里,我们可能有点晕,仍然需要冷静紧咬住 A i j A_{ij} Aij
于是继续刨去无关的部分,留下有关的
= ∑ k ≠ j ∑ m ≠ j A i k B k m A i m + ∑ m ≠ j A i j B j m A i m + ∑ k ≠ j A i k B k m A i j + =\sum_{k\neq j}\sum_{m\neq j}A_{ik}B_{km}A_{im} +\sum_{m\neq j}A_{ij}B_{jm}A_{im} + \sum_{k\neq j}A_{ik}B_{km}A_{ij}+ =k=jm=jAikBkmAim+m=jAijBjmAim+k=jAikBkmAij+
A i j 2 B j j A^{2}_{ij}B_{jj} Aij2Bjj
接下来对 A i j A_{ij} Aij求偏导数,注意到第一项中涉及的元素与 A i j A_{ij} Aij无关,直接丢掉,剩余的做元素交换整理
∑ m ≠ j A i m B j m + ∑ k ≠ j A i k B k m + 2 A i j B j j \sum_{m\neq j}A_{im}B_{jm}+\sum_{k\neq j}A_{ik}B_{km}+2A_{ij}B_{jj} m=jAimBjm+k=jAikBkm+2AijBjj
最后一项 2 A i j B j j 2A_{ij}B_{jj} 2AijBjj拆分成两项,前两个和式一人一个,得到
∑ m A i m B j m + ∑ k A i k B k m = ( A B T ) i j + ( A B ) i j ) = [ A ( B + B T ) ] i j \sum_{m}A_{im}B_{jm}+\sum_{k}A_{ik}B_{km}=(AB^{T})_{ij}+(AB)_{ij})=[A(B+B^T)]_{ij} mAimBjm+kAikBkm=(ABT)ij+(AB)ij)=[A(B+BT)]ij
呼呼呼!


4.2 代码验证

from sympy import symbols, Matrix, trace, diff
import sympy as sp
# 定义符号变量
a11, a12, a21, a22 = symbols('a11 a12 a21 a22')
b11, b12, b21, b22 = symbols('b11 b12 b21 b22')
# 构建2x2矩阵A和B
A = Matrix([[a11, a12], [a21, a22]])
B = Matrix([[b11, b12], [b21, b22]])
# 计算矩阵乘积ABA^T的迹
trace_expr = trace(A * B * A.T)
# 计算左侧偏导数
d_trace_dA = Matrix([
    [diff(trace_expr, a11), diff(trace_expr, a12)],
    [diff(trace_expr, a21), diff(trace_expr, a22)]
])
# 计算右侧公式 A(B+B^T)
right_hand_side = A * (B + B.T)
# 打印出两个偏导数矩阵,检查一下是否一致。
sp.pprint(d_trace_dA)
sp.pprint(right_hand_side)

4.3 闲话

对矩阵求偏导实质是对矩阵中元素求偏导后,将结果重新按元素在矩阵中的位置拼成新的矩阵。没有引入新的数学知识,只是使结论更加优雅、紧凑。


  1. 我们追踪 A i j A_{ij} Aij的目的是想知道因变量 t r ( A B A T ) tr(ABA^{T}) tr(ABAT)随着 A i j A_{ij} Aij的变化而变化的幅度,这也是一种从定性的角度去理解偏导数。 ↩︎

  2. 这里使用了转置后元素的关系: A m i T = A i m A^{T}_{mi}=A_{im} AmiT=Aim ↩︎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值