ray.rllib.core.learner.learner.Learner.additional_update#
- Learner.additional_update(*, module_ids_to_update: Sequence[str] = None, timestep: int, **kwargs) → Mapping[str, Any] [source]#
Apply additional non-gradient-based updates to this Learner.
For example, this could be used to do a Polyak averaging update of a target network in off-policy algorithms like SAC or DQN (a minimal sketch of such a soft target update follows the example below).
Example:
    from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
    from ray.rllib.algorithms.ppo.ppo_learner import (
        LEARNER_RESULTS_CURR_KL_COEFF_KEY,
        PPOLearnerHyperparameters,
    )
    from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
    from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
    import gymnasium as gym

    env = gym.make("CartPole-v1")

    hps = PPOLearnerHyperparameters(
        use_kl_loss=True,
        kl_coeff=0.2,
        kl_target=0.01,
        use_critic=True,
        clip_param=0.3,
        vf_clip_param=10.0,
        entropy_coeff=0.01,
        entropy_coeff_schedule=[
            [0, 0.01],
            [20000000, 0.0],
        ],
        vf_loss_coeff=0.5,
    )

    # Create a single-agent RL module spec.
    module_spec = SingleAgentRLModuleSpec(
        module_class=PPOTorchRLModule,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"hidden": [128, 128]},
        catalog_class=PPOCatalog,
    )


    class CustomPPOLearner(PPOTorchLearner):
        def additional_update_for_module(
            self, *, module_id, hps, timestep, sampled_kl_values
        ):
            results = super().additional_update_for_module(
                module_id=module_id,
                hps=hps,
                timestep=timestep,
                sampled_kl_values=sampled_kl_values,
            )

            # Try a different KL-coefficient adaptation rule than the PPO paper.
            sampled_kl = sampled_kl_values[module_id]
            curr_var = self.curr_kl_coeffs_per_module[module_id]
            if sampled_kl > 1.2 * self.hps.kl_target:
                curr_var.data *= 1.2
            elif sampled_kl < 0.8 * self.hps.kl_target:
                curr_var.data *= 0.4
            results.update({LEARNER_RESULTS_CURR_KL_COEFF_KEY: curr_var.item()})
            return results


    learner = CustomPPOLearner(
        module_spec=module_spec,
        learner_hyperparameters=hps,
    )
    # Note: the learner must be built before it can be used.
    learner.build()

    # Inside a training loop, we can now call the additional update as we like:
    for i in range(100):
        # sample = ...
        # learner.update(sample)
        if i % 10 == 0:
            learner.additional_update(
                timestep=i,
                sampled_kl_values={"default_policy": 0.5},
            )
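For the target-network use case mentioned above, a Polyak (soft) update can be implemented in the same hook. The following is an illustrative sketch only, not code from RLlib: TorchLearner and the additional_update_for_module override are real API, but the encoder/target_encoder attribute names and the tau value are assumptions standing in for whatever target-network structure a concrete module defines.

    import torch

    from ray.rllib.core.learner.torch.torch_learner import TorchLearner


    class PolyakTargetLearner(TorchLearner):
        def additional_update_for_module(self, *, module_id, hps, timestep, **kwargs):
            results = super().additional_update_for_module(
                module_id=module_id, hps=hps, timestep=timestep, **kwargs
            )
            module = self.module[module_id]
            tau = 0.005  # Polyak averaging coefficient (illustrative value).
            # Soft-update every target parameter toward its online counterpart:
            # target <- (1 - tau) * target + tau * online
            with torch.no_grad():
                for target_param, param in zip(
                    module.target_encoder.parameters(),  # hypothetical attribute
                    module.encoder.parameters(),  # hypothetical attribute
                ):
                    target_param.data.mul_(1.0 - tau).add_(tau * param.data)
            return results

With a setup like this, calling learner.additional_update(timestep=...) from the training loop performs the soft target sync for every module, or only for the IDs passed via module_ids_to_update.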
- Parameters:
module_ids_to_update – The IDs of the modules to update. If None, all modules are updated (see the usage sketch below).
timestep – The current timestep.
**kwargs – Keyword arguments to use for the additional update.
- Returns:
A dictionary of results from the update, keyed by module ID.
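A minimal usage sketch for module_ids_to_update, assuming a built multi-module learner of the CustomPPOLearner kind shown above and hypothetical module IDs "policy_0" and "policy_1":

    # Only "policy_0" receives the additional (non-gradient) update; "policy_1" is skipped.
    results = learner.additional_update(
        module_ids_to_update=["policy_0"],
        timestep=1000,
        sampled_kl_values={"policy_0": 0.5},  # extra kwargs are forwarded to each per-module update
    )
    print(results)  # e.g. {"policy_0": {...}}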