import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import matplotlib.pyplot as plt
# Neural network for the dynamics model
class DynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DynamicsModel, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim)

    def forward(self, state, action):
        # Accept actions given as a 1-D tensor by adding a feature dimension
        if len(action.shape) == 1:
            action = action.unsqueeze(1)
        x = torch.cat([state, action], dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        next_state_pred = self.fc3(x)
        return next_state_pred
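# A minimal usage sketch (hypothetical helper, not called anywhere in the training
# loop below): it only checks the tensor shapes the dynamics model expects and produces.
def _example_dynamics_forward():
    model = DynamicsModel(state_dim=4, action_dim=1)  # CartPole has a 4-dim state
    state = torch.zeros(1, 4)    # batch of one state
    action = torch.zeros(1, 1)   # batch of one action, encoded as a scalar feature
    next_state_pred = model(state, action)
    print(next_state_pred.shape)  # -> torch.Size([1, 4])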
# Planning function using the learned model (random shooting method)
def plan_action(dynamics_model, state, action_space, horizon=5, num_simulations=100):
    best_action = None
    best_return = -float('inf')
    # Sample random action sequences, roll them out through the learned model,
    # and keep the first action of the best-scoring sequence.
    with torch.no_grad():
        for _ in range(num_simulations):
            simulated_state = state.clone()
            total_return = 0
            first_action = None
            for t in range(horizon):
                action = torch.tensor([action_space.sample()], dtype=torch.float32)
                action = action.unsqueeze(0)
                if t == 0:
                    first_action = action  # random shooting executes the first action of the sequence
                next_state_pred = dynamics_model(simulated_state, action)
                # Placeholder reward: CartPole gives +1 per surviving step; without a learned
                # reward/termination model, all simulated sequences score identically here.
                reward = 1.0
                total_return += reward
                simulated_state = next_state_pred
            if total_return > best_return:
                best_return = total_return
                best_action = first_action
    return best_action
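# Hypothetical usage sketch for the planner (not part of the training loop): the
# returned tensor is thresholded to a discrete CartPole action exactly as done
# in model_based_rl() below.
def _example_plan_action():
    env = gym.make('CartPole-v1')
    model = DynamicsModel(state_dim=4, action_dim=1)
    state = torch.zeros(1, 4)
    planned = plan_action(model, state, env.action_space, horizon=3, num_simulations=10)
    discrete_action = int(planned.item() > 0.5)
    print(discrete_action)
    env.close()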
# Replay buffer to store experience
class ReplayBuffer:
    def __init__(self, max_size=10000):
        self.buffer = deque(maxlen=max_size)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
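# Hypothetical usage sketch for the replay buffer: each stored item is a
# (state, action, next_state) tuple of tensors, and sample() returns a plain list.
def _example_replay_buffer():
    buffer = ReplayBuffer(max_size=100)
    for _ in range(10):
        buffer.add((torch.zeros(1, 4), torch.zeros(1, 1), torch.zeros(1, 4)))
    states, actions, next_states = zip(*buffer.sample(4))
    print(torch.cat(states).shape)  # -> torch.Size([4, 4])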
# Main MBRL loop
def model_based_rl():
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = 1  # the discrete action is fed to the model as a single scalar feature
    dynamics_model = DynamicsModel(state_dim, action_dim)
    optimizer = optim.Adam(dynamics_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    replay_buffer = ReplayBuffer()

    num_episodes = 500
    batch_size = 64
    max_timesteps = 200
    total_rewards = []  # List to store total rewards per episode
    for episode in range(num_episodes):
        # Note: env.reset()/env.step() follow the classic gym (< 0.26) API here;
        # newer gym/gymnasium versions also return an info dict from reset() and
        # split `done` into `terminated`/`truncated`.
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        total_reward = 0

        for t in range(max_timesteps):
            env.render()  # optional; rendering every step slows training considerably

            # Plan with the learned model, then threshold to a discrete CartPole action
            action = plan_action(dynamics_model, state, env.action_space)
            action = int(action.item() > 0.5)  # Convert to discrete action for CartPole (0 or 1)
            next_state, reward, done, _ = env.step(action)
            next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
            action = torch.tensor([[action]], dtype=torch.float32)
            replay_buffer.add((state, action, next_state))

            # Train the dynamics model on a random minibatch of stored transitions
            if len(replay_buffer.buffer) >= batch_size:
                batch = replay_buffer.sample(batch_size)
                states, actions, next_states = zip(*batch)
                states = torch.cat(states)
                actions = torch.cat(actions)
                next_states = torch.cat(next_states)

                next_state_preds = dynamics_model(states, actions)
                loss = loss_fn(next_state_preds, next_states)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            state = next_state
            total_reward += reward
            if done:
                break

        total_rewards.append(total_reward)  # Store the total reward for this episode
        print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    env.close()
    # Plot the rewards
    plt.figure(figsize=(10, 5))
    plt.plot(total_rewards, label='Total Reward per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('Total Rewards Over Episodes')
    plt.legend()
    plt.show()
if __name__ == "__main__":
    model_based_rl()