
Model-Based Reinforcement Learning (MBRL) in AI

Last Updated : 03 Jun, 2025

Model-based reinforcement learning (MBRL) is a subclass of reinforcement learning in which the agent constructs an internal model of the environment's dynamics and uses it to simulate future states, predict rewards and optimize actions efficiently. The key components of MBRL are:

  • Model of the Environment: This is typically a predictive model that forecasts the next state and reward given the current state and action.
  • Planning Algorithm: Once a model is learned, a planning algorithm evaluates the model to decide the optimal sequence of actions (a minimal sketch of both components follows).
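
The sketch below shows how these two components interact. The ToyModel class, its linear transition rule and the reward function are invented purely for illustration; in practice the model would be learned from data.

Python
import numpy as np

# Toy "model of the environment": predicts the next state and reward for (state, action).
# The hand-coded linear rule below stands in for a learned dynamics model.
class ToyModel:
    def predict(self, state, action):
        next_state = state + 0.1 * action      # assumed transition rule
        reward = -abs(next_state)              # reward for keeping the state near 0
        return next_state, reward

# Simple planning algorithm: score each candidate action with the model, pick the best.
def plan(model, state, candidate_actions):
    returns = [model.predict(state, a)[1] for a in candidate_actions]
    return candidate_actions[int(np.argmax(returns))]

model = ToyModel()
print(plan(model, state=1.0, candidate_actions=[-1.0, 0.0, 1.0]))  # -> -1.0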

How Does MBRL Work?

MBRL follows these steps:

  1. Model Learning: The agent collects experience by interacting with the environment and uses these experiences to learn a model that predicts future states and rewards.
  2. Model-Based Planning: After learning how the environment works, the agent uses the model to plan future steps without interacting with the real world. Algorithms like Monte Carlo Tree Search (MCTS) or Dynamic Programming can be used to identify optimal actions.
  3. Policy Optimization: The agent uses the results from planning to optimize its policy, which is then deployed back into the real environment.
  4. Continuous Learning: The model is updated regularly as the agent gathers new experiences, improving the model's accuracy and the agent's performance over time (an end-to-end toy sketch of this loop follows the list).
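
The sketch below runs this loop end to end on a hypothetical 1-D environment: a linear dynamics model is refit as data accumulates and a simple model-based planner chooses actions. The environment, the least-squares model and the planner are illustrative stand-ins, not a standard implementation.

Python
import numpy as np

# Toy 1-D environment: the state drifts by the chosen action; reward for staying near 0.
class ToyEnv:
    def reset(self):
        self.s = np.random.uniform(-1, 1)
        return self.s

    def step(self, a):
        self.s = self.s + 0.1 * a + np.random.normal(scale=0.01)
        return self.s, -abs(self.s)

# Step 1 - Model learning: fit a linear model s' ~ w . [s, a] by least squares.
def fit_model(data):
    X = np.array([[s, a] for s, a, _ in data])
    y = np.array([s2 for _, _, s2 in data])
    w, *_ = np.linalg.lstsq(X, y, rcond=None)
    return w

# Steps 2-3 - Planning / policy: pick the action whose predicted next state is closest to 0.
def plan(w, s, candidates=(-1.0, 0.0, 1.0)):
    preds = [abs(w @ np.array([s, a])) for a in candidates]
    return candidates[int(np.argmin(preds))]

env, data = ToyEnv(), []
s = env.reset()
for step in range(200):
    # Act randomly at first to gather data, then switch to the model-based planner.
    a = np.random.choice([-1.0, 0.0, 1.0]) if step < 20 else plan(w, s)
    s2, r = env.step(a)
    data.append((s, a, s2))
    s = s2
    w = fit_model(data)   # Step 4 - Continuous learning: refit the model as data grows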

Several popular MBRL methods have emerged in recent years:

1. World Models

  • Description: This approach involves creating a model that simulates the environment's dynamics using past experiences. It often employs recurrent neural networks (RNNs) to learn the transition function f(s,a) from state s and action a.
  • Example: The "World Models" framework allows an agent to learn to play games solely from camera images without predefined state representations.

2. Model Predictive Control (MPC)

  • Description: MPC uses a model to predict future states and optimize control actions over a finite horizon. It samples trajectories based on the model and selects actions that maximize expected rewards.
  • Example: The Model Predictive Path Integral (MPPI) method is a stochastic optimal control technique that iteratively optimizes action sequences by sampling multiple trajectories (a toy version of this update is sketched below).
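
The sketch below shows the core of an MPPI-style update on a toy 1-D system with known dynamics (s' = s + a) and cost |s'|: perturbed action sequences are scored and averaged with exponential weights. The noise scale, temperature and horizon are arbitrary choices for illustration; a real MPPI planner would roll out the learned model instead.

Python
import numpy as np

# Toy MPPI-style update: sample noisy action sequences around a nominal plan,
# weight them by exponentiated (negative) cost and average.
def mppi_step(s, nominal, num_samples=64, noise_std=0.5, temperature=1.0):
    horizon = len(nominal)
    noise = np.random.normal(scale=noise_std, size=(num_samples, horizon))
    candidates = nominal + noise
    costs = np.zeros(num_samples)
    for k in range(num_samples):
        state = s
        for a in candidates[k]:
            state = state + a          # known toy dynamics; a learned model in MBRL
            costs[k] += abs(state)     # cost: keep the state near 0
    weights = np.exp(-(costs - costs.min()) / temperature)
    weights /= weights.sum()
    return weights @ candidates        # updated nominal action sequence

nominal = np.zeros(5)
for _ in range(10):
    nominal = mppi_step(s=2.0, nominal=nominal)
print(nominal[0])   # the first action should push the state toward 0 (negative)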

3. Dyna Architecture

  • Description: This approach combines model learning, data generation and policy learning in an iterative process. The agent learns a model of the environment and uses it to generate synthetic experiences for training.
  • Example: Sutton's Dyna algorithm alternates between updating the model and refining the policy based on both real and simulated experiences (a tabular Dyna-Q sketch follows).
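
The sketch below is a tabular Dyna-Q example on a tiny deterministic chain (states 0 to 4, reward at the last state). Every real transition updates both the Q-values and a table-based model, and extra planning steps replay simulated transitions from that model. The chain environment and hyperparameters are invented for illustration, not Sutton's original code.

Python
import random
from collections import defaultdict

N, ACTIONS, GAMMA, ALPHA, PLAN_STEPS = 5, (-1, +1), 0.95, 0.1, 10
Q = defaultdict(float)
model = {}                                     # learned model: (s, a) -> (reward, next state)

def env_step(s, a):                            # the real environment: move along the chain
    s2 = max(0, min(N - 1, s + a))
    return (1.0 if s2 == N - 1 else 0.0), s2

def greedy(s):                                 # greedy action with random tie-breaking
    vals = [Q[(s, a)] for a in ACTIONS]
    return random.choice(ACTIONS) if vals[0] == vals[1] else ACTIONS[vals.index(max(vals))]

for episode in range(50):
    s = 0
    while s != N - 1:
        a = random.choice(ACTIONS) if random.random() < 0.1 else greedy(s)
        r, s2 = env_step(s, a)                 # real experience
        Q[(s, a)] += ALPHA * (r + GAMMA * max(Q[(s2, b)] for b in ACTIONS) - Q[(s, a)])
        model[(s, a)] = (r, s2)                # model learning
        for _ in range(PLAN_STEPS):            # planning with simulated experience
            ps, pa = random.choice(list(model))
            pr, ps2 = model[(ps, pa)]
            Q[(ps, pa)] += ALPHA * (pr + GAMMA * max(Q[(ps2, b)] for b in ACTIONS) - Q[(ps, pa)])
        s = s2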

4. Sampling-Based Planning

  • Description: This approach generates candidate action sequences through sampling and evaluates them with the learned model. Techniques like random shooting or the cross-entropy method help refine action selection.
  • Example: Algorithms such as PlaNet and PETS use sampling-based planning for effective decision-making in complex environments (a cross-entropy planning sketch follows).
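
Below is a cross-entropy method (CEM) planning sketch on a toy 1-D system with known dynamics (s' = s + a) and cost |s'|; PlaNet and PETS apply the same idea with learned models and larger action spaces. The population size, elite count and horizon are illustrative.

Python
import numpy as np

# Cross-entropy method planner: sample action sequences, keep the elites,
# refit the sampling distribution and repeat.
def cem_plan(s0, horizon=5, iterations=5, population=100, num_elite=10):
    mean, std = np.zeros(horizon), np.ones(horizon)
    for _ in range(iterations):
        samples = np.random.normal(mean, std, size=(population, horizon))
        costs = []
        for seq in samples:
            s, cost = s0, 0.0
            for a in seq:
                s = s + a                  # toy dynamics; a learned model in MBRL
                cost += abs(s)
            costs.append(cost)
        elites = samples[np.argsort(costs)[:num_elite]]
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-3
    return mean[0]                         # execute only the first action (MPC style)

print(cem_plan(s0=2.0))   # should be close to -2.0 (drives the state toward 0)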

5. Analytic Gradient Computation

  • Description: This method leverages assumptions about the dynamics to compute gradients analytically, facilitating local optimal control solutions.
  • Example: Linear Quadratic Regulator (LQR) frameworks often utilize this approach for efficient control in linear systems (a finite-horizon LQR sketch follows).
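
The sketch below solves a finite-horizon discrete-time LQR problem by backward Riccati recursion for a toy double-integrator system. The matrices A, B, Q and R are illustrative; in MBRL they would typically come from a (locally) linearized learned model.

Python
import numpy as np

# Finite-horizon LQR for known linear dynamics x' = A x + B u with quadratic cost.
A = np.array([[1.0, 1.0],
              [0.0, 1.0]])
B = np.array([[0.0],
              [1.0]])
Q = np.eye(2)
R = np.array([[0.1]])

def lqr_gains(A, B, Q, R, horizon=50):
    P, gains = Q, []
    for _ in range(horizon):                                 # backward Riccati recursion
        K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)    # optimal feedback u = -K x
        P = Q + A.T @ P @ (A - B @ K)
        gains.append(K)
    return gains[::-1]

K0 = lqr_gains(A, B, Q, R)[0]
x0 = np.array([[1.0], [0.0]])
print(-K0 @ x0)   # first control input for initial state (1, 0)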

6. Value-Equivalence Prediction

  • Description: This technique predicts value functions using models, allowing for better generalization and improved policy updates.
  • Example: Incorporating model-generated data into reinforcement learning can enhance learning efficiency, although it must be managed carefully to avoid divergence due to modeling errors (a schematic value-prediction training sketch follows).
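
A schematic sketch of the value-prediction idea is shown below: instead of minimizing next-state prediction error, a small latent model is trained so that values computed through it match observed returns. The encoder, latent dynamics, value head, dimensions and the randomly generated training batch are all illustrative placeholders, not a specific published architecture.

Python
import torch
import torch.nn as nn

state_dim, action_dim, latent_dim = 4, 1, 16

encoder = nn.Linear(state_dim, latent_dim)                  # s -> z
dynamics = nn.Linear(latent_dim + action_dim, latent_dim)   # (z, a) -> z'
value_head = nn.Linear(latent_dim, 1)                       # z -> V(z)
params = list(encoder.parameters()) + list(dynamics.parameters()) + list(value_head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# A batch of (state, action, observed return) tuples; random data stands in for real rollouts.
states = torch.randn(32, state_dim)
actions = torch.randn(32, action_dim)
returns = torch.randn(32, 1)

for _ in range(100):
    z = encoder(states)
    z_next = dynamics(torch.cat([z, actions], dim=-1))
    value_pred = value_head(z_next)              # value predicted through the latent model
    loss = ((value_pred - returns) ** 2).mean()  # match observed returns, not next states
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()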

7. Meta-Reinforcement Learning

  • Description: This approach focuses on learning how to learn, where agents adapt their strategies based on previous experiences across different tasks.
  • Example: Context-based meta-reinforcement learning utilizes world models to generalize across various environments and tasks.

Model-Based Reinforcement Learning with Model Predictive Control (MPC)

We use the Random Shooting Method, a basic form of Model Predictive Control (MPC), to optimize actions in Model-Based Reinforcement Learning (MBRL).

1. Dynamics Model: A neural network (two hidden layers, ReLU activations) predicts the next state from the current state and action.

2. Planning with Random Shooting: The agent samples multiple action sequences, simulates their outcomes over a short horizon (5 steps) using the learned model, and executes the first action of the best-scoring sequence.

3. Replay Buffer: Stores experiences (state, action, next state) for training the model using past interactions.

4. Main Loop (500 Episodes):

  • The agent plans the best action using the dynamics model.
  • The action is executed and experience is stored in the replay buffer.
  • The model is periodically trained using random batches from the buffer.
  • The process repeats until episode termination.

5. Model Training: The model minimizes Mean Squared Error (MSE) between predicted and actual next states.

Python
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import matplotlib.pyplot as plt

# Neural network for the dynamics model
class DynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DynamicsModel, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim)
    
    def forward(self, state, action):
        if len(action.shape) == 1:
            action = action.unsqueeze(1)

        x = torch.cat([state, action], dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        next_state_pred = self.fc3(x)
        return next_state_pred

# Planning function using the learned model (random shooting method)
def plan_action(dynamics_model, state, action_space, horizon=5, num_simulations=100):
    best_action = None
    best_return = -float('inf')

    for _ in range(num_simulations):
        simulated_state = state.clone()
        total_return = 0
        first_action = None
        for t in range(horizon):
            action = torch.tensor([[action_space.sample()]], dtype=torch.float32)
            if t == 0:
                first_action = action  # remember the first action of this sequence

            next_state_pred = dynamics_model(simulated_state, action)
            # Placeholder reward (CartPole gives +1 per surviving step); a learned
            # reward model would make the ranking of sequences more meaningful.
            reward = 1.0
            total_return += reward
            simulated_state = next_state_pred

        # Keep the first action of the best-scoring simulated sequence (MPC style)
        if total_return > best_return:
            best_return = total_return
            best_action = first_action

    return best_action

# Replay buffer to store experience
class ReplayBuffer:
    def __init__(self, max_size=10000):
        self.buffer = deque(maxlen=max_size)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

# Main MBRL loop
def model_based_rl():
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = 1
    
    dynamics_model = DynamicsModel(state_dim, action_dim)
    optimizer = optim.Adam(dynamics_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    replay_buffer = ReplayBuffer()

    num_episodes = 500
    batch_size = 64
    max_timesteps = 200

    total_rewards = []  # List to store total rewards per episode

    for episode in range(num_episodes):
        state = env.reset()
        # gym >= 0.26 returns (obs, info) from reset(); older versions return obs only
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        
        total_reward = 0
        for t in range(max_timesteps):
            # env.render()  # optional: rendering every step slows training considerably

            action = plan_action(dynamics_model, state, env.action_space)
            action = int(action.item() > 0.5)  # Convert to discrete action for CartPole (0 or 1)
            
            step_result = env.step(action)
            if len(step_result) == 5:
                # gym >= 0.26 returns (obs, reward, terminated, truncated, info)
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result
            next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
            action = torch.tensor([[action]], dtype=torch.float32)

            replay_buffer.add((state, action, next_state))

            if len(replay_buffer.buffer) >= batch_size:
                batch = replay_buffer.sample(batch_size)
                states, actions, next_states = zip(*batch)
                states = torch.cat(states)
                actions = torch.cat(actions)
                next_states = torch.cat(next_states)

                next_state_preds = dynamics_model(states, actions)
                loss = loss_fn(next_state_preds, next_states)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            state = next_state
            total_reward += reward
            if done:
                break

        total_rewards.append(total_reward)  # Store the total reward for this episode
        print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    env.close()

    # Plot the rewards
    plt.figure(figsize=(10, 5))
    plt.plot(total_rewards, label='Total Reward per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('Total Rewards Over Episodes')
    plt.legend()
    plt.show()

if __name__ == "__main__":
    model_based_rl()

Output:

Episode 1: Total Reward = 10.0
Episode 2: Total Reward = 36.0
Episode 3: Total Reward = 12.0
. . .
Episode 499: Total Reward = 14.0
Episode 500: Total Reward = 39.0

Figure: Total reward per episode over the 500 training episodes (the plot produced by the script above).

Advantages of Model-Based Reinforcement Learning

  • Learns with Fewer Tries: The system builds a model to understand how things work, so it doesn’t need to try every action in the real world. This is helpful when testing is slow, expensive or risky.
  • Handles New Situations Better: Because it learns how things change, it can deal with new situations more easily.
  • Plans Ahead: It can imagine future steps and their outcomes, which helps it plan better in tasks like controlling robots or playing games.

Challenges of Model-Based Reinforcement Learning

  • Hard to Build a Good Model: Making a model that truly represents a complicated environment can be tough. If the model is wrong, learning may head in the wrong direction.
  • Takes a Lot of Computing Power: The system has to both learn the model and plan with it, which can take a lot of time and resources.
  • May Avoid Trying New Things: If it trusts its model too much, it might stop exploring better options and miss the best solution.
