附录C
仍是一期后续的公式推导及代码验证,让我们开始吧。
目录
4. ∂ ∂ A t r ( A B A T ) = A ( B + B T ) \frac{\partial }{\partial A}tr(ABA^{T})=A(B+B^T) ∂A∂tr(ABAT)=A(B+BT)
4.1 证明
对矩阵求偏导体现了线性代数的整洁性,体现了别样的优雅,希望能在证明中大家有所感受。
证明等价于
∂
∂
A
i
j
t
r
(
A
B
A
T
)
=
[
A
(
B
+
B
T
)
]
i
j
\frac{\partial }{\partial A_{ij}}tr(ABA^{T})=[A(B+B^T)]_{ij}
∂Aij∂tr(ABAT)=[A(B+BT)]ij
接下来我们对
A
B
A
T
ABA^{T}
ABAT做切分,其中
B
A
T
BA^{T}
BAT包含
A
i
j
A_{ij}
Aij的元素为其第
i
i
i列元素,故若把
A
B
A
T
ABA^{T}
ABAT看作矩阵
A
A
A与矩阵
B
A
T
BA^{T}
BAT的乘积,仅有第
i
i
i行或第
i
i
i列元素会包含我们追踪的元素
A
i
j
A_{ij}
Aij1。
因为迹是对角线元素的和,进而我们将
A
i
j
A_{ij}
Aij引起的变化部分锁定为
(
A
B
A
T
)
i
i
(ABA^{T})_{ii}
(ABAT)ii,将公式展开有
∑
k
A
i
k
(
B
A
T
)
k
i
=
∑
k
A
i
k
(
∑
m
B
k
m
A
m
i
)
=
\sum_{k}A_{ik}(BA^{T})_{ki}=\sum_{k}A_{ik}(\sum_{m}B_{km}A_{mi})=
∑kAik(BAT)ki=∑kAik(∑mBkmAmi)=
∑
k
∑
m
A
i
k
B
k
m
A
i
m
\sum_{k}\sum_{m}A_{ik}B_{km}A_{im}
∑k∑mAikBkmAim2
到这里,我们可能有点晕,仍然需要冷静紧咬住
A
i
j
A_{ij}
Aij。
于是继续刨去无关的部分,留下有关的
=
∑
k
≠
j
∑
m
≠
j
A
i
k
B
k
m
A
i
m
+
∑
m
≠
j
A
i
j
B
j
m
A
i
m
+
∑
k
≠
j
A
i
k
B
k
m
A
i
j
+
=\sum_{k\neq j}\sum_{m\neq j}A_{ik}B_{km}A_{im} +\sum_{m\neq j}A_{ij}B_{jm}A_{im} + \sum_{k\neq j}A_{ik}B_{km}A_{ij}+
=∑k=j∑m=jAikBkmAim+∑m=jAijBjmAim+∑k=jAikBkmAij+
A
i
j
2
B
j
j
A^{2}_{ij}B_{jj}
Aij2Bjj
接下来对
A
i
j
A_{ij}
Aij求偏导数,注意到第一项中涉及的元素与
A
i
j
A_{ij}
Aij无关,直接丢掉,剩余的做元素交换整理
∑
m
≠
j
A
i
m
B
j
m
+
∑
k
≠
j
A
i
k
B
k
m
+
2
A
i
j
B
j
j
\sum_{m\neq j}A_{im}B_{jm}+\sum_{k\neq j}A_{ik}B_{km}+2A_{ij}B_{jj}
∑m=jAimBjm+∑k=jAikBkm+2AijBjj
最后一项
2
A
i
j
B
j
j
2A_{ij}B_{jj}
2AijBjj拆分成两项,前两个和式一人一个,得到
∑
m
A
i
m
B
j
m
+
∑
k
A
i
k
B
k
m
=
(
A
B
T
)
i
j
+
(
A
B
)
i
j
)
=
[
A
(
B
+
B
T
)
]
i
j
\sum_{m}A_{im}B_{jm}+\sum_{k}A_{ik}B_{km}=(AB^{T})_{ij}+(AB)_{ij})=[A(B+B^T)]_{ij}
∑mAimBjm+∑kAikBkm=(ABT)ij+(AB)ij)=[A(B+BT)]ij
呼呼呼!
4.2 代码验证
from sympy import symbols, Matrix, trace, diff
import sympy as sp
# 定义符号变量
a11, a12, a21, a22 = symbols('a11 a12 a21 a22')
b11, b12, b21, b22 = symbols('b11 b12 b21 b22')
# 构建2x2矩阵A和B
A = Matrix([[a11, a12], [a21, a22]])
B = Matrix([[b11, b12], [b21, b22]])
# 计算矩阵乘积ABA^T的迹
trace_expr = trace(A * B * A.T)
# 计算左侧偏导数
d_trace_dA = Matrix([
[diff(trace_expr, a11), diff(trace_expr, a12)],
[diff(trace_expr, a21), diff(trace_expr, a22)]
])
# 计算右侧公式 A(B+B^T)
right_hand_side = A * (B + B.T)
# 打印出两个偏导数矩阵,检查一下是否一致。
sp.pprint(d_trace_dA)
sp.pprint(right_hand_side)
4.3 闲话
对矩阵求偏导实质是对矩阵中元素求偏导后,将结果重新按元素在矩阵中的位置拼成新的矩阵。没有引入新的数学知识,只是使结论更加优雅、紧凑。