DQN on CartPole-v0
### Implementing a Deep Q-Network (DQN) for the CartPole-v0 Environment
When implementing DQN to solve CartPole-v0, the goal is to optimize the model by minimizing the squared error between the estimated Q-value \(Q(s,a)\) and the target Q-value[^2]. This is typically done with gradient descent.
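Concretely, for a transition \((s, a, r, s', d)\) sampled from the replay buffer, the standard DQN loss has the form

\[
L(\theta) = \Big(r + \gamma \,(1 - d)\, \max_{a'} Q_{\theta^{-}}(s', a') - Q_{\theta}(s, a)\Big)^{2},
\]

where \(\theta\) denotes the policy-network parameters, \(\theta^{-}\) the target-network parameters, \(\gamma\) the discount factor, and \(d\) the episode-termination flag.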
#### Building the Neural Network Architecture
To meet the needs of the CartPole-v0 task, a common approach is to build a simple multilayer perceptron as the function approximator:
```python
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        # Two hidden layers of 128 units map observations to one Q-value per action.
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)
```
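For reference, CartPole-v0 exposes a 4-dimensional observation and 2 discrete actions, so the policy and target networks used below can be instantiated roughly as follows (a minimal sketch; the Adam optimizer with learning rate 1e-3 is an illustrative choice, not taken from the original):

```python
import gym
import torch.optim as optim

env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]   # 4 for CartPole-v0
n_actions = env.action_space.n             # 2 for CartPole-v0

policy_net = DQN(obs_dim, n_actions)
target_net = DQN(obs_dim, n_actions)
target_net.load_state_dict(policy_net.state_dict())  # start the target net as a copy
optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)  # illustrative learning rate
```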
#### Defining the Experience Replay Buffer
Introducing an experience replay mechanism helps break correlations between consecutive samples and improves sample efficiency:
```python
from collections import deque
import random

import numpy as np

class ReplayBuffer(object):
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        # Store one transition; the oldest ones are dropped once capacity is reached.
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Draw a random mini-batch and stack each field into a NumPy array.
        states, actions, rewards, next_states, dones = zip(*random.sample(self.buffer, batch_size))
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

    def __len__(self):
        return len(self.buffer)
```
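A quick usage sketch (the capacity of 10,000 and batch size of 64 are illustrative values):

```python
memory = ReplayBuffer(capacity=10000)
# After each environment step:
# memory.push(state, action, reward, next_state, float(done))
if len(memory) >= 64:
    states, actions, rewards, next_states, dones = memory.sample(batch_size=64)
```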
#### Designing the Training Loop
During training, the policy network's parameters are updated continuously and regularly synchronized to the target network to stabilize learning:
```python
import math
import random

import torch

def train(env, policy_net, target_net, optimizer, memory, gamma=0.99, batch_size=64,
          eps_start=1., eps_end=0.05, eps_decay=1000, tau=0.005, episodes=500):
    steps_done = 0
    for i_episode in range(episodes):
        # Initialize the environment and get its first observation
        # (classic Gym API, where reset() returns only the observation).
        state = env.reset()
        total_reward = 0
        while True:
            # Select an action with an epsilon-greedy strategy.
            sample = random.random()
            eps_threshold = eps_end + (eps_start - eps_end) * math.exp(-1. * steps_done / eps_decay)
            if sample > eps_threshold:
                with torch.no_grad():
                    action = policy_net(torch.tensor([state], dtype=torch.float32)).max(1)[1].item()
            else:
                action = env.action_space.sample()
            # Execute the selected action in the environment.
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            # Store the transition in the replay buffer.
            memory.push(state, action, reward, next_state, float(done))
            # Move to the next state.
            state = next_state
            # Perform one optimization step (on the policy network).
            optimize_model(policy_net, target_net, optimizer, memory, gamma, batch_size)
            # Soft update of the target network's weights:
            # theta_target <- tau * theta_policy + (1 - tau) * theta_target.
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau)
            target_net.load_state_dict(target_net_state_dict)
            steps_done += 1
            if done:
                break
        print(f'Episode {i_episode}: Total Reward={total_reward}')
```
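The loop above calls an `optimize_model` helper that is not defined in the original snippet. One possible implementation, consistent with the replay buffer and networks above and with the squared-error objective described at the start, might look like this (a sketch, not the original author's code):

```python
import torch
import torch.nn.functional as F

def optimize_model(policy_net, target_net, optimizer, memory, gamma, batch_size):
    # Skip optimization until the buffer holds at least one full batch.
    if len(memory) < batch_size:
        return
    states, actions, rewards, next_states, dones = memory.sample(batch_size)
    states = torch.tensor(states, dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.tensor(next_states, dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32)

    # Q(s, a) for the actions that were actually taken.
    q_values = policy_net(states).gather(1, actions).squeeze(1)
    # Target: r + gamma * max_a' Q_target(s', a'), with no bootstrap on terminal states.
    with torch.no_grad():
        next_q = target_net(next_states).max(1)[0]
        targets = rewards + gamma * next_q * (1 - dones)

    # Minimize the squared error between estimated and target Q-values.
    loss = F.mse_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```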
The code snippets above show how to build a basic DQN solution for the CartPole-v0 task using the PyTorch framework.