[Reinforcement Learning] CUHK Reinforcement Learning Course Assignment Walkthrough 02
Course resources
- Course homepage: https://cuhkrlcourse.github.io/
- Video lectures: https://space.bilibili.com/511221970/channel/seriesdetail?sid=764099 [Bilibili]
- Reference material: https://datawhalechina.github.io/easy-rl/#/ [EasyRL]
- Reinforcement Learning: An Introduction: https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf
- GitHub repo (assignment source): https://github.com/cuhkrlcourse/ierg5350-assignment-2021
- Gitee (my solutions): https://gitee.com/cstern-liao/cuhk_rl_assignment
1 SARSA
The name SARSA comes from the tuple State-Action-Reward-State-Action.
Concretely, it updates the current state-action value using the state-action value of the next step.
In short: the agent keeps interacting with the environment, always choosing actions with an $\varepsilon$-greedy policy, and updates the state-action values online (i.e., one step at a time) as it goes. For the theory, please see the materials linked above; the core update rule is summarized below, and then we go straight to the code.
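Each step applies the standard SARSA (on-policy TD) update, where $\alpha$ is the learning rate and $\gamma$ the discount factor:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \big[ r_{t+1} + \gamma\, Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \big]$$

Both $a_t$ and $a_{t+1}$ are chosen by the $\varepsilon$-greedy policy derived from the current Q-table, which is exactly what makes SARSA on-policy.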
class TabularRLTrainerAbstract:
"""This is the abstract class for tabular RL trainer. We will inherent the specify
algorithm's trainer from this abstract class, so that we can reuse the codes like
getting the dynamic of the environment (self._get_transitions()) or rendering the
learned policy (self.render())."""
def __init__(self, env_name='FrozenLake8x8-v1', model_based=True):
self.env_name = env_name
self.env = gym.make(self.env_name)
self.action_dim = self.env.action_space.n
self.obs_dim = self.env.observation_space.n
self.model_based = model_based
def _get_transitions(self, state, act):
"""Query the environment to get the transition probability,
reward, the next state, and done given a pair of state and action.
We implement this function for you. But you need to know the
return format of this function.
"""
self._check_env_name()
assert self.model_based, "You should not use _get_transitions in " \
"model-free algorithm!"
# call the internal attribute of the environments.
        # `transitions` is a list containing all possible next states and the
        # probability, reward, and termination indicator corresponding to each
transitions = self.env.env.P[state][act]
        # Given a certain state-action pair, there may exist multiple
        # transitions, since the environment is not deterministic.
# You need to know the return format of this function: a list of dicts
ret = []
for prob, next_state, reward, done in transitions:
ret.append({
"prob": prob,
"next_state": next_state,
"reward": reward,
"done": done
})
return ret
def _check_env_name(self):
assert self.env_name.startswith('FrozenLake')
def print_table(self):
"""print beautiful table, only work for FrozenLake8X8-v1 env. We
write this function for you."""
self._check_env_name()
print_table(self.table)
def train(self):
"""Conduct one iteration of learning."""
raise NotImplementedError("You need to override the "
"Trainer.train() function.")
def evaluate(self):
"""Use the function you write to evaluate current policy.
Return the mean episode reward of 1000 episodes when seed=0."""
result = evaluate(self.policy, 1000, env_name=self.env_name)
return result
def render(self):
"""Reuse your evaluate function, render current policy
for one episode when seed=0"""
evaluate(self.policy, 1, render=True, env_name=self.env_name)
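# To make the return format of _get_transitions concrete, a quick illustrative
# query (this assumes the FrozenLake8x8-v1 env id exists in your gym version;
# the numbers in the expected output below are only indicative):
demo_trainer = TabularRLTrainerAbstract()            # model_based=True by default
print(demo_trainer._get_transitions(state=0, act=0))
# -> a list of dicts, one per possible outcome of (state=0, act=0), roughly like
# [{'prob': 0.333..., 'next_state': 0, 'reward': 0.0, 'done': False}, ...]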
#%%
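# Before solving the TODOs, a minimal reference sketch (my own helpers, not part
# of the assignment skeleton): epsilon-greedy action selection and the one-step
# SARSA update, assuming the Q-table is a NumPy array of shape
# (obs_dim, action_dim).
import numpy as np

def epsilon_greedy(table, state, eps):
    """With probability eps take a uniformly random action, otherwise the greedy one."""
    if np.random.random() < eps:
        return np.random.randint(table.shape[1])  # random action index
    return int(np.argmax(table[state]))           # greedy action for this state

def sarsa_update(table, s, a, r, s_next, a_next, lr, gamma, done):
    """One-step on-policy TD update: Q(s,a) += lr * (TD target - Q(s,a))."""
    target = r + (0.0 if done else gamma * table[s_next, a_next])
    table[s, a] += lr * (target - table[s, a])

# In SARSATrainer.train() these two steps are interleaved: at s' pick a' with the
# same epsilon-greedy policy, apply the update, then continue with s, a = s', a'.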
# Solve the TODOs and remove `pass`
class SARSATrainer(TabularRLTrainerAbstract):
def __init__(self,
gamma=1.0,
eps=0.1,
learning_rate=1.0,
max_episode_length=100,
env_name='FrozenLake8x8-v1'
):
super(SARSATrainer, self).__init__(env_name, model_based=False)
# discount factor
self.gamma = gamma
# epsilon-greedy exploration policy parameter
self.eps = eps
# maximum steps in single episode
self.max_episode_length = max_episode_length
# the learning rate
self.learning_rate = learning_rate
# build the Q table
# [TODO] uncomment the next line, pay attention to the shape
self.table = np.zeros((self.obs_