Class Optimizer#
Defined in File optimizer.h
Page Contents
Inheritance Relationships#
Derived Types#
public torch::optim::Adagrad (Class Adagrad)
public torch::optim::Adam (Class Adam)
public torch::optim::AdamW (Class AdamW)
public torch::optim::LBFGS (Class LBFGS)
public torch::optim::RMSprop (Class RMSprop)
public torch::optim::SGD (Class SGD)
Class Documentation#
-
class Optimizer#
Subclassed by torch::optim::Adagrad, torch::optim::Adam, torch::optim::AdamW, torch::optim::LBFGS, torch::optim::RMSprop, torch::optim::SGD
Public Types
-
using LossClosure = std::function<Tensor()>#
Public Functions
-
inline explicit Optimizer(const std::vector<OptimizerParamGroup> &param_groups, std::unique_ptr<OptimizerOptions> defaults)#
-
inline explicit Optimizer(std::vector<Tensor> parameters, std::unique_ptr<OptimizerOptions> defaults)#
Constructs the Optimizer from a vector of parameters.
-
void add_param_group(const OptimizerParamGroup &param_group)#
Adds the given param_group to the optimizer’s param_group list.
-
virtual ~Optimizer() = default#
-
virtual Tensor step(LossClosure closure = nullptr) = 0#
A loss function closure, which is expected to return the loss value.
-
void add_parameters(const std::vector<Tensor> &parameters)#
Adds the given vector of parameters to the optimizer’s parameter list.
-
void zero_grad(bool set_to_none = true)#
Zeros out the gradients of all parameters.
-
const std::vector<Tensor> &parameters() const noexcept#
Provides a const reference to the parameters in the first param_group this optimizer holds.
-
std::vector<Tensor> &parameters() noexcept#
Provides a reference to the parameters in the first param_group this optimizer holds.
-
size_t size() const noexcept#
Returns the number of parameters referenced by the optimizer.
-
OptimizerOptions &defaults() noexcept#
-
const OptimizerOptions &defaults() const noexcept#
-
std::vector<OptimizerParamGroup> &param_groups() noexcept#
Provides a reference to the param_groups this optimizer holds.
-
const std::vector<OptimizerParamGroup> &param_groups() const noexcept#
Provides a const reference to the param_groups this optimizer holds.
-
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() noexcept#
Provides a reference to the state this optimizer holds.
-
const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> &state() const noexcept#
Provides a const reference to the state this optimizer holds.
Protected Attributes
-
std::vector<OptimizerParamGroup> param_groups_#
-
ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>> state_#
-
std::unique_ptr<OptimizerOptions> defaults_#
-
using LossClosure = std::function<Tensor()>#