Zygote 是 Julia 上一个实现自动微分、自动求导的包,其中 @adjoint
宏是 Zygote 接口的一个重要组成部分。使用 @adjoint
可以自定义函数的后向传播。
Pullbacks
要理解 @adjoint
首先要先理解更为底层的函数 pullback
。gradient
实际上就是 pullback
的语法糖(syntactic sugar)。
julia> y, back = Zygote.pullback(sin, 0.5)
(0.479425538604203, Zygote.var"#41#42"{Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}}(Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}(ChainRules.var"#sin_pullback#1430"{Float64}(0.8775825618903728))))
julia> y
0.479425538604203
给 pullback
输入两个参数 sin
和 0.5
分别代表要求导的函数和要求导的值,会得到两个输出:给定函数的结果 sin(0.5)
以及一个 pullback
,也就是上面代码中的 back
变量。back
对函数 sin
进行梯度计算,接受的是一个派生,并且产生新的一个变量。。从数学上讲,就是 vector-Jacobian 积的实现。其中
y
=
f
(
x
)
y=f(x)
y=f(x) 和梯度
∂
l
∂
x
\frac{\partial{l}}{\partial{x}}
∂x∂l 写为
x
ˉ
\bar{x}
xˉ,pullback
B
y
\mathcal{B}_y
By 如下计算:
x
ˉ
=
∂
l
∂
x
=
∂
l
∂
y
∂
y
∂
x
=
B
y
(
y
ˉ
)
\bar{x}=\frac{\partial l}{\partial x}=\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}=\mathcal{B}_{y}(\bar{y})
xˉ=∂x∂l=∂y∂l∂x∂y=By(yˉ)
更为具体的讲,以上面的代码为例子,函数
y
=
sin
(
x
)
y=\sin(x)
y=sin(x).
∂
y
∂
x
=
cos
(
x
)
\frac{\partial y}{\partial x}=\cos (x)
∂x∂y=cos(x),所以 pullback 就为
y
ˉ
cos
(
x
)
\bar{y}\cos(x)
yˉcos(x),其中
y
ˉ
=
∂
l
∂
y
\bar{y}=\frac{\partial l}{\partial y}
yˉ=∂y∂l。换句话说,pullback(sin, x)
与 dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))
等价。
gradient
中函数
l
=
f
(
x
)
l=f(x)
l=f(x) 并且假设
l
ˉ
=
∂
l
∂
l
=
1
\bar{l}=\frac{\partial l}{\partial l}=1
lˉ=∂l∂l=1,并且将其输入到 pullback 中。在 sin
的例子中,
julia> dsin(x) = (sin, ȳ -> (ȳ * cos(x),))
dsin (generic function with 1 method)
julia> function gradsin(x)
_, back = dsin(x)
back(1)
end
gradsin (generic function with 1 method)
julia> gradsin(0.5)
(0.8775825618903728,)
julia> cos(0.5)
0.8775825618903728
julia> back(1)
(0.8775825618903728,)
个人理解,为什么前面要加一项
∂
l
∂
y
\frac{\partial l}{\partial y}
∂y∂l,这是为了实现链式法则。比如假设最终的损失是
l
l
l,函数
y
(
x
)
y(x)
y(x),要得到损失函数
l
l
l 对参数
x
x
x 的微分
∂
l
∂
x
\frac{\partial l}{\partial x}
∂x∂l,根据链式法则就是损失函数对函数
y
y
y 的微分乘以函数对参数
x
x
x 的微分,即
∂
l
∂
y
∂
y
∂
x
\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}
∂y∂l∂x∂y。函数
y
y
y 的 pullback
就是损失函数对函数
y
y
y 的微分(用
y
ˉ
\bar{y}
yˉ 表示)乘以函数对
x
x
x 的微分。
对于上面的例子,pullback
函数返回的第一个结果为:假设函数
y
=
sin
(
x
)
y=\sin(x)
y=sin(x) 就是损失函数
l
l
l 时,
x
=
0.5
x=0.5
x=0.5 时的结果,即
cos
(
0.5
)
\cos(0.5)
cos(0.5),并且返回的 back
就是一个关于
∂
l
∂
y
\frac{\partial l}{\partial y}
∂y∂l 的函数,可以看成是
B
(
∂
l
∂
y
)
=
∂
l
∂
y
cos
(
0.5
)
\mathcal{B}(\frac{\partial l}{\partial y})=\frac{\partial l}{\partial y}\cos(0.5)
B(∂y∂l)=∂y∂lcos(0.5)。
假如 l = 0.5 y = 0.5 sin ( x ) l=0.5y=0.5\sin(x) l=0.5y=0.5sin(x),我们可以得到 ∂ l ∂ y = 0.5 \frac{\partial l}{\partial y}=0.5 ∂y∂l=0.5,那么 ∂ l ∂ x = B ( ∂ l ∂ y ) = B ( 0.5 ) \frac{\partial l}{\partial x}=\mathcal{B}(\frac{\partial l}{\partial y})=\mathcal{B}(0.5) ∂x∂l=B(∂y∂l)=B(0.5)。
参考: