actor-critic代码逐行解析(tensorflow版)

本文详细解读了使用Tensorflow实现的actor-critic算法,介绍了Actor网络和Critic网络的工作流程。Actor网络基于Policy-Gradients,采用两层全连接层进行连续动作选择,利用Policy-Gradients的损失函数进行更新。Critic网络作为评估器,给出Actor动作的评分,并通过时间差分误差指导Actor网络的参数更新,同样由两层全连接层构成。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

深度强化学习算法actor-critic代码逐行解析(tensorflow版)

在这里插入图片描述

Actor是基于Policy-Gradients。可以选择连续动作,但是必须循环一个回合才可以更新策略。学习效率低。
Critic网络继承了Q-learning 的传统,依然可以逐步更新。
首先导入需要的包,这没什么好说的。

import numpy as np
import tensorflow as tf
import gym
import matplotlib.pyplot as plt

np.random.seed(2)
tf.set_random_seed(2)  # reproducible

# 超参数
OUTPUT_GRAPH = False
MAX_EPISODE = 5
DISPLAY_REWARD_THRESHOLD = 200  # 刷新阈值
MAX_EP_STEPS = 500  # 最大迭代次数
RENDER = False  # 渲染开关,这玩意儿是gym输出动画的开关
GAMMA = 0.9  # 衰变值
LR_A = 0.001  # Actor学习率
LR_C = 0.01  # Critic学习率

env = gym.make('CartPole-v0')
env.seed(1)
env = env.unwrapped

N_F = env.observation_space.shape[0]  # 状态空间
N_A = env.action_space.n  # 动作空间

Actor网络

class Actor(object):
    def __init__(self, sess, n_features, n_actions, lr=0.001):
        self.sess = sess

        self.s = tf.placeholder(tf.float32, [1, n_features], "state")
        self.a = tf.placeholder(tf.int32, None, "act")
        self.td_error = tf.placeholder(tf.float32, None, "td_error")  # TD_error

        with tf.variable_scope('Actor'):
            l1 = tf.layers.dense(
                inputs=self.s,
                units=20,  # number of hidden units
                activation=tf
``` import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import numpy as np class ActorCritic(nn.Module): def __init__(self, state_dim, action_dim): super(ActorCritic, self).__init__() self.fc1 = nn.Linear(state_dim, 128) self.fc2 = nn.Linear(128, 128) self.actor = nn.Linear(128, action_dim) self.critic = nn.Linear(128, 1) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) action_probs = F.softmax(self.actor(x), dim=-1) state_value = self.critic(x) return action_probs, state_value class A2CScheduler: def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99): self.model = ActorCritic(state_dim, action_dim) self.optimizer = optim.Adam(self.model.parameters(), lr=lr) self.gamma = gamma def select_action(self, state): state = torch.FloatTensor(state).unsqueeze(0) action_probs, _ = self.model(state) action = torch.multinomial(action_probs, 1).item() return action, action_probs[:, action] def update(self, trajectory): rewards, log_probs, state_values = [], [], [] for (state, action, reward, log_prob, state_value) in trajectory: rewards.append(reward) log_probs.append(log_prob) state_values.append(state_value) returns = [] R = 0 for r in reversed(rewards): R = r + self.gamma * R returns.insert(0, R) returns = torch.tensor(returns) log_probs = torch.stack(log_probs) state_values = torch.stack(state_values).squeeze() advantage = returns - state_values actor_loss = -log_probs * advantage.detach() critic_loss = F.mse_loss(state_values, returns) loss = actor_loss.mean() + critic_loss self.optimizer.zero_grad() loss.backward() self.optimizer.step() # 结合 `mp-quic-go` 使用 # 1. 获取状态信息 (如带宽、RTT、丢包等) # 2. 选择路径 (基于 `select_action` 方法) # 3. 收集数据并训练模型 (基于 `update` 方法)```请详细解释每一行代码的含义和意义
最新发布
04-02
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值