ray.rllib.algorithms.algorithm.Algorithm

class ray.rllib.algorithms.algorithm.Algorithm(config: Optional[ray.rllib.algorithms.algorithm_config.AlgorithmConfig] = None, env=None, logger_creator: Optional[Callable[[], ray.tune.logger.logger.Logger]] = None, **kwargs)

Bases: ray.tune.trainable.trainable.Trainable

An RLlib algorithm responsible for optimizing one or more Policies.

Algorithms contain a WorkerSet under self.workers. A WorkerSet is normally composed of a single local worker (self.workers.local_worker()), used to compute and apply learning updates, and optionally one or more remote workers used to generate environment samples in parallel. The WorkerSet is fault tolerant and elastic: it tracks the health state of every managed remote worker actor. For this reason, an Algorithm should never access the underlying actor handles directly; always go through the WorkerSet's foreach APIs instead, addressing individual workers by their assigned worker IDs.
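
For example (a minimal sketch, assuming the built-in PPO algorithm and the CartPole-v1 environment; keyword names such as remote_worker_ids reflect recent Ray 2.x releases and may differ in other versions):

    from ray.rllib.algorithms.ppo import PPOConfig

    algo = (
        PPOConfig()
        .environment("CartPole-v1")
        .rollouts(num_rollout_workers=2)
        .build()
    )

    # Run a function on every worker (local and remote) and collect the results.
    policy_ids = algo.workers.foreach_worker(
        lambda worker: list(worker.policy_map.keys())
    )

    # Address only specific remote workers by their assigned worker IDs.
    sample_counts = algo.workers.foreach_worker(
        lambda worker: worker.sample().count,
        local_worker=False,
        remote_worker_ids=[1, 2],
    )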

Each worker (remote or local) contains a PolicyMap, which itself may hold either a single policy for single-agent training or one or more policies for multi-agent training. Policies are synchronized automatically from time to time using ray.remote calls. The exact synchronization logic depends on the specific algorithm used, but this usually happens from the local worker to all remote workers and after each training update.
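
As a rough illustration of this flow (a sketch only, not the algorithms' internal sync logic; "default_policy" is the default policy ID used in single-agent setups):

    from ray.rllib.algorithms.ppo import PPOConfig

    algo = PPOConfig().environment("CartPole-v1").rollouts(num_rollout_workers=2).build()

    # Look up a policy in the local worker's PolicyMap and read its weights.
    policy = algo.get_policy("default_policy")
    weights = algo.get_weights(["default_policy"])  # Dict mapping policy ID to weights.

    # Push the local worker's current weights out to all remote workers.
    algo.workers.sync_weights()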

You can write your own Algorithm classes by sub-classing from Algorithm or any of its built-in sub-classes. This allows you to override the training_step method to implement your own algorithm logic. You can find the different built-in algorithms’ training_step() methods in their respective main .py files, e.g. rllib.algorithms.dqn.dqn.py or rllib.algorithms.impala.impala.py.
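For instance, a bare-bones custom Algorithm might look like the sketch below (MyAlgo is an illustrative name; the synchronous_parallel_sample and train_one_step helpers live under ray.rllib.execution in recent Ray 2.x releases, though exact import paths may vary by version):

    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
    from ray.rllib.execution.train_ops import train_one_step


    class MyAlgo(Algorithm):
        def training_step(self):
            # Collect a batch of experiences from the rollout workers.
            train_batch = synchronous_parallel_sample(worker_set=self.workers)
            # Perform one update on the local worker's policies ...
            train_results = train_one_step(self, train_batch)
            # ... then broadcast the updated weights to all remote workers.
            self.workers.sync_weights()
            return train_results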

The most important API methods an Algorithm exposes are train(), evaluate(), save(), and restore().
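
A rough end-to-end usage sketch (PPO and CartPole-v1 are again assumed as examples; the exact return type of save() may vary across Ray versions):

    from ray.rllib.algorithms.ppo import PPOConfig

    algo = (
        PPOConfig()
        .environment("CartPole-v1")
        .evaluation(evaluation_num_workers=1)
        .build()
    )

    for _ in range(3):
        results = algo.train()        # Run one training iteration; returns a results dict.
        print(results["episode_reward_mean"])

    print(algo.evaluate())            # Evaluate under the evaluation_config settings.

    checkpoint_path = algo.save()     # Write a checkpoint to disk.
    algo.restore(checkpoint_path)     # Restore training state from that checkpoint.
    algo.stop()                       # Release all resources held by this Algorithm.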

Methods

__init__([config, env, logger_creator])

Initializes an Algorithm instance.

add_policy(policy_id[, policy_cls, policy, ...])

Adds a new policy to this Algorithm.

compute_actions(observations[, state, ...])

Computes actions for the specified policy on the local worker.

compute_single_action([observation, state, ...])

Computes an action for the specified policy on the local worker.

delete_checkpoint(checkpoint_path)

Deletes local copy of checkpoint.

evaluate([duration_fn])

Evaluates current policy under evaluation_config settings.

export_model(export_formats[, export_dir])

Exports model based on export_formats.

export_policy_checkpoint(export_dir[, ...])

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_policy_model(export_dir[, policy_id, ...])

Exports policy model with given policy_id to a local directory.

from_checkpoint(checkpoint[, policy_ids, ...])

Creates a new algorithm instance from a given checkpoint.

from_state(state)

Recovers an Algorithm from a state object.

get_config()

Returns configuration passed in by Tune.

get_default_policy_class(config)

Returns a default Policy class to use, given a config.

get_policy([policy_id])

Returns the policy for the specified id, or None.

get_weights([policies])

Returns a dictionary mapping policy ids to weights.

import_model(import_file)

Imports a model from import_file.

import_policy_model_from_h5(import_file[, ...])

Imports a policy's model with given policy_id from a local h5 file.

merge_trainer_configs(config1, config2[, ...])

Merges a complete Algorithm config dict with a partial override dict.

remove_policy([policy_id, ...])

Removes a policy from this Algorithm.

reset(new_config[, logger_creator, ...])

Resets trial for use with new config.

reset_config(new_config)

Resets configuration without restarting the trial.

restore(checkpoint_path[, ...])

Restores training state from a given model checkpoint.

restore_from_object(obj)

Restores training state from a checkpoint object.

restore_workers(workers)

Tries to restore failed workers if necessary.

save([checkpoint_dir, prevent_upload])

Saves the current model state to a checkpoint.

save_checkpoint(checkpoint_dir)

Exports AIR Checkpoint to a local directory and returns its directory path.

save_to_object()

Saves the current model state to a Python object.

set_weights(weights)

Sets policy weights by policy id.

step()

Implements the main Algorithm.train() logic.

stop()

Releases all resources used by this trainable.

train()

Runs one logical iteration of training.

train_buffered(buffer_time_s[, ...])

Runs multiple iterations of training.

training_step()

Default single iteration logic of an algorithm.

validate_env(env, env_context)

Env validator function for this Algorithm class.

Attributes

iteration

Current training iteration.

logdir

Directory of the results and checkpoints for this Trainable.

training_iteration

Current training iteration (same as self.iteration).

trial_id

Trial ID for the corresponding trial of this Trainable.

trial_name

Trial name for the corresponding trial of this Trainable.

trial_resources

Resources currently assigned to the trial of this Trainable.

uses_cloud_checkpointing