Soft Actor-Critic(SAC)深度解析

当前大部分博客对SAC的解释都不够透彻,对于第一次接触最大熵学习的初学者非常不友好,此外大部分博客混用了有关SAC的两篇论文:

[1] T. Haarnoja et al. 2018. Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor

[2] T. Haarnoja et al. 2018. Soft Actor-Critic Algorithms and Applications.

注意,两篇论文有些许不同,第一次看可能会感到疑惑,这里我介绍初始的方法,即第一篇论文中使用的三个基础网络:状态网络,动作网络和策略网络。

SAC 定义用于涉及连续动作的强化学习任务。SAC 的最大特点在于它使用了一个修改后的强化学习目标函数。SAC 不仅寻求最大化终身奖励,还寻求最大化策略的熵。熵这个术语有一个相当晦涩的定义,并且根据应用的不同有多种解释,但我想在这里分享一个直观的解释。我们可以将熵视为一个随机变量有多不可预测。如果一个随机变量总是取单个值,那么它的熵为零,因为它根本不可预测。如果一个随机变量以相等的概率可以是任何实数,那么它的熵非常高,因为它非常不可预测。我们为什么希望策略具有高熵?我们希望在策略中具有高熵,以明确地鼓励探索,鼓励策略对具有相同或几乎相同的 Q 值的动作分配相等的概率,并且确保它不会崩溃成反复选择某个特定动作,这可能利用了近似 Q 函数中的一些不一致性。 因此,SAC 通过鼓励策略网络进行探索,而不是将非常高概率分配给动作范围中的任何一部分,从而克服了脆性问题。

现在我们知道要优化什么了,让我们来了解如何进行优化。SAC 使用三个网络:由 ψ 参数化的状态值函数 V,由 θ 参数化的软 Q 函数 Q,以及由 ϕ 参数化的策略函数 π。虽然原则上没有必要为与策略相关的 V 和 Q 函数分别使用不同的近似器,但作者们说在实践中使用不同的函数近似器有助于收敛。因此,我们需要按照以下方式训练三个函数近似器:

首先通过最小化误差来训练值网络:

它所表达的意思是,在从经验回放缓冲区中采样的所有状态下,我们需要减少价值网络的预测与 Q 函数的预期预测之间的平方差,再加上策略函数π的熵(这里通过策略函数的负对数来衡量)。其实就是缩小两个估计值的差值,由于Q网络参数的不准确性,其实后一项对target soft V的估计也不完全准确,这里的操作就有点像soft Q-learning 中交替计算soft Q和soft V的过程,通过迭代两者都能够收敛到真实值。该函数梯度为:

随后我们通过最小化以下误差来训练 Q 网络

Q值的近似目标为:

梯度为:

最小化这个目标函数意味着以下内容:对于经验回放缓冲区中的所有(状态,动作)对,我们希望最小化 Q 函数的预测值与即时(一步)奖励加上下一个状态的折扣预期值之间的平方差。这实质上还是两个Q值估计值间的差值,由于V网络参数的不确定性,对targetQ的估计也不是完全准确的

请注意,这个预期值来自一个带有横线的ψ参数化的值函数。这是一个额外的值函数,称为目标值函数。我们将会讲到为什么需要这个,但现在不用担心它,只需将其视为我们正在训练的值函数。

最后我们通过最小化以下误差来训练策略网络π

策略函数的学习目标于soft Q-learning一致,即使得策略分布尽可能接近规范化的Q值分布。

这个目标函数看起来很复杂,但实际上它表达的意思非常简单。期望中的 KL 函数,即 Kullback-Leibler 散度,我强烈建议你阅读有关 KL 散度的内容,因为它在深度学习研究和应用中经常出现。对于本教程而言,你可以将其理解为两个分布的差异。因此,这个目标函数基本上是试图使策略函数的分布看起来更像 Q 函数指数化的分布,再除以另一个函数 Z 后的分布。将其展开便于后续推导:

要使两个分布尽可能接近,我们只需关注对数项,又因为归一化项与策略无关,因此可以当做常数提到前面。将其展开可得:

为了最小化这个目标,作者使用了一种称为重参数化技巧的方法。这个技巧用于确保从策略分布中采样是一个可微的过程,以便在反向传播误差时没有问题。现在,策略被参数化为如下形式:

误差项是一个从高斯分布中采样的噪声向量。我们将在实现部分中进一步解释它。

现在,我们可以将目标函数表示如下:

相应的梯度为

数学部分就到这里!

现在我们已经了解了算法背后的理论,让我们在 Pytorch 中实现它的一个版本。我的实现基于 higgsfield 的,但有一个关键的改变:我使用了重参数化技巧,这使得训练收敛更好,因为方差更低。首先,让我们看看算法的主体部分,以便从高层理解正在发生的事情,然后我们可以深入到各个组件的细节。

env = NormalizedActions(gym.make("Pendulum-v0"))

action_dim = env.action_space.shape[0]
state_dim  = env.observation_space.shape[0]
hidden_dim = 256

value_net        = ValueNetwork(state_dim, hidden_dim).to(device)
target_value_net = ValueNetwork(state_dim, hidden_dim).to(device)

soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)

for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
    target_param.data.copy_(param.data)
    

value_criterion  = nn.MSELoss()
soft_q_criterion1 = nn.MSELoss()
soft_q_criterion2 = nn.MSELoss()
lr  = 3e-4

value_optimizer  = optim.Adam(value_net.parameters(), lr=lr)
soft_q_optimizer = optim.Adam(soft_q_net.parameters(), lr=lr)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=lr)


replay_buffer_size = 1000000
replay_buffer = ReplayBuffer(replay_buffer_size)

首先,我们初始化一个 OpenAI Gym 环境,我们的智能体将在其中玩强化学习游戏。我们存储关于环境观察维度、动作空间维度的信息,然后,设置我们网络中隐藏层的数量超参数。然后我们初始化我们想要训练的三个网络以及一个目标 V 网络。你会注意到我们有两个 Q 网络。我们维护两个 Q 网络来解决 Q 值过估计的问题。为了对抗这个问题,我们维护两个 Q 网络,并使用两者的最小值来执行我们的策略和 V 函数更新。

现在,是时候解释整个目标 V 网络业务了。使用target network是由训练 V 网络中的一个问题所驱动的。如果你回到理论部分中的目标函数,你会发现 Q 网络训练的目标取决于 V 网络,而 V 网络的目标取决于 Q 网络(这很有道理,因为我们试图在两个函数之间强制执行贝尔曼一致性)。相当于V 网络有一个间接依赖于自身的训练目标(这个问题和DQN相同,训练的目标也在不停变化,因为目标V的值也受到V网络参数psi影响),这使得训练非常不稳定。

解决方案也与DQN中相同,解决方案是使用一组参数,这些参数接近主 V 网络的参数,但带有时间延迟。因此我们创建一个第二网络,它滞后于主网络,称为目标网络。有两种方法来处理这个问题。第一种方法是在一定数量的步骤后定期将目标网络从主网络复制过来(DQN中的做法)。另一种方法是通过对主网络本身和目标网络进行 Polyak 平均(一种移动平均)来更新目标网络。在这个实现中,我们使用 Polyak 平均。我们将主网络和目标 V 网络初始化为具有相同的参数。

下面的这部分代码展示的是如何与环境交互并采样的过程:

while frame_idx < max_frames:
    state = env.reset()
    episode_reward = 0
    
    for step in range(max_steps):
        if frame_idx >1000:
            action = policy_net.get_action(state).detach()
            next_state, reward, done, _ = env.step(action.numpy())
        else:
            action = env.action_space.sample()
            next_state, reward, done, _ = env.step(action)
        
        
        replay_buffer.push(state, action, reward, next_state, done)
        
        state = next_state
        episode_reward += reward
        frame_idx += 1
        
        if len(replay_buffer) > batch_size:
            update(batch_size)
        
        if frame_idx % 1000 == 0:
            plot(frame_idx, rewards)
        
        if done:
            break
        
    rewards.append(episode_reward)

这里使用了两层循环,外层循环用于初始化环境作为新episode的开始,内层循环记录时间步并采样动作。在靠前的时间步中我们使用均匀采样策略,当episode足够长时,我们认为遵循soft Q-learning中的建议,通过策略网络对动作采样。将与环境交互的值存放于缓冲区中,直到有最小数量的观测值后执行网络更新。

网络的更新代码如下:

def update(batch_size,gamma=0.99,soft_tau=1e-2,):
    
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state      = torch.FloatTensor(state).to(device)
    next_state = torch.FloatTensor(next_state).to(device)
    action     = torch.FloatTensor(action).to(device)
    reward     = torch.FloatTensor(reward).unsqueeze(1).to(device)
    done       = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)

    predicted_q_value1 = soft_q_net1(state, action)
    predicted_q_value2 = soft_q_net2(state, action)
    predicted_value    = value_net(state)
    new_action, log_prob, epsilon, mean, log_std = policy_net.evaluate(state)

# Training Q Function
    target_value = target_value_net(next_state)
    target_q_value = reward + (1 - done) * gamma * target_value
    q_value_loss1 = soft_q_criterion1(predicted_q_value1, target_q_value.detach())
    q_value_loss2 = soft_q_criterion2(predicted_q_value2, target_q_value.detach())
    print("Q Loss")
    print(q_value_loss1)
    soft_q_optimizer1.zero_grad()
    q_value_loss1.backward()
    soft_q_optimizer1.step()
    soft_q_optimizer2.zero_grad()
    q_value_loss2.backward()
    soft_q_optimizer2.step()    

# Training Value Function
    #predicted_new_q_value = torch.min(soft_q_net1(state, new_action),soft_q_net2(state, new_action))
	predicted_new_q_value = torch.min(soft_q_net1(state, action),soft_q_net2(state, action))
    target_value_func = predicted_new_q_value - log_prob
    value_loss = value_criterion(predicted_value, target_value_func.detach())
    print("V Loss")
    print(value_loss)
    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

# Training Policy Function
    policy_loss = (log_prob - predicted_new_q_value).mean()

    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()
    
    
    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - soft_tau) + param.data * soft_tau
        )

对于Q网络的更新,这里关注到我们使用了两个不同的Q网络,目的是解决对Q值的过度估计问题,但这里处理过度估计的方法和Double DQN不同,这里我们取Q值估计中的较小值,直接砍掉可能出现过度估计的Q值噪声估计。这里的过度估计问题出现在估计V网络的target soft V value中。

此时V网络的更新可改写为:

此外,在Q网络的更新中我们还应用了fixed target方法,这里我们使用了一个滞后的V网络作为固定的优化目标。

这里我们可能会疑惑为什么在更新策略网络时,我们使用了new_action而非回放池采样得到的action。这是因为SAC策略网络的输出高斯分布的均值和方差,直接从其输出采样动作会使策略更新的梯度计算不可导,这是因为我们采样的动作at与策略网络中的参数phi相关,但是采样的返回是一个具体值,在at和网络参数phi之间没有明确的可微映射,这就导致梯度是不可计算的,为了解决这个问题我们使用了重参数化的技巧。

具体来说,重参数化将采样拆成了两部分:固定分布的噪声和可微的变换,即:

为此,我们从标准正态分布中采样一些噪声,并将其乘以我们的标准差,然后加上均值。然后将这个数通过 tanh 函数激活,以得到我们的动作。最后,使用 tanh(mean + std*z) 的对数似然近似计算对数概率。这种操作使得策略网络的可微性得到了保证。

最后使用polyak平均法更新target V network,polyak是一种软更新策略,更新规则为:

下面的代码展示的是网络结构的具体实现:

class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super(ValueNetwork, self).__init__()
        
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
        
        
class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNetwork, self).__init__()
        
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        
        self.linear3.weight.data.uniform_(-init_w, init_w)
        self.linear3.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
        
        
class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()
        
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        
        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)
        
        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)
        
    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        
        mean    = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        
        return mean, log_std
    
    def evaluate(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        
        normal = Normal(0, 1)
        z      = normal.sample()
        action = torch.tanh(mean+ std*z.to(device))
        log_prob = Normal(mean, std).log_prob(mean+ std*z.to(device)) - torch.log(1 - action.pow(2) + epsilon)
        return action, log_prob, z, mean, log_std
        
    
    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        mean, log_std = self.forward(state)
        std = log_std.exp()
        
        normal = Normal(0, 1)
        z      = normal.sample().to(device)
        action = torch.tanh(mean + std*z)
        
        action  = action.cpu()
        return action[0]

这里我们需要注意的只有一点,就是对log_prob的估计,这里的log_prob对应的是策略函数更新部分的

可以看到在实现的部分我们使用Normal(mean, std).log_prob(mean+ std*z.to(device)) - torch.log(1 - action.pow(2) + epsilon)​

虽然我们的动作通过均值和方差还原到了近似于SAC输出的高斯分布上,但是两者并不是同一个分布,并且由于大部分连续动作控制场景中要求将动作a限制在某个区间内,因此重参数化还需要经过一个可逆变换,如果你直接从 里采样,就有可能得到超出边界的值。把样本 过一个 ,就能把它“挤”到 区间内。

当你对随机变量做可逆变换(这里是 ),它们的概率密度要满足

在对数域里,这变成

由于我们实际要算的是 log⁡π(a∣s),而代码中先算了 log⁡pU(u),为了矫正必须减去

这就是代码里 - torch.log(1 - action.pow(2) + ε)​ 那一项。没有这一步校正,算法就会错误地估计动作在真实策略下出现的概率,进而导致策略梯度和熵正则化都跑偏。

以上就是对SAC的全部解析内容,如有需要,欢迎讨论

参考内容:

https://medium.com/data-science/soft-actor-critic-demystified-b8427df61665

### Soft Actor-Critic (SAC) 算法的图解与可视化解释 Soft Actor-Critic (SAC) 是一种基于最大熵强化学习框架的离策略深度强化学习算法[^1]。它通过优化目标函数来平衡奖励最大化和动作分布的熵最大化,从而实现更高效的探索能力。 #### SAC 的核心概念 SAC 结合了 actor-critic 架构以及软策略更新机制。以下是其主要组成部分及其作用: 1. **Actor**: 负责生成遵循当前最优策略的动作概率分布 \( \pi(a|s) \)[^2]。 2. **Critic**: 使用两个 Q 函数估计状态-动作对的价值 \( Q(s, a) \),并通过最小化 TD 错误来进行训练[^3]。 3. **Entropy Term**: 引入了一个额外的目标项——动作分布的熵 \( H(\pi) \),用于鼓励策略更加随机化以促进更好的探索[^4]。 #### 图解说明 为了更好地理解 SAC 工作原理,可以借助以下几种常见的图表形式进行展示: 1. **架构流程图** 这种类型的图形展示了整个系统的组成模块如何相互交互。通常会显示 `actor` 和双路 `critic` 的结构布局,并标注数据流动方向(前向传播 vs 后向梯度传递)。此外还会特别强调温度参数 α 对于调节熵的重要性[^5]。 2. **损失函数分解图** 可视化不同部分损失之间的关系有助于深入剖析模型内部运作细节。例如分别绘制针对 policy improvement step 中涉及的各项成分:expected reward term、regularization entropy bonus;同时对比 twin delayed deep deterministic policy gradient(TD3)-like target network updates during value estimation steps. 3. **实验结果曲线** 绘制在特定环境下的性能指标随时间变化趋势可以帮助直观感受该方法的优势所在。比如平均累积回报 versus episodes 或者 timesteps; 并且可以通过比较其他 baseline algorithms 来突出 sac 较高的样本效率及稳定性表现[^6]. ```python import matplotlib.pyplot as plt # Example of plotting training progress over time. def plot_training_curve(rewards_over_time): plt.figure(figsize=(8, 6)) plt.plot(range(len(rewards_over_time)), rewards_over_time) plt.title('Training Progress') plt.xlabel('Episodes / Timesteps') plt.ylabel('Average Reward per Episode') plt.grid(True) plt.show() ``` 上述代码片段提供了一种简单的方法用来呈现强化学习过程中收益的变化情况。 #### 总结 综上所述,SAC 不仅继承了传统 actor-critic 方法的优点还融入了现代先进理念使得整体设计更为合理高效。利用恰当的形式化表达工具能够极大地辅助我们掌握这一复杂技术背后的精髓之处[^7]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值