A series on automatic differentiation in Julia. Part 1 provides an overview and defines explicit chain rules.
This is part of a series. The other articles are:
All source code can be found at MicroGrad.jl.
A major convenience of modern machine learning frameworks is automatic differentiation (AD). Training a machine learning model typically consists of two steps: a forward pass and a backward pass. The forward pass takes an input sample and calculates a result, for example a label in a classifier model, or a word or image in a generative model. In the backward pass, the result is compared to a ground truth sample and the error is backpropagated throughout the model, from the final layers through to the start. Backpropagation is driven by gradients, which are calculated with the differentiation rules of Calculus.
With modern machine learning frameworks, such as PyTorch or Flux.jl, only the forward pass needs to be defined and they will automatically generate the backward pass. This (1) makes them easier to use and (2) enforces consistency between the forward pass and backward pass.
Andrej Karpathy made an excellent video where he built a minimal automatic differentiation module called Micrograd in Python. This is the first video in his Zero to Hero series. He later uses it to train a multi-layer perceptron model. I highly recommend it for anyone who wants to understand backpropagation.
The aim of this series is to create a minimal automatic differentiation package in Julia. It is based on Zygote.jl and works very differently to the Python AD packages. The latter are based on objects with their own custom implementations of mathematical operations that calculate both the forward and backward passes. All operations are only done with these objects.1 Zygote.jl is instead based on the principle that Julia is a functional programming language. It utilises Julia’s multiple dispatch feature and its comprehensive metaprogramming abilities to generate new code for the backward pass. Barring some limitations, it can be used to differentiate all existing functions as well as any custom code.
Zygote’s approach is complex and pushes the boundaries of Julia’s metaprogramming. It can sometimes be buggy. However its promise is true automatic differentiation of any forward pass code without further work on the coder’s part.
For the final code, see my MicroGrad.jl repository. It is very versatile but has several limitations, including less code coverage than Zygote.jl and an inability to handle control flow or keyword arguments.
There are almost no comprehensive tutorials on AD in Julia, so this series aims to fill that gap. A good understanding of Julia and of Calculus is required.
The Julia automatic differentiation ecosystem is centered around three packages: Flux.jl, ChainRules.jl and Zygote.jl.
Flux.jl is the machine learning framework. For differentiation it relies on Zygote.jl, which provides the user-facing functions `gradient`, `withgradient` and `pullback`. The `pullback` function is a light wrapper around `_pullback`, which does most of the heavy lifting. The first task of `_pullback` is to dispatch a function, its arguments and its keyword arguments to a `ChainRules.rrule`. If it cannot, it will inspect the code, decompose it into smaller steps, and follow the rules of differentiation to dispatch each of those to `_pullback`, recursively searching for an `rrule`. If this recursive process does not find a valid rule, it will raise an error.

ChainRules.jl defines the differentiation rules themselves: `frule` for forward mode and `rrule` for reverse mode. This series deals only with backpropagation, so it will concentrate only on `rrule`.

Also important is IRTools.jl, an extended metaprogramming package for working with an intermediate representation (IR) between raw Julia code and lowered code. MicroGrad.jl in particular is based on the example code at IRTools.jl, with alignment to Zygote.jl's functions and names.
As an example, consider the function $f(x) = \sin(\cos(x))$. Using the chain rule of Calculus, it is differentiated as:
\[\begin{align} \frac{df}{dx} &= \frac{df}{dh}\frac{dh}{dx} \quad ; \quad h(x)=\cos(x)\\ &= \frac{d}{dh}\sin(h)\frac{d}{dx}\cos(x) \\ &= \cos(h)(-\sin(x)) \\ &= -\cos(\cos(x))\sin(x) \end{align}\]

`Zygote.withgradient`, exposed as `Flux.withgradient`, can be used to calculate this:
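For example (the input value 0.9 is an arbitrary choice):

```julia
using Zygote

val, grad = Zygote.withgradient(x -> sin(cos(x)), 0.9)
# val  == sin(cos(0.9))
# grad == (-cos(cos(0.9)) * sin(0.9),)
```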
More commonly we differentiate with respect to the model, not the data:
This is more useful for a model with parameters. For example a dense, fully connected layer:
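A sketch with Flux (the layer size, input and loss are arbitrary choices for illustration):

```julia
using Flux

model = Dense(3 => 2)           # weight W (2×3), bias b (2)
x = randn(Float32, 3)

val, grads = Flux.withgradient(m -> sum(m(x)), model)
grads[1].weight  # gradient with respect to W
grads[1].bias    # gradient with respect to b
```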
The aim of the rest of the series is to recreate this functionality.
This first part will focus solely on ChainRules.jl and recreating the `rrule` function. Part 2 will focus on recreating the `Zygote._pullback` function. Part 3 repeats part 2 in a more robust manner. Part 4 extends part 3's solution to handle maps, anonymous functions and structs. Finally, part 5 shows how this AD code can be used by a machine learning framework.
ChainRules.jl's `rrule` returns the output of the forward pass $y(x)$ and a function $\mathcal{B}$ which calculates the backward pass. $\mathcal{B}$ takes as input $\Delta = \frac{\partial l}{\partial y}$, the gradient of some scalar $l$ with respect to the output variable $y$, and returns a tuple $\left(\frac{\partial l}{\partial \text{self}}, \frac{\partial l}{\partial x_1}, …, \frac{\partial l}{\partial x_n}\right)$, the gradient of $l$ with respect to each of the input variables $x_i$. (The extra gradient $\frac{\partial l}{\partial \text{self}}$ is needed for internal fields and closures; see the `Dense` layer example above.)
According to the chain rule of Calculus, each gradient is calculated as:
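\[\mathcal{B}_i(\Delta) = \frac{\partial l}{\partial x_i} = \frac{\partial l}{\partial y}\frac{\partial y}{\partial x_i} = \Delta \frac{\partial y}{\partial x_i}\]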
As a starting point $\frac{\partial l}{\partial y}=1$ is used to evaluate only $\frac{\partial y}{\partial x}$.
If $x$ and $y$ are vectors, then the gradient $J=\frac{\partial y}{\partial x}$ is a Jacobian:
\[J = \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_n} \end{bmatrix}\]

To maintain the correct order, we need to use the conjugate transpose (adjoint) of the Jacobian. So each gradient is calculated as:
\[\mathcal{B}_i(\Delta) = J_i^{\dagger} \Delta\]

Note that the Jacobian does not need to be explicitly calculated; only the product needs to be.
This can be useful when coding the `rrule` for matrix functions; see the section on the chain rule for matrix multiplication later.
To start, define a default fallback for `rrule` that returns `nothing` for any function with any number of arguments (source):
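A minimal sketch of such a fallback, following the ChainRules.jl convention that the function itself is the first argument:

```julia
# Fallback: signal "no rule defined" by returning nothing.
# The first argument is the function being differentiated; the rest are its arguments.
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
```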
An `rrule` can now be defined for any function. For it to be really useful, `rrule` must cover a large set of functions. Thankfully ChainRules.jl provides us with that. However, in this post I'll only work through a limited set of examples.
The derivatives of adding two variables are:

\[\frac{\partial}{\partial x}(x+y) = 1 + 0; \quad \frac{\partial}{\partial y}(x+y) = 0 + 1\]

$$ \begin{align} \Delta f_x &= (x+\Delta x+ y) - (x+y) \\ \therefore \lim_{\Delta x \to 0}\frac{\Delta f_x}{\Delta x} &=\frac{\partial f}{\partial x}= 1 \\ \Delta f_y &= (x + y + \Delta y) - (x+y) \\ \therefore \lim_{\Delta y \to 0}\frac{\Delta f_y}{\Delta y} &=\frac{\partial f}{\partial y}= 1 \end{align} $$
There are no internal fields, so $\frac{\partial l}{\partial \text{self}}$ is `nothing`. $\mathcal{B}$ can be returned as an anonymous function, but giving it the name `add_back` helps with debugging (source):
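A sketch, following the description above:

```julia
function rrule(::typeof(+), x::Number, y::Number)
    # ∂(x+y)/∂x = 1 and ∂(x+y)/∂y = 1, so the incoming gradient Δ passes straight through.
    add_back(Δ) = (nothing, Δ, Δ)
    return x + y, add_back
end
```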
Usage:
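Using the sketch above:

```julia
y, back = rrule(+, 2.0, 3.0)  # y == 5.0, back == add_back
back(1.0)                      # (nothing, 1.0, 1.0)
```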
Subtraction is almost identical:
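A corresponding sketch:

```julia
function rrule(::typeof(-), x::Number, y::Number)
    # ∂(x-y)/∂x = 1 and ∂(x-y)/∂y = -1
    subtract_back(Δ) = (nothing, Δ, -Δ)
    return x - y, subtract_back
end
```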
With multiplication, the incoming gradient is multiplied by the other variable:
\[\frac{\partial}{\partial x}(xy) = y; \quad \frac{\partial}{\partial y}(xy) = x\]

$$ \begin{align} \Delta f_x &= (x+\Delta x)y - xy \\ \therefore \lim_{\Delta x \to 0}\frac{\Delta f_x}{\Delta x} &=\frac{\partial f}{\partial x}= y \\ \Delta f_y &= x(y+\Delta y) - xy \\ \therefore \lim_{\Delta y \to 0}\frac{\Delta f_y}{\Delta y} &=\frac{\partial f}{\partial y}= x \end{align} $$
In code (source):
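A sketch, using the `times_back` name referred to below:

```julia
function rrule(::typeof(*), x::Number, y::Number)
    # Each gradient is the incoming Δ multiplied by the *other* variable.
    times_back(Δ) = (nothing, Δ * y, Δ * x)
    return x * y, times_back
end
```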
Note that Julia will create a closure around the incoming `x` and `y` variables for `times_back`. A closure is when a function stores the values of variables from its parent scope (it closes over the variables). In other words, `x` and `y` will become constants in the `times_back` scope. In this way, the `times_back` function will always "remember" what values it was called with:
Example:
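Continuing with the sketch above:

```julia
_, back1 = rrule(*, 2.0, 3.0)
_, back2 = rrule(*, 5.0, 7.0)

back1(1.0)  # (nothing, 3.0, 2.0); remembers x = 2.0, y = 3.0
back2(1.0)  # (nothing, 7.0, 5.0); remembers x = 5.0, y = 7.0
```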
Every call to `rrule` with `*` will return a different `back` instance based on the input arguments.
Division is slightly different in that the derivatives look different for $x$ and $y$:
\[\frac{\partial}{\partial x}\frac{x}{y} = \frac{1}{y}; \quad \frac{\partial}{\partial y}\frac{x}{y}= -\frac{x}{y^2}\]

$$ \begin{align} \Delta f_x &= \frac{x+\Delta x}{y} - \frac{x}{y} \\ \therefore \lim_{\Delta x \to 0}\frac{\Delta f_x}{\Delta x} &=\frac{\partial f}{\partial x}= \frac{1}{y} \\ \Delta f_y &= \frac{x}{y+\Delta y} - \frac{x}{y} \\ &= \frac{xy}{y(y+\Delta y)} - \frac{x(y+\Delta y)}{y(y+\Delta y)} \\ &= -\frac{x \Delta y}{y(y+\Delta y)} \\ \therefore \lim_{\Delta y \to 0}\frac{\Delta f_y}{\Delta y} &=\frac{\partial f}{\partial y} = -\frac{x}{y^2} \end{align} $$
Here we can calculate an internal variable `Ω` to close over, and use it for the $\frac{\partial}{\partial y}$ derivative (source):
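A sketch:

```julia
function rrule(::typeof(/), x::Number, y::Number)
    Ω = x / y                                      # forward result, reused in the backward pass
    divide_back(Δ) = (nothing, Δ / y, -Δ * Ω / y)  # ∂/∂y = -x/y² = -Ω/y
    return Ω, divide_back
end
```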
Example:
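Using the sketch above:

```julia
Ω, back = rrule(/, 6.0, 2.0)  # Ω == 3.0
back(1.0)                      # (nothing, 0.5, -1.5)
```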
The derivatives of $\sin$ and $\cos$ are:
\[\begin{align} \frac{\partial}{\partial x} \sin(x) &= \cos(x) \\ \frac{\partial}{\partial x} \cos(x) &= -\sin(x) \end{align}\]

Because both use $\sin$ and $\cos$, we can use `sincos` to calculate both simultaneously and more efficiently than calculating each on its own. This shows the advantage of calculating the forward pass and backward pass at the same time (source):
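A sketch using Base's `sincos`:

```julia
function rrule(::typeof(sin), x::Number)
    s, c = sincos(x)               # one call computes both sin(x) and cos(x)
    sin_back(Δ) = (nothing, Δ * c)
    return s, sin_back
end

function rrule(::typeof(cos), x::Number)
    s, c = sincos(x)
    cos_back(Δ) = (nothing, -Δ * s)
    return c, cos_back
end
```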
Let’s now revisit the example from earlier, $f(x) = \sin(\cos(x))$. We have the forward pass:
\[\begin{align} y_1 &= \cos(x) \\ y_2 &= \sin(y_1)\\ \end{align}\]

And the backwards pass:
\[\begin{align} \frac{\partial y_2}{\partial y_1} &= (1.0) \frac{\partial}{\partial y_1} \sin(y_1) = \cos(y_1) \\ \frac{\partial y_2}{\partial x} &= \frac{\partial y_2}{\partial y_1} \frac{\partial}{\partial x} \cos(x) = -\cos(y_1)\sin(x) \end{align}\]

In code:
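A sketch chaining the two rules by hand (the input 0.9 is arbitrary):

```julia
x = 0.9
y1, back_cos = rrule(cos, x)
y2, back_sin = rrule(sin, y1)

_, Δ1 = back_sin(1.0)   # ∂y2/∂y1 = cos(y1)
_, Δx = back_cos(Δ1)    # ∂y2/∂x  = -cos(cos(x)) * sin(x)

Δx ≈ -cos(cos(x)) * sin(x)  # true
```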
The next section will showcase an example of polynomial curve fitting. This requires an `rrule` for the `evalpoly` function.
For a general polynomial:
\[y = a_0 + a_1x + a_2x^2 + ... + a_n x^n\]

The derivatives are:
\[\begin{align} \frac{\partial y}{\partial x} &= 0 + a_1 + 2a_2x + ... + n a_n x^{n-1} \\ \frac{\partial y}{\partial a_i} &= 0 + ... + x^{i} + ... + 0 \end{align}\]

For the most efficient implementation, the powers of $x$ can be calculated for both the forward and backwards pass at the same time. For simplicity, I'm not going to do that (source):
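A sketch, restricted to a scalar $x$; the coefficients can be a tuple or a vector:

```julia
function rrule(::typeof(evalpoly), x::Number, coeffs)
    y = evalpoly(x, coeffs)   # coeffs[1] = a₀, coeffs[2] = a₁, …
    function evalpoly_back(Δ)
        n = length(coeffs)
        # ∂y/∂x = a₁ + 2a₂x + … + (n-1)aₙ₋₁xⁿ⁻² (coefficients are 1-indexed)
        dydx = zero(float(x))
        for i in 2:n
            dydx += (i - 1) * coeffs[i] * x^(i - 2)
        end
        # ∂y/∂aᵢ = xⁱ
        dcoeffs = ntuple(i -> Δ * x^(i - 1), n)
        return (nothing, Δ * dydx, dcoeffs)
    end
    return y, evalpoly_back
end
```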
Usage:
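Using the sketch above with $y = 1 + 2x + 3x^2$ at $x = 2$:

```julia
y, back = rrule(evalpoly, 2.0, (1.0, 2.0, 3.0))  # y == 17.0
back(1.0)  # (nothing, 14.0, (1.0, 2.0, 4.0))
```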
For some scalar loss function $l$, we can calculate a derivative $\Delta=\frac{\partial l}{\partial Y}$ against some matrix $Y$. Then for $Y=AB$, the partial derivatives are:
\[\begin{align} \frac{\partial l}{\partial A} &= \frac{\partial Y}{\partial A} \frac{\partial l}{\partial Y} \\ &= \Delta B^T \\ \frac{\partial l}{\partial B} &= \frac{\partial Y}{\partial B} \frac{\partial l}{\partial Y} \\ &= A^T \Delta \end{align}\]

Note that the Jacobians $\frac{\partial Y}{\partial A}$ and $\frac{\partial Y}{\partial B}$ are not explicitly calculated here; only the product is. (These Jacobians would have many zeros because each output element depends only on a small subset of the input elements.)
In code (source):
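A sketch, adding a matrix method for `*`:

```julia
function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    # ∂l/∂A = Δ Bᵀ and ∂l/∂B = Aᵀ Δ; the Jacobians are never formed explicitly.
    times_back(Δ) = (nothing, Δ * B', A' * Δ)
    return A * B, times_back
end
```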
Test:
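A quick sanity check with the sketch above (random sizes chosen for illustration):

```julia
A, B = randn(2, 3), randn(3, 4)
Y, back = rrule(*, A, B)
_, dA, dB = back(ones(size(Y)))  # pretend ∂l/∂Y is all ones
size(dA) == size(A) && size(dB) == size(B)  # true
```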
The mean square error (MSE) is a common loss function in machine learning. It will be used shortly for polynomial curve fitting. It is:
\[MSE(\hat{y}, y) = \frac{1}{n}\sum^n_{i=1} (\hat{y}_i - y_i)^2\]

with derivatives:
\[\begin{align} \frac{\partial MSE}{\partial \hat{y}_i} &= \frac{1}{n}(0 + ... + 2(\hat{y}_i - y_i) + ... + 0) \\ &= \frac{2(\hat{y}_i - y_i)}{n} \\ \frac{\partial MSE}{\partial y_i} &= \frac{1}{n}(0 + ... - 2(\hat{y}_i - y_i) + ... + 0) \\ &= -\frac{2(\hat{y}_i - y_i)}{n} \end{align}\]

In code it is:
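A minimal definition matching the formula above:

```julia
using Statistics: mean

mse(ŷ, y) = mean(abs2.(ŷ .- y))
```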
Flux.jl does not define an `rrule` for its `mse` because it can be decomposed into functions which already have an `rrule` (`-`, `broadcast`, `abs2` and `mean`). However, since we don't have `rrule`s for these parts and have not yet automated decomposition, it is simplest to create an `rrule` for the entire function:
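A sketch:

```julia
function rrule(::typeof(mse), ŷ::AbstractVector, y::AbstractVector)
    Ω = mse(ŷ, y)
    function mse_back(Δ)
        g = (2 / length(y)) .* (ŷ .- y) .* Δ  # ∂MSE/∂ŷᵢ = 2(ŷᵢ - yᵢ)/n
        return (nothing, g, -g)
    end
    return Ω, mse_back
end
```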
The `mse` can also be applied per individual data point and summed up separately. This form is not common, but it will be useful for explanatory purposes in the polynomial curve fitting section:
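A sketch of a scalar method (the $\frac{1}{n}$ averaging is then left to the caller):

```julia
mse(ŷ::Number, y::Number) = abs2(ŷ - y)

function rrule(::typeof(mse), ŷ::Number, y::Number)
    mse_back(Δ) = (nothing, 2Δ * (ŷ - y), -2Δ * (ŷ - y))
    return abs2(ŷ - y), mse_back
end
```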
Gradient descent is a great algorithm to illustrate the usefulness of the code developed so far. The toy example of fitting a polynomial to data will be used. This is a useful example because (1) we can start with a target curve and so have ground truth values to compare and (2) this problem can be solved analytically without gradients.
Here is code to create the above data:
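A sketch of how such data could be generated. The target coefficients are taken from the results table below; the sample count, x-range, noise scale and seed are assumptions:

```julia
using Random
Random.seed!(42)                          # assumed seed, for reproducibility only

coeffs_target = (15.0, -2.1, 13.9, 1.5)   # cubic target from the results table
n_samples = 30                            # assumed
xs = sort(rand(n_samples) .* 20 .- 10)    # assumed range x ∈ [-10, 10]
ys = evalpoly.(xs, Ref(coeffs_target)) .+ 20 .* randn(n_samples)  # assumed noise scale
```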
Analytical least squares fitting of polynomials ⇩
For a polynomial of order $p$, if there are exactly $n=p+1$ training samples (one for each coefficient, including the constant $a_0$) then there are exactly $n$ equations for $n$ unknowns ($a_0$,...,$a_p$) and this can be solved as an ordinary linear system: $$ \begin{align} &a_0 + a_1 x_1 + a_2x_1^2 + ... + a_p x_1^p = y_1 \\ &\vdots \\ &a_0 + a_1 x_n + a_2x_n^2 + ... + a_p x_n^p = y_n \\ &\Rightarrow \begin{bmatrix} 1 & x_1 & x_1^2 & \cdots & x_1^p \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & x_n & x_n^2 & \cdots & x_n^p \end{bmatrix} \begin{bmatrix} a_0 \\ \vdots \\ a_p \end{bmatrix} = \begin{bmatrix} y_1 \\ \vdots \\ y_n \end{bmatrix} \\ &\Rightarrow XA=Y \\ &\Rightarrow A = X^{-1}Y \end{align} $$ Where $X^{-1}$ exists as long as the $x_i$ are distinct, since $X$ is then an invertible square (Vandermonde) matrix.
However, usually $n > p + 1$, so $X$ is no longer square and $X^{-1}$ does not exist. In that case the pseudoinverse $X^+$, also called the Moore-Penrose inverse, can be used instead: $$ \begin{align} X^{+} &= (X^T X)^{-1} X^T \\ \Rightarrow A &= X^{+}Y \end{align} $$ It can be proven that this solution for $A$ minimises the squared error.
Here is this solution in code:
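A sketch of this solution, using `pinv` from LinearAlgebra:

```julia
using LinearAlgebra: pinv

function fit_polynomial_analytical(xs, ys, degree)
    X = [x^p for x in xs, p in 0:degree]  # n × (degree+1) Vandermonde matrix
    return pinv(X) * ys                   # A = X⁺Y, the least-squares solution
end
```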
Here is a simple version of gradient descent:
Gradient descent
while (criteria is not met) do:
$\quad$ $\Delta = 0$
$\quad$ for sample, label in train_set do:
$\quad\quad$ $\Delta \leftarrow \Delta + \frac{\partial}{\partial\theta_j}L$($m_{\theta_j}$(sample), label)
$\quad$ $\theta_{j+1}$ $\leftarrow \theta_j - \alpha \Delta$
where $m_\theta$ is the model with parameters $\theta$ and $L$ is the loss function.
This is a Julia implementation applying the algorithm specifically to polynomials. The stopping condition is a maximum number of iterations, so the `while` loop has been replaced with a `for` loop. The code also saves the loss so that the training progress can be analysed:
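A sketch built on the `rrule`s for `evalpoly` and the scalar `mse` defined above (the function and variable names are my own):

```julia
function train_polynomial(coeffs, xs, ys; learning_rate=1e-6, max_iters=10_000)
    history = Float64[]
    for _ in 1:max_iters
        push!(history, mse(evalpoly.(xs, Ref(coeffs)), ys))  # record the full-dataset loss
        grads = map(_ -> 0.0, coeffs)                        # accumulated gradient, same shape as coeffs
        for (x, y) in zip(xs, ys)
            ŷ, back_poly = rrule(evalpoly, x, coeffs)
            _, back_loss = rrule(mse, ŷ, y)
            _, Δŷ, _ = back_loss(1.0)
            _, _, Δcoeffs = back_poly(Δŷ)
            grads = grads .+ Δcoeffs
        end
        coeffs = coeffs .- learning_rate .* grads
    end
    return coeffs, history
end
```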
Calling the code:
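For example (the initial coefficients, learning rate and iteration count are assumptions):

```julia
coeffs0 = (0.0, 0.0, 0.0, 0.0)  # start a cubic fit from zero
coeffs_gd, history = train_polynomial(coeffs0, xs, ys; learning_rate=1e-6, max_iters=100_000)
```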
Plotting the history:
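A sketch with Plots.jl:

```julia
using Plots
plot(history; xlabel="iteration", ylabel="loss", label="", yscale=:log10)
```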
Comparing losses on the train set:
| Method | Loss | Coefficients |
|---|---|---|
| Target | 416.62 | (15.0, -2.1, 13.9, 1.5) |
| Analytical | 391.64 | (15.34, -3.24, 13.84, 1.46) |
| Gradient Descent | 498.50 | (1.37, 0.54, 14.51, 1.26) |
And finally, comparing the curves:
It is possible to replace the inner loop over the training data with `map`:
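A sketch of the gradient accumulation step, again with names of my own choosing:

```julia
function step_gradients(coeffs, xs, ys)
    per_sample = map(xs, ys) do x, y
        ŷ, back_poly = rrule(evalpoly, x, coeffs)
        _, back_loss = rrule(mse, ŷ, y)
        _, Δŷ, _ = back_loss(1.0)
        _, _, Δcoeffs = back_poly(Δŷ)
        Δcoeffs
    end
    return reduce((a, b) -> a .+ b, per_sample)  # sum the per-sample gradients
end
```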
This code is slightly more complex than the previous version. The behaviour and performance are practically identical. However, it is one step closer to being more generic.
In machine learning, models usually execute on multiple inputs at once. We could make a polynomial model that does that:
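A sketch (the `Polynomial` struct name is my own):

```julia
struct Polynomial{V}
    weights::V
end

# Apply the model to a collection of inputs at once.
(m::Polynomial)(x) = map(xi -> evalpoly(xi, m.weights), x)
```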
The goal then is to get gradients for the model’s weights directly:
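Something along the lines of the following hypothetical call, which the rest of the series will make possible:

```julia
model = Polynomial([0.0, 0.0, 0.0, 0.0])

# Hypothetical target API: pullback through the whole model call.
# _, back = pullback(m -> mse(m(xs), ys), model)
# grads = back(1.0)   # gradients with respect to model.weights
```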
In the next sections we will write code that will inspect the model function call, recognise that it calls `map`, and call a `pullback` for `map`.2 This in turn will call the `pullback` for `evalpoly`, which will pass the arguments to the `rrule` defined above.

The next two sections will develop the `pullback` function. It will inspect and decompose code with the goal of passing arguments to `rrule` and accumulating gradients via the chain rule.
Part 2 will introduce metaprogramming in Julia and generate expressions for the backpropagation code. However, that code is unstable and prone to errors (it is recursive metaprogramming), so part 3 will introduce more robust code making use of the IRTools.jl package. This code really pushes Julia's metaprogramming to its limits.
It is possible to jump straight to part 3 if desired.
For example, Micrograd defines a `Value` class that has a custom definition for `__add__`. This custom definition then calculates the forward pass and prepares the backwards pass. The same is true of the `Tensor` objects in PyTorch. ↩
It is a design choice to use `pullback` and not `rrule` for `map`. Both `rrule` and `pullback` have the same outputs. However, `rrule` is intended for standalone gradients, whereas `pullback` will potentially involve recursive calls to itself. ↩