A series on automatic differentiation in Julia. Part 2 uses metaprogramming to generate a modified (primal) forward pass and to reverse differentiate it into a backward pass. This post uses an expression-based approach which can be brittle. Part 3 develops a more robust approach for the same code using IRTools.jl.
This is part of a series. All source code can be found at MicroGrad.jl. The code here is inspired by the example at IRTools.jl.
Part 1 introduced the `rrule` function for implementing chain rules. The challenge now is to automate it. This will be done through metaprogramming and generated functions. For example, from part 1 there are `rrule`s for `+`, `*` and `/`.
The goal is then to automatically differentiate the following:
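This is the worked example used throughout this post (equation $\ref{eq:f}$ below):

```julia
f(a, b) = a / (a + b * b)
```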
like so:
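The intended usage is sketched below; the actual values returned are worked through in the Testing sections:

```julia
y, back = pullback(f, 2.0, 3.0)  # primal output and pullback closure
grads = back(1.0)                # (self gradient, ā, b̄)
```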
where `pullback` is a `@generated` function that inspects the lowered code for `f`.
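For reference, the lowered code for `f` can be inspected with `@code_lowered`; on recent Julia versions it displays similar to:

```julia
julia> @code_lowered f(2.0, 3.0)
CodeInfo(
1 ─ %1 = b * b
│   %2 = a + %1
│   %3 = a / %2
└──      return %3
)
```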
This is an advanced use of the Julia programming language. You should be comfortable with the language before reading this post. At the very least, the Julia documentation page on metaprogramming is assumed knowledge, especially the sections on “Expressions and evaluation”, “Code Generation” and “Generated Functions”.
The Zygote.jl automatic differentiation (AD) package is a realisation of the paper Don’t Unroll Adjoint: Differentiating SSA-Form Programs (2019) by Michael J Innes.
The paper works with Wengert lists, also known as tapes, and a generalisation of them called Static Single Assignment (SSA) form.
The aim here is to develop a minimal AD package, so this series only focuses on the sections on Wengert lists.
A consequence is that the code will not be able to handle any non-linear logic in Julia, for example any control flow like `if`, `while` or `for` blocks.
The paper uses the same example as the introduction:
\[f(a, b) = \frac{a}{a + b^2} \tag{2.1} \label{eq:f}\]This can be broken down into smaller steps where each intermediate variable is saved. This is known as a Wengert list, or tape, or (backpropagation) graph:
\[\begin{align} y_1 &= b \times b \\ y_2 &= a + y_1 \\ y_3 &= a / y_2 \end{align} \tag{2.2} \label{eq:f_wengert}\]To differentiate this, all function calls are wrapped with a differentiation function $\mathcal{J}$ which returns both the output $y$ and a pullback function $\mathcal{B}$. This is called the primal form:
\[\begin{align} y_1, \mathcal{B}_1 &\leftarrow \mathcal{J}(\times, b, b) \\ y_2, \mathcal{B}_2 &\leftarrow \mathcal{J}(+, a, y_1) \\ y_3, \mathcal{B}_3 &\leftarrow \mathcal{J}(/, a, y_2) \end{align} \tag{2.3} \label{eq:primal}\]The pullback function $\mathcal{B}$ takes as input the gradient of a scalar $l$ (typically a loss function) with respect to the output $y$ of a function $y(x)$, and returns the gradient with respect to the variable $x$. This partial gradient $\frac{\partial l}{\partial x}$ is written as $\bar{x}$.
\[\begin{align} \bar{x} &= \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x} \end{align} \tag{2.4} \label{eq:bar_x}\]so the pullback can be written in this mathematical notation as:
\[\begin{align} \bar{x} &\leftarrow \mathcal{B}(\bar{y}) = \bar{y} \frac{\partial y}{\partial x}\\ \text{or} \quad \bar{x} &\leftarrow \mathcal{B}(\bar{y}) = J^{\dagger}\bar{y} \end{align} \tag{2.5} \label{eq:pullback}\]where $\bar{y}=\frac{\partial l}{\partial y}$ and $J=\frac{\partial y}{\partial x}$ is the Jacobian (gradient) for arrays.
The various partial gradients are calculated by reversing the list. Each pullback function $\mathcal{B}_i$ takes as input the gradient $\bar{y}_i$ of its output, produced by the pullback that runs before it in the reverse pass. The first input is a seed gradient $\Delta$, which is usually set to 1 at the start:
\[\begin{align} \overline{\text{self}}_3, \bar{a}_{3,1}, \bar{y}_2 &\leftarrow \mathcal{B}_3(\Delta) \\ \overline{\text{self}}_2, \bar{a}_{2,1}, \bar{y}_1 &\leftarrow \mathcal{B}_2(\bar{y}_2) \\ \overline{\text{self}}_1, \bar{b}_{1,1}, \bar{b}_{1,2} &\leftarrow \mathcal{B}_1(\bar{y}_1) \end{align} \tag{2.6} \label{eq:reverse}\]The final step is to accumulate the gradients for variables which are used multiple times:
\[\begin{align} \bar{a} &\leftarrow \bar{a}_{3,1} + \bar{a}_{2,1} \\ \bar{b} &\leftarrow \bar{b}_{1,1} + \bar{b}_{1,2} \\ \end{align} \tag{2.7} \label{eq:accumulate}\]This end result is equivalent to rolling everything up into one function using the multivariable chain rule:
\[\begin{align} \bar{a} &= \frac{\partial l}{\partial a} = \mathcal{B}_{3,a}(\Delta) + \mathcal{B}_{2,a}(\bar{y}_2) \\ &= \frac{\partial l}{\partial y_3} \frac{\partial y_3}{\partial a} + \frac{\partial l}{\partial y_2} \frac{\partial y_2}{\partial a} \\ &= \Delta \cdot \frac{\partial }{\partial a} \left( \frac{a}{y_2}\right) + \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2} \right)\frac{\partial}{\partial a}(a + y_1) \\ &= \Delta \frac{1}{y_2} + \left(\Delta \frac{-a}{y_2^2} \right) (1+0) \\ &= \Delta \frac{b^2}{(a+b^2)^2} \\ \bar{b} &= \frac{\partial l}{\partial b} = 2 \mathcal{B}_{1,b}(\bar{y}_1) \\ &= 2\frac{\partial l}{\partial y_1} \frac{\partial y_1}{\partial b} \\ &= 2 \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2}\frac{\partial y_2}{\partial y_1} \right) \frac{\partial y_1}{\partial b} \\ &= 2 \left(\Delta \cdot \frac{\partial}{\partial y_2}\left(\frac{a}{y_2}\right) \cdot \frac{\partial}{\partial y_1}(a + y_1) \right)\frac{\partial}{\partial b'}(b'\times b) \\ &= 2\left(\Delta \left(-\frac{a}{y_2^2}\right)(0+1)\right)b \\ &= -\frac{2ab\Delta}{(a+b^2)^2} \end{align} \tag{2.8} \label{eq:rollup}\]The goal is to generate code which automatically implements the equations of section 2.
To start, define a `pullback` function (source).
This will be turned into a generated function. Julia changed the behaviour of generated functions in version 1.10. Before 1.10, they always had access to the world age counter. This is a single number that is incremented every time a method is defined, and it helps optimise compilation. However, from version 1.10, calling `Base.get_world_counter()` inside a generated function will only return `typemax(UInt)`. This is to prevent reflection - code inspection - in generated functions.1 However, the code here relies on reflection. Thankfully, there is a hack that Zygote.jl uses to access the world age in `pullback`. Because of this, the definition of `pullback` is different based on the version, but both will forward to a common internal `_generate_pullback` function.
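Below is a sketch of the two definitions, following the pattern Zygote.jl uses; the generator plumbing (the `Core.GeneratedFunctionStub` call and the argument names) is an approximation of that pattern, not the exact MicroGrad.jl source:

```julia
if VERSION >= v"1.10"
    # The hack: attach the generator manually via Expr(:meta, :generated, ...)
    # so that it receives the world age as its first argument.
    function _pullback_generator(world::UInt, source, self, f, args)
        # Here f and args are types; args is the tuple of argument types.
        ret = _generate_pullback(world, f, args...)
        ret isa Core.CodeInfo && return ret
        stub = Core.GeneratedFunctionStub(identity, Core.svec(:pullback, :f, :args), Core.svec())
        stub(world, source, ret)
    end

    @eval function pullback(f, args...)
        $(Expr(:meta, :generated, _pullback_generator))
        $(Expr(:meta, :generated_only))
    end
else
    # Before 1.10, a plain generated function can reflect freely.
    @generated function pullback(f, args...)
        _generate_pullback(nothing, f, args...)
    end
end
```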
The first goal of `_generate_pullback` will be to forward the function and its arguments to a matching `rrule` if it exists. For now it will throw an error if it cannot find one. In part 1 the most generic method of `rrule` was defined for an `Any` first argument, so if the compiler dispatches to this method it means no specific `rrule` was found.2
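A sketch of this first version, assuming the `meta` helper defined next returns a `Core.MethodMatch` (the exact check in MicroGrad.jl may differ):

```julia
function _generate_pullback(world, f, args...)
    # Which method would rrule(f, args...) dispatch to?
    T = Tuple{typeof(rrule), f, args...}
    m = meta(T; world=something(world, Base.get_world_counter()))
    if m !== nothing && m.method.sig.parameters[2] !== Any
        # A specific rrule exists, so forward to it.
        return :(rrule(f, args...))
    end
    error("No rrule found for ", Tuple{f, args...})
end
```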
The `meta` function uses the internal reflection function `Base._methods_by_ftype` to get all the methods for a specific type. (This same function is used by `methods`.) The most specific method is assumed to be the last one (source):
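A sketch; `Base._methods_by_ftype` is internal and its signature and return type have changed across Julia versions:

```julia
function meta(T; world=Base.get_world_counter())
    # Returns a vector of Core.MethodMatch on recent Julia versions.
    matches = Base._methods_by_ftype(T, -1, world)
    (matches === nothing || isempty(matches)) && return nothing
    last(matches)  # assume the most specific method is the last one
end
```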
Let’s test all this code from bottom to top for a function with an `rrule` and one without: `+` and `f(a, b) = a / (a + b*b)`.
As a reminder, generated functions only have access to a variable’s type, so to test `_generate_pullback` and all functions under it, we can only work with the types.
Firstly, for `+` acting on floats:
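With the sketches above, the generated expression simply forwards to `rrule`:

```julia
julia> _generate_pullback(Base.get_world_counter(), typeof(+), Float64, Float64)
:(rrule(f, args...))
```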
Now for `f`, also acting on floats:
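There is no method of `rrule` more specific than the `Any` fallback, so an error is thrown:

```julia
julia> _generate_pullback(Base.get_world_counter(), typeof(f), Float64, Float64)
ERROR: No rrule found for Tuple{typeof(f), Float64, Float64}
```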
The more interesting task is to inspect `f` and apply the equations of section 2 to fully differentiate it with respect to all input parameters. The first step is to create a Wengert list for `f`.
This is trivial because Julia already does this as part of the compilation process.
As the first step of lowering code, the compiler will create an Abstract Syntax Tree (AST) which in the absence of control flow is the same as a Wengert list.
This AST can be retrieved by calling `Base.uncompressed_ast` on the method we have found above:
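For example (the exact display depends on the Julia version):

```julia
julia> m = meta(Tuple{typeof(f), Float64, Float64});

julia> ci = Base.uncompressed_ast(m.method)
CodeInfo(
1 ─ %1 = b * b
│   %2 = a + %1
│   %3 = a / %2
└──      return %3
)
```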
The returned object is a `CodeInfo` struct and it corresponds exactly to $\ref{eq:f_wengert}$.
Using this knowledge, we can now create a new function `_generate_pullback_via_decomposition` which will be called if no `rrule` exists. It uses the `CodeInfo` block to create the primal (equation $\ref{eq:primal}$) (source).
The goal here is to create an expression for equation $\ref{eq:primal}$. This is what it will look like:
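An illustration for `f`; the `y$i`/`back$i` naming convention is described below, and `_X`/`%X` are how Julia prints slot numbers and SSA values:

```julia
quote
    (y1, back1) = pullback(*, _3, _3)
    (y2, back2) = pullback(+, _2, %1)
    (y3, back3) = pullback(/, _2, %2)
    return (%3, Pullback{Tuple{typeof(f), Float64, Float64}}((back1, back2, back3)))
end
```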
Note that this expression cannot be executed because it still has slot numbers, which correspond to input arguments (`_X`), and SSA values, which correspond to intermediate values (e.g. `%X`). This will be fixed in the Sanitise section.
The first step for the primal function is to define three arrays to store information (source):
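A sketch (the exact element types are an assumption):

```julia
tape = Expr[]         # every expression of the primal block, in order
calls = []            # the calls that require a pullback, for the reverse pass
pullbacks = Symbol[]  # the names of the pullback variables: back1, back2, ...
```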
The `tape` array stores the new expressions which will be part of the final expression. The `calls` array stores the subset of expressions that require a pullback; this will be used to generate the reverse code (equation $\ref{eq:reverse}$) in the next section. Lastly, `pullbacks` stores all the pullbacks.
Next, iterate over each line in the `CodeInfo` instance. Each output variable will be called `y$i`. Then the line’s expression type is inspected. This minimal code cannot handle control flow or the creation of new objects, so errors will be explicitly thrown if those cases are encountered. (Please refer to the Lowered form section in the Julia documentation.) If the expression is of type `Expr`, it makes a call, and it is not in a specialised ignore list (to be defined shortly), then the new expression can be created and the three arrays updated. Otherwise, it is left as is.
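A sketch of this loop, assuming the arrays above and the lowered code in `ci` (the helpers `ignored` and `xcall` are defined below):

```julia
for (i, ex) in enumerate(ci.code)
    if ex isa Core.GotoNode || ex isa Core.GotoIfNot
        error("control flow is not supported")
    elseif ex isa Expr && ex.head == :new
        error("object creation is not supported")
    elseif ex isa Expr && ex.head == :call && !ignored(ex)
        y, back = Symbol(:y, i), Symbol(:back, i)
        push!(tape, :(($y, $back) = $(xcall(:pullback, ex.args...))))
        push!(calls, (; variable=Core.SSAValue(i), expr=ex))
        push!(pullbacks, back)
    elseif ex isa Core.ReturnNode
        continue  # the return statement is handled after the loop
    else
        push!(tape, :($(Symbol(:y, i)) = $ex))  # leave as is
    end
end
```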
After working through all the lines, a final expression is added which returns a tuple with the final output of the function and a `Pullback` struct which stores all the pullbacks. Everything is then grouped into a single `:block` expression:
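A sketch, where `T` is the signature type, e.g. `Tuple{typeof(f), Float64, Float64}`:

```julia
push!(tape, :(return ($(returnvalue(ci)), Pullback{$T}(($(pullbacks...),)))))
pr = Expr(:block, tape...)
```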
This code requires definitions for the `Pullback` struct as well as the following functions: `ignored`, `xcall` and `returnvalue`.
There are no closures in lowered Julia code, so instead Zygote.jl stores the pullbacks in a generic struct:
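The struct carries the primal’s signature type `S` as a parameter and the pullbacks as `data`:

```julia
struct Pullback{S, T}
    data::T
end

Pullback{S}(data) where S = Pullback{S, typeof(data)}(data)
```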
In the next section this struct will be turned into a callable struct. That is, for `back = Pullback{S}(data)`, we will create a generated function that dispatches on the struct itself, `(j::Pullback)(Δ)`, so that we can call `back(Δ)`. This `back` has all the information to generate the reverse pass independently of the forward pass: the method can be retrieved using `meta(S)`, and the relevant data and input parameters from `back.data`.
Here is the ignored functions list (source):
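This is a subset of Zygote.jl’s list; calls to these functions pass through without a pullback:

```julia
const ignored_f = Any[
    GlobalRef(Base, :not_int),
    GlobalRef(Core, :(===)),
    GlobalRef(Core, :apply_type),
    GlobalRef(Core, :typeof),
    GlobalRef(Core, :throw),
    GlobalRef(Base, :kwerr),
    GlobalRef(Core, :kwfunc),
]

ignored(ex::Expr) = ex.head == :call && ex.args[1] in ignored_f
```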
`xcall` and `returnvalue` are convenience functions from IRTools:
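Simplified versions are sketched here (the originals live in IRTools.jl):

```julia
# Build a :call expression, optionally qualifying the function with a module.
xcall(f, args...) = Expr(:call, f, args...)
xcall(mod::Module, f::Symbol, args...) = xcall(GlobalRef(mod, f), args...)

# The value of the final return statement in the lowered code.
returnvalue(ci::Core.CodeInfo) = (ci.code[end]::Core.ReturnNode).val
```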
Running this code:
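Here the primal construction above is assumed to be wrapped in a function, hypothetically named `primal(ci, T)`, which returns the expression and the recorded calls:

```julia
T = Tuple{typeof(f), Float64, Float64}
m = meta(T)
ci = Base.uncompressed_ast(m.method)
pr, calls = primal(ci, T)
```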
gives the expression at the start.
To evaluate the expression we need to remove all slot values and SSA values. For the slot values (`_X`), the first parameter in `T` will always be the function `f`, and the remainder are from `args`. Therefore the first slot needs to be replaced with the symbol `:f`, and the remainder with `Base.getindex(args, idx)` where `idx` is offset by 1. Here are two recursive functions to accomplish this:
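A sketch of the slot replacement (the function name is an assumption):

```julia
# Recursively rebuild the expression, replacing slot numbers as described.
interpolate_slots(ex::Expr) = Expr(ex.head, map(interpolate_slots, ex.args)...)
function interpolate_slots(slot::Core.SlotNumber)
    slot.id == 1 ? :f : :(Base.getindex(args, $(slot.id - 1)))
end
interpolate_slots(x) = x  # fallback: leave every other node unchanged
```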
The SSA values (`%id`) need to be replaced by the `y$id` symbol:
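The same recursive pattern applies (again, the name is an assumption):

```julia
interpolate_ssa(ex::Expr) = Expr(ex.head, map(interpolate_ssa, ex.args)...)
interpolate_ssa(ssa::Core.SSAValue) = Symbol(:y, ssa.id)
interpolate_ssa(x) = x
```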
Running this code on `pr` results in:
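Using the hypothetical names above:

```julia
julia> pr |> interpolate_slots |> interpolate_ssa
quote
    (y1, back1) = pullback(*, Base.getindex(args, 2), Base.getindex(args, 2))
    (y2, back2) = pullback(+, Base.getindex(args, 1), y1)
    (y3, back3) = pullback(/, Base.getindex(args, 1), y2)
    return (y3, Pullback{Tuple{typeof(f), Float64, Float64}}((back1, back2, back3)))
end
```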
We can now complete `_generate_pullback` to also call the decomposition code:
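A sketch of the completed function (the `rrule` branch from before is elided):

```julia
function _generate_pullback(world, f, args...)
    T = Tuple{f, args...}
    # 1. Forward to a specific rrule if one exists (as before) ...
    # 2. Otherwise, decompose the function's lowered code.
    m = meta(T; world=something(world, Base.get_world_counter()))
    m === nothing && error("No method found for ", T)
    _generate_pullback_via_decomposition(m, T)
end
```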
Testing (you should redefine the `@generated pullback` function first):
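Expected behaviour, with the exact `Pullback` type parameters elided:

```julia
julia> y, back = pullback(f, 2.0, 3.0)
(0.18181818181818182, Pullback{Tuple{typeof(f), Float64, Float64}, ...}(...))
```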
The goal is now to turn `Pullback` into a callable struct so that we can call `back(1.0)` to evaluate equations $\ref{eq:reverse}$ and $\ref{eq:accumulate}$. With `typeof(back)` and `back.data` we have all the information to do this independently of the forward pass. The result will be:
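An illustration of the generated reverse pass for `f`; the variable names are placeholders for the machine-generated symbols:

```julia
quote
    data = j.data
    (dself3, da1, dy2) = data[3](Δ)    # B3: pullback of /
    (dself2, da2, dy1) = data[2](dy2)  # B2: pullback of +
    (dself1, db1, db2) = data[1](dy1)  # B1: pullback of *
    return (nothing, accum(da1, da2), accum(db1, db2))
end
```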
As with the forward pass, an internal function `_generate_callable_pullback` will do most of the work. It uses the `meta` function defined above to get the `CodeInfo` struct based on the input types:
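A sketch; the `reverse_differentiate` signature is an assumption:

```julia
function _generate_callable_pullback(j::Type{<:Pullback{S}}, world, Δ) where S
    m = meta(S; world=something(world, Base.get_world_counter()))
    ci = Base.uncompressed_ast(m.method)
    reverse_differentiate(ci, S)
end
```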
The `reverse_differentiate` function is a simplified version of Zygote.adjoint and Zygote.reverse_stacks!. To start, a dictionary is created to store the gradients. It maps variable names (symbols) to an array of gradients. It is not accessed directly (e.g. `grads[x]`) but rather through the closure functions `grad` and `grad!`, which automatically handle the arrays. The first gradient stored is `Δ`, associated with the final return value of the forward pass. (`_var_name` and `xaccum` will be defined shortly.)
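A sketch of this setup inside `reverse_differentiate`:

```julia
grads = Dict{Symbol, Vector{Any}}()
grad!(x, dx) = push!(get!(Vector{Any}, grads, x), dx)
grad(x) = haskey(grads, x) ? xaccum(grads[x]...) : nothing

# Seed: Δ is the gradient with respect to the primal's return value.
grad!(_var_name(returnvalue(ci)), :Δ)
```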
The `tape` for the expression block is started by retrieving the `data` field in the struct.
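For example:

```julia
tape = Expr[:(data = j.data)]
```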
Next the code retrieves all the calls with pullbacks from the primal and loops over them in reverse, calling the pullbacks one by one. For each call it also loops over the input arguments and unpacks them one by one. Each variable’s gradient is added to `grads` and may be used later in the loop. The `_var_name` function ensures that the keys of `grads` can be connected back to the original functions. Finally, the last call retrieves all the necessary gradients for the input arguments and returns a single `quote` block.
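A sketch of this loop; `calls` is recomputed via the primal decomposition, and `arg_names` stands for the `_var_name`s of the method’s input slots (both are assumptions):

```julia
for (i, call) in reverse(collect(enumerate(calls)))
    dy = Symbol(:dy, i)
    # Call pullback i with the gradient accumulated for its output variable.
    push!(tape, :($dy = data[$i]($(grad(_var_name(call.variable))))))
    # Unpack one gradient per input argument (index 1 holds the self gradient).
    for (k, arg) in enumerate(arguments(call.expr))
        dx = Symbol(:dx, i, :_, k)
        push!(tape, :($dx = getindex($dy, $(k + 1))))
        grad!(_var_name(arg), dx)
    end
end
# Return nothing for the function itself plus the accumulated input gradients.
push!(tape, :(return $(Expr(:tuple, :nothing, map(grad, arg_names)...))))
Expr(:block, tape...)  # the final quote block
```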
This code required the following functions: `xaccum`, `_var_name` and `arguments`. They are as follows:
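Sketches of the three helpers:

```julia
# Emit an accumulation call only when there are multiple gradients to combine.
xaccum() = nothing
xaccum(x) = x
xaccum(xs...) = xcall(:accum, xs...)

# Stable symbol names for slots and SSA values, matching the primal's naming.
_var_name(slot::Core.SlotNumber) = Symbol(:x, slot.id)
_var_name(ssa::Core.SSAValue) = Symbol(:y, ssa.id)
_var_name(s::Symbol) = s

# The arguments of a :call expression, excluding the function itself.
arguments(ex::Expr) = ex.args[2:end]
```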
The `xaccum` function calls an internal accumulate function if it acts on multiple inputs. At its simplest, `accum` is the same as `sum`. However it also handles `nothing` inputs, `Tuple`s and `NamedTuple`s (source).
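A sketch of `accum`, following Zygote.jl’s version:

```julia
accum() = nothing
accum(x) = x
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x, y, zs...) = accum(accum(x, y), zs...)
accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
accum(x::NamedTuple, y::NamedTuple) = map(accum, x, y)
```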
Examples:
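With the sketch above:

```julia
julia> accum(1.0, 2.0, 3.0)
6.0

julia> accum(1.0, nothing)
1.0

julia> accum((1.0, 2.0), (3.0, 4.0))
(4.0, 6.0)
```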
Finally, dispatch on the `Pullback` struct to turn it into a callable struct:
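For Julia versions before 1.10 this can be as simple as the definition below; on 1.10+ the same world-age hack used for `pullback` is needed:

```julia
@generated function (j::Pullback)(Δ)
    _generate_callable_pullback(j, nothing, Δ)
end
```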
Testing:
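Assuming the sketches above, with $a=2$, $b=3$ and $\Delta=1$:

```julia
julia> y, back = pullback(f, 2.0, 3.0);

julia> back(1.0)  # ≈ (nothing, 0.0744, -0.0992)
```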
The results should match equation $\ref{eq:rollup}$:
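Evaluating equation $\ref{eq:rollup}$ directly:

```julia
julia> a, b = 2.0, 3.0;

julia> b^2 / (a + b^2)^2      # ā = 9/121 ≈ 0.0744
julia> -2a * b / (a + b^2)^2  # b̄ = -12/121 ≈ -0.0992
```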
This code works well enough for this simple case. It also works for the trigonometry example from part 1:
However it will fail for the polynomial model:
The error is raised three levels down:
This can be fixed by explicitly writing a pullback for `map`. However, rather than fixing it here, I first want to rewrite the code using IRTools.
The code written here is brittle and difficult to debug. Instead of writing expressions, it would be better to directly create a `CodeInfo` struct which always contains valid code. Julia does not allow us to do that, but working with an `IR` object which can be readily converted is the next best thing. This will be the goal of part 3.
Presumably the reason the Julia team tried to prevent reflection in generated functions is that it interferes with the compiler’s ability to properly predict, trigger and/or optimise compilations. ↩
Zygote.jl has more complex rules which also consider other fallbacks, keyword arguments and a possible opt-out through a `no_rrule`. ↩