pong_appo.py
import gymnasium as gym
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.connectors.env_to_module.frame_stacking import FrameStackingEnvToModule
from ray.rllib.connectors.learner.frame_stacking import FrameStackingLearner
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
from ray.rllib.utils.test_utils import add_rllib_example_script_args
from ray.tune.registry import register_env
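
# Command-line setup: `add_rllib_example_script_args` builds the standard argparse
# parser used by RLlib's example scripts. The defaults below set the stopping
# criteria (target episode return of 20.0 or 10M sampled env timesteps) and
# pre-select the new API stack and the ALE Pong environment.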
parser = add_rllib_example_script_args(
    default_reward=20.0,
    default_timesteps=10000000,
)
parser.set_defaults(
    enable_new_api_stack=True,
    env="ale_py:ALE/Pong-v5",
)
args = parser.parse_args()
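

# Connector factories: stack the last 4 frames into each observation, once on the
# env-runner side (for action computation) and once on the learner side (for
# training batches), so the gym wrapper itself does not need to frame-stack.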
def _make_env_to_module_connector(env):
    return FrameStackingEnvToModule(num_frames=4)


def _make_learner_connector(input_observation_space, input_action_space):
    return FrameStackingLearner(num_frames=4)
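

# Env creator: apply RLlib's Atari preprocessing (resizing to 64x64 and other
# standard wrappers); frame stacking is disabled here (`framestack=None`) because
# the connectors above already handle it.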
def _env_creator(cfg):
    return wrap_atari_for_new_api_stack(
        gym.make(args.env, **cfg, **{"render_mode": "rgb_array"}),
        dim=64,
        framestack=None,
    )


register_env("env", _env_creator)
config = (
    APPOConfig()
    .environment(
        "env",
        env_config={
            # Make analogous to old v4 + NoFrameskip.
            "frameskip": 1,
            "full_action_space": False,
            "repeat_action_probability": 0.0,
        },
        clip_rewards=True,
    )
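    # Sampling: attach the frame-stacking env-to-module connector and step 2
    # vectorized env copies per env runner.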
    .env_runners(
        env_to_module_connector=_make_env_to_module_connector,
        num_envs_per_env_runner=2,
    )
    .learners(
        num_aggregator_actors_per_learner=2,
    )
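    # Learner-side frame-stacking connector plus APPO training hyperparameters;
    # the learning rate is scaled with the square root of the number of learners.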
    .training(
        learner_connector=_make_learner_connector,
        train_batch_size_per_learner=500,
        target_network_update_freq=2,
        lr=0.0005 * ((args.num_learners or 1) ** 0.5),
        vf_loss_coeff=1.0,
        # Entropy coefficient schedule as [timestep, value] pairs;
        # crucial parameter to fine-tune.
        entropy_coeff=[[0, 0.01], [3000000, 0.0]],
        # Only update connector states and model weights every n training_step calls.
        broadcast_interval=5,
        # learner_queue_size=1,
        circular_buffer_num_batches=4,
        circular_buffer_iterations_per_batch=2,
    )
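    # Model: CNN trunk shared between policy and value function, 256-unit head.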
    .rl_module(
        model_config=DefaultModelConfig(
            vf_share_layers=True,
            conv_filters=[(16, 4, 2), (32, 4, 2), (64, 4, 2), (128, 4, 2)],
            conv_activation="relu",
            head_fcnet_hiddens=[256],
        )
    )
)
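
# Run the experiment via RLlib's example-script runner, which sets up Ray Tune
# and applies the stopping criteria from the parsed command-line args.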
if __name__ == "__main__":
    from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

    run_rllib_example_script_experiment(config, args)