import logging
from typing import Any, Dict, Optional, Tuple, Type, Union
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn.dqn import DQN
from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
class SACConfig(AlgorithmConfig):
"""Defines a configuration class from which an SAC Algorithm can be built.
.. testcode::
from ray.rllib.algorithms.sac import SACConfig
config = (
SACConfig()
.environment("Pendulum-v1")
.env_runners(num_env_runners=1)
.training(
gamma=0.9,
actor_lr=0.001,
critic_lr=0.002,
train_batch_size_per_learner=32,
)
)
# Build the SAC algo object from the config and run 1 training iteration.
algo = config.build()
algo.train()
"""
def __init__(self, algo_class=None):
self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
# You can also provide the python class directly or the full location
# of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
# EpsilonGreedy").
"type": "StochasticSampling",
# Add constructor kwargs here (if any).
}
super().__init__(algo_class=algo_class or SAC)
# fmt: off
# __sphinx_doc_begin__
# SAC-specific config settings.
# `.training()`
self.twin_q = True
self.q_model_config = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define custom Q-model(s).
"custom_model_config": {},
}
self.policy_model_config = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define a custom policy model.
"custom_model_config": {},
}
self.clip_actions = False
self.tau = 5e-3
self.initial_alpha = 1.0
self.target_entropy = "auto"
self.n_step = 1
# Replay buffer configuration.
self.replay_buffer_config = {
"type": "PrioritizedEpisodeReplayBuffer",
# Size of the replay buffer. Note that if async_updates is set,
# then each worker will have a replay buffer of this size.
"capacity": int(1e6),
"alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"beta": 0.4,
}
self.store_buffer_in_checkpoints = False
self.training_intensity = None
self.optimization = {
"actor_learning_rate": 3e-4,
"critic_learning_rate": 3e-4,
"entropy_learning_rate": 3e-4,
}
self.actor_lr = 3e-5
self.critic_lr = 3e-4
self.alpha_lr = 3e-4
# Set `lr` parameter to `None` and ensure it is not used.
self.lr = None
self.grad_clip = None
self.target_network_update_freq = 0
# .env_runners()
# If "auto", set to `self.n_step` (or `self.n_step[1]`, if a tuple).
self.rollout_fragment_length = "auto"
# .training()
self.train_batch_size_per_learner = 256
self.train_batch_size = 256 # @OldAPIstack
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config.multi_agent(count_steps_by=..).
self.num_steps_sampled_before_learning_starts = 1500
# .reporting()
self.min_time_s_per_iteration = 1
self.min_sample_timesteps_per_iteration = 100
# __sphinx_doc_end__
# fmt: on
self._deterministic_loss = False
self._use_beta_distribution = False
self.use_state_preprocessor = DEPRECATED_VALUE
self.worker_side_prioritization = DEPRECATED_VALUE
@override(AlgorithmConfig)
def training(
self,
*,
twin_q: Optional[bool] = NotProvided,
q_model_config: Optional[Dict[str, Any]] = NotProvided,
policy_model_config: Optional[Dict[str, Any]] = NotProvided,
tau: Optional[float] = NotProvided,
initial_alpha: Optional[float] = NotProvided,
target_entropy: Optional[Union[str, float]] = NotProvided,
n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided,
store_buffer_in_checkpoints: Optional[bool] = NotProvided,
replay_buffer_config: Optional[Dict[str, Any]] = NotProvided,
training_intensity: Optional[float] = NotProvided,
clip_actions: Optional[bool] = NotProvided,
grad_clip: Optional[float] = NotProvided,
optimization_config: Optional[Dict[str, Any]] = NotProvided,
actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
alpha_lr: Optional[LearningRateOrSchedule] = NotProvided,
target_network_update_freq: Optional[int] = NotProvided,
_deterministic_loss: Optional[bool] = NotProvided,
_use_beta_distribution: Optional[bool] = NotProvided,
num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
**kwargs,
) -> "SACConfig":
"""Sets the training related configuration.
Args:
twin_q: Use two Q-networks (instead of one) for action-value estimation.
Note: Each Q-network will have its own target network.
q_model_config: Model configs for the Q network(s). These will override
MODEL_DEFAULTS. This is treated just as the top-level `model` dict in
setting up the Q-network(s) (2 if twin_q=True).
That means you can, for example, set this up for different observation spaces:
`obs=Box(1D)` -> `Tuple(Box(1D) + Action)` -> `concat` -> `post_fcnet`
obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
-> post_fcnet
obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
-> vision-net -> concat w/ Box(1D) and action -> post_fcnet
You can also have SAC use your custom model as Q-model(s) by simply
specifying the `custom_model` sub-key in the below dict (just like you
would do in the top-level `model` dict).
policy_model_config: Model options for the policy function (see
`q_model_config` above for details). The difference to `q_model_config`
above is that no action concat'ing is performed before the post_fcnet
stack.
tau: Update the target by \tau * policy + (1-\tau) * target_policy.
initial_alpha: Initial value to use for the entropy weight alpha.
target_entropy: Target entropy lower bound. If "auto", will be set
to `-|A|` (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))).
This is the inverse of reward scale, and will be optimized
automatically.
n_step: N-step target updates. If >1, sars' tuples in trajectories will be
postprocessed to become sa[discounted sum of R][s t+n] tuples. An
integer will be interpreted as a fixed n-step value. If a tuple of 2
ints is provided here, the n-step value will be drawn for each sample(!)
in the train batch from a uniform distribution over the closed interval
defined by `[n_step[0], n_step[1]]`.
store_buffer_in_checkpoints: Set this to True, if you want the contents of
your buffer(s) to be stored in any saved checkpoints as well.
Warnings will be created if:
- This is True AND restoring from a checkpoint that contains no buffer
data.
- This is False AND restoring from a checkpoint that does contain
buffer data.
replay_buffer_config: Replay buffer config.
Examples:
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"replay_sequence_length": 1,
}
- OR -
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 50000,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"replay_sequence_length": 1,
}
- Where -
prioritized_replay_alpha: Alpha parameter controls the degree of
prioritization in the buffer. In other words, when a buffer sample has
a higher temporal-difference error, with how much higher probability
should it be drawn to update the parameterized Q-network. 0.0
corresponds to uniform probability. Setting this much above 1.0 may
quickly result in the sampling distribution becoming heavily "pointy"
with low entropy.
prioritized_replay_beta: Beta parameter controls the degree of
importance sampling which suppresses the influence of gradient updates
from samples that have a higher probability of being sampled via the alpha
parameter and the temporal-difference error.
prioritized_replay_eps: Epsilon parameter sets the baseline probability
for sampling so that when the temporal-difference error of a sample is
zero, there is still a chance of drawing the sample.
training_intensity: The intensity with which to update the model (vs
collecting samples from the env).
If None, uses "natural" values of:
`train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x
`num_envs_per_env_runner`).
If not None, will make sure that the ratio between timesteps inserted
into and sampled from the buffer matches the given value.
Example:
training_intensity=1000.0
train_batch_size=250
rollout_fragment_length=1
num_env_runners=1 (or 0)
num_envs_per_env_runner=1
-> natural value = 250 / 1 = 250.0
-> will make sure that the replay+train op is executed 4x as often as
the rollout+insert op (4 * 250 = 1000).
See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further
details.
clip_actions: Whether to clip actions. If actions are already normalized,
this should be set to False.
grad_clip: If not None, clip gradients during optimization at this value.
optimization_config: Config dict for optimization. Set the supported keys
`actor_learning_rate`, `critic_learning_rate`, and
`entropy_learning_rate` in here.
actor_lr: The learning rate (float) or learning rate schedule for the
policy in the format of
[[timestep, lr-value], [timestep, lr-value], ...] In case of a
schedule, intermediary timesteps will be assigned to linearly
interpolated learning rate values. A schedule config's first entry
must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: It is common practice (two-timescale approach) to use a smaller
learning rate for the policy than for the critic to ensure that the
critic gives adequate values for improving the policy.
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
The default value is 3e-5, one order of magnitude lower than the respective
learning rate of the critic (see `critic_lr`).
critic_lr: The learning rate (float) or learning rate schedule for the
critic in the format of
[[timestep, lr-value], [timestep, lr-value], ...] In case of a
schedule, intermediary timesteps will be assigned to linearly
interpolated learning rate values. A schedule config's first entry
must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: It is common practice (two-timescale approach) to use a smaller
learning rate for the policy than for the critic to ensure that the
critic gives adequate values for improving the policy.
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
The default value is 3e-4, one order of magnitude higher than the respective
learning rate of the actor (policy) (see `actor_lr`).
alpha_lr: The learning rate (float) or learning rate schedule for the
hyperparameter alpha in the format of
[[timestep, lr-value], [timestep, lr-value], ...] In case of a
schedule, intermediary timesteps will be assigned to linearly
interpolated learning rate values. A schedule config's first entry
must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
The default value is 3e-4, identical to the critic learning rate (`critic_lr`).
target_network_update_freq: Update the target network every
`target_network_update_freq` steps.
_deterministic_loss: Whether the loss should be calculated deterministically
(w/o the stochastic action sampling step). Setting this to True is only
useful for continuous actions and for debugging.
_use_beta_distribution: Use a Beta-distribution instead of a
`SquashedGaussian` for bounded, continuous action spaces (not
recommended; for debugging only).
Returns:
This updated AlgorithmConfig object.
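The following is an illustrative sketch (the argument values are arbitrary
examples, not recommended settings) of setting several SAC-specific options
at once:
.. testcode::
from ray.rllib.algorithms.sac import SACConfig
config = SACConfig().training(
twin_q=True,
n_step=(1, 3),
tau=5e-3,
actor_lr=3e-5,
critic_lr=3e-4,
alpha_lr=3e-4,
)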
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if twin_q is not NotProvided:
self.twin_q = twin_q
if q_model_config is not NotProvided:
self.q_model_config.update(q_model_config)
if policy_model_config is not NotProvided:
self.policy_model_config.update(policy_model_config)
if tau is not NotProvided:
self.tau = tau
if initial_alpha is not NotProvided:
self.initial_alpha = initial_alpha
if target_entropy is not NotProvided:
self.target_entropy = target_entropy
if n_step is not NotProvided:
self.n_step = n_step
if store_buffer_in_checkpoints is not NotProvided:
self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
if replay_buffer_config is not NotProvided:
# Override entire `replay_buffer_config` if `type` key changes.
# Update, if `type` key remains the same or is not specified.
new_replay_buffer_config = deep_update(
{"replay_buffer_config": self.replay_buffer_config},
{"replay_buffer_config": replay_buffer_config},
False,
["replay_buffer_config"],
["replay_buffer_config"],
)
self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
if training_intensity is not NotProvided:
self.training_intensity = training_intensity
if clip_actions is not NotProvided:
self.clip_actions = clip_actions
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if optimization_config is not NotProvided:
self.optimization = optimization_config
if actor_lr is not NotProvided:
self.actor_lr = actor_lr
if critic_lr is not NotProvided:
self.critic_lr = critic_lr
if alpha_lr is not NotProvided:
self.alpha_lr = alpha_lr
if target_network_update_freq is not NotProvided:
self.target_network_update_freq = target_network_update_freq
if _deterministic_loss is not NotProvided:
self._deterministic_loss = _deterministic_loss
if _use_beta_distribution is not NotProvided:
self._use_beta_distribution = _use_beta_distribution
if num_steps_sampled_before_learning_starts is not NotProvided:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self
@override(AlgorithmConfig)
def validate(self) -> None:
# Call super's validation method.
super().validate()
# Check rollout_fragment_length to be compatible with n_step.
if isinstance(self.n_step, tuple):
min_rollout_fragment_length = self.n_step[1]
else:
min_rollout_fragment_length = self.n_step
if (
not self.in_evaluation
and self.rollout_fragment_length != "auto"
and self.rollout_fragment_length
< min_rollout_fragment_length # (self.n_step or 1)
):
raise ValueError(
f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
f"smaller than needed for `n_step` ({self.n_step})! If `n_step` is "
f"an integer try setting `rollout_fragment_length={self.n_step}`. If "
"`n_step` is a tuple, try setting "
f"`rollout_fragment_length={self.n_step[1]}`."
)
if self.use_state_preprocessor != DEPRECATED_VALUE:
deprecation_warning(
old="config['use_state_preprocessor']",
error=False,
)
self.use_state_preprocessor = DEPRECATED_VALUE
if self.grad_clip is not None and self.grad_clip <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")
if self.framework in ["tf", "tf2"] and tfp is None:
logger.warning(
"You need `tensorflow_probability` in order to run SAC! "
"Install it via `pip install tensorflow_probability`. Your "
f"tf.__version__={tf.__version__ if tf else None}."
"Trying to import tfp results in the following error:"
)
try_import_tfp(error=True)
# Validate that we use the corresponding `EpisodeReplayBuffer` when using
# episodes.
if (
self.enable_env_runner_and_connector_v2
and self.replay_buffer_config["type"]
not in [
"EpisodeReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
"MultiAgentEpisodeReplayBuffer",
"MultiAgentPrioritizedEpisodeReplayBuffer",
]
and not (
# TODO (simon): Set up an indicator `is_offline_new_stack` that
# includes all these variable checks.
self.input_
and (
isinstance(self.input_, str)
or (
isinstance(self.input_, list)
and isinstance(self.input_[0], str)
)
)
and self.input_ != "sampler"
and self.enable_rl_module_and_learner
)
):
raise ValueError(
"When using the new `EnvRunner API` the replay buffer must be of type "
"`EpisodeReplayBuffer`."
)
elif not self.enable_env_runner_and_connector_v2 and (
(
isinstance(self.replay_buffer_config["type"], str)
and "Episode" in self.replay_buffer_config["type"]
)
or (
isinstance(self.replay_buffer_config["type"], type)
and issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
)
):
raise ValueError(
"When using the old API stack the replay buffer must not be of type "
"`EpisodeReplayBuffer`! We suggest you use the following config to run "
"SAC on the old API stack: `config.training(replay_buffer_config={"
"'type': 'MultiAgentPrioritizedReplayBuffer', "
"'prioritized_replay_alpha': [alpha], "
"'prioritized_replay_beta': [beta], "
"'prioritized_replay_eps': [eps], "
"})`."
)
if self.enable_rl_module_and_learner:
if self.lr is not None:
raise ValueError(
"Basic learning rate parameter `lr` is not `None`. For SAC "
"use the specific learning rate parameters `actor_lr`, `critic_lr` "
"and `alpha_lr`, for the actor, critic, and the hyperparameter "
"`alpha`, respectively and set `config.lr` to None."
)
# Warn about new API stack on by default.
logger.warning(
"You are running SAC on the new API stack! This is the new default "
"behavior for this algorithm. If you don't want to use the new API "
"stack, set `config.api_stack(enable_rl_module_and_learner=False, "
"enable_env_runner_and_connector_v2=False)`. For a detailed "
"migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html" # noqa
)
@override(AlgorithmConfig)
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
if self.rollout_fragment_length == "auto":
return (
self.n_step[1]
if isinstance(self.n_step, (tuple, list))
else self.n_step
)
else:
return self.rollout_fragment_length
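# Example (illustrative, not from the original source): with `n_step=(1, 5)` and
# the default `rollout_fragment_length="auto"`, this method returns 5, so every
# sampled fragment is long enough to build the largest possible n-step transition.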
@override(AlgorithmConfig)
def get_default_rl_module_spec(self) -> RLModuleSpecType:
if self.framework_str == "torch":
from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import (
DefaultSACTorchRLModule,
)
return RLModuleSpec(module_class=DefaultSACTorchRLModule)
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. Use `torch`."
)
@override(AlgorithmConfig)
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
if self.framework_str == "torch":
from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner
return SACTorchLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. Use `torch`."
)
@override(AlgorithmConfig)
def build_learner_connector(
self,
input_observation_space,
input_action_space,
device=None,
):
pipeline = super().build_learner_connector(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
device=device,
)
# Insert the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece right
# after the corresponding "add-OBS-..." default piece.
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)
return pipeline
@property
def _model_config_auto_includes(self):
return super()._model_config_auto_includes | {"twin_q": self.twin_q}
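# Illustrative sketch (not part of the original source): replay-buffer settings
# that pass `SACConfig.validate()` above.
# New API stack (default) -- use an episode-based buffer, e.g.:
#   config = SACConfig().training(
#       replay_buffer_config={
#           "type": "PrioritizedEpisodeReplayBuffer",
#           "capacity": 100000,
#           "alpha": 0.6,
#           "beta": 0.4,
#       },
#   )
# Old API stack -- use a non-episode buffer instead, e.g.
# "MultiAgentPrioritizedReplayBuffer" with the `prioritized_replay_alpha`,
# `prioritized_replay_beta`, and `prioritized_replay_eps` keys.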
class SAC(DQN):
"""Soft Actor Critic (SAC) Algorithm class.
This file defines the distributed Algorithm class for the soft actor critic
algorithm.
See `sac_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#sac
"""
def __init__(self, *args, **kwargs):
self._allow_unknown_subkeys += ["policy_model_config", "q_model_config"]
super().__init__(*args, **kwargs)
@classmethod
@override(DQN)
def get_default_config(cls) -> AlgorithmConfig:
return SACConfig()
@classmethod
@override(DQN)
def get_default_policy_class(
cls, config: AlgorithmConfig
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy
return SACTorchPolicy
else:
return SACTFPolicy
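if __name__ == "__main__":
    # Minimal usage sketch (not part of RLlib's original module): build the SAC
    # algorithm from its config and run a single training iteration on
    # Pendulum-v1. Assumes `ray[rllib]`, `torch`, and `gymnasium` are installed.
    config = (
        SACConfig()
        .environment("Pendulum-v1")
        .training(
            actor_lr=3e-5,
            critic_lr=3e-4,
            alpha_lr=3e-4,
            train_batch_size_per_learner=256,
        )
    )
    algo = config.build()
    print(algo.train())
    algo.stop()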