Rate this Page

Class Optimizer

Inheritance Relationships

Derived Types

Class Documentation

class Optimizer

Subclassed by torch::optim::Adagrad, torch::optim::Adam, torch::optim::AdamW, torch::optim::LBFGS, torch::optim::RMSprop, torch::optim::SGD

Public Types

using LossClosure = std::function<Tensor()>#

Public Functions

Optimizer(const Optimizer &optimizer) = delete
Optimizer(Optimizer &&optimizer) = default
Optimizer &operator=(const Optimizer &optimizer) = delete
Optimizer &operator=(Optimizer &&optimizer) = default
inline explicit Optimizer(const std::vector<OptimizerParamGroup> &param_groups, std::unique_ptr<OptimizerOptions> defaults)
inline explicit Optimizer(std::vector<Tensor> parameters, std::unique_ptr<OptimizerOptions> defaults)

Constructs the Optimizer from a vector of parameters.

void add_param_group(const OptimizerParamGroup &param_group)#

Adds the given param_group to the optimizer’s param_group list.

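virtual ~Optimizer() = default
virtual Tensor step(LossClosure closure = nullptr) = 0

Pure virtual; each subclass supplies the implementation. The optional closure argument (default nullptr) is a loss function closure, which is expected to return the loss value.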

void add_parameters(const std::vector<Tensor> &parameters)#

Adds the given vector of parameters to the optimizer’s parameter list.

void zero_grad(bool set_to_none = true)#

Zeros out the gradients of all parameters.

const std::vector<Tensor> &parameters() const noexcept#

Provides a const reference to the parameters in the first param_group this optimizer holds.

std::vector<Tensor> &parameters() noexcept#

Provides a reference to the parameters in the first param_group this optimizer holds.

size_t size() const noexcept#

Returns the number of parameters referenced by the optimizer.

OptimizerOptions &defaults() noexcept
const OptimizerOptions &defaults() const noexcept
std::vector<OptimizerParamGroup> &param_groups() noexcept

Provides a reference to the param_groups this optimizer holds.

const std::vector<OptimizerParamGroup> &param_groups() const noexcept#

Provides a const reference to the param_groups this optimizer holds.

ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() noexcept#

Provides a reference to the state this optimizer holds.

const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() const noexcept#

Provides a const reference to the state this optimizer holds.

virtual void save(serialize::OutputArchive &archive) const#

Serializes the optimizer state into the given archive.

virtual void load(serialize::InputArchive &archive)#

Deserializes the optimizer state from the given archive.

Protected Attributes

std::vector<OptimizerParamGroup> param_groups_
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_
std::unique_ptr<OptimizerOptions> defaults_