Today I will try to explain how the forward and backward modes in automatic differentiation work. I will only cover the principle, not actual algorithms and the optimizations they apply. While the so-called forward mode is quite intuitive, it is not so easy to wrap your head around the backward mode. I will try to go through all steps and not leave out anything seemingly trivial.
We consider the computation of a function $F \colon \mathbb{R}^n \to \mathbb{R}^m$, $y = F(x)$, with independent variables $x_1, \dots, x_n$ and dependent variables $y_1, \dots, y_m$. The ultimate goal is to compute the Jacobian
$$J_F = \left( \frac{\partial y_i}{\partial x_j} \right)_{i = 1, \dots, m;\; j = 1, \dots, n}.$$
We view the function as a composite of elementary operations
$$v_i = \varphi_i\big( (v_j)_{j \prec i} \big)$$
for $i = n+1, \dots, N$, where we set $v_i = x_i$ for $i = 1, \dots, n$ (i.e. we reserve these indices for the start values of the computation) and $y_i = v_{N-m+i}$ for $i = 1, \dots, m$ (i.e. these are the final results of the computation). The notation $(v_j)_{j \prec i}$ should suggest that $v_i$ depends on prior results $v_j$ with $j$ in the index set $\{\, j : j \prec i \,\}$. Note that if $j \prec i$ this refers to a direct dependency of $v_i$ on $v_j$, i.e. if $v_i$ depends on $v_j$, but $v_j$ does not enter the calculation of $v_i$ directly, then $j \not\prec i$.
As an example consider the function $f(x_1, x_2) = \sin(x_1 \cdot x_2)$, for which we would have $v_1 = x_1$, $v_2 = x_2$, $v_3 = v_1 \cdot v_2$, $v_4 = \sin(v_3) = y_1$. The direct dependencies are $1 \prec 3$, $2 \prec 3$ and $3 \prec 4$, but not $1 \prec 4$, because $v_1$ does not enter the expression for $v_4$ directly.
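To make the bookkeeping concrete, here is a minimal Python sketch of how this example graph could be stored and evaluated. The layout (a `parents` dictionary of direct dependencies and a `funcs` dictionary of elementary operations, with 1-based node indices) is just my own choice for this toy example, not the convention of any particular AD tool.

```python
import math

# Encoding of the example f(x1, x2) = sin(x1 * x2):
# node i is computed by funcs[i] from its direct predecessors parents[i].
n, N = 2, 4                      # two initial nodes, four nodes in total
parents = {3: (1, 2), 4: (3,)}   # parents[i] lists the direct dependencies of node i
funcs = {
    3: lambda v1, v2: v1 * v2,   # v3 = v1 * v2
    4: lambda v3: math.sin(v3),  # v4 = sin(v3) = y1
}

def evaluate(x):
    """Run the computation chain forward and return all node values v_i."""
    v = {i + 1: x[i] for i in range(n)}   # v_i = x_i for i = 1, ..., n
    for i in range(n + 1, N + 1):
        v[i] = funcs[i](*(v[j] for j in parents[i]))
    return v

print(evaluate([1.0, 2.0]))      # {1: 1.0, 2: 2.0, 3: 2.0, 4: sin(2.0)}
```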
We can view the computation chain as a directed graph with vertices $1, \dots, N$ and edges $(j, i)$ if $j \prec i$. There are no cycles allowed in this graph (it is an acyclic graph) and it consists of $N$ vertices.
We write $d(i, j)$ for the length of the longest path from $i$ to $j$ and call that number the distance from $i$ to $j$. Note that this is not the usual definition of distance, which is normally the length of the shortest path. If $j$ is not reachable from $i$ we set $d(i, j) = -\infty$. If $j$ is reachable from $i$ the distance is finite, since the graph is acyclic.
We can compute a partial derivative using the chain rule
$$\frac{\partial v_i}{\partial x_k} = \sum_{j \prec i} \frac{\partial \varphi_i}{\partial v_j} \, \frac{\partial v_j}{\partial x_k}.$$
This suggests a forward propagation scheme: We start at the initial nodes $v_1, \dots, v_n$. For all nodes $v_i$ with maximum distance $1$ from all of these nodes we compute
$$\dot v_i = \sum_{j \prec i} \frac{\partial \varphi_i}{\partial v_j} \, \dot v_j,$$
where we can choose $\dot v_j$ for $j = 1, \dots, n$ freely at this stage. This assigns the dot product of the gradient of $\varphi_i$ w.r.t. $(v_j)_{j \prec i}$ and the vector $(\dot v_j)_{j \prec i}$ to the node $v_i$.

If we choose $\dot v_j = 1$ for one specific $j = k$ and zero otherwise, we get the partial derivative of $v_i$ w.r.t. $x_k$, but we can compute any other directional derivatives using other vectors $(\dot v_1, \dots, \dot v_n)$. (Remember that the directional derivative is the gradient times the direction w.r.t. which the derivative shall be computed.)
Next we consider nodes with maximum distance $2$ from all nodes $v_1, \dots, v_n$. For such a node $v_i$ we again compute
$$\dot v_i = \sum_{j \prec i} \frac{\partial \varphi_i}{\partial v_j} \, \dot v_j,$$
where we can assume that the $\dot v_j$ were computed in the previous step, because their maximum distance to all initial nodes must be less than $2$, hence at most $1$.
Also note that if $j \le n$ for some $j \prec i$, which may be the case, $\dot v_j$ was chosen to be $1$ if $j = k$ and zero otherwise, so $\dot v_j = \partial v_j / \partial x_k$ holds trivially. Or seemingly trivially.
The same argument can be iterated for nodes with maximum distance $3, 4, \dots$ until we reach the final nodes $v_{N-m+1}, \dots, v_N$. This way we can work forward through the computational graph and compute the directional derivative we seek.
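Here is a sketch of the forward sweep for the toy example, again with my own naming; `partials[i][j]` is assumed to return the partial derivative of the elementary operation $\varphi_i$ w.r.t. its argument $v_j$, evaluated at the current node values. Processing the nodes in increasing index order is a valid refinement of the distance-based order described above, because every direct dependency has a smaller index.

```python
import math

n, N = 2, 4
parents = {3: (1, 2), 4: (3,)}
funcs = {3: lambda v1, v2: v1 * v2, 4: lambda v3: math.sin(v3)}
partials = {                                     # d(phi_i)/d(v_j) as functions of v
    3: {1: lambda v: v[2], 2: lambda v: v[1]},   # d(v1*v2)/dv1, d(v1*v2)/dv2
    4: {3: lambda v: math.cos(v[3])},            # d(sin v3)/dv3
}

def forward_mode(x, xdot):
    """Return v_N together with its directional derivative for the seed xdot."""
    v = {i + 1: x[i] for i in range(n)}
    vdot = {i + 1: xdot[i] for i in range(n)}    # freely chosen seed values
    for i in range(n + 1, N + 1):
        v[i] = funcs[i](*(v[j] for j in parents[i]))
        vdot[i] = sum(partials[i][j](v) * vdot[j] for j in parents[i])
    return v[N], vdot[N]

# Seeding with the unit vectors gives the two partial derivatives of y1:
x = [1.0, 2.0]
print(forward_mode(x, [1.0, 0.0]))   # df/dx1 = x2 * cos(x1 * x2)
print(forward_mode(x, [0.0, 1.0]))   # df/dx2 = x1 * cos(x1 * x2)
```

Each call computes one directional derivative, so obtaining the full Jacobian this way costs one sweep per independent variable.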
In the backward mode we do very similar things, but in a dual way: We start at the final nodes and compute for all nodes $v_j$ with maximum distance $1$ from all of these nodes
$$\bar v_j = \sum_{i=1}^{m} \bar y_i \, \frac{\partial y_i}{\partial v_j}.$$
Note that we compute a weighted sum in the dependent variables now. By setting a specific $\bar y_i$ to $1$ and the rest to zero again we can compute the partial derivatives of a single final variable. Again using the chain rule we can compute
$$\bar v_j = \sum_{i :\, j \prec i} \bar v_i \, \frac{\partial \varphi_i}{\partial v_j}$$
for all nodes with maximum distance of $2$ from all the final nodes.
Note that the chain rule formally requires us to include all variables $v_i$ on which the weighted sum depends. However, if $v_i$ does not depend on $v_j$ the whole term will effectively be zero, since the partial derivative $\partial \varphi_i / \partial v_j$ vanishes, so we can drop these summands from the beginning. Also we may include indices $i$ on which the weighted sum does not depend in the first place, which is not harmful for the same reason.
As above we can assume all $\bar v_i$ appearing on the right-hand side to be computed in the previous step, so that we can iterate backwards to the initial nodes to get all partial derivatives of the weighted sum of the final nodes w.r.t. the initial nodes.
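And here is the corresponding reverse sweep for the toy example, reusing the assumed `parents`/`funcs`/`partials` layout from the previous sketches. Walking the nodes in decreasing index order is again a valid refinement of the distance-based order, because every node that depends on $v_j$ has a larger index than $j$.

```python
import math

n, N, m = 2, 4, 1
parents = {3: (1, 2), 4: (3,)}
funcs = {3: lambda v1, v2: v1 * v2, 4: lambda v3: math.sin(v3)}
partials = {
    3: {1: lambda v: v[2], 2: lambda v: v[1]},
    4: {3: lambda v: math.cos(v[3])},
}

def backward_mode(x, ybar):
    """Return bar(v_1), ..., bar(v_n), i.e. the gradient of the weighted sum
    sum_i ybar_i * y_i w.r.t. the initial nodes."""
    v = {i + 1: x[i] for i in range(n)}
    for i in range(n + 1, N + 1):                # plain forward evaluation first
        v[i] = funcs[i](*(v[j] for j in parents[i]))
    vbar = {i: 0.0 for i in range(1, N + 1)}
    for i in range(m):                           # seed the final nodes with ybar
        vbar[N - m + 1 + i] = ybar[i]
    for i in range(N, n, -1):                    # walk the graph backwards
        for j in parents[i]:
            vbar[j] += vbar[i] * partials[i][j](v)
    return [vbar[j] for j in range(1, n + 1)]

# One reverse sweep with ybar = [1.0] yields the full gradient of y1 = sin(x1*x2):
print(backward_mode([1.0, 2.0], [1.0]))   # [x2*cos(x1*x2), x1*cos(x1*x2)]
```

A single backward sweep gives the derivatives of one weighted sum of the outputs w.r.t. all inputs at once, which is why this mode is attractive when there are many inputs and few outputs.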
