-
Notifications
You must be signed in to change notification settings - Fork 6.2k
/
Copy pathtest_registry.py
37 lines (27 loc) · 1.08 KB
/
test_registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import unittest
from ray.rllib.algorithms.registry import (
POLICIES,
get_policy_class,
get_policy_class_name,
ALGORITHMS_CLASS_TO_NAME,
ALGORITHMS,
)
class TestPolicies(unittest.TestCase):
    """Sanity checks for the RLlib policy and algorithm registries."""

    def test_load_policies(self):
        """Every name listed in POLICIES must resolve to a non-None class."""
        for name in POLICIES:
            # subTest tags a failure with the offending policy name, so a
            # broken registry entry is identifiable from the test report.
            with self.subTest(policy=name):
                self.assertIsNotNone(get_policy_class(name))

    def test_get_eager_traced_class_name(self):
        """The class from `with_tracing()` still reports "PPOTF2Policy"."""
        # Function-scope import: the PPO TF policy module is only loaded
        # when this particular test runs.
        from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy

        traced = PPOTF2Policy.with_tracing()
        self.assertEqual(get_policy_class_name(traced), "PPOTF2Policy")

    def test_registered_algorithm_names(self):
        """All RLlib registered algorithms should have their name listed in the
        registry dictionary.

        The comparison between the registered class name and the loaded
        class's ``__name__`` is case-insensitive.
        """
        for class_name, registered_name in ALGORITHMS_CLASS_TO_NAME.items():
            with self.subTest(algorithm=registered_name):
                # Each ALGORITHMS entry is a zero-argument callable returning
                # a 2-tuple; only the first element (the class) is checked.
                algo_class, _ = ALGORITHMS[registered_name]()
                self.assertEqual(
                    class_name.upper(), algo_class.__name__.upper()
                )
if __name__ == "__main__":
    import sys

    import pytest

    # Run only this file's tests in verbose mode and hand pytest's return
    # value to the interpreter as the process exit status.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)