A series on automatic differentiation in Julia. Part 5 shows how the MicroGrad.jl code can be used as the backbone of a machine learning framework like Flux.jl. The working example is a multi-layer perceptron trained on the moons dataset.
The previous four parts developed a minimal automatic differentiation package.
The aim of this part is to demonstrate how it can be used as the backbone for a machine learning framework like Flux.jl.
In this post we will create a multi-layer perceptron, also known as a fully connected neural network.
This is an extremely popular and powerful machine learning model.
New code will be needed for the forward pass and for some extra rrules.
Otherwise, the rest is handled by code from the previous parts.
2 Moons dataset
The moons dataset is a toy dataset for testing and visualising classification algorithms.
While the two classes are clearly distinct, their curved shapes require a non-linear algorithm to separate them.
This was the dataset chosen by Karpathy to demonstrate his micrograd package, and so it will be used here too.
This dataset can be reconstructed in Julia as follows, based on the Scikit-Learn function:
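Below is a minimal sketch of such a generator; the function name make_moons and its keyword arguments are assumptions rather than the post's exact code:

```julia
using Random

# A minimal sketch of a moons generator, loosely following Scikit-Learn's make_moons.
function make_moons(n_samples::Int=100; noise::Float64=0.0, rng::AbstractRNG=Random.default_rng())
    n_top = div(n_samples, 2)
    n_bottom = n_samples - n_top
    θ_top = range(0, π; length=n_top)
    θ_bottom = range(0, π; length=n_bottom)
    # upper moon around the origin, lower moon shifted right and down
    x_top = [cos.(θ_top) sin.(θ_top)]
    x_bottom = [(1 .- cos.(θ_bottom)) (1 .- sin.(θ_bottom) .- 0.5)]
    X = vcat(x_top, x_bottom) .+ noise .* randn(rng, n_samples, 2)
    y = vcat(zeros(Int, n_top), ones(Int, n_bottom))
    permutedims(X), y  # features as a 2×n matrix, labels 0/1
end
```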
Creating the moons and labels:
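A hypothetical call (the sample count is an assumption):

```julia
X, y = make_moons(100)
```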
3 Layers
3.1 ReLU
The Rectified Linear Unit (ReLU) is a common activation function in machine learning. It is defined as follows:
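\[\text{relu}(x) = \max(0, x)\]

In Julia this is a one-liner (essentially the NNlib/Flux definition):

```julia
relu(x) = max(zero(x), x)  # zero(x) preserves the element type
```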
3.2 Dense

This is the code from Flux.jl to create this fully connected layer (source):
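A simplified sketch along the same lines (the real Flux definition has more constructors, initialisation options and checks):

```julia
# A simplified sketch in the spirit of Flux.jl's Dense layer.
struct Dense{M<:AbstractMatrix, B<:AbstractVector, F}
    weight::M
    bias::B
    σ::F
end

Dense((in, out)::Pair{<:Integer, <:Integer}, σ=identity) =
    Dense(0.1f0 .* randn(Float32, out, in), zeros(Float32, out), σ)

# forward pass: an affine transform followed by a broadcast of the activation
(a::Dense)(x) = a.σ.(a.weight * x .+ a.bias)
```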
Also add a method to parameters:
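A sketch, assuming the parameters function from part 4 collects the trainable arrays:

```julia
parameters(a::Dense) = (a.weight, a.bias)
```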
Create and test:
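Hypothetical usage, mapping the 2 moon features to 5 hidden units:

```julia
layer = Dense(2 => 5, relu)
layer(X)  # 5×100 output
```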
3.3 Reverse broadcast
Inspect the IR with @code_ir layer(X):
From part 1 and part 4 we have rrules for getproperty (getfield), matrix multiplication (*) and for the activation (relu). We still need rrules for broadcasted and materialize.
Creating rules for broadcasting in general is complex, so instead create a specific rule for the broadcast invoked here:
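As a sketch, assuming the rrule convention from the earlier parts (the pullback returns one gradient per argument, with nothing for non-differentiable ones), here are possible rules for the two broadcasts in the Dense layer; the post's exact rules may differ:

```julia
import Base.Broadcast: broadcasted, materialize

# Rule for the `.+ bias` broadcast. The primal stays lazy, so the downstream
# materialize call is handled by the copy and instantiate rules below.
function rrule(::typeof(broadcasted), ::typeof(+), A::AbstractMatrix, b::AbstractVector)
    broadcasted_add_back(Δ) = (nothing, nothing, Δ, vec(sum(Δ; dims=2)))
    broadcasted(+, A, b), broadcasted_add_back
end

# Rule for the activation broadcast. The input is materialized so that the
# pullback can compute the relu mask.
function rrule(::typeof(broadcasted), ::typeof(relu), x)
    xm = materialize(x)
    broadcasted_relu_back(Δ) = (nothing, nothing, Δ .* (xm .> 0))
    broadcasted(relu, xm), broadcasted_relu_back
end
```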
Testing:
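A hypothetical check:

```julia
x_test = [-1.0 2.0; 3.0 -4.0]
val, back = rrule(broadcasted, relu, x_test)
materialize(val)      # negatives clamped to zero
back(ones(2, 2))[3]   # gradient: 1 where x_test > 0, 0 elsewhere
```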
The definition for Base.Broadcast.materialize is:
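```julia
# From base/broadcast.jl (approximately):
@inline materialize(bc::Broadcasted) = copy(instantiate(bc))
materialize(x) = x
```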
Hence we need rrules for copy and instantiate (source):
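A sketch of the two rules, which simply pass the gradient through (in the primal, copy of a Broadcasted is what actually executes the fused broadcast):

```julia
import Base.Broadcast: instantiate, Broadcasted

function rrule(::typeof(copy), bc::Broadcasted)
    copy_back(Δ) = (nothing, Δ)
    copy(bc), copy_back
end

function rrule(::typeof(instantiate), bc::Broadcasted)
    instantiate_back(Δ) = (nothing, Δ)
    instantiate(bc), instantiate_back
end
```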
Now the pullback for the Dense layer works:
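A hypothetical end-to-end check, assuming the pullback function from the earlier parts:

```julia
z, back = pullback(layer, X)
back(ones(size(z)))  # gradients for the layer's fields and for X
```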
3.4 Chain
Here is the Flux code to create a generic chain (source):
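A simplified sketch along the same lines (a tuple of layers applied recursively):

```julia
# A simplified sketch in the spirit of Flux.jl's Chain.
struct Chain{T<:Tuple}
    layers::T
end
Chain(layers...) = Chain(layers)

(c::Chain)(x) = applychain(c.layers, x)
applychain(::Tuple{}, x) = x
applychain(layers::Tuple, x) = applychain(Base.tail(layers), first(layers)(x))

# collect the trainable arrays of every layer (assumes `parameters` as above)
parameters(c::Chain) = Tuple(Iterators.flatten(map(parameters, c.layers)))
```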
4 Loss

The output of the machine learning model will be a probability $p_j$ for a sample $j$ being in a certain class. This will be compared to a probability for a known label $y_j$, which is either 1 if that sample is in the class or 0 if it is not.
An obvious value to maximise is their product:
\[y_j p_j \tag{4.1}\]
with range $[0, 1]$.
However, most machine learning optimisation algorithms aim to minimise a loss.
So instead $p_j$ is scaled as $-\log(p_j)$, so that the loss lies in $[0, \infty)$ and the goal is to minimise it towards 0.
This is called the cross entropy loss:
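\[L = -\frac{1}{n}\sum_{j=1}^{n}\sum_{i} y_{ij} \log(p_{ij})\]

written here in a standard form: averaged over the $n$ samples and summed over the classes, which are indexed by $i$.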
The outputs of the neural network are not probabilities but instead a vector of logits containing $N$ real values for $N$ classes.
By convention these logits are scaled to a probability distribution using the softmax function:
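\[p_{ij} = \frac{e^{x_{ij}}}{\sum_{k=1}^{N} e^{x_{kj}}}\]

In practice the logarithm of the softmax is computed directly, so the loss becomes:

\[L = -\frac{1}{n}\sum_{j=1}^{n}\sum_{i=1}^{N} y_{ij} z_{ij}, \qquad z_{ij} = \log(p_{ij}) = x_{ij} - \log\left(\sum_{k=1}^{N} e^{x_{kj}}\right)\]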
where $z_{ij}$ is the output of the logsoftmax function. Assuming that $y_{ij}$ is 1 for exactly one value of $i$ and 0 otherwise, this can be simplified to:
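\[L = -\frac{1}{n}\sum_{j=1}^{n} z_{i_j j}\]

where $i_j$ is the index of the class for which $y_{i_j j} = 1$.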
In Julia this can be implemented as follows (source):
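As a sketch (a numerically stable version in the spirit of NNlib's logsoftmax; the cross_entropy name is an assumption, not necessarily the post's exact listing):

```julia
function logsoftmax(x::AbstractMatrix)
    m = maximum(x; dims=1)             # subtract the per-sample maximum for stability
    z = x .- m
    z .- log.(sum(exp.(z); dims=1))
end

# mean cross entropy over the samples (columns), given one-hot targets y
cross_entropy(x::AbstractMatrix, y::AbstractMatrix) = -sum(y .* logsoftmax(x)) / size(x, 2)
```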
According to the multivariable chain rule, the derivative with respect to a single logit $x_{ij}$ in the vector for sample $j$ is (gradients come from the main $k=i$ case as well as from the sum in the softmax for $k \neq i$):
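\[\frac{\partial z_{kj}}{\partial x_{ij}} = \delta_{ki} - p_{ij}
\quad\Rightarrow\quad
\frac{\partial L}{\partial x_{ij}} = -\frac{1}{n}\sum_{k} y_{kj}\left(\delta_{ki} - p_{ij}\right) = \frac{1}{n}\left(p_{ij} - y_{ij}\right)\]

where $\delta_{ki}$ is the Kronecker delta and the last step uses $\sum_k y_{kj} = 1$.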
In Julia this can be implemented as follows (source):
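A sketch, assuming the rrule convention from the earlier parts and the cross_entropy function above:

```julia
function rrule(::typeof(cross_entropy), x, y)
    z = logsoftmax(x)
    n = size(x, 2)
    # exp.(z) recovers the softmax probabilities, so the gradient is (p .- y) ./ n
    cross_entropy_back(Δ) = (nothing, Δ .* (exp.(z) .- y) ./ n, nothing)
    -sum(y .* z) / n, cross_entropy_back
end
```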
Testing:
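A hypothetical check:

```julia
logits = randn(2, 4)
targets = Float64[1 0 1 0; 0 1 0 1]
loss, back = rrule(cross_entropy, logits, targets)
back(1.0)[2]  # should equal (exp.(logsoftmax(logits)) .- targets) ./ 4
```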
5 Train and Evaluate
5.1 Train
Create the moons data and labels:
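A hypothetical call (the sample count and noise level are assumptions, similar to Karpathy's demo):

```julia
X, y = make_moons(100; noise=0.1)
```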
Convert the labels to a one-hot representation:
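A sketch for the 0/1 labels, with class $k$ stored in row $k+1$:

```julia
Y = zeros(Float32, 2, length(y))
for (j, label) in enumerate(y)
    Y[label + 1, j] = 1
end
```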
Create the model:
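A hypothetical model; the layer sizes are an assumption (two hidden layers of 16 units, as in Karpathy's demo, and 2 output classes):

```julia
model = Chain(
    Dense(2 => 16, relu),
    Dense(16 => 16, relu),
    Dense(16 => 2),
)
```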
Test the loss function:
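For an untrained model on two balanced classes the loss should be roughly $\log(2) \approx 0.69$:

```julia
cross_entropy(model(X), Y)
```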
Use the exact same gradient_descent! function from part 4:
5.2 Evaluate
Plot the history:
Calculate accuracy:
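A hypothetical calculation, comparing the predicted class (1 or 2) against the 0/1 labels:

```julia
using Statistics: mean

predictions = [argmax(col) - 1 for col in eachcol(model(X))]
mean(predictions .== y)
```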
Plot decision boundary:
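A sketch (the plotting ranges and resolution are assumptions): evaluate the model on a grid and colour by the predicted class.

```julia
using Plots

xs = range(minimum(X[1, :]) - 0.5, maximum(X[1, :]) + 0.5; length=200)
ys = range(minimum(X[2, :]) - 0.5, maximum(X[2, :]) + 0.5; length=200)
zs = [argmax(model(Float32[xi, yi])) for yi in ys, xi in xs]  # predicted class on the grid
heatmap(xs, ys, zs; alpha=0.4, colorbar=false)
```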
Plot points over the boundary:
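A hypothetical overlay, coloured by the true label:

```julia
scatter!(X[1, :], X[2, :]; group=y, markersize=4)
```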
The result:
6 Conclusion
That was a long and difficult journey.
I hope you understand how automatic differentiation with Zygote.jl works now!
The Zygote.jl code for broadcast has this gem of a comment:
There's a saying that debugging code is about twice as hard as
writing it in the first place. So if you're as clever as you can
be when writing code, how will you ever debug it?
AD faces a similar dilemma: if you write code that's as clever as
the compiler can handle, how will you ever differentiate it?
Differentiating makes clever code that bit more complex and the
compiler gives up, usually resulting in 100x worse performance.
Base's broadcasting is very cleverly written, and this makes
differentiating it... somewhat tricky.