A series on automatic differentiation in Julia. Part 3 uses metaprogramming based on IRTools.jl to generate a modified (primal) forward pass and to reverse differentiate it into a backward pass. This is a more robust approach than the expression-based approach in Part 2.
This is part of a series.
All source code can be found at MicroGrad.jl. The code here is based on the example at IRTools.jl.
Part 1 introduced the `rrule` function for implementing chain rules, and Part 2 defined a `@generated` `pullback` function for inspecting and decomposing complex code.
The goal here is to replicate the results of Part 2 except in a more robust manner using the IRTools.jl package.
For example, from Part 1 there are `rrule`s for `+`, `*` and `/`.
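As a reminder, they look something like this (a sketch in the style of Part 1, where the first `nothing` is the gradient with respect to the function itself):

```julia
rrule(::typeof(+), a, b) = a + b, Δ -> (nothing, Δ, Δ)
rrule(::typeof(*), a, b) = a * b, Δ -> (nothing, Δ * b, a * Δ)
rrule(::typeof(/), a, b) = a / b, Δ -> (nothing, Δ / b, -Δ * a / b^2)
```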
The goal is then to automatically differentiate the following:
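```julia
f(a, b) = a / (a + b * b)
```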
like so:
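In sketch form (illustrative values; the exact printed output may differ):

```julia
y, back = pullback(f, 3.0, 2.0)  # y = 3/(3 + 4) ≈ 0.4286, back is a pullback function
grads = back(1.0)                # (nothing, ā, b̄) — gradients for (f, a, b)
```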
where `pullback` is a `@generated` function that inspects the Intermediate Representation (IR) code for `f`:
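For `f` above, the IR looks roughly like this (exact variable numbering may differ):

```
1: (%1, %2, %3)
  %4 = %3 * %3
  %5 = %2 + %4
  %6 = %2 / %5
  return %6
```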
This is an advanced use of the Julia programming language. You should be comfortable with the language before reading this post. At the very least, the Julia documentation page on metaprogramming is required for this post and will be considered assumed knowledge, especially the sections on “Expressions and evaluation”, “Code Generation” and “Generated Functions”. I also suggest going through the IRTools.jl documentation first.
This post can be read independently of Part 2 and will repeat parts of it. However, it is advisable to read Part 2 first because it is easier to understand than this post.
The Zygote.jl automatic differentiation (AD) package is a realisation of the paper Don’t Unroll Adjoint: Differentiating SSA-Form Programs (2019) by Michael J Innes.
The paper works with Wengert lists, also known as tapes, and a generalisation of it called Static Single Assignment (SSA) form.
The aim here is to develop a minimal AD package, so this series only focuses on the sections on Wengert lists.
A consequence is that the code will not be able to handle any non-linear logic in Julia, for example any control flow like `if`, `while` or `for` blocks.
The paper uses the same example as the introduction:
\[f(a, b) = \frac{a}{a + b^2} \tag{2.1} \label{eq:f}\]

This can be broken down into smaller steps where each intermediate variable is saved. This is known as a Wengert list, or tape, or (backpropagation) graph:
\[\begin{align} y_1 &= b \times b \\ y_2 &= a + y_1 \\ y_3 &= a / y_2 \end{align} \tag{2.2} \label{eq:f_wengert}\]

To differentiate this, all function calls are wrapped with a differentiation function $\mathcal{J}$ which returns both the output $y$ and a pullback function $\mathcal{B}$. This is called the primal form:
\[\begin{align} y_1, \mathcal{B}_1 &\leftarrow \mathcal{J}(\times, b, b) \\ y_2, \mathcal{B}_2 &\leftarrow \mathcal{J}(+, a, y_1) \\ y_3, \mathcal{B}_3 &\leftarrow \mathcal{J}(/, a, y_2) \end{align} \tag{2.3} \label{eq:primal}\]

The pullback function $\mathcal{B}$ takes as input the gradient of a scalar $l$ (typically a loss function) with respect to the output $y$ of a function $y(x)$, and returns the gradient with respect to the variable $x$. This partial gradient $\frac{\partial l}{\partial x}$ is written as $\bar{x}$:
\[\begin{align} \bar{x} &= \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x} \end{align} \tag{2.4} \label{eq:bar_x}\]

so we can write the pullback in this mathematical notation as:
\[\begin{align} \bar{x} &\leftarrow \mathcal{B}(\bar{y}) = \bar{y} \frac{\partial y}{\partial x}\\ \text{or} \quad \bar{x} &\leftarrow \mathcal{B}(\bar{y}) = J^{\dagger}\bar{y} \end{align} \tag{2.5} \label{eq:pullback}\]

where $\bar{y}=\frac{\partial l}{\partial y}$ and $J=\frac{\partial y}{\partial x}$ is the Jacobian (gradient) for arrays.
The various partial gradients are calculated by working through the list in reverse. Each pullback function $\mathcal{B}_i$ takes as input the gradient $\bar{y}_i$ of its output variable. The chain is seeded with an initial gradient $\Delta$, which is usually set to 1:
\[\begin{align} \text{s̄elf}_3, \bar{a}_{3,1}, \bar{y}_2 &\leftarrow \mathcal{B}_3(\Delta) \\ \text{s̄elf}_2, \bar{a}_{2,1}, \bar{y}_1 &\leftarrow \mathcal{B}_2(\bar{y}_2) \\ \text{s̄elf}_1, \bar{b}_{1,1}, \bar{b}_{1,2} &\leftarrow \mathcal{B}_1(\bar{y}_1) \end{align} \tag{2.6} \label{eq:reverse}\]

The final step is to accumulate the gradients for variables which are used multiple times:
\[\begin{align} \bar{a} &\leftarrow \bar{a}_{3,1} + \bar{a}_{2,1} \\ \bar{b} &\leftarrow \bar{b}_{1,1} + \bar{b}_{1,2} \\ \end{align} \tag{2.7} \label{eq:accumulate}\]

The end result is equivalent to rolling everything up into one function using the multivariable chain rule:
\[\begin{align} \bar{a} &= \frac{\partial l}{\partial a} = \mathcal{B}_{3,a}(\Delta) + \mathcal{B}_{2,a}(\bar{y}_2) \\ &= \frac{\partial l}{\partial y_3} \frac{\partial y_3}{\partial a} + \frac{\partial l}{\partial y_2} \frac{\partial y_2}{\partial a} \\ &= \Delta \cdot \frac{\partial }{\partial a} \left( \frac{a}{y_2}\right) + \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2} \right)\frac{\partial}{\partial a}(a + y_1) \\ &= \Delta \frac{1}{y_2} + \left(\Delta \frac{-a}{y_2^2} \right) (1+0) \\ &= \Delta \frac{b^2}{(a+b^2)^2} \\ \bar{b} &= \frac{\partial l}{\partial b} = 2 \mathcal{B}_{1,b}(\bar{y}_1) \\ &= 2\frac{\partial l}{\partial y_1} \frac{\partial y_1}{\partial b} \\ &= 2 \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2}\frac{\partial y_2}{\partial y_1} \right) \frac{\partial y_1}{\partial b} \\ &= 2 \left(\Delta \cdot \frac{\partial}{\partial y_2}\left(\frac{a}{y_2}\right) \cdot \frac{\partial}{\partial y_1}(a + y_1) \right)\frac{\partial}{\partial b'}(b'\times b) \\ &= 2\left(\Delta \left(-\frac{a}{y_2^2}\right)(0+1)\right)b \\ &= -\frac{2ab\Delta}{(a+b^2)^2} \end{align} \tag{2.8} \label{eq:rollup}\]

The goal is to generate code which automatically implements the equations of section 2.
To start, define a `pullback` function (source); this will be turned into a generated function.
Julia changed the behaviour of generated functions in version 1.10.
Before 1.10, they always had access to the world age counter.
This is a single number that is incremented every time a method is defined, and helps optimise compilations.
However, from version 1.10, calling `Base.get_world_counter()` inside a generated function will only return `typemax(UInt)`.
This is to prevent reflection - code inspection - in generated functions.1
However the code here relies on reflection.
Thankfully, there is a hack that Zygote.jl uses to access the world age in `pullback`.
Because of this, the definition of `pullback` differs based on the Julia version, but both versions forward to a common internal `_generate_pullback` function.
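A sketch of those definitions, following the pattern used by Zygote.jl (the `Core.GeneratedFunctionStub` details in particular are adapted from Zygote and may differ from the linked source):

```julia
@static if VERSION >= v"1.10"
    # The world age is captured through a generator function, since
    # Base.get_world_counter() no longer works inside generated functions.
    function _pullback_generator(world::UInt, source, self, f, args)
        ret = _generate_pullback(world, f, args...)
        ret isa Core.CodeInfo && return ret
        stub = Core.GeneratedFunctionStub(identity, Core.svec(:pullback, :f, :args), Core.svec())
        stub(world, source, ret)
    end

    @eval function pullback(f, args...)
        $(Expr(:meta, :generated, _pullback_generator))
        $(Expr(:meta, :generated_only))
    end
else
    # Before 1.10 a plain @generated function suffices.
    @generated function pullback(f, args...)
        _generate_pullback(nothing, f, args...)
    end
end
```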
The first goal of `_generate_pullback` will be to forward the function and its arguments to a matching `rrule` if it exists.
For now it will throw an error if it cannot find one.
In Part 1 the most generic method of `rrule` was defined for an `Any` first argument, so if the compiler dispatches to this method it means no specific `rrule` was found.2
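A minimal sketch (the `has_rrule` helper is a hypothetical stand-in for the dispatch check just described):

```julia
function _generate_pullback(world, f, args...)
    T = Tuple{f, args...}
    # If dispatch would land on a specific rrule, forward to it.
    if has_rrule(T, world)  # hypothetical helper; see footnote 2
        return :(rrule(f, args...))
    end
    # Decomposition is implemented in the next section; fail loudly for now.
    return :(error("No rrule found for ", repr($T)))
end
```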
Let’s test all this code from bottom to top for a function with an `rrule` and one without: `+` and `f(a, b) = a / (a + b*b)`.
As a reminder, generated functions only have access to a variable's type, so to test `_generate_pullback` and all functions under it, we can only work with the types.
Firstly, for `+` acting on floats (redefine the `@generated` `pullback` if necessary):
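For example (outputs shown as comments; exact printing may vary):

```julia
_generate_pullback(nothing, typeof(+), Float64, Float64)  # :(rrule(f, args...))
pullback(+, 1.0, 2.0)                                     # (3.0, pullback function for +)
```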
Now for `f`, also acting on floats:
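Since no `rrule` exists for `f`, this should hit the error path for now:

```julia
pullback(f, 1.0, 2.0)
# ERROR: No rrule found for Tuple{typeof(f), Float64, Float64}
```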
The more interesting task is to inspect `f` and apply the equations of section 2 to fully differentiate it with respect to all input parameters.
The first step is to create a Wengert list for `f` in Intermediate Representation (IR) form.
Julia already does this as part of the compilation process.
IRTools.jl mimics this internal IR form with its own custom IR struct.
It can be generated as follows:
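For example, using IRTools.jl (the two forms below are equivalent):

```julia
using IRTools: IR, @code_ir, meta

ir = @code_ir f(1.0, 2.0)
# or, starting from the method metadata for a signature:
m = meta(Tuple{typeof(f), Float64, Float64})
ir = IR(m)
```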
The returned object corresponds exactly to $\ref{eq:f_wengert}$.
Using this knowledge, we can now create a new function `_generate_pullback_via_decomposition`, which will be called if no `rrule` exists. It uses the IR to create the primal (equation $\ref{eq:primal}$) (source).
The goal here is to create an IR for equation $\ref{eq:primal}$. This is what it will look like:
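Roughly (variable numbering and the exact form of the `Pullback` call are illustrative):

```
1: (%1, %2, %3)
  %4 = pullback(*, %3, %3)
  %5 = getindex(%4, 1)
  %6 = getindex(%4, 2)
  %7 = pullback(+, %2, %5)
  %8 = getindex(%7, 1)
  %9 = getindex(%7, 2)
  %10 = pullback(/, %2, %8)
  %11 = getindex(%10, 1)
  %12 = getindex(%10, 2)
  %13 = tuple(%6, %9, %12)
  %14 = Pullback{Tuple{typeof(f), Float64, Float64}}(%13)
  %15 = tuple(%11, %14)
  return %15
```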
Although harder to read, this represents the same computation as the expressions in Part 2.
The primal function first wraps the existing IR in a `Pipe` to make insertions more efficient. It then defines two arrays to store information (source).
The `calls` array stores the subset of variables that require a pullback. Because the IR acts like a dictionary (`ir[Variable(i)]` returns statement `i`), this creates a direct link to the statement called.
These will be used to generate the reverse code (equation $\ref{eq:reverse}$) in the next section.
Next, iterate over each statement in the IR.
For each statement, if it is a `:call` expression and not part of a special ignored list, replace it with three calls: the first to `pullback`, and then two calls to `getindex` to get the output variable `v` and the back function `J` from the tuple `t` (see the sketch below).
After working through all the statements, a final statement is added which returns a tuple with the output of the function and a `Pullback` struct which stores all the pullbacks.
In the last step the pipe is converted back into an IR.
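Putting this together, here is a simplified sketch of the forward transform, modeled on Zygote's `primal` (the linked source differs in details; `ignored` and `Pullback` are defined next, and the sketch assumes straight-line code with no control flow):

```julia
using IRTools: IR, Pipe, Variable, finish, substitute, xcall

function primal(ir::IR, T = Any)
    pr = Pipe(ir)
    calls = Variable[]      # variables whose statements were wrapped in a pullback
    pullbacks = Variable[]  # the matching pullback functions Jᵢ
    for (v, st) in pr
        ex = st.expr
        if Meta.isexpr(ex, :call) && !ignored(ir, ex)
            # y = f(xs...) becomes: t = pullback(f, xs...); J = t[2]; y = t[1]
            t = insert!(pr, v, xcall(Main, :pullback, ex.args...))
            J = insert!(pr, v, xcall(Base, :getindex, t, 2))
            pr[v] = xcall(Base, :getindex, t, 1)
            push!(calls, v)
            push!(pullbacks, substitute(pr, J))
        end
    end
    ir = finish(pr)
    # The final `return y` is then replaced with
    # `return (y, Pullback{T}((J₁, ..., Jₙ)))`; see the linked source for details.
    return ir, calls, pullbacks
end
```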
This code requires a definition for the `Pullback` struct as well as the `ignored` function.
There are no closures in lowered Julia code, so instead Zygote.jl stores the pullbacks in a generic struct:
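Following Zygote.jl, a sketch of this struct (the field name `data` comes from the usage below):

```julia
struct Pullback{S,T}
    data::T  # tuple of the pullback functions B₁, B₂, ... from the primal
end

# Convenience constructor, so the primal IR only needs to supply the signature S.
Pullback{S}(data) where {S} = Pullback{S, typeof(data)}(data)
```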
In the next section this struct will be turned into a callable struct.
That is, for `back = Pullback{S}(data)`, we will create a generated function that dispatches on the struct itself, `(j::Pullback)(Δ)`, so that we can call `back(Δ)`. This `back` has all the information needed to generate the reverse pass independently of the forward pass: the method can be retrieved using `meta(S)`, and the relevant data and input parameters from `back.data`.
Here is the ignored functions list (source):
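A simplified version, modeled on Zygote's equivalent (the exact set of ignored functions is a judgment call):

```julia
using IRTools: Variable

# Functions that should be passed through untouched rather than differentiated.
ignored_f(f) = f in (GlobalRef(Core, :(===)),
                     GlobalRef(Core, :apply_type),
                     GlobalRef(Core, :typeof),
                     GlobalRef(Core, :throw),
                     GlobalRef(Base, :kwerr),
                     GlobalRef(Core, :kwfunc),
                     GlobalRef(Core, :isdefined))
ignored_f(ir, f) = ignored_f(f)
ignored_f(ir, f::Variable) = ignored_f(get(ir, f, nothing))

ignored(ir, ex) = Meta.isexpr(ex, :call) && ignored_f(ir, ex.args[1])
ignored(ir, ex::Variable) = ignored(ir, ir[ex])
```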
Running `_generate_pullback_via_decomposition` on `f` gives the primal IR shown at the start of this section.
To evaluate the IR it needs to be converted into a `CodeInfo` struct.
Zygote.jl uses `IRTools.Inner.update!` to modify the existing struct in `meta_T.code`.
To me, it makes more sense to construct a new code info block directly from the IR using a slightly modified version of `IRTools.Inner.build_codeinfo`.
This can now be used in `_generate_pullback`:
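An updated sketch (helper names follow the text; their exact signatures are assumptions, and `build_codeinfo` stands for the modified function above):

```julia
function _generate_pullback(world, f, args...)
    T = Tuple{f, args...}
    # Use an rrule if a specific one exists for this signature.
    has_rrule(T, world) && return :(rrule(f, args...))
    # Otherwise decompose the function via its IR.
    pr, calls = _generate_pullback_via_decomposition(T, world)
    pr === nothing && return :(error("No rrule or IR found for ", repr($T)))
    return build_codeinfo(pr)  # lower the primal IR to a CodeInfo
end
```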
Testing (you should redefine the `@generated` `pullback` function first):
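For example (the printed form of the `Pullback` struct will differ):

```julia
y, back = pullback(f, 3.0, 2.0)
# y    = 0.42857142857142855  (= 3/7)
# back = Pullback{Tuple{typeof(f), Float64, Float64}}(...)
```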
The goal now is to turn `Pullback` into a callable struct so that we can call `back(1.0)` to evaluate equations $\ref{eq:reverse}$ and $\ref{eq:accumulate}$. With `typeof(back)` and `back.data` we have all the information to do this independently of the forward pass.
The result will be:
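Roughly (numbering and `getfield`/`getindex` forms are illustrative):

```
1: (%1, %2)
  %3 = getfield(%1, :data)
  %4 = getindex(%3, 3)
  %5 = %4(%2)                 # B₃(Δ): gradients for (/, a, y₂)
  %6 = getindex(%5, 2)
  %7 = getindex(%5, 3)
  %8 = getindex(%3, 2)
  %9 = %8(%7)                 # B₂(ȳ₂): gradients for (+, a, y₁)
  %10 = getindex(%9, 2)
  %11 = getindex(%9, 3)
  %12 = getindex(%3, 1)
  %13 = %12(%11)              # B₁(ȳ₁): gradients for (*, b, b)
  %14 = getindex(%13, 2)
  %15 = getindex(%13, 3)
  %16 = accum(%6, %10)        # ā
  %17 = accum(%14, %15)       # b̄
  %18 = tuple(nothing, %16, %17)
  return %18
```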
Although harder to read, this represents the same computation as the expressions in Part 2.
As with the forward pass, an internal function `_generate_callable_pullback` will do most of the work:
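A sketch (helper signatures are assumptions; compare the forward pass above):

```julia
function _generate_callable_pullback(j::Type{<:Pullback{S}}, world, Δ) where {S}
    m = meta(S)                   # recover the method of the original call from S
    ir = IR(m)
    _, calls, _ = primal(ir, S)   # re-run the forward transform for its bookkeeping
    back = reverse_differentiate(ir, calls)
    return build_codeinfo(back)   # lower the reverse IR to a CodeInfo
end
```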
The `reverse_differentiate` function is a simplified version of Zygote.adjoint and Zygote.reverse_stacks!.
To start, a dictionary is created to store the gradients.
It maps variable names (symbols) to an array of gradients.
It is not accessed directly (e.g. `grads[x]`) but rather through the closure functions `grad` and `grad!`, which automatically handle the arrays.
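In sketch form (inside `reverse_differentiate`, where `ir` is the reverse IR being built):

```julia
grads = Dict()                               # variable => list of partial gradients
grad!(x, x̄) = push!(get!(grads, x, []), x̄)   # record a partial gradient for x
grad(x) = xaccum(ir, get(grads, x, [])...)   # accumulate all partials for x in the IR
```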
The first gradient stored is `%2 = Δ`, associated with the final return value of the forward pass. (`xaccum` will be defined shortly.)
The first statement retrieves the `data` field in the struct.
Next the code retrieves all the calls with pullbacks from the primal and loops over them, calling the pullbacks one by one.
For each call it also loops over the input arguments and unpacks them one by one.
Each variable’s gradient is added to `grads` and may be used later in the loop.
Finally, the last call retrieves all the necessary gradients for the input arguments and returns the IR.
This code calls an `xaccum` function. It is as follows:
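This follows Zygote's `xaccum` (with `Main` standing in for the package module):

```julia
using IRTools: xcall

xaccum(ir) = nothing                                       # no gradients recorded
xaccum(ir, x) = x                                          # a single gradient passes through
xaccum(ir, xs...) = push!(ir, xcall(Main, :accum, xs...))  # several: emit an accum call
```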
The `xaccum` function calls an internal accumulate function if it acts on multiple inputs. At its simplest, `accum` is the same as `sum`. However it also handles `nothing` inputs, `Tuple`s and `NamedTuple`s (source).
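A sketch following Zygote's `accum`:

```julia
accum() = nothing
accum(x) = x
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x, y, zs...) = accum(accum(x, y), zs...)

# Tuples accumulate elementwise.
accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)

# NamedTuples accumulate field by field; missing fields count as nothing.
@generated function accum(x::NamedTuple, y::NamedTuple)
    grad(field) = field in fieldnames(y) ? :(y.$field) : :nothing
    Expr(:tuple, [:($f = accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...)
end
```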
Examples:
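```julia
accum(1.0, 2.0)                # 3.0
accum(1.0, nothing, 2.0)       # 3.0 — nothing is skipped
accum((1.0, 2.0), (3.0, 4.0))  # (4.0, 6.0)
accum((a = 1.0,), (a = 2.0,))  # (a = 3.0,)
```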
Finally, dispatch on the `Pullback` struct to turn it into a callable struct:
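On Julia versions before 1.10 this is a one-liner; later versions need the same generator hack as `pullback` above (sketch):

```julia
@generated function (j::Pullback)(Δ)
    _generate_callable_pullback(j, nothing, Δ)
end
```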
Testing:
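With illustrative values:

```julia
a, b = 3.0, 2.0
y, back = pullback(f, a, b)  # y = 3/7
grads = back(1.0)            # (nothing, ā, b̄)
```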
The results should match equation $\ref{eq:rollup}$:
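With $a=3$, $b=2$ and $\Delta=1$: $\bar{a} = b^2/(a+b^2)^2 = 4/49 \approx 0.0816$ and $\bar{b} = -2ab/(a+b^2)^2 = -12/49 \approx -0.2449$.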
This code works well enough for this simple case. It also works for the trigonometry example from Part 1. However, it will fail for the polynomial model, with an error raised five levels down the call stack.
This can be fixed by explicitly defining a pullback for `map`.
These and other extensions will be the goal of part 4.
Presumably the reason the Julia team tried to prevent reflection in generated functions is that it interferes with the compiler's ability to properly predict, trigger and/or optimise compilations. ↩
Zygote.jl has more complex rules which also consider other fallbacks, keyword arguments and a possible opt-out through a `no_rrule`. ↩