-
Notifications
You must be signed in to change notification settings - Fork 6.2k
/
Copy pathmock.py
154 lines (127 loc) · 4.33 KB
/
mock.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import pickle
import time
import numpy as np
from ray.tune import result as tune_result
from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
from ray.rllib.utils.annotations import override
class _MockTrainer(Algorithm):
    """Mock Algorithm for use in tests.

    Provides test-only config knobs for injecting failures
    (``mock_error``/``persistent_error``), requesting user checkpoints
    (``user_checkpoint_freq``), and simulating per-step latency (``sleep``).
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        # Base config on the "tf" framework, then layer in the mock-only
        # settings used by the tests.
        config = AlgorithmConfig().framework("tf")
        return config.update_from_dict(
            {
                "mock_error": False,
                "persistent_error": False,
                "test_variable": 1,
                "user_checkpoint_freq": 0,
                "sleep": 0,
            }
        )

    @classmethod
    def default_resource_request(cls, config: AlgorithmConfig):
        # The mock requires no resources.
        return None

    @override(Algorithm)
    def setup(self, config):
        self.callbacks = self.config.callbacks_class()
        # Add needed properties.
        self.info = None
        self.restored = False

    @override(Algorithm)
    def step(self):
        # Optionally fail on the second iteration; a persistent error fails
        # even after restore, a transient one only before restore.
        should_fail = (
            self.config.mock_error
            and self.iteration == 1
            and (self.config.persistent_error or not self.restored)
        )
        if should_fail:
            raise Exception("mock error")

        if self.config.sleep:
            time.sleep(self.config.sleep)

        metrics = {
            "episode_reward_mean": 10,
            "episode_len_mean": 10,
            "timesteps_this_iter": 10,
            "info": {},
        }
        # Ask Tune to checkpoint every `user_checkpoint_freq` iterations.
        freq = self.config.user_checkpoint_freq
        if freq > 0 and self.iteration > 0 and self.iteration % freq == 0:
            metrics[tune_result.SHOULD_CHECKPOINT] = True
        return metrics

    @override(Algorithm)
    def save_checkpoint(self, checkpoint_dir):
        # Persist only the `info` payload set via `set_info()`.
        with open(os.path.join(checkpoint_dir, "mock_agent.pkl"), "wb") as f:
            pickle.dump(self.info, f)

    @override(Algorithm)
    def load_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "mock_agent.pkl"), "rb") as f:
            self.info = pickle.load(f)
        # Mark as restored so transient mock errors stop firing.
        self.restored = True

    @staticmethod
    @override(Algorithm)
    def _get_env_id_and_creator(env_specifier, config):
        # No env to register.
        return None, None

    def set_info(self, info):
        self.info = info
        return info

    def get_info(self, sess=None):
        return self.info
class _SigmoidFakeData(_MockTrainer):
    """Algorithm that returns sigmoid learning curves.
    This can be helpful for evaluating early stopping algorithms."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        # Curve shape (`width`, `height`, `offset`) plus fake per-iteration
        # timing/timestep bookkeeping.
        return AlgorithmConfig().update_from_dict(
            {
                "width": 100,
                "height": 100,
                "offset": 0,
                "iter_time": 10,
                "iter_timesteps": 1,
            }
        )

    def step(self):
        # Iterations before `offset` are clamped to the start of the curve.
        effective_iter = max(0, self.iteration - self.config.offset)
        # tanh ramps from 0 toward `height` with characteristic scale `width`.
        reward = self.config.height * np.tanh(
            float(effective_iter) / self.config.width
        )
        return {
            "episode_reward_mean": reward,
            "episode_len_mean": reward,
            "timesteps_this_iter": self.config.iter_timesteps,
            "time_this_iter_s": self.config.iter_time,
            "info": {},
        }
class _ParameterTuningTrainer(_MockTrainer):
    """Mock Algorithm whose reward grows linearly with the iteration count.

    Useful for exercising parameter-tuning logic: the slope is controlled by
    `reward_amt`, and `dummy_param`/`dummy_param2` exist purely to be tuned.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        return AlgorithmConfig().update_from_dict(
            {
                "reward_amt": 10,
                "dummy_param": 10,
                "dummy_param2": 15,
                "iter_time": 10,
                "iter_timesteps": 1,
            }
        )

    def step(self):
        amt = self.config.reward_amt
        return {
            # Reward climbs by `reward_amt` per iteration.
            "episode_reward_mean": amt * self.iteration,
            "episode_len_mean": amt,
            "timesteps_this_iter": self.config.iter_timesteps,
            "time_this_iter_s": self.config.iter_time,
            "info": {},
        }
def _algorithm_import_failed(trace):
    """Returns dummy Algorithm class for if PyTorch etc. is not installed.

    The returned class defers failure: importing/registering it succeeds, but
    any attempt to actually set it up re-raises the captured import trace.
    """

    class _AlgorithmImportFailed(Algorithm):
        # Name surfaced for the failed-import placeholder.
        _name = "AlgorithmImportFailed"

        def setup(self, config):
            # Surface the original import failure at setup time.
            raise ImportError(trace)

    return _AlgorithmImportFailed