Multivariate Chain Rule

In the multivariate case, where $x \in \mathbb{R}^n$, the basic differentiation rules that we know from school (e.g., sum rule, product rule, chain rule; in the scalar case, $(fg)' = f'g + fg'$, $(f+g)' = f' + g'$, and $(g \circ f)' = g'(f)\,f'$) still apply. However, we need to pay attention because now we have to deal with matrices, where multiplication is no longer commutative, i.e., the order matters.

$$\text{Product Rule:}\quad \frac{\partial}{\partial x}\bigl(f(x)g(x)\bigr) = \frac{\partial f}{\partial x}g(x) + f(x)\frac{\partial g}{\partial x} \tag{1}$$

$$\text{Sum Rule:}\quad \frac{\partial}{\partial x}\bigl(f(x) + g(x)\bigr) = \frac{\partial f}{\partial x} + \frac{\partial g}{\partial x} \tag{2}$$

$$\text{Chain Rule:}\quad \frac{\partial}{\partial x}(g \circ f)(x) = \frac{\partial}{\partial x}g(f(x)) = \frac{\partial g}{\partial f}\frac{\partial f}{\partial x} \tag{3}$$
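To make these rules concrete, here is a minimal numerical sanity check in Python (a sketch, not part of the original text; the choices $f = \sin$, $g = \exp$, and the test point are arbitrary assumptions):

```python
# Check rules (1)-(3) in the scalar case against a central
# finite-difference approximation of the derivative.
import math

def diff(fn, x, h=1e-6):
    """Central finite-difference approximation of fn'(x)."""
    return (fn(x + h) - fn(x - h)) / (2 * h)

f, g = math.sin, math.exp      # arbitrary example functions
df, dg = math.cos, math.exp    # their known derivatives
x0 = 0.7                       # arbitrary test point

# Product rule: (fg)' = f'g + fg'
assert abs(diff(lambda x: f(x) * g(x), x0)
           - (df(x0) * g(x0) + f(x0) * dg(x0))) < 1e-8

# Sum rule: (f + g)' = f' + g'
assert abs(diff(lambda x: f(x) + g(x), x0) - (df(x0) + dg(x0))) < 1e-8

# Chain rule: (g o f)' = g'(f) f'
assert abs(diff(lambda x: g(f(x)), x0) - dg(f(x0)) * df(x0)) < 1e-8
```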
Let us have a closer look at the chain rule. The chain rule formula (3) resembles to some degree the rules for matrix multiplication, where "neighboring" dimensions have to match for matrix multiplication to be defined. If we go from left to right, the chain rule exhibits similar properties: $\partial f$ shows up in the "denominator" of the first factor and in the "numerator" of the second factor. If we multiply the factors together, multiplication is defined, i.e., the dimensions of $\partial f$ match, and $\partial f$ "cancels", such that $\partial g/\partial x$ remains. (This is only an intuition, not a mathematically correct statement, since the partial derivative is not a fraction.)
Consider a function $f : \mathbb{R}^2 \to \mathbb{R}$ of two variables $x_1, x_2$. Furthermore, $x_1(t)$ and $x_2(t)$ are themselves functions of $t$. To compute the gradient of $f$ with respect to $t$, we need to apply the chain rule (3) for multivariate functions as
" #
df h i ∂x1 (t) ∂f ∂x1 ∂f ∂x2
∂f ∂f ∂t
= ∂x1 ∂x2 ∂x2 (t) = + (4)
dt ∂t
∂x1 ∂t ∂x 2 ∂t
where d denotes the gradient and ∂ partial derivatives.
Example:
Consider $f(x_1, x_2) = x_1^2 + 2x_2$, where $x_1 = \sin t$ and $x_2 = \cos t$; then

$$\frac{\mathrm{d}f}{\mathrm{d}t} = \frac{\partial f}{\partial x_1}\frac{\partial x_1}{\partial t} + \frac{\partial f}{\partial x_2}\frac{\partial x_2}{\partial t} \tag{5}$$
$$= 2\sin t\,\frac{\partial \sin t}{\partial t} + 2\,\frac{\partial \cos t}{\partial t} \tag{6}$$
$$= 2\sin t\cos t - 2\sin t = 2\sin t(\cos t - 1) \tag{7}$$

is the corresponding derivative of $f$ with respect to $t$.
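As a quick sanity check (a Python sketch, not from the original text; the test point $t_0 = 1.3$ is an arbitrary assumption), we can compare the closed form (7) against a finite-difference approximation:

```python
# Numerical check of (7): df/dt for f(x1, x2) = x1^2 + 2*x2
# with x1 = sin(t), x2 = cos(t).
import math

def f_of_t(t):
    x1, x2 = math.sin(t), math.cos(t)
    return x1**2 + 2 * x2

t0 = 1.3                                                 # arbitrary test point
h = 1e-6
numeric = (f_of_t(t0 + h) - f_of_t(t0 - h)) / (2 * h)    # central difference
analytic = 2 * math.sin(t0) * (math.cos(t0) - 1)         # result (7)
assert abs(numeric - analytic) < 1e-8
```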
If $f(x_1, x_2)$ is a function of $x_1$ and $x_2$, where $x_1(s, t)$ and $x_2(s, t)$ are themselves functions of two variables $s$ and $t$, the chain rule yields

$$\frac{\mathrm{d}f}{\mathrm{d}s} = \frac{\partial f}{\partial x_1}\frac{\partial x_1}{\partial s} + \frac{\partial f}{\partial x_2}\frac{\partial x_2}{\partial s}, \tag{8}$$

$$\frac{\mathrm{d}f}{\mathrm{d}t} = \frac{\partial f}{\partial x_1}\frac{\partial x_1}{\partial t} + \frac{\partial f}{\partial x_2}\frac{\partial x_2}{\partial t}, \tag{9}$$
which can be expressed as the matrix multiplication

$$\frac{\mathrm{d}f}{\mathrm{d}(s,t)} = \frac{\partial f}{\partial x}\frac{\partial x}{\partial (s,t)} = \underbrace{\begin{bmatrix} \dfrac{\partial f}{\partial x_1} & \dfrac{\partial f}{\partial x_2} \end{bmatrix}}_{=\,\frac{\partial f}{\partial x}}\; \underbrace{\begin{bmatrix} \dfrac{\partial x_1}{\partial s} & \dfrac{\partial x_1}{\partial t} \\[4pt] \dfrac{\partial x_2}{\partial s} & \dfrac{\partial x_2}{\partial t} \end{bmatrix}}_{=\,\frac{\partial x}{\partial (s,t)}}. \tag{10}$$
This compact way of writing the chain rule as a matrix multiplication only makes sense if the gradient is defined as a row vector. Otherwise, we need to start transposing gradients for the matrix dimensions to match. This may still be straightforward as long as the gradient is a vector or a matrix; however, when the gradient becomes a tensor (we will discuss this in the following), the transpose is no longer trivial.
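The following Python sketch illustrates (10) with the gradient laid out as a row vector; the concrete choices $x_1(s,t) = st$, $x_2(s,t) = s + t^2$, $f(x_1, x_2) = x_1^2 + 2x_2$, and the evaluation point are assumptions for illustration, not from the text:

```python
# Chain rule (10) as a matrix product, with the gradient as a ROW vector.
import numpy as np

def f(x):
    return x[0]**2 + 2 * x[1]

def x_of_st(s, t):
    return np.array([s * t, s + t**2])

s0, t0 = 0.5, -1.2                          # arbitrary evaluation point
x0 = x_of_st(s0, t0)

# df/dx: 1x2 row vector [df/dx1, df/dx2], evaluated at x(s0, t0).
df_dx = np.array([[2 * x0[0], 2.0]])

# dx/d(s,t): 2x2 Jacobian of (x1, x2) with respect to (s, t).
dx_dst = np.array([[t0, s0],                # dx1/ds, dx1/dt
                   [1.0, 2 * t0]])          # dx2/ds, dx2/dt

grad = df_dx @ dx_dst                       # 1x2 row vector [df/ds, df/dt]

# Finite-difference check of df/ds and df/dt.
h = 1e-6
fd = np.array([[(f(x_of_st(s0 + h, t0)) - f(x_of_st(s0 - h, t0))) / (2 * h),
                (f(x_of_st(s0, t0 + h)) - f(x_of_st(s0, t0 - h))) / (2 * h)]])
assert np.allclose(grad, fd, atol=1e-6)
```

Note how the row-vector convention makes the shapes line up without any transposes: $(1 \times 2)$ times $(2 \times 2)$ gives the $(1 \times 2)$ gradient with respect to $(s, t)$.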
Example (Gradient of a Linear Model):
Let us consider the linear model

$$y = \Phi\theta, \tag{11}$$

where $\theta \in \mathbb{R}^D$ is a parameter vector, $\Phi \in \mathbb{R}^{N \times D}$ are input features, and $y \in \mathbb{R}^N$ are the corresponding observations. We define the following functions:

$$L(e) := \|e\|^2, \tag{12}$$
$$e(\theta) := y - \Phi\theta. \tag{13}$$
We seek $\partial L/\partial \theta$, and we will use the chain rule for this purpose.
Before we start any calculation, we determine the dimensionality of the gradient as

$$\frac{\partial L}{\partial \theta} \in \mathbb{R}^{1 \times D}. \tag{14}$$
The chain rule allows us to compute the gradient as
$$\frac{\partial L}{\partial \theta} = \frac{\partial L}{\partial e}\frac{\partial e}{\partial \theta}. \tag{15}$$
We know that $\|e\|^2 = e^\top e$ and determine

$$\frac{\partial L}{\partial e} = 2e^\top \in \mathbb{R}^{1 \times N}. \tag{16}$$
Furthermore, we obtain
$$\frac{\partial e}{\partial \theta} = -\Phi \in \mathbb{R}^{N \times D}, \tag{17}$$
such that our desired derivative is
$$\frac{\partial L}{\partial \theta} = -2e^\top \Phi \overset{(13)}{=} -2\,\underbrace{(y^\top - \theta^\top \Phi^\top)}_{1 \times N}\,\underbrace{\Phi}_{N \times D} \in \mathbb{R}^{1 \times D}. \tag{18}$$
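As a sanity check (a Python sketch with arbitrary random data; the sizes $N = 5$, $D = 3$ and the seed are assumptions, not from the text), we can compare (18) against finite differences:

```python
# Numerical check of (18): dL/dtheta = -2 e^T Phi for L(e) = ||e||^2,
# e(theta) = y - Phi @ theta.
import numpy as np

rng = np.random.default_rng(0)             # arbitrary seed
N, D = 5, 3                                # arbitrary sizes
Phi = rng.standard_normal((N, D))
y = rng.standard_normal(N)
theta = rng.standard_normal(D)

def L(theta):
    e = y - Phi @ theta
    return e @ e                           # ||e||^2 = e^T e

e = y - Phi @ theta
analytic = -2 * e @ Phi                    # row vector (18), shape (D,)

# Finite-difference check, one coordinate of theta at a time.
h = 1e-6
fd = np.array([(L(theta + h * np.eye(D)[i]) - L(theta - h * np.eye(D)[i]))
               / (2 * h) for i in range(D)])
assert np.allclose(analytic, fd, atol=1e-5)
```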
Remark. We would have obtained the same result without using the chain rule by immediately looking at the function

$$L_2(\theta) := \|y - \Phi\theta\|^2 = (y - \Phi\theta)^\top (y - \Phi\theta). \tag{19}$$

This approach is still practical for simple functions like $L_2$ but becomes impractical for deep function compositions.
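For completeness (this worked step is not in the original text, but it follows from standard matrix calculus with the row-vector gradient convention), expanding (19) and differentiating directly gives

$$\frac{\partial L_2}{\partial \theta} = \frac{\partial}{\partial \theta}\left(y^\top y - 2y^\top \Phi\theta + \theta^\top \Phi^\top \Phi\theta\right) = -2y^\top \Phi + 2\theta^\top \Phi^\top \Phi = -2(y^\top - \theta^\top \Phi^\top)\Phi,$$

which matches (18).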