DQN Code
### Code Implementations of DQN (Deep Q-Network)
#### TensorFlow Implementation
The TensorFlow framework makes it straightforward to build and train a deep Q-network. Below is a simple TensorFlow-based DQN example:
```python
import random
from collections import deque

import gym
import numpy as np
import tensorflow as tf


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # replay buffer
        self.gamma = 0.95                 # discount factor
        self.epsilon = 1.0                # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()

    def _build_model(self):
        # Simple MLP mapping a state vector to one Q-value per action.
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(tf.keras.layers.Dense(24, activation='relu'))
        model.add(tf.keras.layers.Dense(self.action_size, activation='linear'))
        model.compile(loss='mse',
                      optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))


env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DQNAgent(state_size, action_size)
done = False
batch_size = 32
episodes = 1000  # number of training episodes

for e in range(episodes):
    state = env.reset()  # note: gym >= 0.26 returns (obs, info) instead of just obs
    state = np.reshape(state, [1, state_size])
    for time in range(500):
        ...  # act, step the environment, remember the transition, then replay
```
This code shows how to create a basic DQN agent, initialize the environment, and set the essential hyperparameters (replay buffer size, discount factor, exploration schedule, and learning rate)[^3].
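The `...` in the training loop hides the agent's action selection and experience-replay update. As a rough sketch of how those pieces are commonly written for a Keras-based agent of this shape (the names `CompleteDQNAgent`, `act`, and `replay` are illustrative assumptions, not part of the original code):

```python
import random
import numpy as np


class CompleteDQNAgent(DQNAgent):  # hypothetical extension of the agent above
    def act(self, state):
        # Epsilon-greedy: explore with probability epsilon, otherwise take argmax Q.
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        q_values = self.model.predict(state, verbose=0)
        return int(np.argmax(q_values[0]))

    def replay(self, batch_size):
        # Regress each sampled Q(s, a) toward the target r + gamma * max_a' Q(s', a').
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                target = reward + self.gamma * np.amax(
                    self.model.predict(next_state, verbose=0)[0])
            target_f = self.model.predict(state, verbose=0)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        # Anneal the exploration rate after each learning step.
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
```

Inside the per-episode loop, the usual pattern is to call `agent.act(state)`, step the environment, store the transition with `agent.remember(...)`, and call `agent.replay(batch_size)` once the buffer holds at least `batch_size` transitions.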
#### PyTorch Implementation
The PyTorch implementation is very similar; the main differences lie in how the model is defined and which optimizer is chosen:
```python
import math
import random
from collections import namedtuple, deque

import gym
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


# Exploration schedule and target-network update period (typical values).
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10

n_actions = ...   # number of discrete actions in the environment
policy_net = ...  # define the policy network
target_net = ...  # define the target network

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)
steps_done = 0


def select_action(state):
    """Epsilon-greedy action selection with an exponentially decaying threshold."""
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # max(1) returns the largest Q-value in each row; its second element
            # is the index of that maximum, i.e. the greedy action.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)


episode_durations = []
num_episodes = 500
for i_episode in range(num_episodes):
    # training loop ...
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
```
This snippet shows how to build a replay memory in PyTorch, implement the epsilon-greedy action-selection logic, and periodically synchronize the target network's weights with the policy network[^4].
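The snippet leaves the network definition and the learning step as placeholders. A minimal sketch of how they might be filled in, assuming a small fully connected network and the standard one-step TD target (the class name `QNetwork` and helper `optimize_model` are illustrative; `Transition` and `device` refer to the snippet above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    """Small MLP mapping a state vector to one Q-value per action."""
    def __init__(self, n_observations, n_actions):
        super().__init__()
        self.fc1 = nn.Linear(n_observations, 128)
        self.fc2 = nn.Linear(128, 128)
        self.out = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)


def optimize_model(policy_net, target_net, memory, optimizer, batch_size, gamma):
    """One gradient step on a sampled minibatch (assumed helper, not in the original snippet)."""
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  dtype=torch.bool, device=device)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # max_a Q_target(s_{t+1}, a), with zero for terminal states.
    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    expected_state_action_values = reward_batch + gamma * next_state_values

    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

In this pattern, `policy_net` and `target_net` would both be `QNetwork` instances (with the target initialized from the policy's weights), and `optimize_model` would be called once per environment step inside the per-episode loop, before the periodic `target_net` synchronization shown above.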