标签:链式法则, 反向传播, 深度学习, PyTorch, 数学基础
大家好。今天我们来聊聊数学中的“链式法则”(Chain Rule),这是一个让很多初学者“懵逼”的概念,尤其在深度学习中,它是反向传播(backpropagation)的核心。如果你看到loss.backward()
就头大,别担心,这篇文章会从零基础一步步讲清楚。
这篇文章适合数学或深度学习新手,老鸟也可以复习。
引言:为什么链式法则这么重要?
在深度学习中,模型是层层嵌套的函数(比如神经网络的层级)。要训练模型,我们需要计算损失函数对每个参数的导数,来知道怎么调整它们。这时候,链式法则就登场了——它帮我们处理“复合函数”的求导。
简单说:如果你有一个函数套函数(f(g(x))),链式法则告诉你怎么求整体导数,而不用展开成巨长的表达式。在反向传播中,它让PyTorch等框架高效计算梯度,避免手动求导的噩梦。
如果你学过高中数学,这其实就是“复合函数求导”。但在AI中,它被放大到成千上万的参数上。别慌,咱们一步步来。
1. 链式法则的基础:单变量版本
核心公式:对于复合函数 ( y = f(g(x)) ),它的导数是:
[dydx=dydg⋅dgdx][
\frac{dy}{dx} = \frac{dy}{dg} \cdot \frac{dg}{dx}
][dxdy=dgdy⋅dxdg]
翻译成白话:整体变化率 = “外层对内层”的变化率 × “内层对输入”的变化率。
为什么这样? 想象一下,x 变一点,g(x) 变一点,然后f 再根据g的变化变一点。链式就是把这些“小变化”乘起来。
简单例子:假设 ( y = (x^2 + 1)^3 )。
- 这里,内层 g(x) = x² + 1,外层 f(g) = g³。
- 求 dy/dx:
- dg/dx = 2x
- df/dg = 3g² = 3(x² + 1)²
- 所以 dy/dx = 3(x² + 1)² * 2x = 6x(x² + 1)²
手动展开 y = (x² + 1)³ = x^6 + 3x^4 + 3x² + 1,再求导是 6x^5 + 12x^3 + 6x,也一样。但链式更简单,尤其是函数很复杂时。
小测试:试试 y = sin(2x)。答案:dy/dx = cos(2x) * 2。
2. 多变量扩展:神经网络的现实场景
现实中,神经网络不是单变量,而是多层、多参数。链式法则扩展到偏导数(partial derivatives)。
多层复合:假设 y = f(u(v(w(x)))),那么:
[dydx=dydu⋅dudv⋅dvdw⋅dwdx][
\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dv} \cdot \frac{dv}{dw} \cdot \frac{dw}{dx}
][dxdy=dudy⋅dvdu⋅dwdv⋅dxdw]
就像一条链,从输出端“反向”乘回去。
向量版本(深度学习常用):如果参数是向量,梯度是雅可比矩阵的乘积。但别怕,框架自动处理。
3. 链式法则在深度学习中的应用:反向传播
现在进入重头戏!在神经网络中,反向传播(backprop)就是链式法则的递归应用。
- 神经网络简化:一个简单网络:输入 x,经过权重 w1 到隐藏层 h = w1 * x + b1,再经过 w2 到输出 y = w2 * h + b2。损失 L = (y - target)^2。
- 前向传播:从 x 到 y 计算。
- 反向传播:从 L 开始,反向计算每个权重的梯度 dL/dw。
用链式:
- dL/dy = 2(y - target) # 损失对输出的导数
- dL/dw2 = dL/dy * dy/dw2 = dL/dy * h # 因为 y = w2 * h + b2,dy/dw2 = h
- dL/dh = dL/dy * dy/dh = dL/dy * w2
- dL/dw1 = dL/dh * dh/dw1 = dL/dh * x # 因为 h = w1 * x + b1,dh/dw1 = x
看到没?从后往前乘,就是链式法则!PyTorch的loss.backward()
就是在做这个:构建计算图,然后递归应用链式求导。
为什么高效? 手动求导需要展开所有层(指数级复杂),链式只需O(n)时间(n是层数)。
4. 代码示例:用PyTorch看链式法则在行动
咱们用PyTorch模拟一个简单网络,观察梯度计算。
import torch
# 定义变量(requires_grad=True 启用自动微分)
x = torch.tensor(2.0)
w1 = torch.tensor(3.0, requires_grad=True)
b1 = torch.tensor(1.0, requires_grad=True)
w2 = torch.tensor(4.0, requires_grad=True)
b2 = torch.tensor(1.0, requires_grad=True)
target = torch.tensor(20.0)
# 前向传播
h = w1 * x + b1 # 内层
y = w2 * h + b2 # 外层
loss = (y - target) ** 2 # 损失
# 反向传播:自动应用链式法则
loss.backward()
# 查看梯度
print("dL/dw1:", w1.grad) # 应该接近 (2*(y-target)*w2*x)
print("dL/dw2:", w2.grad) # 应该接近 (2*(y-target)*h)
运行结果(手动计算验证):
- h = 3*2 + 1 = 7
- y = 4*7 + 1 = 29
- loss = (29-20)^2 = 81
- dL/dw1 = 2*9 * 4 * 2 = 144 # 链式:dL/dy * dy/dh * dh/dw1 * (隐含的)
- dL/dw2 = 2*9 * 7 = 126
PyTorch自动算出这些,就是链式法则的功劳!
5. 常见困惑解答
- 为什么叫‘链式’? 因为像链条一样,一环扣一环,从输出反向传导。
- 和前向传播的区别? 前向是计算值,反向是计算导数。
- 多层网络呢? 链条更长,但原理一样。Transformer有上百层,也靠这个。
- 如果不明白数学,能用吗? 能!框架封装好了,但理解原理能帮你调试(如梯度消失问题)。
结语:从懵逼到自信
链式法则不是什么高深魔法,就是处理复合函数求导的工具。在深度学习中,它让反向传播成为可能,让AI高效学习。掌握了这个,你看代码时就不会迷糊了。
如果还有问题,比如“梯度消失怎么破?”或想看更多例子,评论区见!点赞、收藏、转发支持一下~ 更多数学与AI教程,关注我。