"""Example on how to run RLlib in combination with Ray Serve.
This example trains an agent with PPO on the CartPole environment, then creates
an RLModule checkpoint and returns its location. After that, it sends the checkpoint
to the Serve deployment for serving the trained RLModule (policy).
This example:
- shows how to set up a Ray Serve deployment for serving an already trained
RLModule (policy network).
- shows how to request new actions from the Ray Serve deployment while actually
running through episodes in an environment (on which the RLModule that's served
was trained).
How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --stop-reward=200.0`
Use the `--stop-iters`, `--stop-reward`, and/or `--stop-timesteps` options to
determine how long to train the policy for. Use the `--serve-episodes` option to
set the number of episodes to serve (after training) and the `--no-render` option
to NOT render the environment during the serving phase.
For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.
For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
You can visualize experiment results in ~/ray_results using TensorBoard.
Results to expect
-----------------
You should see something similar to the following on the command line when using the
options: `--stop-reward=250.0`, `--num-episodes-served=2`, and `--port=12345`:
[First, the RLModule is trained through PPO]
+-----------------------------+------------+-----------------+--------+
| Trial name | status | loc | iter |
| | | | |
|-----------------------------+------------+-----------------+--------+
| PPO_CartPole-v1_84778_00000 | TERMINATED | 127.0.0.1:40411 | 1 |
+-----------------------------+------------+-----------------+--------+
+------------------+---------------------+------------------------+
| total time (s) | episode_return_mean | num_env_steps_sample |
| | | d_lifetime |
|------------------+---------------------|------------------------|
| 2.87052 | 253.2 | 12000 |
+------------------+---------------------+------------------------+
[The RLModule is deployed through Ray Serve on port 12345]
Started Ray Serve with PID: 40458
[A few episodes are played through using the policy service (w/ greedy, non-exploratory
actions)]
Episode R=500.0
Episode R=500.0
"""
import atexit
import os
import subprocess
import time
from pathlib import Path

import gymnasium as gym
import requests

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import (
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
    DEFAULT_MODULE_ID,
)
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)

parser = add_rllib_example_script_args()
parser.set_defaults(
    enable_new_api_stack=True,
    checkpoint_freq=1,
    checkpoint_at_end=True,
)
parser.add_argument("--num-episodes-served", type=int, default=2)
parser.add_argument("--no-render", action="store_true")
parser.add_argument("--port", type=int, default=12345)


def kill_proc(proc):
    try:
        proc.terminate()  # Send SIGTERM.
        proc.wait(timeout=5)  # Wait for process to terminate.
    except subprocess.TimeoutExpired:
        proc.kill()  # Send SIGKILL.
        proc.wait()  # Ensure process is dead.
if __name__ == "__main__":
args = parser.parse_args()
# Config for the served RLlib RLModule/Algorithm.
base_config = PPOConfig().environment("CartPole-v1")
results = run_rllib_example_script_experiment(base_config, args)
algo_checkpoint = results.get_best_result(
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
).checkpoint.path
# We only need the RLModule component from the algorithm checkpoint. It's located
# under "[algo checkpoint dir]/learner_group/learner/rl_module/[default policy ID]
rl_module_checkpoint = (
Path(algo_checkpoint)
/ COMPONENT_LEARNER_GROUP
/ COMPONENT_LEARNER
/ COMPONENT_RL_MODULE
/ DEFAULT_MODULE_ID
)
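
    # For quick local debugging (an optional alternative to serving), such an
    # extracted RLModule directory can normally be restored directly in-process on
    # the new API stack, roughly like this (commented out; not needed by this script):
    # from ray.rllib.core.rl_module.rl_module import RLModule
    # rl_module = RLModule.from_checkpoint(rl_module_checkpoint)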

    path_of_this_file = Path(__file__).parent
    os.chdir(path_of_this_file)

    # Start the serve app with the trained checkpoint.
    serve_proc = subprocess.Popen(
        [
            "serve",
            "run",
            "classes.cartpole_deployment:rl_module",
            f"rl_module_checkpoint={rl_module_checkpoint}",
            f"port={args.port}",
            "route_prefix=/rllib-rlmodule",
        ]
    )
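
    # For orientation: the Serve application started above is defined in
    # `classes/cartpole_deployment.py` (not shown in this file). A heavily
    # simplified sketch of such a deployment, with illustrative (assumed) class
    # and function names rather than the actual implementation, could look like:
    #
    #     from starlette.requests import Request
    #     import torch
    #     from ray import serve
    #     from ray.rllib.core.rl_module.rl_module import RLModule
    #
    #     @serve.deployment
    #     class ServeRLModule:
    #         def __init__(self, rl_module_checkpoint):
    #             self.rl_module = RLModule.from_checkpoint(rl_module_checkpoint)
    #
    #         async def __call__(self, request: Request):
    #             obs = (await request.json())["observation"]
    #             batch = {"obs": torch.tensor([obs], dtype=torch.float32)}
    #             out = self.rl_module.forward_inference(batch)
    #             # Greedy action = argmax over the action-distribution logits.
    #             return {"action": int(out["action_dist_inputs"][0].argmax())}
    #
    #     def rl_module(args: dict):
    #         return ServeRLModule.bind(args["rl_module_checkpoint"])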

    # Register our `kill_proc` function to be called on exit to stop Ray Serve again.
    atexit.register(kill_proc, serve_proc)

    # Wait a while to make sure the app is ready to serve.
    time.sleep(20)

    print(f"Started Ray Serve with PID: {serve_proc.pid}")

    try:
        # Create the environment that we would like to receive
        # served actions for.
        env = gym.make("CartPole-v1", render_mode="human")
        obs, _ = env.reset()

        num_episodes = 0
        episode_return = 0.0

        while num_episodes < args.num_episodes_served:
            # Render env if necessary.
            if not args.no_render:
                env.render()

            # print(f"-> Requesting action for obs={obs} ...", end="")
            # Send a request to serve.
            resp = requests.get(
                f"http://localhost:{args.port}/rllib-rlmodule",
                json={"observation": obs.tolist()},
            )
            response = resp.json()
            # print(f" received: action={response['action']}")

            # Apply the action in the env.
            action = response["action"]
            obs, reward, terminated, truncated, _ = env.step(action)
            episode_return += reward

            # If episode done -> reset to get initial observation of new episode.
            if terminated or truncated:
                print(f"Episode R={episode_return}")
                obs, _ = env.reset()
                num_episodes += 1
                episode_return = 0.0
    finally:
        # Make sure to kill the process on script termination.
        kill_proc(serve_proc)