The chain rule of derivatives is, in my opinion, the most important formula in differential calculus. In this post I want to explain how the chain rule works for single-variable and multivariate functions, with some interesting examples along the way.

Preliminaries: composition of functions and differentiability

We denote a function f that maps from the domain X to the codomain Y as f:X \rightarrow Y. With this f and given g:Y \rightarrow Z, we can define g \circ f:X \rightarrow Z as the composition of g and f. It's defined for all x \in X as:

\[(g \circ f)(x)=g(f(x))\]

In calculus we are usually concerned with the real number domain of some dimensionality. In the single-variable case, we can think of f and g as two regular real-valued functions: f:\mathbb{R} \rightarrow \mathbb{R} and g:\mathbb{R} \rightarrow \mathbb{R}.

As an example, say f(x)=x+1 and g(x)=x^2. Then:

\[(g \circ f)(x)=g(f(x))=g(x+1)=(x+1)^2\]

We can compose the functions the other way around as well:

\[(f \circ g)(x)=f(g(x))=f(x^2)=x^2+1\]

Obviously, we shouldn't expect composition to be commutative. It is, however, associative: h \circ (g \circ f) and (h \circ g) \circ f are equivalent, and both end up being h(g(f(x))) for all x \in X.
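
To make this concrete, here's a small Python sketch of composition (the compose helper is my own, just for illustration):

def compose(g, f):
    # Returns the function x -> g(f(x)), i.e. the composition g ∘ f.
    return lambda x: g(f(x))

f = lambda x: x + 1
g = lambda x: x ** 2

print(compose(g, f)(2))  # (2 + 1)^2 = 9
print(compose(f, g)(2))  # 2^2 + 1 = 5 -- not commutative

h = lambda x: 2 * x
left = compose(h, compose(g, f))   # h ∘ (g ∘ f)
right = compose(compose(h, g), f)  # (h ∘ g) ∘ f
print(left(3) == right(3))         # True -- associative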

To better handle compositions in one's head it sometimes helps to denote the independent variable of the outer function (g in our case) by a different letter (such as g(a)). For simple cases it doesn't matter, but I'll be using this technique occasionally throughout the article. The important thing to remember here is that the name of the independent variable is completely arbitrary, and we should always be able to replace it by another name throughout the formula without any semantic change.

The other preliminary I want to mention is differentiability. The function f is differentiable at some point x_0 if the following limit exists:

\[\lim_{h \to 0}\frac{f(x_0+h)-f(x_0)}{h}\]

This limit is then the derivative of f at the point x_0, or {f}'(x_0). Another way to express this is \frac{d}{dx}f(x_0). Note that x_0 can be an arbitrary point on the real line. I sometimes say something like "f is differentiable at g(x_0)". Here too, g(x_0) is just a real value that happens to be the value of the function g at x_0.
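
Here's a quick way to see this limit in action: a small Python sketch (the function name and step size are my choices) that approximates the derivative by taking a small but finite h:

def numeric_derivative(f, x0, h=1e-6):
    # Approximates f'(x0) with the difference quotient for a small h.
    return (f(x0 + h) - f(x0)) / h

print(numeric_derivative(lambda x: x ** 2, 3.0))  # close to 6.0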

The single-variable chain rule

The chain rule for single-variable functions states: if g is differentiable at x_0 and f is differentiable at g(x_0), then f \circ g is differentiable at x_0 and its derivative is:

\[(f \circ g)'(x_0)={f}'(g(x_0)){g}'(x_0)\]

The proof of the chain rule is a bit tricky - I left it for the appendix. However, we can get a better feel for it using some intuition and a couple of examples.

First, the intuition. By definition:

\[{g}'(x_0)=\lim_{h \to 0}\frac{g(x_0+h)-g(x_0)}{h}\]

Multiplying both sides by h we get [1]:

\[{g}'(x_0)h=\lim_{h \to 0}\left[g(x_0+h)-g(x_0)\right]\]

Therefore we can say that when x_0 changes by some very small amount, g(x_0) changes by {g}'(x_0) times that small amount.

Similarly, {f}'(a_0) tells us how much the value of f changes for some very small change from a_0. However, since in our case we compose f \circ g, we can say that a_0=g(x_0), evaluating f(g(x_0)). Suppose we shift x_0 by a small amount h. This causes g(x_0) to shift by {g}'(x_0)h. So the input a_0 of f shifts by {g}'(x_0)h - this is still a small amount! Therefore, the total change in the value of f should be {f}'(g(x_0)){g}'(x_0)h [2].
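
We can check this intuition numerically. A small sketch, with g(x)=x^2 and f(a)=a^3 as my arbitrary choices:

x0, h = 2.0, 1e-6
g = lambda x: x ** 2   # g'(x) = 2x
f = lambda a: a ** 3   # f'(a) = 3a^2

# The shift in g(x0) vs. the linear estimate g'(x0)*h:
print(g(x0 + h) - g(x0), 2 * x0 * h)
# The shift in f(g(x0)) vs. the estimate f'(g(x0)) * g'(x0) * h:
print(f(g(x0 + h)) - f(g(x0)), 3 * g(x0) ** 2 * 2 * x0 * h)

Both pairs of printed numbers should agree to several digits.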

Now, a couple of simple examples. Let's take the function f(x)=(x+1)^2. The idea is to think of this function as a composition of simpler functions. In this case, one option is: g(x)=x+1 and then w(g(x))=g(x)^2, so the original f is now the composition w \circ g.

The derivative of this composition is {w}'(g(x)){g}'(x), or 2(x+1) since {g}'(x)=1. Note that w is differentiable at any point, so this derivative always exists.

Another example will use a longer chain of composition. Let's differentiate f(x)=sin[(x+1)^2]. This is a composition of three functions:

\[\begin{align*} g(x)&=x+1\\ w(x)&=x^2\\ v(x)&=\sin(x) \end{align*}\]

Function composition is associative, so f can be expressed as either v \circ (w \circ g) or (v \circ w) \circ g. Since we already know what the derivative of w \circ g is, let's use the former:

\[\begin{align*} \frac{df(x)}{dx}=\frac{d\,v(w(g(x)))}{dx}&={v}'(w(g(x)))\,(w \circ g)'(x)\\ &=\cos(w(g(x)))\,2(x+1)\\ &=2\cos[(x+1)^2](x+1) \end{align*}\]
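
As a quick sanity check, we can compare this result to a finite-difference approximation (the sample point and step size are arbitrary):

import math

f = lambda x: math.sin((x + 1) ** 2)
df = lambda x: 2 * math.cos((x + 1) ** 2) * (x + 1)  # the chain-rule result

x0, h = 0.5, 1e-6
print((f(x0 + h) - f(x0)) / h)  # finite-difference approximation
print(df(x0))                   # should be very close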

The chain rule as a computational procedure

As the last example demonstrates, the chain rule can be applied multiple times in a single derivation. This makes the chain rule a powerful tool for computing derivatives of very complex functions, which can be broken up into compositions of simpler functions. I like to draw a parallel between this process and programming; a function in a programming language can be seen as a computational procedure - we have a set of input parameters and we produce outputs. On the way, several transformations happen that can be expressed mathematically. These transformations are composed, so their derivatives can be computed naturally with the chain rule.

This may be somewhat abstract, so let's use another example. We'll compute the derivative of the Sigmoid function - a very important function in machine learning:

\[S(x)=\frac{1}{1+e^{-x}}\]

To make the equivalence between functions and computational procedures clearer, let's think how we'd compute S in Python:

import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

This doesn't look much different, but that's just because Python is a high level language with arbitrarily nested expressions. Its VM (or the CPU in general) would execute this computation step by step. Let's break it up to be clearer, assuming we can only apply a single operation at every step:

def sigmoid(x):
    f = -x              # negate the input
    g = math.exp(f)     # exponentiate
    w = 1 + g           # add 1
    v = 1 / w           # take the reciprocal
    return v

I hope you're starting to see the resemblance to our chain rule examples at this point. Sacrificing some rigor in the notation for the sake of expressiveness, we can write:

\[S'=v'(w)w'(g)g'(f)f'(x)\]

This is the chain rule applied to v \circ (w \circ (g \circ f)). Solving this is easy because every single derivative in the chain above is trivial:

\[\begin{align*} S'&=v'(w)w'(g)g'(f)(-1)\\   &=v'(w)w'(g)e^{-x}(-1)\\   &=v'(w)(1)e^{-x}(-1)\\   &=\frac{-1}{(1+e^{-x})^2}e^{-x}(-1)\\   &=\frac{e^{-x}}{(1+e^{-x})^2} \end{align*}\]
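
The same kind of finite-difference check works here too (the sample point is arbitrary):

import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def sigmoid_deriv(x):
    # The chain-rule result we just derived.
    return math.exp(-x) / (1 + math.exp(-x)) ** 2

x0, h = 0.3, 1e-6
print((sigmoid(x0 + h) - sigmoid(x0)) / h)  # finite-difference approximation
print(sigmoid_deriv(x0))                    # should be very close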

Now you may be thinking:

  1. Every function computable by a program can be broken down to trivial steps like our sigmoid above.
  2. Using the chain rule, we can easily find the derivative of such a sequence of steps... therefore:
  3. We can easily find the derivative of any function computable by a program!!

And you'll be right. This is precisely the basis for the technique known as automatic differentiation, which is widely used in scientific computing. The most notable use of automatic differentiation in recent times is the backpropagation algorithm - an essential backbone of modern machine learning. I personally find automatic differentiation fascinating, and will write a more dedicated article about it in the future.
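
To give a taste of the idea, here's a minimal sketch of forward-mode automatic differentiation using dual numbers - a toy of my own, not the implementation of any real library. Every value carries its derivative along, and each elementary operation applies the chain rule for its single step:

import math

class Dual:
    # A value paired with its derivative with respect to the input.
    def __init__(self, val, dval):
        self.val, self.dval = val, dval

    def __neg__(self):
        return Dual(-self.val, -self.dval)

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.dval + other.dval)

    __radd__ = __add__

def dual_exp(d):
    # (e^u)' = e^u * u' -- one step of the chain rule.
    return Dual(math.exp(d.val), math.exp(d.val) * d.dval)

def dual_recip(d):
    # (1/u)' = -u' / u^2 -- another step of the chain rule.
    return Dual(1.0 / d.val, -d.dval / d.val ** 2)

def sigmoid(x):
    f = -x
    g = dual_exp(f)
    w = 1 + g
    v = dual_recip(w)
    return v

x = Dual(0.3, 1.0)  # seed with dx/dx = 1
s = sigmoid(x)
print(s.val, s.dval)  # sigmoid at 0.3 and its derivative

The derivative pops out of the step-by-step computation with no symbolic manipulation at all - just the chain rule applied at every line of the procedure.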

Multivariate chain rule - general formulation

So far this article has been looking at functions with a single input and output: f:\mathbb{R} \to \mathbb{R}. In the most general case of multivariate calculus, we're dealing with functions that map from n dimensions to m dimensions: f:\mathbb{R}^{n} \to \mathbb{R}^{m}. Because every one of the m outputs of f can be considered a separate function dependent on n variables, it's very natural to deal with such math using vectors and matrices.

First let's define some notation. We'll consider the outputs of f to be numbered from 1 to m as f_1,f_2 \dots f_m. For each such f_i we can compute its partial derivative by any of the n inputs as:

\[D_j f_i(a)=\frac{\partial f_i}{\partial a_j}(a)\]

Where j goes from 1 to n and a is a vector with n components. If f is differentiable at a [3] then the derivative of f at a is the Jacobian matrix:

\[Df(a)=\begin{bmatrix} D_1 f_1(a) & \cdots & D_n f_1(a) \\ \vdots &  & \vdots \\ D_1 f_m(a) & \cdots & D_n f_m(a) \\ \end{bmatrix}\]
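
For illustration, here's a sketch that builds this matrix numerically, one partial derivative at a time (the helper name and step size are my own):

def numeric_jacobian(f, a, h=1e-6):
    # f maps a list of n numbers to a list of m numbers; a is the point.
    # Returns the m-by-n Jacobian as a list of rows, entry (i, j) ~ D_j f_i(a).
    fa = f(a)
    jac = []
    for i in range(len(fa)):
        row = []
        for j in range(len(a)):
            shifted = list(a)
            shifted[j] += h
            row.append((f(shifted)[i] - fa[i]) / h)
        jac.append(row)
    return jac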

The multivariate chain rule states: given g:\mathbb{R}^n \to \mathbb{R}^m and f:\mathbb{R}^m \to \mathbb{R}^p and a point a \in \mathbb{R}^n, if g is differentiable at a and f is differentiable at g(a) then the composition f \circ g is differentiable at a and its derivative is:

\[D(f \circ g)(a)=Df(g(a)) \cdot Dg(a)\]

Which is the matrix multiplication of Df(g(a)) and Dg(a) [4]. Intuitively, the multivariate chain rule mirrors the single-variable one (and as we'll soon see, the latter is just a special case of the former) with derivatives replaced by derivative matrices. From linear algebra, we represent linear transformations by matrices, and the composition of two linear transformations is the product of their matrices. Therefore, since derivative matrices - like derivatives in one dimension - are a linear approximation to the function, the chain rule makes sense. This is a really nice connection between linear algebra and calculus, though a full proof of the multivariate rule is very technical and outside the scope of this article.

Multivariate chain rule - examples

Since the chain rule deals with compositions of functions, it's natural to present examples from the world of parametric curves and surfaces. For example, suppose we define f(x,y,z) as a scalar function \mathbb{R}^3 \to \mathbb{R} giving the temperature at some point in 3D. Now imagine that we're moving through this 3D space on a curve defined by a function g:\mathbb{R} \to \mathbb{R}^3 which takes time t and gives the coordinates x(t),y(t),z(t) at that time. We want to compute how the temperature changes as a function of time t - how do we do that? Recall that the temperature is not a direct function of time, but rather is a function of location, while location is a function of time. Therefore, we'll want to compose f \circ g. Here's a concrete example:

\[g(t)=\begin{pmatrix} t\\ t^2\\ t^3 \end{pmatrix}\]

And:

\[f\begin{pmatrix} x \\ y \\ z \end{pmatrix}=x^2+xyz+5y\]

If we reformulate x, y and z as functions of t:

\[f(x(t),y(t),z(t))=x(t)^2+x(t)y(t)z(t)+5y(t)\]

Composing f \circ g, we get:

\[(f \circ g)(t)=f(g(t))=f(t,t^2,t^3)=t^2+t^6+5t^2=6t^2+t^6\]

Since this is a simple function, we can find its derivative directly:

\[(f \circ g)'(t)=12t+6t^5\]

Now let's repeat this exercise using the multivariate chain rule. To compute D(f \circ g)(t) we need Df(g(t)) and Dg(t). Let's start with Dg(t). g(t) maps \mathbb{R} \to \mathbb{R}^3, so its Jacobian is a 3-by-1 matrix, or column vector:

\[Dg(t)=\begin{bmatrix} 1 \\ 2t \\ 3t^2 \end{bmatrix}\]

To compute Df(g(t)) let's first find Df(x,y,z). Since f(x,y,z) maps \mathbb{R}^3 \to \mathbb{R}, its Jacobian is a 1-by-3 matrix, or row vector:

\[Df(x,y,z)=\begin{bmatrix} 2x+yz & xz+5 & xy \end{bmatrix}\]

To apply the chain rule, we need Df(g(t)):

\[Df(g(t))=\begin{bmatrix} 2t+t^5 & t^4+5 & t^3 \end{bmatrix}\]

Finally, multiplying Df(g(t)) by Dg(t), we get:

\[\begin{align*} D(f \circ g)(t)=Df(g(t)) \cdot Dg(t)&=\begin{bmatrix} 2t+t^5 & t^4+5 & t^3 \end{bmatrix} \cdot \begin{bmatrix} 1 \\ 2t \\ 3t^2 \end{bmatrix}\\ &=2t+t^5+2t^5+10t+3t^5\\ &=12t+6t^5 \end{align*}\]
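
With the numeric_jacobian sketch from the previous section we can verify this end to end at a sample point, say t=2:

g = lambda t: [t[0], t[0] ** 2, t[0] ** 3]
f = lambda v: [v[0] ** 2 + v[0] * v[1] * v[2] + 5 * v[1]]

t = [2.0]
Dg = numeric_jacobian(g, t)     # 3-by-1
Df = numeric_jacobian(f, g(t))  # 1-by-3

# Multiply the 1-by-3 Df by the 3-by-1 Dg:
print(sum(Df[0][j] * Dg[j][0] for j in range(3)))  # close to 12*2 + 6*2^5 = 216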

Another interesting way to interpret this result for the case where f:\mathbb{R}^3 \to \mathbb{R} and g:\mathbb{R} \to \mathbb{R}^3 is to recall that the directional derivative of f in the direction of some vector \vec{v} is:

\[D_{\vec{v}}f=(\nabla f) \cdot \vec{v}\]

In our case (\nabla f) is the Jacobian of f (because of its dimensionality). So if we take \vec{v} to be the vector Dg(t), and evaluate the gradient at g(t) we get [5]:

\[D_{Dg(t)}f(g(t))=(\nabla f(g(t))) \cdot Dg(t)\]

This gives us some additional intuition for the temperature change question. The change in temperature as a function of time is the directional derivative of f in the direction of the change in location (Dg(t)).

For additional examples of applying the chain rule, see my post about softmax.

Tricks with the multivariate chain rule - derivative of products

Earlier in the article we've seen how the chain rule helps find derivatives of complicated functions by decomposing them into simpler functions. The multivariate chain rule allows even more of that, as the following example demonstrates. Suppose h(x)=f(x)g(x). Then, the well-known product rule of derivatives states that:

\[h'(x)=f'(x)g(x)+f(x)g'(x)\]

Proving this from first principles (the definition of the derivative as a limit) isn't hard, but I want to show how it stems very easily from the multivariate chain rule.

Let's begin by re-formulating h(x) as a composition of two functions. The first takes a vector \vec{s} in \mathbb{R}^2 and maps it to \mathbb{R} by computing the product of its two components:

\[p(\vec{s})=s_1 s_2\]

The second is a vector-valued function that maps a number x \in \mathbb{R} to \mathbb{R}^2 :

\[s(x)=\begin{pmatrix} f(x)\\ g(x) \end{pmatrix}\]

We can compose p \circ s, producing a function that takes a scalar and returns a scalar: (p \circ s) : \mathbb{R} \to \mathbb{R}. We get:

\[h(x)=(p \circ s)(x) = f(x)g(x)\]

Since we're composing two multivariate functions, we can apply the multivariate chain rule here:

\[\begin{align*} D(p \circ s)(x) &= Dp(s(x)) \cdot Ds(x)\\ &=\begin{bmatrix} \frac{\partial p}{\partial s_1}(s(x)) & \frac{\partial p}{\partial s_2}(s(x)) \end{bmatrix}\cdot \begin{bmatrix} {s_1}'(x)\\ {s_2}'(x) \end{bmatrix}\\ &=\begin{bmatrix} s_2(x) & s_1(x) \end{bmatrix} \cdot \begin{bmatrix} {s_1}'(x)\\ {s_2}'(x) \end{bmatrix}\\ &={s_1}'(x)s_2(x)+{s_2}'(x)s_1(x) \end{align*}\]

Since s_1(x)=f(x) and s_2(x)=g(x), this is exactly the product rule.
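
A quick numeric check of the result, with f(x)=sin(x) and g(x)=x^2 as arbitrary choices:

import math

f = lambda x: math.sin(x)
g = lambda x: x ** 2
prod = lambda x: f(x) * g(x)

x0, eps = 1.2, 1e-6
print((prod(x0 + eps) - prod(x0)) / eps)      # finite-difference approximation
print(math.cos(x0) * g(x0) + f(x0) * 2 * x0)  # f'(x)g(x) + f(x)g'(x)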

Connecting the single-variable and multivariate chain rules

Given a function f(x) : \mathbb{R} \to \mathbb{R}, its Jacobian matrix has a single entry:

\[Df(a)=\begin{bmatrix}D_{x}f(a)\end{bmatrix}=       \begin{bmatrix}\frac{df}{dx}(a)\end{bmatrix}\]

Therefore, given two functions mapping \mathbb{R} \to \mathbb{R}, the derivative of their composition using the multivariate chain rule is:

\[D(f \circ g)(a)=Df(g(a))\cdot Dg(a)=f'(g(a))g'(a)\]

Which is precisely the single-variable chain rule. This results from matrix multiplication between two 1-by-1 matrices, which ends up being just the product of their single entries.

Appendix: proving the single-variable chain rule

It turns out that many online resources (including Khan Academy) provide a flawed proof for the chain rule. It's flawed due to a careless division by a quantity that may be zero. This flaw can be corrected by making the proof somewhat more complicated; I won't take that road here - for details see Spivak's Calculus. Instead, I'll present a simpler proof inspired by the one I found at Casey Douglas's site.

We want to prove that:

\[(f \circ g)'(x)={f}'(g(x)){g}'(x)\]

Note that previously we defined derivatives at some concrete point x_0. Here for the sake of brevity I'll just use x as an arbitrary point, assuming the derivative exists.

Let's start with the definition of g'(x):

\[{g}'(x)=\lim_{h \to 0}\frac{g(x+h)-g(x)}{h}\]

We can reorder it as follows:

\[\lim_{h \to 0}\left [ \frac{g(x+h)-g(x)}{h} - g'(x) \right ] = 0\]

Let's give the part in the brackets the name \Delta g(x). Note that \Delta g(x) depends on h as well as on x, and that it goes to 0 as h goes to 0.

Similarly, if the function f is differentiable at the point a=g(x), we have:

\[f'(a)=\lim_{k \to 0}\frac{f(a+k)-f(a)}{k}\]

We reorder:

\[\lim_{k \to 0}\left [ \frac{f(a+k)-f(a)}{k} - f'(a) \right ] = 0\]

And call the part in the brackets \Delta f(a); it too goes to 0 as k goes to 0. The choice of the variable that goes to zero - k instead of h - is arbitrary, and is useful to simplify the discussion that follows.

Let's reorder the definition of \Delta g(x) a bit:

\[g(x+h)=g(x)+[g'(x)+\Delta g(x)]h\]

We can apply f to both sides:

\[\begin{equation} f(g(x+h))=f(g(x)+[g'(x)+\Delta g(x)]h) \tag{1} \end{equation}\]

By reordering the definition of \Delta f(a) we get:

\[\begin{equation} f(a+k)=f(a)+[f'(a)+\Delta f(a)]k \tag{2} \end{equation}\]

Now taking the right-hand side of (1), we can look at it as f(a+k) since a=g(x) and we can define k=[g'(x)+\Delta g(x)]h. We still have k going to zero when h goes to zero. Assigning these a and k into (2) we get:

\[f(a+k)=f(g(x))+[f'(g(x))+\Delta f(g(x))][g'(x)+\Delta g(x)]h\]

So, starting from (1) again, we have:

\[\begin{align*} f(g(x+h))&=f(a+k) \\          &=f(g(x))+[f'(g(x))+\Delta f(g(x))][g'(x)+\Delta g(x)]h \end{align*}\]

Subtracting f(g(x)) from both sides and dividing by h (which is legal, since h is not zero, it's just very small) we get:

\[\frac{f(g(x+h))-f(g(x))}{h}=[f'(g(x))+\Delta f(g(x))][g'(x)+\Delta g(x)]\]

Apply a limit to both sides:

\[\lim_{h \to 0} \frac{f(g(x+h))-f(g(x))}{h}= \lim_{h \to 0} [f'(g(x))+\Delta f(g(x))][g'(x)+\Delta g(x)]\]

Now recall that both \Delta f(g(x)) and \Delta g(x) go to 0 when h goes to 0. Taking this into account, we get:

\[\lim_{h \to 0} \frac{f(g(x+h))-f(g(x))}{h}= f'(g(x))g'(x)\]

Q.E.D.


[1]Here, as in the rest of the post, I'm being careless with the usage of \lim, sometimes leaving its existence to be implicit. In general, wherever h appears in a formula we know there's a \lim_{h \to 0} there, whether explicitly or not.
[2]An alternative way to think about it is: suppose the functions f and g are linear: f(x)=ax+b and g(x)=cx+d. Then the chain rule is trivially true. But now recall what the derivative is. The derivative at some point x_0 is the best linear approximation for the function at that point. Therefore the chain rule is true for any pair of differentiable functions - even when the functions are not linear, we approximate their rate of change in an infinitesimal area around x_0 with a linear function.
[3]The condition for f being differentiable at a is stronger than simply saying that all partial derivatives exist at a, but I won't spend more time on this subtlety here.
[4]As an exercise, verify that the matrix dimensions of Df and Dg make this multiplication valid.
[5]It shouldn't be surprising that we end up here, since the formula for the directional derivative in terms of the gradient was itself derived using the multivariate chain rule.