A denoising diffusion probabilistic model for generating numbers based on the MNIST dataset. The underlying machine learning model is a U-Net model, which is a convolutional autoencoder with skip connections. A large part of this post is dedicated to implementing this model architecture in Julia.
Update 28 August 2023: code refactoring and update to Flux 0.13.11 explicit syntax.
This is part of a series. The other articles are:
Full code available at github.com/LiorSinai/DenoisingDiffusion.jl. Examples notebooks at github.com/LiorSinai/DenoisingDiffusion-examples. Google Colab training notebook at DenoisingDiffusion-MNIST.ipynb.
Part 1 detailed the first principles of denoising diffusion probabilistic models for data generation. It applied the technique to generate a spiral. It is now time to extend it to image generation. To limit the amount of training data and the model size, we’ll be focusing on the task of generating handwritten numbers instead of wholly unique artworks. We’ll use the famous MNIST database as our data source. This dataset is commonly used in introductory machine learning courses on data classification. However, it works just as well for an introduction to data generation. (It is recommended to have experience with MNIST and basic convolutional models because we’ll be making a more complex type of convolutional model here.)
The good news is that we can reuse almost all of the code from part 1. The bad news is that we’ll need a much more complex denoising model to get good results. Even with the given limitations, this task is much more complex than the spiral. Each sample has 28×28 pixels for a total of 784 features per sample. For the spiral, each sample had 2 features: 2 co-ordinates. This is therefore two orders of magnitude harder. We’ll find that the underlying machine learning model scales accordingly. So instead of 5,400 parameters the final model will have at least 400,000. A large part of this post is solely dedicated to building the model.
As a reminder from the first part, the purpose of the machine learning model is not exactly obvious. It is used to predict the total noise that needs to be removed from the current iteration in order to produce a valid sample on the last iteration. This total noise is then used in analytical equations to calculate the amount of noise that needs to be removed in the current iteration, which is necessarily only some fractional part. The purpose of multiple iterations is to refine the predicted noise and hence refine the final sample. Please review part 1-reverse process for a full explanation.
The U-Net model architecture was first introduced in the 2015 paper U-Net: Convolutional Networks for Biomedical Image Segmentation by Olaf Ronneberger, Philipp Fischer and Thomas Brox. The name U-Net comes from the shape of the model in the schematic that the original authors drew. See above. It is debatable if this is the best choice of name. There are other ways to represent the model and this name obscures the main features. A U-Net model can best be described as a convolutional autoencoder with skip connections (also called residual connections). Here is a linear representation that shows this more clearly, where the skip connections form the U’s:
Nonetheless, the name U-Net has come to mean exactly this in the literature - a convolutional autoencoder with skip connections. Therefore I’ll keep using the name.1
Many papers and websites still cite the original 2015 paper. However this model should be seen as a “grandfather” to more recent versions. It is hard to tell from papers what their actual models look like. Looking at the reference code we find several improvements on the original design (e.g. no cropping) and many different kinds of layers which differ from codebase to codebase. Furthermore there is usually flexibility with various parameters of the model.
I’ve seen a U-Net model described as “stacking layers like Lego bricks”. It is really up to you to decide what is best and what layers you want to use. It is an open question what value any of the more complex layers provide. At the very least what can be said is that in machine learning brute force wins over targeted optimisation: for more complex tasks bigger and deeper models are preferable.
But to summarise, we can say that a U-Net model has three primary features:
The second point is what makes it an autoencoder. It is worth reading theories on the latent space of variational autoencoders, which justify the whole model architecture.
This is the model architecture that will be implemented here. It is based on PyTorch models by OpenAI and LucidRains.
Each downsample layer will decrease the image height and width by a factor of 2 but increase the channel dimension by a multiplication factor $m$. The sample is therefore downscaled by an overall factor of $\tfrac{1}{2}\tfrac{1}{2}(m)=\tfrac{m}{4}$ per level. This is reversed with the upsample layers. The model however is based on convolutional layers whose sizes are independent of the input image height and width and only depend on the channel dimension. (The inference time, however, does scale with the image size.) Each layer grows with a factor of $d^2$ for a channel dimension $d$. Therefore the largest layers are the middle (bottom) layers where the channel dimension is largest. Also, because of the concatenation on the channel dimension, the up layers will tend to be much larger than their down counterparts.
This figure shows this quadratic relationship with the channel dimension:
So despite the symmetrical nature of the schematic, the model is not very symmetric.
The rest of this section will detail the code for the model in a top-down fashion. For the full model see UNet.jl. Please see this source code for the printing functions, which are not detailed here.
The blocks used in the model are described in the next section, Blocks.
There is a slight problem with making the U-Net model in Julia. The reference models are written in PyTorch and the code unfortunately cannot be directly translated because of restrictions in the Julia language. The PyTorch implementations use an array for the skip connections. They append to it on the down side and then pop values out on the up side. The Julia machine learning library Flux, however, does not support these mutating array operations in backward passes.
Here is a minimum working example.
Define the model and a basic forward pass:
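Here is a minimal sketch of such a model. The layers and sizes are purely illustrative; it stores the skip connections in a vector, as the PyTorch implementations do:

```julia
using Flux

downs  = [Conv((3, 3), 1 => 8, relu; pad=1), Conv((3, 3), 8 => 16, relu; pad=1)]
middle = Conv((3, 3), 16 => 16, relu; pad=1)
ups    = [Conv((3, 3), 32 => 8, relu; pad=1), Conv((3, 3), 16 => 1; pad=1)]

function forward(downs, middle, ups, x)
    skips = []
    h = x
    for layer in downs
        h = layer(h)
        push!(skips, h)  # mutates the skip connection array
    end
    h = middle(h)
    for layer in ups
        h = layer(cat(h, pop!(skips); dims=3))  # mutates it again
    end
    h
end
```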
Test the forward pass using a very basic model:
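Running it on a random batch of four 28×28 grey images:

```julia
x = rand(Float32, 28, 28, 1, 4)
size(forward(downs, middle, ups, x)) # (28, 28, 1, 4)
```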
This works. However during training we need to apply the backward pass. Here is an example of that:
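Continuing the sketch, take gradients with respect to the layers using Flux’s explicit syntax and a made-up target:

```julia
y_target = rand(Float32, 28, 28, 1, 4)
grads = Flux.gradient(downs, middle, ups) do d, m, u
    Flux.Losses.mse(forward(d, m, u, x), y_target)
end
```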
The last line will result in this error:
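The exact wording depends on the Zygote version, but it will be along these lines:

```
ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
```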
There are three solutions:
Option 2 is best when the operation can be isolated to only one layer.2 However here it encompasses most of the forward pass so it essentially is the same as option 3. Option 3 is not a good idea because (1) it requires much more work and (2) the forward pass and backward pass will not be automatically in sync.
So we are left with option 1. There are two ways to implement it. The first is to fix the number of layers so that we don’t need a mutable array. For example, for the minimum working example with only two layers:
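Sketching this for the toy model above:

```julia
function forward_fixed(downs, middle, ups, x)
    h1 = downs[1](x)
    h2 = downs[2](h1)
    h = middle(h2)
    h = ups[1](cat(h, h2; dims=3))
    ups[2](cat(h, h1; dims=3))
end

grads = Flux.gradient(downs, middle, ups) do d, m, u
    Flux.Losses.mse(forward_fixed(d, m, u, x), y_target)
end
```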
The backward pass code will now work.
For a full working example, see my UNetFixed
model at UNetFixed.jl.
The second is to embrace that Julia is a functional programming language and use a functional approach.
We can use a Flux.SkipConnection
layer to implement the skip connection.
This will lead to a much more flexible model and hence is the construction that will be described in the remainder of this post.
For example:
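A sketch with the same illustrative layer sizes as before, where the connection concatenates on the channel dimension:

```julia
model = Chain(
    Conv((3, 3), 1 => 8, relu; pad=1),
    SkipConnection(
        Chain(
            Conv((3, 3), 8 => 16, relu; pad=1),  # down
            Conv((3, 3), 16 => 8, relu; pad=1),  # up
        ),
        (h, x) -> cat(h, x; dims=3),  # concatenate on the channel dimension
    ),
    Conv((3, 3), 16 => 1; pad=1),
)

grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y_target), model)
```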
Again the backward pass code will now work.
For a model with multiple skip connections we can nest layers recursively.
As in part 1 we need to pass two inputs to the model and so we’ll need to reuse the ConditionalChain
from
part 1.
We’ll need a custom ConditionalSkipConnection
layer which can handle multiple inputs too.
The source code for Flux.SkipConnection
can easily be adapted for this:
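A sketch of the adaptation (the version in the repository may differ in details such as pretty printing):

```julia
struct ConditionalSkipConnection{T,F}
    layers::T
    connection::F
end

Flux.@functor ConditionalSkipConnection

# pass all inputs to the inner layers but only skip-connect the first
function (skip::ConditionalSkipConnection)(x, ys...)
    skip.connection(skip.layers(x, ys...), x)
end
```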
The model is all held in a struct:
The model will have a complex constructor. The most important parameters are the number of input channels (1 for grey images and 3 for RGB images), the model dimension $d$ and the number of time steps $T$. The total number of parameters scales with $d^2$, so if we double the model dimension we quadruple the number of parameters.
Additionally, the user will also be able to specify the channel multipliers instead of the default value of 2 per level.
The block layer (the purple blocks) will be configurable (either a ConvEmbed or a ResBlock), as will the number of blocks per level.
As a simplification only the last block in each level will be connected to a skip connection.
We will have Flux.GroupNorm
layers, which require a number of groups G
; this will be set the same throughout the model.
Lastly we’ll have a parameter for the attention layer.
There are many different variations of the time embedding layer.
This is one used by LucidRains.
For the SinusoidalPositionEmbedding
see part 1-sinusoidal embedding.
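As a sketch, assuming the SinusoidalPositionEmbedding layer from part 1 (its exact argument order may differ from what is shown here), a model dimension model_channels and an embedding dimension of four times that, matching the $d_e=4d$ used in the parameter counts below:

```julia
time_embedding = Chain(
    SinusoidalPositionEmbedding(num_timesteps, model_channels),
    Dense(model_channels => 4 * model_channels, gelu),
    Dense(4 * model_channels => 4 * model_channels),
)
```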
Next we’ll make the up and down blocks, where we can have more than one per level. The keys will be very useful when we print the entire model:
Here is the chain constructor.
We have an initial convolution to transform the image so that it has $d$ channels.
The final convolution reverses this.
In between are the down blocks, the skip connection and the up blocks.
The skip connection calls the recursive function _add_unet_level
(to be defined shortly).
Recursion should not be overused for making a model. We’ll only go four levels deep.
Finally, build the struct:
Here is the cat_on_channel_dim
function:
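Given that the image arrays are laid out as $W \times H \times C \times B$, it is presumably a one-liner along these lines:

```julia
cat_on_channel_dim(h::AbstractArray, x::AbstractArray) = cat(h, x; dims=3)
```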
And here is the _add_unet_level
function.
As always for recursive functions, start with the base case that terminates the recursion, otherwise enter the recursive step:
For the forward pass we calculate the time embedding and pass it along with the input to the chain.
Printing the full model takes 181 lines. (Please see the source code for the printing functions.) See UNet.txt. Here is a condensed version:
For the final scripts please see DenoisingDiffusion.jl/src/models. Please also see these scripts for the printing functions which are not detailed here.
This post assumes you already have a knowledge of convolutional layers. Otherwise this post on Convolutional Neural Networks is a good source with a nice animation. This repository has nice animations on the effects of altering the stride and padding.
For an image input size $i$ (either the height or width), kernel/filter size $k$, padding $p$ and stride $s$, the output image will have size:
\[o=\left\lfloor{\frac{i + 2p - k}{s} + 1} \right\rfloor \tag{3.1} \label{eq:conv}\]

Choosing $p=1$, $s=1$ and $k=3$ ensures that the output size is the same as the input size, $o=i$.
Let the input channel size be $d_i$ and the output channel size be $d_o$. The convolutional layer has $d_o$ kernels of size $k \cdot k \cdot d_i$ each with a bias of size $d_o$. Therefore the layer has $(k^2d_i + 1)d_o$ parameters. For $d_i = m_id$ and $d_o=m_od$ where $m_i$ and $m_o$ are the input and output channel multipliers respectively, the number of parameters is approximately $k^2m_im_od^2$.
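For example, with $k=3$ and $d_i=d_o=16$ the formula gives $(9\cdot 16 + 1)\cdot 16=2320$ parameters, which a quick check in Flux confirms:

```julia
layer = Conv((3, 3), 16 => 16; pad=1)
sum(length, Flux.params(layer)) # 2320
```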
The ConvEmbed
block will form the main block layer.
It will perform a convolution on the input and add the time embedding.
Additionally it will apply a GroupNorm
, adjust the time embedding into a layer-specific embedding and apply an activation function.
The output is equivalent to activation(norm(conv(x)) .+ embed_layers(emb))
.
This has one convolution with $(3^2d_i + 1)d_o$ parameters, one fully connected layer with $(d_e+1)d_o$ parameters and a group norm with $2d_o$ parameters. For $d_i=m_id$, $d_o=m_od$ and $d_e=4d$, the number of parameters is approximately $(9m_i+4)m_od^2$. For the up layers $m_i$ will be a combination of two multipliers because of the concatenation: $m_i = m_t + m_{t-1}$.
This table summarises the resultant block sizes (all values are a slight underestimate):
level | down $m_i$ | down $m_o$ | down size ($d^2$) | up $m_i$ | up $m_o$ | up size ($d^2$) |
---|---|---|---|---|---|---|
1 | 1 | 1 | 13 | 1+1 | 1 | 22 |
2 | 1 | 1 | 13 | 2+1 | 2 | 62 |
3 | 2 | 2 | 44 | 3+2 | 3 | 147 |
4 | 3 | 3 | 93 | | | |
Constructor:
Forward pass. The embedding is reshaped so that each embedding value is applied to the whole cross-section at each channel for each batch:
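For example, with arbitrary sizes:

```julia
h = rand(Float32, 28, 28, 16, 4)  # feature map: W × H × C × B
emb = rand(Float32, 16, 4)        # layer-specific embedding: C × B
h .+ reshape(emb, 1, 1, size(emb)...)  # each embedding value broadcasts over a whole W × H cross-section
```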
We can now use this ConvEmbed
block as our block layer in the U-Net Constructor.
Instead of a single ConvEmbed
block we can use a more complex Resnet style block which connects two blocks with a skip connection.
Only the first block will have the time embedding.
A complication of this block is that the input channels can be different to the output channels, so an extra convolution is needed to make the skip connection the same size.
The ConvEmbed
block has approximately $(9m_i+4)m_od^2$ parameters, the convolution layer has approximately $3^2m_o^2d^2$ and
the skip connection has approximately $3^2m_im_od^2$ where $m_i\neq m_o$.
In total there are $(9(m_i+m_o) + 4)m_od^2$ parameters on the down blocks and $(9(2m_i+m_o) + 4)m_od^2$ on the up blocks.
This table summarises the resultant block sizes (all values are a slight underestimate):
level | down $m_i$ | down $m_o$ | down size ($d^2$) | up $m_i$ | up $m_o$ | up size ($d^2$) |
---|---|---|---|---|---|---|
1 | 1 | 1 | 22 | 1+1 | 1 | 49 |
2 | 1 | 1 | 22 | 2+1 | 2 | 152 |
3 | 2 | 2 | 80 | 3+2 | 3 | 363 |
4 | 3 | 3 | 174 | | | |
The constructor:
Forward pass:
The original U-Net used a 2×2 MaxPool layer for the down sampling. This samples the maximum value from every 2×2 window so that the output is $o=\left \lfloor \frac{i}{2} \right \rfloor$. Using Flux it is made with:
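```julia
MaxPool((2, 2))  # 2×2 max pooling; the stride defaults to the window size
```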
By default the stride is also 2×2. There are no parameters.
By looking at equation $\ref{eq:conv}$, we can also make a down sample layer by setting $k=4$, $s=2$ and $p=1$. The result is then also $o=\left \lfloor \frac{i}{2} \right \rfloor$. This is the method I’ve chosen to use:
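```julia
Conv((4, 4), 1 => 1; stride=2, pad=1)  # 1 => 1 stands in for the level's channel pair
```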
This is a convolution layer with $16m_im_od^2+m_od$ parameters.
If the input image has an odd length then the downsampled dimension will be $\frac{i-1}{2}$. Then if we upsample back by 2 the size will be $i-1$ instead of $i$ and the concatenation will fail. So the image height and width have to be divisible by 2 at every downsampling level of the forward pass.
Here the strategy is to first apply nearest neighbour upsampling followed by a convolution. In nearest neighbour upsampling we take each value and copy them to 2×2 cells. Hence the image will increase from an initial input dimension $i$ to an output dimension $o=2i$.
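In Flux this can be written as (again with 1 => 1 standing in for the level’s channel pair):

```julia
Chain(Upsample(:nearest; scale=2), Conv((3, 3), 1 => 1; pad=1))
```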
There are $9m_om_id^2+m_od$ parameters from the convolution.
Another technique uses a transpose convolution layer. This can be made with ConvTranspose((2,2), 1=>1, stride=(2,2))
.
This however is not recommended because of the “checkerboard” effect that transpose convolutions suffer from.
This is the most complicated block in the model and also the largest. This block is based on the transformer self-attention layer that was first introduced in Google’s 2017 paper Attention is all you need. Ho et al. do not give a justification for including it in the model.
For a full discussion of self-attention please see my earlier post on transformers.
Unfortunately we can’t reuse all the code exactly. The biggest difference is that here the attention is applied across the pixels of the 2D images, with the channels acting as the feature dimension, whereas for transformers it is applied across the words, with the embedding dimension as the feature dimension. So here we take in a $W \times H \times C \times B$ input and rearrange it into $d_h \times WH \times h \times B$ arrays. For language models we take in a $d \times N \times B$ input and rearrange it into $d_h \times N \times h \times B$ arrays.
This version uses convolutions instead of dense layers for the query, key and value matrices and the output matrix. The query, key and value are all combined into one convolution with $3^2(md)(3md)=27m^2d^2$ parameters (no bias). The output has $9m^2d^2+md$ parameters. So in total this layer has approximately $36m^2d^2$ parameters.
Define the struct:
Constructor:
Forward pass:
Scaled dot attention (same as for the transformer):
Batched multiplication (same as for the transformer):
The following is a summary of all the (approximate) equations for the parameters:
Type | Parameters |
---|---|
Conv | $k^2 m_i m_o d^2 +m_o d$ |
ConvEmbed | $(9 m_i + 4)m_o d^2$ |
ResBlock | $(9(m_i + m_o) + 4)m_o d^2 $ |
ResBlock (up) | $(9(2m_i + m_o) + 4)m_o d^2 $ |
Downsample | $16m_i m_od^2 +m_o d$ |
Upsample | $9m_i m_od^2 +m_o d$ |
Attention | $36m_i^2d^2 + m_i d$ |
With these we can construct a table for the full UNet model:
| | Key | Type | $m_i$ | $m_o$ | size ($d^2$) |
|---|---|---|---|---|---|
| 1 | init | Conv | 0 | 1 | 0 |
| 2 | down_1 | ResBlock | 1 | 1 | 22 |
| 3 | downsample_1 | Conv | 1 | 1 | 16 |
| 4 | down_2 | ResBlock | 1 | 1 | 22 |
| 5 | downsample_2 | Conv | 1 | 2 | 32 |
| 6 | down_3 | ResBlock | 2 | 2 | 80 |
| 7 | down_4 | Conv | 2 | 3 | 54 |
| 8 | middle_1 | ResBlock | 3 | 3 | 174 |
| 9 | middle_attention | MultiHeadAttention | 3 | 3 | 324 |
| 10 | middle_2 | ResBlock | 3 | 3 | 174 |
| 11 | up_3 | ResBlock | 5 | 3 | 363 |
| 12 | upsample_3 | Conv | 3 | 2 | 54 |
| 13 | up_2 | ResBlock | 3 | 2 | 152 |
| 14 | upsample_2 | Conv | 2 | 1 | 18 |
| 15 | up_1 | ResBlock | 2 | 1 | 49 |
| 16 | final | Conv | 1 | 0 | 0 |
| | Total | | | | 1534 |
Setting model_channels
to $d=16$ results in $1534d^2=392,704$ parameters which is 97.3% of the true total of 403,409.
The difference comes from ignoring bias terms and terms in $d$.
We can load data using the MLDatasets.jl package.
The first call to MNIST
will download the data (approximately 11 MB) to data_directory
.
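A sketch, where the directory is a placeholder:

```julia
using MLDatasets

data_directory = "path/to/MNIST"  # placeholder
trainset = MNIST(Float32, :train; dir=data_directory)  # downloads on the first call
X_train, y_train = trainset.features, trainset.targets  # 28×28×60000 array and a vector of labels 0-9
```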
Some quick data exploration:3
Normalise the dataset (function defined in part 1):
We can mostly reuse the code from part 1.
The one difference is that we’ll be using a cosine schedule for the $\beta_t$’s instead of a linear schedule. This was proposed in the 2021 paper Improved Denoising Diffusion Probabilistic Models by Alex Nichol and Prafulla Dhariwal. The authors found that this schedule more evenly distributes noise over the whole time range for images.
The formula for $\bar{\alpha}_t$ is:
\[\bar{\alpha}_t = \frac{f(t)}{f(0)}, \quad f(t) = \cos\left(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2}\right)^2\]

They set $s=0.008$. Each $\alpha_t$ can then be calculated as \(\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-1}}\) and hence $ \beta_t = 1 - \alpha_t $.
In code:
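A sketch that follows the formula above; the version in the repository may differ in details such as how the $\beta_t$ are clipped:

```julia
function cosine_beta_schedule(num_timesteps::Int, s=0.008)
    t = 0:num_timesteps
    f = cos.((t ./ num_timesteps .+ s) ./ (1 + s) .* (π / 2)) .^ 2
    α_bar = f ./ f[1]
    β = 1 .- α_bar[2:end] ./ α_bar[1:(end - 1)]
    clamp.(β, 0, 0.999)  # Nichol & Dhariwal clip β_t at 0.999 to avoid singularities near t = T
end

βs = cosine_beta_schedule(1000)
```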
Forward diffusion:
The result:
For the full training script, see train_images.jl. I also have made a Jupyter Notebook hosted on Google Colab available at DenoisingDiffusion-MNIST.ipynb.4
Firstly, split the training dataset into train and validation sets. The test dataset will be used for evaluating the model after training is finished.
Create the model:
Train the model using the functions from part 1-training:
The training history:
Sample:
Here is the resulting reverse process (left) and first image estimates made using the U-Net model (right):5
We would now like to evaluate how good our image generation model is. In part 1 we had a clearly defined target - a spiral defined by mathematical equations - and hence a clearly defined measure of error: the shortest distance from each point to the spiral. Here our sample should be recognisable as one of 10 digits without ambiguity. At the same time it should not match any in the original dataset because we want to create new images. This makes it tricky to evaluate.
Here I propose two techniques. The first is the mean Euclidean distance compared to the mean test images. The second is the Fréchet LeNet Distance, inspired by the popular Fréchet Inception Distance (FID). Both require generating a large number of samples for statistical significance. I have generated 10,000 samples to match the 10,000 test data samples.
For a notebook with the full code please see MNIST-test.ipynb.
We can calculate the mean test images using:
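A sketch, assuming test_x holds the 28×28×10000 test images and test_y their labels:

```julia
using Statistics

# mean image for each digit 0-9
mean_images = [
    dropdims(mean(test_x[:, :, test_y .== label], dims=3), dims=3)
    for label in 0:9
]
```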
As an aside, the generated sample means tend to be remarkably close to these test set means. This may be because the U-Net model is only used for $\tilde{\mu}_t$ and not $\tilde{\beta}_t$ in the reverse equations.
Define the distance as the minimum of the mean Euclidean distance of $x$ to each mean $\bar{x}_k$:
\[d = \min_{0 \leq k \leq 9} \frac{1}{WH}\sqrt{\sum_{i}^W\sum_{j}^H (x_{ij} - \bar{x}_{k,ij})^2}\]

The score is then the average of $d$ over all samples.
In code:
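A sketch (the function name is my own):

```julia
function mean_distance_score(X::AbstractArray{<:Real,3}, mean_images)
    scores = map(1:size(X, 3)) do idx
        x = X[:, :, idx]
        # minimum mean Euclidean distance to the 10 mean images; length(x) = W*H
        minimum(sqrt(sum(abs2, x .- x̄)) / length(x) for x̄ in mean_images)
    end
    mean(scores)
end
```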
Sample values look like:
dataset | score |
---|---|
test | 0.0080 |
train | 0.0080 |
generated | 0.0084 |
random | 0.0913 |
The counts per label look like:
This method only labels the test set with 82% accuracy, so it is not good enough for a bias-free evaluation.
A smarter way to evaluate the model is the Fréchet Inception Distance (FID). This was introduced in the 2017 paper GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium by Heusel et al. There are two important insights. Firstly, image classification can be considered a solved task after the work of the last two decades in machine learning. That is, we can use an image classification model to get insights into our generated data. Secondly, we can view the outputs of the penultimate layer of the model as a probability distribution and compare the statistical distance of it between different datasets. The intuition behind using the penultimate layer is that it is an abstract feature space containing essential information about the samples that we can use for comparison, rather than manually specifying features.
In particular, the FID score uses the Inception V3 model as the classification model and the Fréchet Distance as the statistical measure:
\[d(\mathcal{N}(\mu_1, \Sigma_1^2), \mathcal{N}(\mu_2, \Sigma_2^2)) = ||\mu_1 - \mu_2||^2 + \text{trace}\left(\Sigma_1 + \Sigma_2 -2\left( \Sigma_1 \Sigma_2 \right)^{1/2} \right) \label{eq:Frechet} \tag{4.7}\]

Intuitively, the first term represents the distance between the means and the second term accounts for the difference in variances.
The Inception V3 model however is overkill here. It has over 27 million parameters and the penultimate layer has a length of 2048. My proposal is to instead use the smaller LeNet-5 with 44,000 parameters and an output length of 84. It was first proposed in the 1998 paper Gradient-Based Learning Applied to Document Recognition by Yann LeCun, Leon Bottou, Yoshua Bengio and Patrick Haffner.
This is a Julia implementation from the Flux model zoo:
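A sketch of that model:

```julia
using Flux

function LeNet5(; imgsize=(28, 28, 1), nclasses=10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)
    Chain(
        Conv((5, 5), imgsize[end] => 6, relu),
        MaxPool((2, 2)),
        Conv((5, 5), 6 => 16, relu),
        MaxPool((2, 2)),
        Flux.flatten,
        Dense(prod(out_conv_size) => 120, relu),
        Dense(120 => 84, relu),  # the penultimate layer: 84 features
        Dense(84 => nclasses),
    )
end
```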
Training this model should be covered in any introductory machine learning material. You can view my attempt at train_classifier.jl. This model gets a test accuracy of 98.54%. This is good enough for us to consider it a perfect oracle.
After training we can load the model:
To use the generated outputs we’ll have to normalise them between 0 and 1 (function defined in part 1-spiral dataset):
We can apply our classifier to the normalised generated outputs and compare label counts to the test data:
We can immediately see that our model has a bias to producing 0s and 6s and produces too few 2s and 4s.
Now for the FLD score. We can implement equation $\ref{eq:Frechet}$ in Julia as follows:6
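A sketch, assuming each feature matrix has one column of LeNet features (length 84) per sample:

```julia
using LinearAlgebra, Statistics

function frechet_distance(X1::AbstractMatrix, X2::AbstractMatrix)
    μ1, μ2 = vec(mean(X1, dims=2)), vec(mean(X2, dims=2))
    Σ1, Σ2 = cov(X1, dims=2), cov(X2, dims=2)
    # matrix square root (not element-wise); Σ1 * Σ2 need not be symmetric,
    # so discard any small imaginary component
    sqrt_Σ1Σ2 = real(sqrt(Σ1 * Σ2))
    sum(abs2, μ1 .- μ2) + tr(Σ1 + Σ2 - 2 * sqrt_Σ1Σ2)
end
```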
Usage:
Sample values:
dataset | score |
---|---|
test | 0.0001 |
train | 0.4706 |
generated | 23.8847 |
random | 337.7282 |
Our generated dataset is indeed significantly better than random. However it is still noticeably different from the original dataset.
The main focus of this post was to build a U-Net model. It is general purpose and can be used for other image generation tasks. For example, I did try to generate Pokemon with it, inspired by This Pokémon Does Not Exist but unfortunately did not get good results. You can see my notebooks at github.com/LiorSinai/DenoisingDiffusion-examples.
Now that we can generate numbers the next task is to tell the model which number we want. That way we can avoid the bias towards certain numbers that was evidenced in the evaluation. This will be the focus of the third and final part on Classifier-free guidance. That same method is used with text embeddings to direct the outcome of AI art generators.
Now that our models are getting large, it is also desirable to improve the generation time. This can be accomplished with a technique introduced in the paper Denoising Diffusion Implicit Models (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon. DDIM sampling allows the model to skip timesteps during image generation. This results in much faster image generation with a trade off of a minor loss in quality. I have implemented DDIM sampling in my code. Please review this if you are interested.
One has to admit it has the added advantage of being very short. ↩
For an example of this, see the “backpropagation for mul4d” card in an earlier post on transformers. ↩
The convert2image
code can be rewritten as:
```julia
function img_WH_to_gray(img_WH::AbstractArray{T,N}) where {T,N}
    @assert N == 2 || N == 3
    img_HW = permutedims(img_WH, (2, 1, 3:N...))
    Images.colorview(Images.Gray, img_HW)
end
```
Google Colab does not natively support Julia so you’ll have to install it every time you run the notebook. Plots.jl does not work on Google Colab. ↩
I’ve used a function to combine multiple images into one:
```julia
function combine(imgs::AbstractArray, nrows::Int, ncols::Int, border::Int)
    canvas = zeros(Gray, 28 * nrows + (nrows + 1) * border, 28 * ncols + (ncols + 1) * border)
    for i in 1:nrows
        for j in 1:ncols
            left = 28 * (i - 1) + 1 + border * i
            right = 28 * i + border * i
            top = 28 * (j - 1) + 1 + border * j
            bottom = 28 * j + border * j
            canvas[left:right, top:bottom] = imgs[:, :, ncols * (i - 1) + j]
        end
    end
    canvas
end
```
The one tricky part of the equation is the matrix square root $A^{1/2}$, defined as a matrix $B=A^{1/2}$ such that $BB=A$. In code it is therefore important to realise that sqrt.(A)
and sqrt(A)
are very different operations. ↩