{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "UHvnYSvvdDCg"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import random\n",
        "\n",
        "\n",
        "class GridWorld:\n",
        "    def __init__(self, width, height, start, goal, obstacles):\n",
        "        self.width = width\n",
        "        self.height = height\n",
        "        self.start = start\n",
        "        self.goal = goal\n",
        "        self.obstacles = obstacles\n",
        "        self.state = start\n",
        "\n",
        "    def reset(self):\n",
        "        self.state = self.start\n",
        "        return self.state\n",
        "\n",
        "    def step(self, action):\n",
        "        x, y = self.state\n",
        "        if action == 0:\n",
        "            x = max(x - 1, 0)\n",
        "        elif action == 1:\n",
        "            x = min(x + 1, self.height - 1)\n",
        "        elif action == 2:\n",
        "            y = max(y - 1, 0)\n",
        "        elif action == 3:\n",
        "            y = min(y + 1, self.width - 1)\n",
        "\n",
        "        next_state = (x, y)\n",
        "\n",
        "        if next_state in self.obstacles:\n",
        "            reward = -10\n",
        "            done = True\n",
        "        elif next_state == self.goal:\n",
        "            reward = 10\n",
        "            done = True\n",
        "        else:\n",
        "            reward = -1\n",
        "            done = False\n",
        "\n",
        "        self.state = next_state\n",
        "        return next_state, reward, done"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def sarsa(env, episodes, alpha, gamma, epsilon):\n",
        "    Q = np.zeros((env.height, env.width, 4))\n",
        "\n",
        "    for episode in range(episodes):\n",
        "        state = env.reset()\n",
        "        action = epsilon_greedy_policy(Q, state, epsilon)\n",
        "        done = False\n",
        "\n",
        "        while not done:\n",
        "            next_state, reward, done = env.step(action)\n",
        "            next_action = epsilon_greedy_policy(Q, next_state, epsilon)\n",
        "\n",
        "            Q[state[0], state[1], action] += alpha * \\\n",
        "                (reward + gamma * Q[next_state[0], next_state[1],\n",
        "                 next_action] - Q[state[0], state[1], action])\n",
        "\n",
        "            state = next_state\n",
        "            action = next_action\n",
        "\n",
        "    return Q"
      ],
      "metadata": {
        "id": "PVr2AEFkdFAb"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def epsilon_greedy_policy(Q, state, epsilon):\n",
        "    if random.uniform(0, 1) < epsilon:\n",
        "        return random.randint(0, 3)\n",
        "    else:\n",
        "        q_values = Q[state[0], state[1]]\n",
        "        max_q = np.max(q_values)\n",
        "        best_actions = np.where(q_values == max_q)[0]\n",
        "        return np.random.choice(best_actions)"
      ],
      "metadata": {
        "id": "THtAh2HqdJj5"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "if __name__ == \"__main__\":\n",
        "\n",
        "    width = 5\n",
        "    height = 5\n",
        "    start = (0, 0)\n",
        "    goal = (4, 4)\n",
        "    obstacles = [(2, 2), (3, 2)]\n",
        "    env = GridWorld(width, height, start, goal, obstacles)\n",
        "\n",
        "    episodes = 1000\n",
        "    alpha = 0.1\n",
        "    gamma = 0.99\n",
        "    epsilon = 0.1\n",
        "\n",
        "    Q = sarsa(env, episodes, alpha, gamma, epsilon)\n",
        "\n",
        "    print(\"Learned Q-values:\")\n",
        "    print(Q)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JI4dyxIJdQZT",
        "outputId": "29c6efe7-ae8a-4bf0-e131-6146cef78826"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Learned Q-values:\n",
            "[[[-8.04838360e-01 -1.85411867e+00 -7.05829566e-01  1.05373225e+00]\n",
            "  [ 3.98875940e-01  4.77539429e-01 -8.84406329e-01  2.29605159e+00]\n",
            "  [ 1.76010876e+00  2.96340545e+00  4.04240080e-03  3.52732887e+00]\n",
            "  [ 2.56881594e+00  4.76124217e+00  5.85886187e-01  1.75068735e+00]\n",
            "  [-5.54862273e-01  5.51577487e+00  2.83247321e-01 -1.05295398e+00]]\n",
            "\n",
            " [[-3.09305026e+00 -2.69157157e+00 -2.47726923e+00  4.76361397e-01]\n",
            "  [-2.18086110e+00 -2.18944998e+00 -2.13437600e+00  2.22789896e+00]\n",
            "  [-2.66343371e-01 -3.43900000e+00 -1.58967192e+00  4.57561167e+00]\n",
            "  [ 2.93178677e+00  3.26518041e+00  2.49604817e+00  6.13540729e+00]\n",
            "  [ 2.92462953e+00  7.47725020e+00  3.95987419e+00  5.05655601e+00]]\n",
            "\n",
            " [[-2.22666418e+00 -2.13054417e+00 -2.27919936e+00 -2.16434496e+00]\n",
            "  [-1.91308245e+00 -1.85264347e+00 -1.96782075e+00 -1.90000000e+00]\n",
            "  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
            "  [-2.51020743e-01  2.17259045e-01 -3.43900000e+00  6.78556078e+00]\n",
            "  [ 5.35088477e+00  8.76701798e+00  3.80250343e+00  6.55629787e+00]]\n",
            "\n",
            " [[-1.59443914e+00 -1.62531442e+00 -1.59589189e+00 -1.59854212e+00]\n",
            "  [-1.39502463e+00 -5.74977908e-01 -1.36485960e+00 -1.90000000e+00]\n",
            "  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
            "  [ 0.00000000e+00  1.44328813e+00 -2.71000000e+00  7.85827499e+00]\n",
            "  [ 5.73771142e+00  1.00000000e+01  3.96517610e+00  7.99426292e+00]]\n",
            "\n",
            " [[-1.24380433e+00 -1.15693957e+00 -1.20114395e+00 -1.10074877e+00]\n",
            "  [-1.20147838e+00 -8.70036101e-01 -7.32672624e-01  2.08474111e+00]\n",
            "  [-1.00000000e+00 -4.81992878e-01 -3.66408232e-01  5.74361064e+00]\n",
            "  [-1.66194361e-01 -1.90000000e-01  3.05786943e-02  9.28210201e+00]\n",
            "  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "NxyBuLa9dRw2"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}