Reinforcement Learning Example: Inverted Pendulum (CartPole) in Python
The inverted pendulum (CartPole) is a classic reinforcement learning benchmark, available in OpenAI Gym as the CartPole-v1 environment, so the natural approach is to build the example with common libraries: Gym together with PyTorch or TensorFlow.
The earlier references mention PyTorch, and reference [^2] in particular points to a PyTorch DDPG sample. However, DDPG is better suited to continuous action spaces, while CartPole-v1 has a discrete action space (push left or push right), so DQN (Deep Q-Network) or another algorithm designed for discrete actions is the better fit here.
There are existing starting points: the OpenAI Gym documentation has basic examples, and PyTorch's reinforcement learning tutorials are another reference. The cited DDPG sample would need adaptation, so it is simpler to write a dedicated example. The code should also run without version problems, and since PyTorch was previously noted as friendlier than TensorFlow, PyTorch is the framework used here.
The example is built step by step: install the required libraries (gym and PyTorch), define the neural network model, design an experience replay mechanism, implement the DQN algorithm, then run the training loop, syncing the target network at a fixed interval and saving the model at the end.
Points to watch include hyperparameter tuning (learning rate, gamma, epsilon decay, and so on), the size of the replay buffer, the network architecture, and keeping training stable. The balance between exploration and exploitation also needs handling, for example with an epsilon-greedy policy.
Related follow-up topics include algorithm selection, hyperparameter tuning, and performance optimization. Below is an example of deep Q-learning (DQN) for inverted pendulum control (CartPole-v1) implemented with PyTorch:
```python
import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque

class DQN(nn.Module):
    """Small fully connected Q-network mapping a state to one Q-value per action."""
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_size, 24),
            nn.ReLU(),
            nn.Linear(24, 24),
            nn.ReLU(),
            nn.Linear(24, action_size)
        )

    def forward(self, x):
        return self.fc(x)

class ReplayBuffer:
    """Fixed-size FIFO buffer of (state, action, reward, next_state, done) transitions."""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

# Note: this script uses the classic Gym API (gym < 0.26), where env.reset()
# returns only the observation and env.step() returns 4 values.
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]   # CartPole observation has 4 dimensions
action_dim = env.action_space.n              # 2 discrete actions: push left / push right

# Hyperparameters
gamma = 0.99             # discount factor
batch_size = 64
lr = 0.001
epsilon = 1.0            # initial exploration rate for epsilon-greedy
epsilon_min = 0.01
epsilon_decay = 0.995    # multiplicative decay applied after each episode
target_update = 10       # episodes between target-network syncs

policy_net = DQN(state_dim, action_dim)
target_net = DQN(state_dim, action_dim)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.Adam(policy_net.parameters(), lr=lr)
memory = ReplayBuffer(10000)

def select_action(state):
    """Epsilon-greedy action selection over the policy network's Q-values."""
    if random.random() < epsilon:
        return env.action_space.sample()
    with torch.no_grad():
        return policy_net(torch.FloatTensor(state)).argmax().item()

def update_model():
    """Sample a minibatch from the replay buffer and take one gradient step."""
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    batch = list(zip(*transitions))
    states = torch.FloatTensor(np.array(batch[0]))
    actions = torch.LongTensor(batch[1])
    rewards = torch.FloatTensor(batch[2])
    next_states = torch.FloatTensor(np.array(batch[3]))
    dones = torch.FloatTensor(batch[4])

    # Q(s, a) for the actions actually taken
    current_q = policy_net(states).gather(1, actions.unsqueeze(1))
    # Bootstrapped target computed from the frozen target network
    next_q = target_net(next_states).max(1)[0].detach()
    target = rewards + gamma * next_q * (1 - dones)

    loss = nn.MSELoss()(current_q, target.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

episodes = 500
for episode in range(episodes):
    state = env.reset()
    total_reward = 0
    while True:
        action = select_action(state)
        next_state, reward, done, _ = env.step(action)
        memory.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        update_model()
        if done:
            break

    # Sync the target network every `target_update` episodes
    if episode % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())

    # Decay the exploration rate
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    print(f"Episode: {episode}, Reward: {total_reward}")

# Save the trained policy network
torch.save(policy_net.state_dict(), 'cartpole_dqn.pth')
```
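To sanity-check the result after training, the sketch below loads the saved `cartpole_dqn.pth` weights and runs a few purely greedy evaluation episodes (no exploration). It is a minimal example that assumes the `DQN` class defined above is in scope (for instance, appended to the same file) and the same pre-0.26 Gym API:
```python
import gym
import torch

# Rebuild the network and load the weights saved by the training script
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy_net = DQN(state_dim, action_dim)
policy_net.load_state_dict(torch.load('cartpole_dqn.pth'))
policy_net.eval()

for episode in range(5):
    state = env.reset()
    total_reward = 0
    done = False
    while not done:
        with torch.no_grad():
            # Greedy action: pick the argmax Q-value, no epsilon-greedy exploration
            action = policy_net(torch.FloatTensor(state)).argmax().item()
        state, reward, done, _ = env.step(action)
        total_reward += reward
    print(f"Evaluation episode {episode}: total reward = {total_reward}")
```
If you are on Gymnasium or gym 0.26+, note that `env.reset()` returns `(obs, info)` and `env.step()` returns `(obs, reward, terminated, truncated, info)`, so both the training and evaluation loops need the corresponding small adjustment (`done = terminated or truncated`).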