epsilon_greedy.py
import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree
import random
from typing import Union, Optional
from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override, OldAPIStack
from ray.rllib.utils.exploration.exploration import Exploration, TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
from ray.rllib.utils.torch_utils import FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


@OldAPIStack
class EpsilonGreedy(Exploration):
"""Epsilon-greedy Exploration class that produces exploration actions.
When given a Model's output and a current epsilon value (based on some
Schedule), it produces a random action (if rand(1) < eps) or
uses the model-computed one (if rand(1) >= eps).
"""
def __init__(
self,
action_space: gym.spaces.Space,
*,
framework: str,
initial_epsilon: float = 1.0,
final_epsilon: float = 0.05,
warmup_timesteps: int = 0,
epsilon_timesteps: int = int(1e5),
epsilon_schedule: Optional[Schedule] = None,
**kwargs,
):
"""Create an EpsilonGreedy exploration class.
Args:
action_space: The action space the exploration should occur in.
framework: The framework specifier.
initial_epsilon: The initial epsilon value to use.
final_epsilon: The final epsilon value to use.
warmup_timesteps: The timesteps over which to not change epsilon in the
beginning.
epsilon_timesteps: The timesteps (additional to `warmup_timesteps`)
after which epsilon should always be `final_epsilon`.
E.g.: warmup_timesteps=20k epsilon_timesteps=50k -> After 70k timesteps,
epsilon will reach its final value.
epsilon_schedule: An optional Schedule object
to use (instead of constructing one from the given parameters).
"""
assert framework is not None
super().__init__(action_space=action_space, framework=framework, **kwargs)
self.epsilon_schedule = from_config(
Schedule, epsilon_schedule, framework=framework
) or PiecewiseSchedule(
endpoints=[
(0, initial_epsilon),
(warmup_timesteps, initial_epsilon),
(warmup_timesteps + epsilon_timesteps, final_epsilon),
],
outside_value=final_epsilon,
framework=self.framework,
)
# The current timestep value (tf-var or python int).
self.last_timestep = get_variable(
np.array(0, np.int64),
framework=framework,
tf_name="timestep",
dtype=np.int64,
)
# Build the tf-info-op.
if self.framework == "tf":
self._tf_state_op = self.get_state()
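    # Illustrative note (added to this listing, not part of the RLlib source):
    # with the example values from the docstring above, the default
    # PiecewiseSchedule interpolates linearly between its endpoints, so epsilon
    # evolves roughly like this:
    #
    #     sched = PiecewiseSchedule(
    #         endpoints=[(0, 1.0), (20000, 1.0), (70000, 0.05)],
    #         outside_value=0.05,
    #         framework=None,
    #     )
    #     sched(0)        # -> 1.0    (warmup: epsilon held at initial_epsilon)
    #     sched(45000)    # -> 0.525  (halfway through the 50k annealing window)
    #     sched(70000)    # -> 0.05   (final_epsilon reached)
    #     sched(100000)   # -> 0.05   (outside_value)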
    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: Optional[Union[bool, TensorType]] = True,
    ):
        if self.framework in ["tf2", "tf"]:
            return self._get_tf_exploration_action_op(
                action_distribution, explore, timestep
            )
        else:
            return self._get_torch_exploration_action(
                action_distribution, explore, timestep
            )
    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":
        """TF method to produce the tf op for an epsilon exploration action.

        Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.

        Returns:
            The tf exploration-action op.
        """
        # TODO: Support MultiActionDistr for tf.
        q_values = action_distribution.inputs
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )

        # Get the exploit action as the one with the highest logit value.
        exploit_action = tf.argmax(q_values, axis=1)

        batch_size = tf.shape(q_values)[0]
        # Mask out actions with q-value=-inf so that we don't even consider
        # them for exploration.
        random_valid_action_logits = tf.where(
            tf.equal(q_values, tf.float32.min),
            tf.ones_like(q_values) * tf.float32.min,
            tf.ones_like(q_values),
        )
        random_actions = tf.squeeze(
            tf.random.categorical(random_valid_action_logits, 1), axis=1
        )

        chose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
            self.last_timestep = timestep
            return action, tf.zeros_like(action, dtype=tf.float32)
        else:
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, tf.zeros_like(action, dtype=tf.float32)
    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":
        """Torch method to produce an epsilon exploration action.

        Args:
            action_distribution: The instantiated
                ActionDistribution object to work with when creating
                exploration actions.

        Returns:
            The exploration-action.
        """
        q_values = action_distribution.inputs
        self.last_timestep = timestep
        exploit_action = action_distribution.deterministic_sample()
        batch_size = q_values.size()[0]
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            if isinstance(action_distribution, TorchMultiActionDistribution):
                exploit_action = tree.flatten(exploit_action)
                for i in range(batch_size):
                    if random.random() < epsilon:
                        # TODO: (bcahlit) Mask out actions
                        random_action = tree.flatten(self.action_space.sample())
                        for j in range(len(exploit_action)):
                            exploit_action[j][i] = torch.tensor(random_action[j])
                exploit_action = tree.unflatten_as(
                    action_distribution.action_space_struct, exploit_action
                )
                return exploit_action, action_logp
            else:
                # Mask out actions, whose Q-values are -inf, so that we don't
                # even consider them for exploration.
                random_valid_action_logits = torch.where(
                    q_values <= FLOAT_MIN,
                    torch.ones_like(q_values) * 0.0,
                    torch.ones_like(q_values),
                )
                # A random action.
                random_actions = torch.squeeze(
                    torch.multinomial(random_valid_action_logits, 1), axis=1
                )

                # Pick either random or greedy.
                action = torch.where(
                    torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
                    random_actions,
                    exploit_action,
                )

                return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp
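    # Illustrative note (added to this listing, not part of the RLlib source):
    # in the torch branch above, `random_valid_action_logits` holds 1.0 for
    # valid actions and 0.0 for actions whose Q-value was masked to -inf, and
    # `torch.multinomial` treats each row as unnormalized sampling weights, e.g.:
    #
    #     weights = torch.tensor([[1.0, 0.0, 1.0]])
    #     torch.multinomial(weights, 1)  # only ever samples index 0 or 2
    #
    # so masked-out actions can never be drawn as the "random" action.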
    @override(Exploration)
    def get_state(self, sess: Optional["tf.Session"] = None):
        if sess:
            return sess.run(self._tf_state_op)
        eps = self.epsilon_schedule(self.last_timestep)
        return {
            "cur_epsilon": convert_to_numpy(eps) if self.framework != "tf" else eps,
            "last_timestep": convert_to_numpy(self.last_timestep)
            if self.framework != "tf"
            else self.last_timestep,
        }
    @override(Exploration)
    def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
        if self.framework == "tf":
            self.last_timestep.load(state["last_timestep"], session=sess)
        elif isinstance(self.last_timestep, int):
            self.last_timestep = state["last_timestep"]
        else:
            self.last_timestep.assign(state["last_timestep"])
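Usage sketch (not part of the file above): on the old API stack, EpsilonGreedy is usually selected through an algorithm's `exploration_config` rather than constructed directly; the keys below simply mirror the constructor arguments of `EpsilonGreedy.__init__`. The DQN/CartPole setup is an assumed example, not something defined in this file.

# Hedged example: wire EpsilonGreedy into an old-API-stack DQN config dict.
config = {
    "env": "CartPole-v1",            # assumed example environment
    "framework": "torch",
    "exploration_config": {
        "type": "EpsilonGreedy",
        "initial_epsilon": 1.0,
        "final_epsilon": 0.05,
        "warmup_timesteps": 20000,   # hold epsilon at 1.0 for the first 20k steps
        "epsilon_timesteps": 50000,  # then anneal linearly to 0.05 by step 70k
    },
}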