Classifier-free guidance for denoising diffusion probabilistic models in Julia.
Update 28 August 2023: code refactoring and update to Flux 0.13.11 explicit syntax.
This is the third part of a series, building on the 2D pattern generation of part 1 and the number generation of part 2.
Full code available at github.com/LiorSinai/DenoisingDiffusion.jl. Examples notebooks at github.com/LiorSinai/DenoisingDiffusion-examples. Google Colab training notebook at DenoisingDiffusion-MNIST.ipynb.
A major disadvantage of the models developed so far is that the output is random. The user has no way to control it. For the pattern in part 1, we always have to create a spiral. For the number generation in part 2, we cannot specify which number we want. However it is easy enough to extend the existing code to include guidance where the user can specify the output. This post details the conditioning and classifier-free guidance techniques to accomplish this. It applies them to the 2D pattern generation and number generation problems from the previous posts.
These same techniques are very powerful when combined with other models. For example, we can use a model to generate text embeddings and pass these to the diffusion model, creating a text-to-image workflow. This is the underlying idea behind all the popular AI art generators like Stable Diffusion, DALLE-2 and Imagen.
We already have time embeddings and we can combine other sorts of embeddings with it. These embeddings can be added, multiplied or concatenated to the time embeddings. So instead of estimating the noise $\epsilon_\theta(x_t | t)$ we’ll estimate $\epsilon_\theta(x_t | t, y)$ where $y$ is the label of the class. We can also directly pass in the conditioning parameters (text embeddings) $c$ and estimate $\epsilon_\theta(x_t | t, c)$.
It is still useful for this guided model to generate random samples. This gives it backwards compatibility with our previous code and, as we’ll see shortly, gives us a reference point for increasing the conditioning signal. We can do this by passing a special label/embedding denoting the empty set $\emptyset$. Mathematically we can write $\epsilon_\theta(x_t | t) = \epsilon_\theta(x_t | t, \emptyset)$. During training we sometimes randomly change the label/embedding to the empty set label/embedding so that the model learns to associate it with random choice.
In practice when conditioning on the label, the first column of the embedding matrix (label 0 in Python or label 1 in Julia) is reserved for the empty set and the rest are used for the classes. Or when conditioning on the embedding vectors directly, a vector of all zeros can be used for the empty set.
Conditioning on its own works well enough for the toy problems in this blog series. However for text-to-image models we need a stronger signal, and so need a way of amplifying this conditioning. This is known as guidance.
An early proposal for guided diffusion was classifier guidance in the paper Diffusion Models Beat GANs on Image Synthesis by Prafulla Dhariwal and Alex Nichol. This technique used a classifier model to estimate a gradient through a loss function which was then added to the estimate of the sample:
\[\mu_t = \tilde{\mu}_t(x_t | y) + s\Sigma_\theta(x_t | y)\nabla_{x_t} \log p_\phi (y | x_t) \tag{2.1}\]

Where $\mu_t$ is the mean sample per time step, $y$ is the target class, $s$ is the guidance scale, $\Sigma_\theta$ is the model variance and $\nabla_{x_t} \log p_\phi$ is the gradient of the log-probability of the output of the classifier.1 This is an incredibly complicated solution. Creating a classifier model is easy (see part 2-Frechet LeNet Distance) but creating a classifier model that can work throughout all the noisy time steps of diffusion is very difficult. This requires a second U-Net model that needs to be trained in parallel.
Thankfully a much simpler and more effective solution was proposed in the 2022 paper Classifier-Free Diffusion Guidance by Jonathan Ho and Tim Salimans. They proposed estimating the noise twice per time step: once with conditioning parameters (text embeddings) and once without. The predicted noise $\epsilon_\theta$ is then given as a weighted sum of the two:
\[\begin{align} \epsilon_\theta(x_t, c) &= (1 + s)\epsilon_\theta(x_t | c) - s\epsilon_\theta(x_t) \\ &= \epsilon_\theta(x_t | c) + s(\epsilon_\theta(x_t | c) - \epsilon_\theta(x_t)) \tag{2.2} \end{align}\]

Where $t$ is the time step, $x_t$ is the sample, $c$ are the conditioning parameters and $s$ is the guidance scale. From the second line we can see that the gradient is estimated as the difference between the noise with conditioning $\epsilon_\theta(x_t | c)$ and the noise without conditioning $\epsilon_\theta(x_t)$. It is then added to the original noise term, essentially telling it very crudely “go in this direction”. This is an incredibly simple idea but it is effective.
Nichol et al. quickly adapted this technique, realising it was superior to their complicated classifier guidance. They made the minor improvement of putting the unconditioned noise $\epsilon_\theta(x_t)$ as the first term:
\[\epsilon_\theta(x_t, c) = \epsilon_\theta(x_t) + s(\epsilon_\theta(x_t | c) - \epsilon_\theta(x_t)) \tag{2.3} \label{eq:classifier-free-guidance}\]

This form makes slightly more intuitive sense, as we are taking the random noise and guiding it with the gradient, rather than guiding the conditioned (and already guided) noise with the gradient. Then for the special case of $s=1$, $\epsilon_\theta(x_t, c) = \epsilon_\theta(x_t | c)$. That is just plain conditioning. This is significant because it means that for $s=1$ we can use conditioning without estimating the noise twice for classifier-free guidance.
We can now implement this in Julia. We can reuse the same functions and use multiple dispatch to differentiate between them: the guided versions will have three inputs ($x_t$, $t$, $c$) while the originals will have two ($x_t$, $t$).
For the full code see guidance.jl.
This is the guided version of the reverse process from part 1.
For text embeddings coming from a language model, the new argument to `p_sample` would be a matrix. However here the embeddings are calculated in the model too, so `p_sample` will expect a vector of integer labels instead. These will then be passed to an embedding layer. (For modifications which can take in external embeddings, see the feature/external-embeddings branch on the repository.) A label of 1 will be considered random choice and other integers will correspond to the classes. As discussed above, for the special case of `guidance_scale=1` the inputs will be passed directly to the model. Otherwise the more computationally intensive classifier-free guidance will be invoked.
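Putting this together, the guided `p_sample` might look something like the following sketch. The helpers `predict_start_from_noise` and `q_posterior_mean_variance` are placeholders for the corresponding part 1 functions, and the exact names and signatures in guidance.jl may differ.

```julia
function p_sample(
    diffusion::GaussianDiffusion, x::AbstractArray, timesteps::AbstractVector{Int},
    labels::AbstractVector{Int}, noise::AbstractArray;
    clip_denoised::Bool=true, add_noise::Bool=true, guidance_scale::AbstractFloat=1.0f0
)
    # estimate the noise; only do the more expensive classifier-free guidance if the scale is not 1
    if guidance_scale == 1.0f0
        ϵ = denoise(diffusion, x, timesteps, labels)
    else
        ϵ = classifier_free_guidance(diffusion, x, timesteps, labels, guidance_scale)
    end
    # from here on it is identical to the unguided p_sample from part 1
    x_start = predict_start_from_noise(diffusion, x, timesteps, ϵ)
    if clip_denoised
        clamp!(x_start, -1, 1)
    end
    posterior_mean, posterior_variance = q_posterior_mean_variance(diffusion, x_start, x, timesteps)
    x_prev = posterior_mean
    if add_noise
        x_prev += sqrt.(posterior_variance) .* noise
    end
    x_prev, x_start
end
```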
We need a new method for the `denoise` function to handle three inputs:
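It simply forwards the extra argument to the underlying model. A minimal sketch, assuming the model is stored in a `denoise_fn` field as in part 1 (the field name is my assumption):

```julia
function denoise(diffusion::GaussianDiffusion, x::AbstractArray, timesteps::AbstractVector{Int}, labels::AbstractVector{Int})
    diffusion.denoise_fn(x, timesteps, labels)
end
```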
Next is the implementation of classifier-free guidance with equation $\ref{eq:classifier-free-guidance}$. This code calculates the conditioned noise and the unconditioned noise at the same time.
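One way to do this is to stack the conditioned and unconditioned batches and run the model once, then split the output and apply equation 2.3. The function name and batching details below are illustrative and may differ from guidance.jl:

```julia
function classifier_free_guidance(
    diffusion::GaussianDiffusion, x::AbstractArray, timesteps::AbstractVector{Int},
    labels::AbstractVector{Int}, guidance_scale::AbstractFloat
)
    batch_size = size(x)[end]
    # duplicate the batch: the first half is conditioned on the labels,
    # the second half uses label 1, the empty set label
    x_double = cat(x, x; dims=ndims(x))
    timesteps_double = vcat(timesteps, timesteps)
    labels_double = vcat(labels, fill(1, batch_size))
    ϵ_both = denoise(diffusion, x_double, timesteps_double, labels_double)

    inds = ntuple(_ -> Colon(), ndims(x) - 1)
    ϵ_cond = ϵ_both[inds..., 1:batch_size]
    ϵ_uncond = ϵ_both[inds..., (batch_size + 1):(2 * batch_size)]

    # equation 2.3: guide the unconditioned noise towards the conditioned noise
    ϵ_uncond + guidance_scale .* (ϵ_cond - ϵ_uncond)
end
```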
All the other functions remain the same as in part 1.
This is the guided version of the sampling loops from part 1.
Sampling loops with an extra input:
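For example, the main loop might look like this, assuming the diffusion struct stores the number of time steps in a `num_timesteps` field (my placeholder) and using the guided `p_sample` sketched above:

```julia
function p_sample_loop(
    diffusion::GaussianDiffusion, shape::NTuple, labels::AbstractVector{Int};
    clip_denoised::Bool=true, guidance_scale::AbstractFloat=1.0f0
)
    x = randn(Float32, shape)
    for t in diffusion.num_timesteps:-1:1
        timesteps = fill(t, shape[end])
        noise = randn(Float32, size(x))
        x, _ = p_sample(
            diffusion, x, timesteps, labels, noise;
            clip_denoised=clip_denoised, add_noise=(t > 1), guidance_scale=guidance_scale
        )
    end
    x
end
```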
Sampling loop returning all diffusion time steps and $x_0$ estimates:
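The "all" variant accumulates the intermediate samples and the $x_0$ estimates along a new trailing dimension. A sketch under the same assumptions as above:

```julia
function p_sample_loop_all(
    diffusion::GaussianDiffusion, shape::NTuple, labels::AbstractVector{Int};
    clip_denoised::Bool=true, guidance_scale::AbstractFloat=1.0f0
)
    x = randn(Float32, shape)
    x_all = Array{Float32}(undef, shape..., 0)
    x_start_all = Array{Float32}(undef, shape..., 0)
    dim = length(shape) + 1
    for t in diffusion.num_timesteps:-1:1
        timesteps = fill(t, shape[end])
        noise = randn(Float32, size(x))
        x, x_start = p_sample(
            diffusion, x, timesteps, labels, noise;
            clip_denoised=clip_denoised, add_noise=(t > 1), guidance_scale=guidance_scale
        )
        x_all = cat(x_all, x; dims=dim)                    # every denoised sample x_t
        x_start_all = cat(x_start_all, x_start; dims=dim)  # every x_0 estimate
    end
    x_all, x_start_all
end
```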
This is the guided version of the training process from part 1.
As before we’ll have two methods for `p_losses`. The first takes in the four inputs (`x_start`, `timesteps`, `noise` and the new `labels`) and calculates the losses:
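A sketch: noise the clean sample forward with `q_sample` from part 1, predict that noise with the label-conditioned model, and compare. The argument order here is illustrative:

```julia
function p_losses(
    diffusion::GaussianDiffusion, loss, x_start::AbstractArray,
    timesteps::AbstractVector{Int}, labels::AbstractVector{Int}, noise::AbstractArray
)
    # forward diffusion to the given timesteps (q_sample as in part 1)
    x = q_sample(diffusion, x_start, timesteps, noise)
    # predict the added noise, now also conditioned on the labels
    model_out = denoise(diffusion, x, timesteps, labels)
    loss(model_out, noise)
end
```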
The second will generate the `timesteps` and `noise` based on `x_start` (as before):
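A sketch of this convenience method, again assuming a `num_timesteps` field on the diffusion struct:

```julia
function p_losses(diffusion::GaussianDiffusion, loss, x_start::AbstractArray, labels::AbstractVector{Int})
    batch_size = size(x_start)[end]
    timesteps = rand(1:diffusion.num_timesteps, batch_size)
    noise = randn(eltype(x_start), size(x_start))
    p_losses(diffusion, loss, x_start, timesteps, labels, noise)
end
```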
The training loop needs to be updated so that it can randomly set labels to that of the empty set $\emptyset$.2 The `train!` function from part 1 is extended as follows:
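Below is a stripped-down sketch of such a loop using Flux's explicit-style API. It assumes `data` iterates over `(x, labels)` batches (for example from a `Flux.DataLoader`) and that the loss takes `(model, x, labels)`; the keyword `prob_uncond` sets the probability of replacing a label with the empty set label 1:

```julia
using Flux

function train!(loss, model, data, opt_state; num_epochs::Int=10, prob_uncond::AbstractFloat=0.2f0)
    for epoch in 1:num_epochs
        for (x, labels) in data
            # randomly set some labels to 1, the empty set label, so that the model
            # also learns unconditional generation.
            # This happens outside the gradient call, so mutation is allowed.
            labels = copy(labels)
            is_not_class_cond = rand(length(labels)) .< prob_uncond
            labels[is_not_class_cond] .= 1

            batch_loss, grads = Flux.withgradient(model) do m
                loss(m, x, labels)
            end
            Flux.update!(opt_state, model, grads[1])
        end
        println("finished epoch $epoch")
    end
end
```

Here `opt_state` would come from Flux's explicit optimiser setup, for example `opt_state = Flux.setup(Adam(0.001), model)`.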
For more functionality like calculating the validation loss after each epoch, see the train.jl script in the repository.
That’s it. Our code can now do guided diffusion. Let’s test it out on both the 2D patterns and the number generation.
In part 1 we only worked with a spiral dataset. However we can work with other patterns too. Here is a Julia implementation of the Scikit-learn moon dataset:
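Something like the following, using random angles rather than Scikit-learn's evenly spaced points, and returning a 2×n matrix to match the spiral from part 1:

```julia
function make_moons(n_samples::Int=1000)
    n_out = n_samples ÷ 2
    n_in = n_samples - n_out
    θ_out = π .* rand(Float32, n_out)
    θ_in = π .* rand(Float32, n_in)
    # outer moon is the top half circle, inner moon is a shifted bottom half circle
    x = vcat(cos.(θ_out), 1 .- cos.(θ_in))
    y = vcat(sin.(θ_out), 1 .- sin.(θ_in) .- 0.5f0)
    permutedims(hcat(x, y)) # 2×n_samples
end
```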
Similarly for the Scikit-learn s-curve:
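Scikit-learn's s-curve is a 3D dataset; here I keep only the $(x, z)$ projection so that it matches the 2D format of the other patterns. The details may differ from the repository's version:

```julia
function make_s_curve(n_samples::Int=1000)
    t = 3.0f0 * π .* (rand(Float32, n_samples) .- 0.5f0)
    x = sin.(t)
    z = sign.(t) .* (cos.(t) .- 1)
    permutedims(hcat(x, z)) # 2×n_samples
end
```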
Then we can make and combine samples of all three:
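For example, scaling each pattern to roughly $[-1, 1]$ and concatenating them along the second dimension. The `scale_to_unit` helper is a simple stand-in for whatever normalisation the repository uses, and `make_spiral` stands for the spiral function from part 1:

```julia
# simple stand-in normalisation: scale the whole pattern to [-1, 1]
scale_to_unit(X) = 2 .* (X .- minimum(X)) ./ (maximum(X) - minimum(X)) .- 1

n_each = 3_000
X_spiral = scale_to_unit(make_spiral(n_each))
X_moons = scale_to_unit(make_moons(n_each))
X_s_curve = scale_to_unit(make_s_curve(n_each))
X = hcat(X_spiral, X_moons, X_s_curve) # 2×9000
```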
Plotting all at once:
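One way with Plots.jl, using the three pattern matrices from the previous step:

```julia
using Plots

canvas = scatter(X_spiral[1, :], X_spiral[2, :]; label="spiral", markersize=2, alpha=0.5, aspect_ratio=:equal)
scatter!(canvas, X_moons[1, :], X_moons[2, :]; label="moons", markersize=2, alpha=0.5)
scatter!(canvas, X_s_curve[1, :], X_s_curve[2, :]; label="s-curve", markersize=2, alpha=0.5)
```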
Make the labels. Remember that 1 is reserved for random choice.
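Following the order of the concatenation above:

```julia
labels = vcat(fill(2, n_each), fill(3, n_each), fill(4, n_each)) # 2=spiral, 3=moons, 4=s-curve; 1 is the empty set
```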
We can use the same model as part 1-sinusoidal embeddings with an additional `Flux.Embedding` layer at each level. Otherwise no other changes are required.
Training is the same as part 1-training. The only difference is that the `loss` function must now take in three arguments:
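For example, with the mean squared error and the label-conditioned `p_losses` sketched earlier (the exact signature may differ from the repository):

```julia
loss_type = Flux.mse
# model is the GaussianDiffusion wrapping the conditioned network
loss(model, x, labels) = p_losses(model, loss_type, x, labels)
```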
See train_generated_2d_cond.jl for the full script.
We can sample at a guidance scale of 1 to avoid classifier-free guidance:
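For example, to generate 1000 points of a single pattern (the signature of `p_sample_loop` follows the sketch above and may differ from the repository):

```julia
num_samples = 1000
target_labels = fill(3, num_samples) # label 3 = moons, following the labelling above
X0 = p_sample_loop(diffusion, (2, num_samples), target_labels; guidance_scale=1.0f0)
```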
The result:
The random choice label of 1 will result in a combination of all the patterns. This is because the 2D points are sampled from one of the patterns independently of their neighbours. It is therefore highly unlikely that they will all be randomly chosen from one pattern.
The MNIST dataset comes with labels. We will have to shift them over by 2 because 1 is reserved for random choice and so 0 needs to correspond to 2.
The code:
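A sketch using MLDatasets.jl, reshaping to WHCN order and rescaling to $[-1, 1]$ as in part 2:

```julia
using MLDatasets

trainset = MLDatasets.MNIST(Tx=Float32, split=:train)
X = reshape(trainset.features, 28, 28, 1, :) # add a channel dimension
X = 2f0 .* X .- 1f0                          # rescale from [0, 1] to [-1, 1]
labels = trainset.targets .+ 2               # 0-9 become 2-11; 1 is reserved for the empty set
```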
We need a new method for `split_validation` to split the labels too:
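A sketch that shuffles the sample indices and splits both the data and the labels; the version in the repository may differ in its details:

```julia
using Random

function split_validation(rng::AbstractRNG, data::AbstractArray, labels::AbstractVector; frac=0.1)
    nsamples = size(data)[end]
    idxs = randperm(rng, nsamples)
    ntrain = nsamples - floor(Int, frac * nsamples)
    inds = ntuple(_ -> Colon(), ndims(data) - 1)
    train = (data[inds..., idxs[1:ntrain]], labels[idxs[1:ntrain]])
    val = (data[inds..., idxs[(ntrain + 1):end]], labels[idxs[(ntrain + 1):end]])
    train, val
end
```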
We would like to use the same model from part 2-constructor.
However we need to add two extra elements to the struct: a new embedding layer and a function to combine embeddings.
The `combine_embeddings` function can be one of `+` (as used above for the spiral), `*` or `vcat`. It should not have parameters.
The new struct is:
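A heavily simplified sketch of what it could look like. I'm assuming the part 2 model keeps its time-embedding layers and its main network in separate fields (the names here are illustrative); see UNetConditioned.jl for the real definition:

```julia
struct UNetConditioned{E1,E2,F,C}
    time_embedding::E1     # as in part 2, e.g. a Chain with a sinusoidal position embedding
    class_embedding::E2    # new: e.g. Flux.Embedding(num_classes + 1 => embed_dim)
    combine_embeddings::F  # new: +, * or vcat; contributes no trainable parameters
    chain::C               # the U-Net layers from part 2, taking (x, emb)
end

Flux.@functor UNetConditioned
```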
The constructor is almost the same, so I’ll skip most of it:
The forward pass has additional steps for calculating the class embedding and combining it with the time embeddings:
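Continuing the sketch above:

```julia
function (u::UNetConditioned)(x::AbstractArray, timesteps::AbstractVector{Int}, labels::AbstractVector{Int})
    time_emb = u.time_embedding(timesteps)
    class_emb = u.class_embedding(labels)
    emb = u.combine_embeddings(time_emb, class_emb)
    u.chain(x, emb)
end
```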
Automatically set labels to random choice if none are supplied:
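That is, the two-input method falls back to the empty set label:

```julia
function (u::UNetConditioned)(x::AbstractArray, timesteps::AbstractVector{Int})
    labels = fill(1, size(x)[end]) # 1 is the empty set label, so this is unconditioned generation
    u(x, timesteps, labels)
end
```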
For printing functions please see UNetConditioned.jl.
We can make a model with:
The whole model will have 420,497 parameters. This is only an increase of 4.3% on the unconditioned `UNet`.
Training is the same as part 2-training. For the full training script, see train_images_cond.jl. I also have made a Jupyter Notebook hosted on Google Colab available at DenoisingDiffusion-MNIST.ipynb.3
We can now finally make the video I showed at the very top of part 1:
Another interesting question to ask is: what is the effect of the guidance scale on the output? Here is an image showing that:
It seems that lower guidance scale values don’t affect the output much but higher values interfere too much.
We can also go through the same exercise of calculating the Frechet LeNet Distance (FLD) from part 2-Frechet LeNet Distance. Except this time we can ask the model to generate 1000 samples of each label. This gives us very uniform label counts:
The classifier agrees very well with our desired labels. When the guidance scale is 1 the average recall, precision and F1 score across all labels is 89%. When it is 2 the averages increase to 99%.
The FLD score is also lower. It dropped from values in the 20s without conditioning to between 4 and 8 with conditioning.
This three part series sought to explain the magic behind AI art generators. In particular it focused on denoising diffusion probabilistic models, the main image generation model used in text-to-image pipelines. I hope you understand them much better now and are able to experiment with your own models.
There is still ongoing controversy over AI art generators - the way training data is collected, the quality of the outputs, the potential for forgeries and debates about the value of machine generated art. As an engineer I am always tempted to promote the technology but the adverse effects of it cannot be ignored. I hope that society is able to navigate the complex questions posed by these art generators to the benefit of all involved.
One thing that is for sure is that the technology is here to stay. If you’ve followed this series properly, you’ll have a model running on your own machine, training on your own data. That is not something that can be taken away easily.
For an implementation of the classifier guidance equation, see OpenAI’s classifier_sample.py from the guided-diffusion repository. To be honest, I don’t understand this equation or the code. Why is there a $\log$ in the gradient? Why are they computing the sum of the probabilities? ↩
The original code put the `randomly_set_unconditioned` logic in the `p_losses` function. Apart from that not being the best logical place for it, it created the problem that this code fell under the `Flux.withgradient` scope and so it needed to be differentiated. However `Zygote` cannot differentiate mutating operations like `.=`. So instead of using:
```julia
is_not_class_cond = rand(batch_size) .< prob_uncond
labels[is_not_class_cond] .= 1
```
I used:
```julia
is_class_cond = rand(batch_size) .>= prob_uncond
is_not_class_cond = .~is_class_cond
labels = (labels .* is_class_cond) + is_not_class_cond
```
This hack worked. However it is better to not do unnecessary differentiation. ↩
Google Colab does not natively support Julia so you’ll have to install it every time you run the notebook. Plots.jl does not work on Google Colab. ↩