一、起源与定义 Deep Q-Network(深度Q网络,简称DQN)
是2013年由DeepMind团队提出的强化学习算法,首次将深度学习与传统Q-Learning结合,解决了传统强化学习在处理高维状态空间时的维度灾难问题。其核心贡献在于通过神经网络近似Q值函数,使算法能够处理图像、视频等复杂输入场景。
二、核心原理
1. Q-Learning基础
- Q值函数:
表示在状态sss下执行动作aaa后,智能体能获得的最大期望累计奖励,记为Q(s,a)Q(s,a)Q(s,a)。 - 贝尔曼方程:Q-Learning的理论基础,表达式为: Q(s,a)←Q(s,a)+α[r+γmaxa′Q(s′,a′)−Q(s,a)]Q(s,a) \leftarrow Q(s,a) + \alpha \left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]Q(s,a)←Q(s,a)+α[r+γa′maxQ(s′,a′)−Q(s,a)] 其中α\alphaα为学习率,γ\gammaγ为折扣因子,rrr为即时奖励,s′s's′为下一个状态。
- 目标:通过迭代更新Q值,找到最优策略π∗(s)=argmaxaQ(s,a)\pi^*(s)=\arg\max_a Q(s,a)π∗(s)=argmaxaQ(s,a)。
2. DQN对Q-Learning的改进
传统Q-Learning使用表格存储Q值,无法处理高维状态(如图像),DQN通过神经网络Q(s,a;θ)Q(s,a;\theta)Q(s,a;θ)近似Q值函数,其中θ\thetaθ为网络参数。
三、关键创新点
1. 经验回放(Experience Replay)
- 问题:传统Q-Learning中,连续样本具有强相关性,导致神经网络训练不稳定。
- 机制:
- 智能体与环境交互,将过渡样本(st,at,rt,st+1)(s_t, a_t, r_t, s_{t+1})(st,at,rt,st+1)存储到经验回放缓冲区DDD;
- 训练时从DDD中随机采样批量数据,打破样本相关性。
- 示例:类似学生从错题本中随机复习,避免重复学习相似知识。
2. 目标网络(Target Network)
- 问题:Q-Learning中使用当前网络更新自身,导致训练目标不稳定。
- 机制:
- 维护两个结构相同的网络:
- 评估网络Q(s,a;θ)Q(s,a;\theta)Q(s,a;θ):用于选择动作和生成训练标签;
- 目标网络Q(s,a;θ−)Q(s,a;\theta^-)Q(s,a;θ−):提供稳定的目标值,定期从评估网络复制参数。
- 维护两个结构相同的网络:
- 更新公式: yj=rj+γmaxa′Q(sj′,a′;θ−)y_j = r_j + \gamma \max_{a'} Q(s'_j, a';\theta^-)yj=rj+γa′maxQ(sj′,a′;θ−) 其中yjy_jyj为目标值,通过最小化均方误差minθE[(yj−Q(sj,aj;θ))2]\min_\theta \mathbb{E}[(y_j - Q(s_j,a_j;\theta))^2]minθE[(yj−Q(sj,aj;θ))2]更新评估网络。
四、网络架构
1. 典型结构(以Atari游戏为例)
- 输入层:预处理后的4帧游戏图像(如84×84×4),捕捉动态信息;
- 卷积层:通常3-4层,提取视觉特征,如:
- 第一层:卷积核8×8,步长4,激活函数ReLU;
- 第二层:卷积核4×4,步长2,激活函数ReLU;
- 全连接层:1-2层,输出各动作的Q值,维度为动作空间大小(如Atari游戏中通常为4-18个动作)。
2. 输入预处理
- 灰度化:减少计算量;
- 帧堆叠:保留时间维度信息,如4帧堆叠捕捉运动轨迹。
五、训练流程
- 初始化: - 经验回放缓冲区DDD(容量NNN); - 评估网络Q(θ)Q(\theta)Q(θ)和目标网络Q(θ−)Q(\theta^-)Q(θ−),参数θ\thetaθ随机初始化,θ−=θ\theta^-=\thetaθ−=θ; - 初始状态s1s_1s1预处理为输入图像。
- 交互循环: - 动作选择:以ϵ\epsilonϵ-贪婪策略从Q(θ)Q(\theta)Q(θ)选择动作ata_tat(ϵ\epsilonϵ随训练衰减); - 环境交互:执行动作ata_tat,获取奖励rtr_trt和新状态st+1s_{t+1}st+1,标记是否终止; - 存储经验:将(st,at,rt,st+1)(s_t, a_t, r_t, s_{t+1})(st,at,rt,st+1)存入DDD; - 状态更新:st←st+1s_t \leftarrow s_{t+1}st←st+1。
- 网络更新: - 当∣D∣≥|D| \geq∣D∣≥批量大小时,从DDD中随机采样批量数据; - 计算目标值yjy_jyj(使用目标网络); - 反向传播更新评估网络参数θ\thetaθ; - 定期(如每1000步)更新目标网络:θ−←θ\theta^- \leftarrow \thetaθ−←θ。
六、算法完整伪代码
# 参数设置
capacity = N # 回放缓存容量
batch_size = B
γ = 0.99 # 折扣因子
τ = 0.001 # 软更新比例
update_target_every = C # 硬更新步数间隔 (或仅用软更新)
# 初始化
初始化在线网络 Q(θ) (随机权重 θ)
初始化目标网络 Q(θ⁻) with θ⁻ = θ
初始化回放缓存 D (容量 = N)
for episode = 1 to M:
s = env.reset()
while not done:
# 1. 动作选择 (ε-greedy)
a = argmax_{a} Q(s, a; θ) 以概率 (1-ε), 否则随机选动作
# 2. 执行动作,观测环境
s_next, r, done, _ = env.step(a)
# 3. 存储转移
D.push((s, a, r, s_next, done))
s = s_next
# 4. 训练阶段 (每隔几步或每步都训)
if len(D) >= batch_size:
# 4a. 采样批次
batch = D.sample(batch_size)
# 4b. 计算目标值 (对于批次中每个样本)
targets = []
for each (s_j, a_j, r_j, s'_j, done_j) in batch:
if done_j: target_j = r_j
else: target_j = r_j + γ * max_{a'} Q(s'_j, a'; θ⁻) # ⚠️ 目标网络
targets.append(target_j)
# 4c. 计算在线网络预测值 & 损失 (MSE)
preds = Q(s_j, a_j; θ) # 计算当前批次状态的预测值
loss = MSE(preds, targets)
# 4d. 梯度下降更新在线网络参数 θ
θ <- optimizer.step(∇θ(loss), θ)
# 4e. 更新目标网络 (可选软更新或硬更新)
if using_soft_update: θ⁻ <- τ * θ + (1-τ) * θ⁻
if step % update_target_every == 0: # 硬更新
θ⁻ <- θ
七、应用场景
领域 | 案例 | 特点 | |
---|---|---|---|
游戏AI | Atari游戏(如Pong、Breakout) | 处理像素输入,学会复杂策略 | |
机器人控制 | 机械臂抓取、导航 | 结合视觉输入实现动态决策 | |
推荐系统 | 广告投放、内容推荐 | 优化长期收益(点击、留存等) | |
自动驾驶 | 车道保持、障碍物规避 | 处理传感器数据,实时决策 |
八、优缺点分析
优点:
- 处理高维状态:通过卷积神经网络有效提取图像特征;
- 稳定性提升:经验回放和目标网络减少训练波动;
- 泛化能力:可迁移至类似场景(如不同Atari游戏)。
缺点:
- 样本效率低:需大量交互数据,训练成本高;
- 动作空间限制:适用于离散动作空间,连续动作需特殊处理;
- 超参数敏感:如学习率、折扣因子需精细调优。
九、后续改进
- Double DQN:分离动作选择与评估,减少Q值过估计;
- Dueling DQN:将Q值分解为状态价值和动作优势,提升收敛速度;
- Rainbow DQN:整合多种改进技术(如优先经验回放、分布强化学习等)。
通过结合深度学习与强化学习,DQN开创了深度强化学习的先河,为后续算法(如Policy Gradient、Actor-Critic)奠定了基础,至今仍是复杂决策任务的重要解决方案。