"""Example of how to write a custom Algorithm.

This is an end-to-end example of how to implement a custom Algorithm, including
a matching AlgorithmConfig class and Learner class. The algorithm doesn't need any
particular RLModule API, which means that any TorchRLModule returning actions
or action distribution parameters suffices.

The algorithm implemented here is "vanilla policy gradient" (VPG) in its simplest
form, without a value function baseline.

See the actual VPG algorithm class here:
https://github.com/ray-project/ray/blob/master/rllib/examples/algorithms/classes/vpg.py

The Learner class the algorithm uses by default (if the user doesn't specify a custom
Learner):
https://github.com/ray-project/ray/blob/master/rllib/examples/learners/classes/vpg_torch_learner.py  # noqa

And the RLModule class the algorithm uses by default (if the user doesn't specify a
custom RLModule):
https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/classes/vpg_torch_rlm.py  # noqa

This example shows:
- how to subclass the AlgorithmConfig base class to implement a custom algorithm's
config class.
- how to subclass the Algorithm base class to implement a custom Algorithm,
including its `training_step` method.
- how to subclass the TorchLearner base class to implement a custom Learner with
a custom loss function, overriding `compute_loss_for_module` and
`after_gradient_based_update`.
- how to define a default RLModule used by the algorithm in case the user
doesn't bring their own custom RLModule. The VPG algorithm doesn't require any
specific RLModule APIs, so any RLModule returning actions or action distribution
inputs suffices.
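
The division of labor in the list above (a config object, an Algorithm whose
`training_step` samples episodes, and a Learner that applies the loss) can be
illustrated without RLlib. The following is a framework-free toy sketch only:
`MiniVPGConfig`, `MiniVPGAlgorithm`, and `MiniVPGLearner` are hypothetical
stand-ins, not RLlib classes, and a one-step two-armed bandit replaces
CartPole-v1 to keep it short. It mirrors the fluent `.training(...)` config style
and a `training_step` that samples `num_episodes_per_train_batch` episodes and
then performs one baseline-free policy-gradient update.

```python
import numpy as np


class MiniVPGConfig:
    # Hypothetical config object with a fluent setter, mimicking the
    # `.training(...)` call style (not RLlib's AlgorithmConfig).
    def __init__(self):
        self.lr = 0.1
        self.num_episodes_per_train_batch = 10

    def training(self, *, lr=None, num_episodes_per_train_batch=None):
        if lr is not None:
            self.lr = lr
        if num_episodes_per_train_batch is not None:
            self.num_episodes_per_train_batch = num_episodes_per_train_batch
        return self  # returning self enables method chaining


class MiniVPGLearner:
    # Hypothetical learner holding a softmax policy over two bandit arms.
    def __init__(self, lr):
        self.lr = lr
        self.logits = np.zeros(2)

    def probs(self):
        z = np.exp(self.logits - self.logits.max())  # numerically stable softmax
        return z / z.sum()

    def update(self, actions, returns):
        # The gradient of mean(log pi(a) * G) w.r.t. the logits is
        # mean((onehot(a) - pi) * G); take one ascent step (no baseline).
        p = self.probs()
        grad = np.zeros(2)
        for a, g in zip(actions, returns):
            grad += (np.eye(2)[a] - p) * g
        self.logits += self.lr * grad / len(actions)


class MiniVPGAlgorithm:
    # Hypothetical algorithm: each one-step "episode" pulls one arm.
    def __init__(self, config, seed=0):
        self.config = config
        self.learner = MiniVPGLearner(config.lr)
        self.rng = np.random.default_rng(seed)
        self.arm_rewards = np.array([0.2, 1.0])  # arm 1 pays more

    def training_step(self):
        # 1) Sample a batch of episodes with the current policy.
        n = self.config.num_episodes_per_train_batch
        actions = self.rng.choice(2, size=n, p=self.learner.probs())
        returns = self.arm_rewards[actions]
        # 2) Pass the batch to the learner for one gradient update.
        self.learner.update(actions, returns)
        return {"episode_return_mean": float(returns.mean())}


config = MiniVPGConfig().training(lr=0.5, num_episodes_per_train_batch=10)
algo = MiniVPGAlgorithm(config)
for _ in range(200):
    results = algo.training_step()
print(results, algo.learner.probs())
```

Because both arms pay a positive reward and there is no baseline, every sampled
action gets reinforced. The better arm still wins because, on average, it is
reinforced more strongly.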

We compute a plain policy gradient loss without a value function baseline.
The experiment shows that even with such a simple setup, our custom algorithm is still
able to successfully learn CartPole-v1.
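
To make the loss concrete, here is a minimal NumPy sketch. It assumes each
action's log-probability is weighted by its discounted return-to-go; the exact
weighting used by `VPGTorchLearner` may differ. `returns_to_go` and `vpg_loss`
are hypothetical helpers, not RLlib APIs.

```python
import numpy as np


def returns_to_go(rewards, gamma=0.99):
    # Hypothetical helper: discounted sum of future rewards at each timestep
    # of a single episode, G_t = r_t + gamma * G_{t+1}.
    out = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    # Walk backwards so each step accumulates everything that follows it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out


def vpg_loss(log_probs, rewards, gamma=0.99):
    # Hypothetical helper: plain policy-gradient (REINFORCE-style) loss with
    # no value-function baseline. The gradient of -mean(log pi(a_t|s_t) * G_t)
    # is the (negated) baseline-free policy-gradient estimate.
    weights = returns_to_go(rewards, gamma)
    return -np.mean(np.asarray(log_probs, dtype=np.float64) * weights)


# A 3-step episode with reward 1.0 per step and a 50/50 policy.
print(returns_to_go([1.0, 1.0, 1.0], gamma=1.0))  # [3. 2. 1.]
print(vpg_loss(np.log([0.5, 0.5, 0.5]), [1.0, 1.0, 1.0], gamma=1.0))
```

Without a baseline, any positive return increases the probability of the action
that was taken. This is why such a loss tends to have high variance, even though
it is enough for CartPole-v1 here.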

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`

Results to expect
-----------------
With some fine-tuning of the learning rate, the batch size, and maybe the
number of env runners and number of envs per env runner, you should see decent
learning behavior on the CartPole-v1 environment:
+-----------------------------+------------+--------+------------------+
| Trial name | status | iter | total time (s) |
| | | | |
|-----------------------------+------------+--------+------------------+
| VPG_CartPole-v1_2973e_00000 | TERMINATED | 451 | 59.5184 |
+-----------------------------+------------+--------+------------------+
+-----------------------+------------------------+------------------------+
| episode_return_mean | num_env_steps_sample | ...env_steps_sampled |
| | d_lifetime | _lifetime_throughput |
|-----------------------+------------------------+------------------------|
| 250.52 | 415787 | 7428.98 |
+-----------------------+------------------------+------------------------+
"""
from ray.rllib.examples.algorithms.classes.vpg import VPGConfig
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)

parser = add_rllib_example_script_args(
    default_reward=250.0,
    default_iters=1000,
    default_timesteps=750000,
)
parser.set_defaults(enable_new_api_stack=True)


if __name__ == "__main__":
    args = parser.parse_args()

    base_config = (
        VPGConfig()
        .environment("CartPole-v1")
        .training(
            # The only VPG-specific setting. How many episodes per train batch?
            num_episodes_per_train_batch=10,
            # Set other config parameters.
            lr=0.0005,
            # Note that you don't have to set any specific Learner class, because
            # our custom Algorithm already defines the default Learner class to use
            # through its `get_default_learner_class` method, which returns
            # `VPGTorchLearner`.
            # learner_class=VPGTorchLearner,
        )
        # Increase the number of EnvRunners (default is 1 for VPG)
        # or the number of envs per EnvRunner.
        .env_runners(num_env_runners=2, num_envs_per_env_runner=1)
        # Plug in your own RLModule class. VPG doesn't require any specific
        # RLModule APIs, so any RLModule returning `actions` or `action_dist_inputs`
        # from the forward methods works ok.
        # .rl_module(
        #     rl_module_spec=RLModuleSpec(module_class=...),
        # )
    )

    run_rllib_example_script_experiment(base_config, args)