A series on automatic differentiation in Julia. Part 4 extends part 3 to handle maps, getfield and anonymous functions. It creates a generic gradient descent and uses this to fit a polynomial.
By the end of part 3 we had code that could automatically differentiate many functions, as long as rrules existed for them and there was no control flow.
However, the code failed for the polynomial model:
Calling @code_ir model(x), we can see that code is lowered as follows:
And further that model(1.0) is lowered to:
We could have also defined the map using an anonymous function:
In which case it would have been lowered to:
The calls to Core.typeof and Core.apply_type are in the list of ignored functions.
However we need to handle map, getproperty and %new.
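To make the three problem calls concrete, here is a minimal sketch of a polynomial model of the kind discussed above. The struct layout, the use of `evalpoly` and the helper name `predict_anon` are assumptions for illustration and may differ from the definition in the introduction. Reading `m.weights` is what produces the `getproperty` call, and the closure in `predict_anon` is what produces `%new`.

```julia
# Assumed model definition: a struct holding the polynomial coefficients.
struct Polynomial{V<:AbstractVector}
    weights::V
end

# Evaluate at a scalar: the field access lowers to Base.getproperty(m, :weights).
(m::Polynomial)(x::Real) = evalpoly(x, m.weights)

# Evaluate at a vector by mapping the model itself over the inputs.
(m::Polynomial)(x::AbstractVector) = map(m, x)

# Hypothetical variant using an anonymous function. The closure captures `m`,
# so lowering has to construct a closure object for it.
predict_anon(m::Polynomial, xs::AbstractVector) = map(x -> m(x), xs)

model = Polynomial([1.0, 2.0, 3.0])  # 1 + 2x + 3x^2
ys = model([1.0, 2.0])
ys_anon = predict_anon(model, [1.0, 2.0])

# Inspect the lowered code for the scalar method.
lowered = code_lowered(model, (Float64,))
```

Both variants compute the same values; only the lowered code differs, which is why each needs its own handling.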
These sorts of functions do not have formal mathematical derivatives and so they do not have rrules in ChainRules.jl.
Instead, Zygote.jl handles these functions with their own custom pullbacks.
Zygote also replaces some low level functions like new, getproperty and getindex entirely with custom code.
2 Extending pullback
2.1 map
The pullback for map is fairly complex. What will be presented here is a simplified version.
It might also help to look at the less generic code in the example in part 1.
Consider the following code:
The pullback for map should return 3 values: $\text{s̄elf}$ for map, $\bar{f}$ for the function f and $\bar{x}$ for each value in x.
The code will start by getting pullbacks for each value in x:
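As a rough illustration of this step, the sketch below maps a pullback over each element. `toy_pullback` is a hypothetical stand-in for the `pullback` function from part 3, written by hand for `sin` only so that the example runs on its own.

```julia
# Hypothetical stand-in for `pullback`: returns the value and a closure that
# maps an incoming gradient Δ to (s̄elf, x̄).
toy_pullback(::typeof(sin), x) = (sin(x), Δ -> (nothing, Δ * cos(x)))

xs = [0.0, 0.5, 1.0]

# One (y, back) pair per element of xs.
ys_and_backs = map(x -> toy_pullback(sin, x), xs)

y1, back1 = first(ys_and_backs)
```

Each entry pairs a primal value with the closure needed to propagate a gradient back through that single call.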
This list is in a “zipped” format: there are $n$ entries of $(y_i, \mathcal{B}_i)$ for an array of length $n$.
This will be unzipped into two lists each of length $n$: $(y_1,…,y_n), (\mathcal{B}_1,…,\mathcal{B}_n)$:
This is done with an unzip function which generalises first to any index i (source):
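A simplified sketch of such an `unzip` is shown here. It assumes every entry has the same length and uses `ntuple` rather than a generated function, so it is not the exact linked source. `StaticGetter` is the callable that generalises `first` to index `i`.

```julia
# Callable that extracts index `i` from its argument; StaticGetter{1}() acts like `first`.
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

# Turn a list of N-tuples into an N-tuple of lists.
function unzip(tuples)
    N = length(first(tuples))
    return ntuple(i -> map(StaticGetter{i}(), tuples), N)
end

zipped = [(1.0, "a"), (2.0, "b"), (3.0, "c")]
ys, labels = unzip(zipped)
```

The strings here only stand in for the pullback closures; any second element works the same way.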
The result:
As a final step, all the gradients for the function are accumulated into one value:
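The accumulation can be sketched with a small `accum` helper folded over the per-element function gradients. The `accum` below is a simplified assumption that only handles `nothing`, numbers or arrays, and NamedTuples with identical field names.

```julia
# Add two gradients, treating `nothing` as "no gradient".
accum(x, y) = x === nothing ? y : (y === nothing ? x : x + y)

# Structural gradients (e.g. for a callable struct) are accumulated field by field.
accum(x::NamedTuple, y::NamedTuple) = map(accum, x, y)

# Hypothetical per-element gradients with respect to a function with one field `a`.
f̄s = [(a = 1.0,), (a = 2.0,), nothing]

# Fold them into a single gradient for the function.
f̄ = reduce(accum, f̄s; init = nothing)
```

Starting the fold from `nothing` means a plain function with no fields simply ends up with `nothing` as its gradient.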
Putting all this code in a single function (source):
Testing:
And also:
2.2 Instrument
Zygote.jl modifies some of the source code before creating the primal and reverse passes.
Here is a simplified version of this instrument function which only replaces new and getfield (source):
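The idea of the substitution can be sketched on plain `Expr` values without any IR tooling. This is an illustration only: it rewrites one statement at a time, and the names `instrument_statement` and `callee_name` are hypothetical. The version in this series works on the IR itself.

```julia
# Name of the called function, whether it appears as a Symbol or a GlobalRef.
callee_name(f) = f isa GlobalRef ? f.name : f

function instrument_statement(ex)
    ex isa Expr || return ex
    if ex.head == :new
        # %new(T, args...) becomes __new__(T, args...)
        return Expr(:call, :__new__, ex.args...)
    elseif ex.head == :call && length(ex.args) == 3 &&
           callee_name(ex.args[1]) == :getfield && ex.args[3] isa QuoteNode
        # getfield(x, :f) becomes literal_getfield(x, Val(:f))
        return Expr(:call, :literal_getfield, ex.args[2], Val(ex.args[3].value))
    end
    return ex  # every other statement is left untouched
end

# Apply the rewrite to a list of statements.
instrument(statements) = map(instrument_statement, statements)

rewritten = instrument([:(getfield(m, :weights)), Expr(:new, :T, :a), :(sin(x))])
```

Wrapping the field name in `Val` moves it into the type domain, so the replacement function can dispatch on it.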
Modify the existing _generate_pullback_via_decomposition and _generate_callable_pullback functions to call it:
Now we need to define literal_getfield and __new__ and their pullbacks.
2.3 getfield
Calls to getproperty default to getfield, where a field is one of the names declared in the struct’s definition.
The getfield function is substituted with literal_getfield (source):
The pullback will return a NamedTuple with one entry per field, where the gradient is Δ for the relevant field and nothing for the others (source):
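A minimal sketch of both pieces, assuming immutable structs only. The pullback is written as a standalone function, `literal_getfield_pullback`, rather than as a method of `pullback`, so that it runs on its own. `nt_nothing` and the struct `Foo` are illustrative names.

```julia
# Same behaviour as getfield, but with the field name lifted into the type domain.
literal_getfield(x, ::Val{f}) where {f} = getfield(x, f)

# NamedTuple with `nothing` for every field of x.
function nt_nothing(x)
    T = typeof(x)
    return NamedTuple{fieldnames(T)}(ntuple(_ -> nothing, fieldcount(T)))
end

function literal_getfield_pullback(x, ::Val{f}) where {f}
    val = getfield(x, f)
    function back(Δ)
        # Δ goes to field `f`; every other field keeps `nothing`.
        x̄ = merge(nt_nothing(x), NamedTuple{(f,)}((Δ,)))
        return (nothing, x̄, nothing)  # s̄elf, x̄, gradient for the Val
    end
    return val, back
end

# Hypothetical struct for illustration.
struct Foo
    a::Float64
    b::Float64
end

val, back = literal_getfield_pullback(Foo(1.0, 2.0), Val(:b))
grads = back(1.0)
```

Here `grads[2]` is the NamedTuple gradient for the struct, with `nothing` for `a` and the incoming gradient for `b`.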
For example:
And for the polynomial model:
For the first time we have a value $\text{s̄elf}$, which is the named tuple for the fields.
2.4 new
The code now works with:
It returns $\text{s̄elf}$ and $\bar{x}$:
However with an anonymous function:
nothing is returned for $\text{s̄elf}$:
If we inspect the primal(ir), we see that it’s because no pullbacks and hence no gradients are recorded against variable %1 (self):
The solution is to swap %new with a call to a custom function __new__ with a pullback.
This function is as follows (source):
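A rough sketch of what such a function and its pullback might look like, for immutable types only. `new_pullback` is a hypothetical standalone name used here so the example runs without the `pullback` machinery from the earlier parts, and `Line` is an illustrative struct.

```julia
# Construct a T directly from its field values, bypassing any constructors.
# Like %new, this performs no conversion, so argument types must match the fields.
@generated function __new__(T, args...)
    return Expr(:new, :T, [:(args[$i]) for i in 1:length(args)]...)
end

function new_pullback(T, args...)
    x = __new__(T, args...)
    # A NamedTuple gradient for the struct is split into one gradient per field.
    back(Δ::NamedTuple) = (nothing, nothing, map(f -> getfield(Δ, f), fieldnames(T))...)
    back(::Nothing) = (nothing, nothing, map(_ -> nothing, args)...)
    return x, back  # back returns (s̄elf, T̄, ārgs...)
end

# Hypothetical struct for illustration.
struct Line
    slope::Float64
    intercept::Float64
end

line, back = new_pullback(Line, 2.0, 1.0)
grads = back((slope = 0.5, intercept = nothing))
```

The pullback is the inverse of the `getfield` one: it takes the struct's NamedTuple gradient apart and hands each piece to the argument that created that field.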
Now if we try the following (after redefining @generated function pullback and function (methodinstance::Pullback)) we should get the same results:
3 Gradient Descent revisited
3.1 Generic Gradient Descent
Now that we have an automatic differentiation engine, it is possible to create a much more generic gradient descent function than in part 1:
Note that pullback(m->f(m), model) is directly equivalent to the do-block form pullback(model) do m; f(m) end.
The update_params! function is defined as follows:
The parameters function is defined per model.
(Flux uses the generic Functors.jl library to accomplish something similar.)
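The pieces of this section can be sketched together as follows. Everything here is an assumption for illustration: `parameters` returns a NamedTuple, `update_params!` takes a matching NamedTuple of gradients, and `gradient_descent!` takes a `pullback_loss` function in place of the series' `pullback`. The hand-written `mse_pullback` stands in for the AD engine so that the block runs on its own.

```julia
struct Polynomial{V<:AbstractVector}
    weights::V
end
(m::Polynomial)(x::Real) = evalpoly(x, m.weights)
(m::Polynomial)(x::AbstractVector) = map(m, x)

# Defined per model: the trainable arrays, keyed by field name.
parameters(m::Polynomial) = (weights = m.weights,)

# In-place step: θ <- θ - learning_rate * ḡ for every parameter with a gradient.
function update_params!(model, grads; learning_rate = 0.1)
    for (key, θ) in pairs(parameters(model))
        ḡ = getproperty(grads, key)
        ḡ === nothing && continue
        θ .-= learning_rate .* ḡ
    end
    return model
end

# Generic loop: pullback_loss(model) must return (loss, back), where
# back(1.0) returns (s̄elf, m̄) and m̄ is the gradient for the model.
function gradient_descent!(model, pullback_loss; learning_rate = 0.1, max_iters = 100)
    history = Float64[]
    for _ in 1:max_iters
        loss, back = pullback_loss(model)
        _, m̄ = back(1.0)
        update_params!(model, m̄; learning_rate = learning_rate)
        push!(history, loss)
    end
    return history
end

# Hand-written stand-in for the AD engine: mean squared error and its gradient.
function mse_pullback(model::Polynomial, X, Y)
    residual = model(X) .- Y
    loss = sum(abs2, residual) / length(Y)
    function back(Δ)
        w̄ = [Δ * 2 / length(Y) * sum(residual .* X .^ (k - 1))
             for k in eachindex(model.weights)]
        return (nothing, (weights = w̄,))
    end
    return loss, back
end

X = collect(range(-1, 1, length = 21))
Y = 1 .+ 2 .* X  # synthetic noise-free line, not the part 1 data set
model = Polynomial([0.0, 0.0])
history = gradient_descent!(model, m -> mse_pullback(m, X, Y); learning_rate = 0.1, max_iters = 500)
```

Keeping the loop independent of the model means a different model only needs its own `parameters` method and a way to produce a pullback for its loss.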
3.2 Polynomial curve fitting revisited
Let’s create the exact same data set from part 1:
The Polynomial model is defined in the introduction. We also need a custom method for parameters:
Define the model:
Some sanity checks:
Train the model:
This works just as well as before.
4 Conclusion
We now have a fully working AD package.
It has some limitations: it cannot handle control flow or keyword arguments.
However it can already work on a wide variety of code.
All that might be needed is new rrule definitions.
The next and final part of this series is a demonstration of exactly that.