<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://liorsinai.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://liorsinai.github.io/" rel="alternate" type="text/html" /><updated>2026-01-10T14:46:54+00:00</updated><id>https://liorsinai.github.io/feed.xml</id><title type="html">Lior Sinai</title><subtitle>Musings on mathematics, coding and physics.</subtitle><author><name>Lior Sinai</name></author><entry><title type="html">DeepSeek’s Multi-Head Latent Attention</title><link href="https://liorsinai.github.io/machine-learning/2025/02/22/mla.html" rel="alternate" type="text/html" title="DeepSeek’s Multi-Head Latent Attention" /><published>2025-02-22T00:00:00+00:00</published><updated>2026-01-10T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2025/02/22/mla</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2025/02/22/mla.html"><![CDATA[<p><em>A deep dive into DeepSeek’s Multi-Head Latent Attention, including the mathematics and implementation details. The layer is recreated in Julia using Flux.jl.</em></p>

<p>See also previous posts on transformers:</p>
<ul>
  <li><a href="/machine-learning/2022/05/18/transformers">Transformers from first principles in Julia</a>.</li>
  <li><a href="/machine-learning/2024/03/23/transformers-gpt">Generative transformer from first principles in Julia</a>.</li>
</ul>

<p>All code is available at <a href="https://github.com/LiorSinai/TransformersLite.jl/tree/feature/mla">github.com/LiorSinai/TransformersLite.jl/tree/feature/mla</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p>In January 2025, <a href="https://www.deepseek.com/">DeepSeek</a> unveiled their new DeepSeek-V3 and DeepSeek R1 models.
They took the <a href="https://edition.cnn.com/2025/01/27/tech/deepseek-ai-explainer/index.html">world</a> <a href="https://www.technologyreview.com/2025/01/31/1110740/how-deepseek-ripped-up-the-ai-playbook-and-why-everyones-going-to-follow-it/">by</a> <a href="https://hn.algolia.com/?q=deepseek">storm</a>.
Users were impressed with the models’ abilities, on top of DeepSeek’s claims that they were up to 50× more efficient to train and run than competing models.</p>

<p>They also released multiple papers 
(<a href="https://github.com/deepseek-ai/DeepSeek-VL2">DeepSeek-V2</a>, <a href="https://github.com/deepseek-ai/DeepSeek-V3">DeepSeek-V3</a>, <a href="https://github.com/deepseek-ai/DeepSeek-R1">DeepSeek-R1</a>)
with an impressive array of new techniques across the whole machine learning pipeline, from high level theory to intricate implementation details.
Most of it built on existing ideas in innovative ways. They include:</p>
<ul>
  <li>Theory
    <ul>
      <li>Multi-Head Latent Attention (MLA): compress vectors during attention, which reduces the cache size during inference.</li>
      <li>DeepSeekMoE: segmented and isolated mixture of experts.</li>
      <li>Multi-token prediction.</li>
      <li>Reinforcement learning with Group Relative Policy Optimization but without supervised data.</li>
      <li>Improved chain-of-thought reasoning.</li>
    </ul>
  </li>
  <li>Implementation
    <ul>
      <li>DualPipe: accelerate training by overlapping the forward and backward computation-communication phases.</li>
      <li>Exponential moving average on the CPU.</li>
      <li>Mixed precision floating point numbers during training. FP8 and FP32 are used.</li>
      <li>Low-precision storage and communication.</li>
    </ul>
  </li>
</ul>

<p>The aim of this post is to explore the first of these ideas in depth, namely Multi-Head Latent Attention (MLA). It is actually a combination of 3 ideas:</p>
<ol>
  <li>Attention and multi-head attention from <a href="https://arxiv.org/abs/1706.03762">Attention is all you need (2017)</a>.</li>
  <li>KV Caching.</li>
  <li>Low-Rank Adaptation (LoRA) matrices from <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of Large Language Models (2021)</a>.</li>
</ol>

<p>The idea behind MLA is simple: it compresses the input matrix into a single matrix for caching during inference, as opposed to standard KV caching which caches two intermediate matrices.
This is a novel innovation by DeepSeek.
However, MLA has side effects which DeepSeek barely address. 
The compression also results in a performance boost but most likely comes with a qualitative penalty. That these are not discussed is a notable omission in their <a href="https://github.com/deepseek-ai/DeepSeek-VL2">paper</a>.</p>

<p>DeepSeek also adds two unwieldy enhancements to MLA. While they only speak well of these, they complicate the mathematics and require specialised, optimised code to see performance gains:</p>
<ol start="1">
  <li>Weight absorption.</li>
  <li>Decoupled rotary position embeddings (RoPE), a modification of <a href="https://arxiv.org/abs/2104.09864">RoFormer: Enhanced Transformer with Rotary Position Embedding (2021)</a>.</li>
</ol>

<p>The original source code is written in Python with the PyTorch machine learning framework.
However, my favourite language is Julia and continuing with my previous posts on transformers, all the code here is written in Julia using the Flux.jl machine learning framework.</p>

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
	</div>
	<div class="content-container">
		<div class="message-body">
		Julia uses column major format whereas Python uses row major format. In Julia sample vectors are columns while in Python they are rows.
		Equations written in one format look reversed in the other.
		Both the equations and the definitions need to be transposed when converting between them. 
		E.g. $K^TQ \rightarrow (K_c^TQ_c)^T=Q_c^TK_c= Q_r K_r^T$
		</div>
	</div>
</div>

<p>The mathematics therefore follows Julia’s column major format and not Python’s row major format.
In this format each column represents a single sample. The second dimension is sequence length and all remaining dimensions represent batches.</p>
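<p>As a quick illustration of the transpose rule in the note above, here is a minimal sketch in plain Julia. The sizes and the names <code class="language-plaintext highlighter-rouge">Kc</code>, <code class="language-plaintext highlighter-rouge">Qc</code>, <code class="language-plaintext highlighter-rouge">Kr</code> and <code class="language-plaintext highlighter-rouge">Qr</code> are made up for this example and do not come from the original source code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Column major convention (this post): each column is one token's vector.
dh, n = 4, 3            # hypothetical head dimension and sequence length
Kc = randn(dh, n)
Qc = randn(dh, n)

# Row major convention (PyTorch): each row is one token's vector.
Kr = permutedims(Kc)    # n × dh
Qr = permutedims(Qc)    # n × dh

scores_col = Kc' * Qc   # n × n, entry (i, j) pairs key i with query j
scores_row = Qr * Kr'   # n × n, entry (i, j) pairs query i with key j

# The two score matrices are transposes of each other.
scores_row ≈ scores_col'</code></pre></figure>

<p>The same scores are computed either way. Only the orientation changes, which is why the softmax in this post is applied down each column rather than along each row.</p>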

<h2 id="kv-caching">2 KV Caching</h2>
<h3 id="kv-caching-theory">2.1 Theory</h3>

<p>There is an inefficiency in using transformers for generation.
In classification <a href="/machine-learning/2022/05/18/transformers#use-case-amazon-reviews">use cases</a>, the model only needs to calculate attention over the input sentence once (per layer) before making its prediction.
But in text generation, it needs to recalculate it over the entire sentence every time a new token is generated.
One analogy I read is that this is like having to reread the entire sentence so far in order to produce the next word and then rereading again to produce the next word, and so on.</p>

<p>In mathematical terms, this is computing the attention between the first token and itself, between the first two tokens, between the first 3 tokens and so on until it’s between all the tokens, and repeating this whole process each time for every new token.
This was the case for the <a href="/machine-learning/2024/03/23/transformers-gpt">generator</a> I made based on Andrej Karpathy’s <a href="https://karpathy.ai/zero-to-hero.html">Zero to Hero course</a>.</p>

<p>It would be better to have a “memory” of what has been generated so far.
This is where the KV cache comes in, for Key-Value cache.
We can store the previous keys and values and use them to calculate the new attention value.
This will give the exact same value as recalculating the entire attention.</p>

<p>Here is a mathematical proof.
For a visual interpretation, please see João Lages’ <a href="https://medium.com/@joaolages/kv-caching-explained-276520203249">excellent article</a>.</p>

<p>The attention equation is</p>

\[A = V\text{softmax}\left(\frac{1}{\sqrt{d_h}}K^T Q\right)
\label{eq:attention}
\tag{2.1.1}\]

<p>where</p>

\[\begin{align}
    K &amp;= W^K X \\
    Q &amp;= W^Q X \\
    V &amp;= W^V X
\end{align}
\label{eq:KQV}
\tag{2.1.2}\]

<p>where $X \in \mathbb{R}^{d\times n \times B}$ and $W^K, W^Q, W^V \in \mathbb{R}^{d_hH \times d}$.</p>

<p>Focusing on one head in one batch with dimension $d_h$, the input matrix $X$ is split into the first $n-1$ columns and the $n$th column.
The first multiplication can then be written as (dimensional analysis on right):<sup id="fnref:forgive" role="doc-noteref"><a href="#fn:forgive" class="footnote" rel="footnote">1</a></sup></p>

\[\begin{align}
S &amp;= K^T Q &amp; &amp; \\
  &amp;=
\begin{bmatrix}
K_{1:n-1} &amp; K_n 
\end{bmatrix}^T
\begin{bmatrix}
Q_{1:n-1} &amp; Q_n
\end{bmatrix}
&amp;; &amp;\begin{bmatrix}d_h \times (n-1) &amp; d_h \times 1 \end{bmatrix} ^T \begin{bmatrix} d_h \times (n-1) &amp; d_h \times 1\end{bmatrix}\\
 &amp;=
\begin{bmatrix}
K_{1:n-1}^T \\ K_n^T 
\end{bmatrix}
\begin{bmatrix}
Q_{1:n-1} &amp; Q_n
\end{bmatrix} &amp;; &amp;\begin{bmatrix}(n-1) \times d_h \\ 1 \times d_h \end{bmatrix} \begin{bmatrix} d_h \times (n-1) &amp; d_h \times 1\end{bmatrix}\\
 &amp;= 
\begin{bmatrix}
K_{1:n-1}^T Q_{1:n-1} &amp; K_{1:n-1}^T Q_n \\
K_n^T Q_{1:n-1} &amp; K_n^T Q_n
\end{bmatrix} &amp;; &amp;\begin{bmatrix}(n-1)\times(n-1) &amp; (n-1)\times 1 \\ 1 \times (n-1) &amp; 1\times 1 \end{bmatrix}
\end{align}
\label{eq:Kcache}
\tag{2.1.3}\]

<p>Looking at the final line, the first $(n-1)$ columns of the query can be safely dropped without affecting the $n$th column.
(We would do this anyway in <a href="/machine-learning/2024/03/23/transformers-gpt#generation">generation</a>.)
The $n$th column only depends on $K_{1:n-1}$, $K_n$ and $Q_n$.
Of these, $K_{1:n-1}$ will come from the cache and the other two will be calculated from $X_n$.</p>

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
	</div>
	<div class="content-container">
		<div class="message-body">
    It is important to note that dropping the first $(n-1)$ columns is also valid because almost all other layers in a transformer are independent of position.
    For example, for a dense layer $Y=WX$, permuting the columns of $X$ will result in a corresponding permutation of the columns of $Y$.
    E.g. if $X$ has two columns and $[Y_1 Y_2] = W[X_1 X_2]$ then $[Y_2 Y_1] = W[X_2 X_1]$.
    The exception is the position embedding layers which will require a new parameter to be passed through the whole transformer to indicate the position.
		</div>
	</div>
</div>
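<p>The column independence of a dense layer described in the note above can be checked numerically. This is a minimal sketch with made-up sizes and a plain weight matrix standing in for a <code class="language-plaintext highlighter-rouge">Dense</code> layer without bias:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random

W = randn(3, 4)          # hypothetical dense layer weights
X = randn(4, 5)          # 5 tokens, one per column
perm = randperm(5)       # a random reordering of the tokens

# Permuting the inputs permutes the outputs in the same way,
# so each output column depends only on its own input column.
W * X[:, perm] ≈ (W * X)[:, perm]

# In particular, the last column can be computed on its own.
W * X[:, end] ≈ (W * X)[:, end]</code></pre></figure>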

<p>Similarly for the next multiplication we have ($Z=\text{softmax}(S)$):</p>

\[\begin{align}
A &amp;= V Z \\
  &amp;=
\begin{bmatrix}
V_{1:n-1} &amp; V_n 
\end{bmatrix}
\begin{bmatrix}
Z_{1:n-1,1:n-1} &amp; Z_{1:n-1,n} \\
Z_{n,1:n-1} &amp; Z_{n,n} 
\end{bmatrix} \\
 &amp;= 
 \begin{bmatrix}
    V_{1:n-1} Z_{1:n-1,1:n-1} +  V_{n} Z_{n,1:n-1}  &amp;
    V_{1:n-1} Z_{1:n-1,n} + V_n Z_{n, n}
    \end{bmatrix}
\end{align}
\label{eq:Vcache}
\tag{2.1.4}\]

<p>However, we said we are dropping the first $(n-1)$ columns.
Without them only the $n$th column is calculated.
Hence we have:</p>

\[\begin{align}
A_n &amp;= 
\begin{bmatrix}
V_{1:n-1} &amp; V_n 
\end{bmatrix}
\begin{bmatrix}
Z_{1:n-1,n} \\
Z_{n,n} 
\end{bmatrix} \\
  &amp;= V_{1:n-1} Z_{1:n-1,n} + V_n Z_{n, n}
\end{align}
\tag{2.1.5}\]

<p>which depends on $V_{1:n-1}$, which will come from the cache, and $V_n$, which will be calculated from $X_n$.</p>
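<p>Equations 2.1.3 to 2.1.5 can be verified numerically for a single head in a single batch. The following is a minimal sketch in plain Julia with random weights. The helper <code class="language-plaintext highlighter-rouge">softmax_cols</code> and all sizes are assumptions made for this example and are not part of TransformersLite.jl:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random

# Column-wise softmax: each column of S holds the scores for one query.
function softmax_cols(S::AbstractMatrix)
    E = exp.(S .- maximum(S; dims=1))
    E ./ sum(E; dims=1)
end

Random.seed!(1)
d, dh, n = 8, 4, 5                       # hypothetical sizes
WK, WQ, WV = randn(dh, d), randn(dh, d), randn(dh, d)
X = randn(d, n)

# Full recomputation over all n tokens.
K, Q, V = WK * X, WQ * X, WV * X
A_full = V * softmax_cols(K' * Q ./ sqrt(dh))

# Cached version: keys and values for tokens 1:n-1 are already stored,
# so only the nth column of X is multiplied by the weights.
cache_k, cache_v = WK * X[:, 1:n-1], WV * X[:, 1:n-1]
xn = X[:, n:n]
Kc = hcat(cache_k, WK * xn)
Vc = hcat(cache_v, WV * xn)
A_n = Vc * softmax_cols(Kc' * (WQ * xn) ./ sqrt(dh))

# The nth column matches the full calculation.
A_full[:, n] ≈ vec(A_n)</code></pre></figure>

<p>No mask is needed in this check because the $n$th query is allowed to attend to all $n$ keys.</p>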

<p>There will be two caches each with size $d_h H \times N \times B$ for $H$ heads, a maximum sequence length of $N$ and a maximum batch size of $B$.
The total cache size can grow very large for large transformers with many multi-head attention layers.
The primary aim of MLA is to reduce the size of this cache. That will be covered in the next section.</p>

<h3 id="kv-caching-code">2.2 Code</h3>

<p>Building on my code in <a href="https://github.com/LiorSinai/TransformersLite.jl">TransformersLite.jl</a>, it is straightforward to create a new <code class="language-plaintext highlighter-rouge">MultiHeadAttentionKVCache</code> layer with two caches:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> MultiHeadAttentionKVCache</span><span class="x">{</span>
    <span class="n">Q</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">K</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">V</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">O</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">C</span><span class="o">&lt;:</span><span class="kt">Array</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">}</span>  <span class="k">where</span> <span class="n">T</span>
    <span class="x">}</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span>
    <span class="n">denseQ</span><span class="o">::</span><span class="n">Q</span>
    <span class="n">denseK</span><span class="o">::</span><span class="n">K</span>
    <span class="n">denseV</span><span class="o">::</span><span class="n">V</span>
    <span class="n">denseO</span><span class="o">::</span><span class="n">O</span>
    <span class="n">cache_k</span><span class="o">::</span><span class="n">C</span>
    <span class="n">cache_v</span><span class="o">::</span><span class="n">C</span>
<span class="k">end</span>

<span class="n">Flux</span><span class="o">.</span><span class="nd">@layer</span> <span class="n">MultiHeadAttentionKVCache</span> <span class="n">trainable</span><span class="o">=</span><span class="x">(</span><span class="n">denseQ</span><span class="x">,</span> <span class="n">denseK</span><span class="x">,</span> <span class="n">denseV</span><span class="x">,</span> <span class="n">denseO</span><span class="x">)</span></code></pre></figure>

<p>The forward pass calculates the current <code class="language-plaintext highlighter-rouge">q</code>, <code class="language-plaintext highlighter-rouge">k</code> and <code class="language-plaintext highlighter-rouge">v</code> values from the input and gets the rest from the cache.
It then continues without any additional modifications from the original code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">mha</span><span class="o">::</span><span class="n">MultiHeadAttentionKVCache</span><span class="x">)(</span>
    <span class="n">query</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">value</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">;</span> <span class="n">start_pos</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">1</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">::</span><span class="kt">Bool</span><span class="o">=</span><span class="nb">true</span><span class="x">,</span> <span class="n">kwargs</span><span class="o">...</span>
    <span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">}}</span>
    <span class="n">q</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">denseQ</span><span class="x">(</span><span class="n">query</span><span class="x">)</span> <span class="c"># size(q) == (dh*nhead, n, B), with n == 1 during generation</span>
    <span class="n">k</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">denseK</span><span class="x">(</span><span class="n">key</span><span class="x">)</span>
    <span class="n">v</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">denseV</span><span class="x">(</span><span class="n">value</span><span class="x">)</span> <span class="c"># size(k) == size(v) == (dh*nhead, n, B)</span>
    <span class="k">if</span> <span class="n">use_cache</span>
        <span class="n">dim</span><span class="x">,</span> <span class="n">seq_length</span><span class="x">,</span> <span class="n">batch_dim</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">query</span><span class="x">)</span>
        <span class="n">end_pos</span> <span class="o">=</span> <span class="n">start_pos</span> <span class="o">+</span> <span class="n">seq_length</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="n">mha</span><span class="o">.</span><span class="n">cache_k</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span> <span class="o">=</span> <span class="n">k</span>
        <span class="n">mha</span><span class="o">.</span><span class="n">cache_v</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span> <span class="o">=</span> <span class="n">v</span>
        <span class="n">K</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">cache_k</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span>
        <span class="n">V</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">cache_v</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span>
    <span class="k">else</span>
        <span class="n">K</span> <span class="o">=</span> <span class="n">k</span>
        <span class="n">V</span> <span class="o">=</span> <span class="n">v</span>
    <span class="k">end</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">multi_head_scaled_dot_attention</span><span class="x">(</span><span class="n">mha</span><span class="o">.</span><span class="n">nhead</span><span class="x">,</span> <span class="n">q</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">V</span><span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span>
    <span class="n">mha</span><span class="o">.</span><span class="n">denseO</span><span class="x">(</span><span class="n">A</span><span class="x">),</span> <span class="n">scores</span>
<span class="k">end</span></code></pre></figure>

<p>Here is a small example of it in action. (For the full code, see <a href="https://github.com/LiorSinai/TransformersLite.jl/blob/feature/mla/test/MultiHeadAttention.jl">test/MultiHeadAttention.jl</a>.)</p>

<p>Create the layer and inputs:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">TransformersLite</span>
<span class="k">using</span> <span class="n">TransformersLite</span><span class="o">:</span> <span class="n">MultiHeadAttention</span><span class="x">,</span> <span class="n">MultiHeadAttentionKVCache</span>
<span class="k">using</span> <span class="n">TransformersLite</span><span class="o">:</span> <span class="n">make_causal_mask</span><span class="x">,</span> <span class="n">clone_add_kv_cache</span>
<span class="n">nhead</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">,</span> <span class="n">dim_out</span> <span class="o">=</span> <span class="mi">4</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">13</span>
<span class="n">mha</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="x">(</span><span class="n">nhead</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">,</span> <span class="n">dim_out</span><span class="x">)</span> 
<span class="n">mha</span> <span class="o">=</span> <span class="n">clone_add_kv_cache</span><span class="x">(</span><span class="n">mha</span><span class="x">,</span> <span class="mi">64</span><span class="x">,</span> <span class="mi">8</span><span class="x">)</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">10</span><span class="x">,</span> <span class="mi">5</span><span class="x">)</span></code></pre></figure>

<p>Fill the cache:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">10</span><span class="x">,</span> <span class="mi">10</span><span class="x">))</span>
<span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">mha</span><span class="x">(</span><span class="n">X</span><span class="x">,</span> <span class="n">X</span><span class="x">,</span> <span class="n">X</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">=</span><span class="mi">1</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">true</span><span class="x">)</span>
<span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">)</span> <span class="c"># (13, 10, 5)</span>
<span class="n">size</span><span class="x">(</span><span class="n">scores</span><span class="x">)</span> <span class="c"># (10, 10, 4, 5)</span></code></pre></figure>

<p>Use the cache with a new vector:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">5</span><span class="x">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">repeat</span><span class="x">([</span><span class="nb">true</span><span class="x">],</span> <span class="n">inner</span><span class="o">=</span><span class="x">(</span><span class="mi">11</span><span class="x">,</span> <span class="mi">1</span><span class="x">))</span>
<span class="n">Ax</span><span class="x">,</span> <span class="n">scoresx</span> <span class="o">=</span> <span class="n">mha</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">=</span><span class="mi">11</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">true</span><span class="x">)</span>
<span class="n">size</span><span class="x">(</span><span class="n">Ax</span><span class="x">)</span> <span class="c"># (13, 1, 5)</span>
<span class="n">size</span><span class="x">(</span><span class="n">scoresx</span><span class="x">)</span> <span class="c"># (11, 1, 4, 5)</span></code></pre></figure>

<p>Compare without the cache:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Xx</span> <span class="o">=</span> <span class="n">cat</span><span class="x">(</span><span class="n">X</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">2</span><span class="x">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">11</span><span class="x">,</span> <span class="mi">11</span><span class="x">))</span>
<span class="n">AXx</span><span class="x">,</span> <span class="n">scoresXx</span> <span class="o">=</span> <span class="n">mha</span><span class="x">(</span><span class="n">Xx</span><span class="x">,</span> <span class="n">Xx</span><span class="x">,</span> <span class="n">Xx</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">=</span><span class="mi">1</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
<span class="n">isapprox</span><span class="x">(</span><span class="n">AXx</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="k">end</span><span class="x">,</span> <span class="o">:</span><span class="x">],</span> <span class="n">Ax</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="k">end</span><span class="x">,</span> <span class="o">:</span><span class="x">])</span> <span class="c"># true</span></code></pre></figure>

<h2 id="multi-head-latent-attention">3 Multi-Head Latent Attention</h2>
<h3 id="mla-cache">3.1 C cache</h3>

<p>We’ve seen that the KV cache has size $2d_h H \times N \times B$ elements per multi-head attention layer. (Each element is 1-4 bytes depending on whether FP8, FP16 or FP32 is used.)
The aim of MLA is to reduce this, specifically to $d_c \times N \times B$ elements per multi-head attention layer.
Therefore we will choose $d_c &lt; 2d_h H$.</p>
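<p>To get a feel for the numbers, here is a minimal sketch that compares the two cache sizes for one attention layer. The head sizes are the DeepSeek-V3 values quoted later in this section. The maximum sequence length, batch size and the two function names are hypothetical and chosen only for illustration. The extra cache entries needed for decoupled RoPE are ignored:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Number of cached elements per multi-head attention layer.
kv_cache_elements(dh, H, N, B) = 2 * dh * H * N * B   # separate K and V caches
c_cache_elements(dc, N, B) = dc * N * B               # single compressed cache

dh, H, dc = 128, 128, 512   # head dimension, heads, compressed dimension
N, B = 4096, 8              # hypothetical maximum sequence length and batch size

kv = kv_cache_elements(dh, H, N, B)   # 1_073_741_824 elements
c = c_cache_elements(dc, N, B)        # 16_777_216 elements
kv / c                                # 64.0, which is 2 * dh * H / dc</code></pre></figure>

<p>The ratio is independent of $N$ and $B$, so the saving applies equally to short and long sequences.</p>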

<figure class="post-figure" id="fig-mla">
<img class="img-95" src="/assets/posts/transformers/deepseek_mla.png" alt="Illustration of KV caching methods" />
<figcaption>Different KV caching techniques.</figcaption>
</figure>

<p><a href="https://github.com/deepseek-ai/DeepSeek-VL2">DeepSeek</a>’s innovation is to introduce a weight matrix $W^{DKV} \in \mathbb{R}^{d_c\times d}$ to compress the input $X \in \mathbb{R}^{d\times n}$ to a lower rank matrix $C^{KV} \in \mathbb{R}^{d_c \times n}$.
This $C^{KV}$ matrix is then stored in the cache.
Then two other weight matrices $W^{UK}$ and $W^{UV} \in \mathbb{R}^{d_h H\times d_c}$ uncompress the same $C^{KV}$ matrix to the key $K$ and value $V$ respectively. The above <a href="#fig-mla">figure</a> shows this visually.</p>

\[\begin{align}
c^{KV}_n &amp;= W^{DKV} x_n \\
K &amp;= W^{UK} C^{KV}_{1:n} \\
V &amp;= W^{UV} C^{KV}_{1:n}
\end{align}
\tag{3.1.1}
\label{eq:mla}\]

<p>The KV cache is now replaced with a $C^{KV}$ cache of size $d_c \times N \times B$.
DeepSeek theorises that this compression also has a regularization effect that improves model quality.
This is supported by other LoRA research.
However, the lossy compression might instead adversely affect quality.
DeepSeek provides no evidence for either claim.</p>
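<p>Before moving on, equation 3.1.1 can be sketched with plain matrices. This is only an illustration under assumed sizes. The names <code class="language-plaintext highlighter-rouge">WDKV</code>, <code class="language-plaintext highlighter-rouge">WUK</code> and <code class="language-plaintext highlighter-rouge">WUV</code> mirror the symbols above, and it is not the layer developed later in this post:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra

d, dhH, dc, N = 16, 16, 4, 10   # hypothetical sizes with dc &lt; dhH
WDKV = randn(dc, d)             # compression matrix
WUK = randn(dhH, dc)            # uncompression matrix for the keys
WUV = randn(dhH, dc)            # uncompression matrix for the values

cache_ckv = zeros(dc, N)        # one cache instead of two
X = randn(d, 6)                 # 6 tokens so far
for i in axes(X, 2)
    cache_ckv[:, i] = WDKV * X[:, i]   # c_i = W^DKV x_i
end

n = size(X, 2)
C = cache_ckv[:, 1:n]
K = WUK * C                     # keys rebuilt from the cache
V = WUV * C                     # values rebuilt from the same cache

size(C), size(K), size(V)       # ((4, 6), (16, 6), (16, 6))
K ≈ WUK * WDKV * X              # same as compressing and uncompressing X directly
rank(K) &lt;= dc                   # the keys are confined to a dc-dimensional subspace</code></pre></figure>

<p>The rank bound in the last line is the lossy part: a standard $W^K$ can produce keys of rank up to $\min(d_h H, d, n)$, whereas here the rank is capped at $d_c$.</p>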

<p>The compression also results in a significant performance boost which DeepSeek strangely does not mention in their <a href="https://github.com/deepseek-ai/DeepSeek-VL2">paper</a>.
Note that in MLA there are three matrix multiplications to perform to create $K$ and $V$
instead of two matrix multiplications in MHA.
However, the three multiplications comprise fewer scalar operations:<sup id="fnref:matrix_complexity" role="doc-noteref"><a href="#fn:matrix_complexity" class="footnote" rel="footnote">2</a></sup>
\(\begin{align}
\frac{\text{# MLA ops}}{\text{# MHA ops}} 
 &amp;= \frac{2(2d_h H + d)d_c nB}{2(2d_h H)d n B} \\
 &amp;= \frac{2\frac{d_h H}{d} + 1}{2 \tfrac{d_h H}{d_c}} \\  
 &amp;= \frac{3}{2r}                                  
\end{align}
\tag{3.1.2}
\label{eq:mla_ops}\)
with the standard $d = d_h H$ and a compression ratio $r=\tfrac{d_h H}{d_c}$.
This requires $r &gt; 1.5$ for performance gains.
The only performance penalty is the memory required for the $n$th $c^{KV}_n$ vector before it is transferred to the cache, which is $d_c \times 1 \times B$.</p>

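<p>DeepSeek-V3 uses $d_h = 128$, $H=128$ and $d_c=4 d_h = 512$ which means it has a compression ratio of $r=32$ and, under the $d = d_h H$ assumption above, roughly a 21× reduction in these scalar operations ($\tfrac{2r}{3}=\tfrac{64}{3}$)!</p>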
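<p>The arithmetic behind equation 3.1.2 can be reproduced with a few lines of Julia. This is a minimal sketch. The function names are made up for this example, and it keeps the simplifying assumption $d = d_h H$ from the equation rather than using the model’s actual embedding dimension:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Scalar operations needed to create K and V for n tokens and a batch size of B.
mha_kv_ops(d, dhH, n, B) = 2 * (2 * dhH) * d * n * B
mla_kv_ops(d, dhH, dc, n, B) = 2 * (2 * dhH + d) * dc * n * B

dh, H, dc = 128, 128, 512
dhH = dh * H
d = dhH              # assumption: d = dh * H
n, B = 1, 1          # the ratio does not depend on these

ratio = mla_kv_ops(d, dhH, dc, n, B) / mha_kv_ops(d, dhH, n, B)   # 0.046875
r = dhH / dc         # 32.0
ratio ≈ 3 / (2r)     # matches the closed form
1 / ratio            # 64/3, approximately 21.3</code></pre></figure>

<p>If $d$ is smaller than $d_h H$ the second line of equation 3.1.2 applies instead and the gain is smaller, which can be explored by changing <code class="language-plaintext highlighter-rouge">d</code> above.</p>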

<p>To reduce the activation memory during training, DeepSeek also applies the same strategy to the query:</p>

\[\begin{align}
C^{Q} &amp;= W^{DQ} X \\
Q &amp;= W^{UQ} C^{Q}
\end{align}
\tag{3.1.3}
\label{eq:cq}\]

<p>In total, five matrix multiplications are needed to create $Q$, $K$ and $V$ instead of three.
The overall ratio of scalar operations is $\tfrac{5}{3r}$, which requires $r&gt;1.67$ for performance gains.</p>
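<p>The $\tfrac{5}{3r}$ ratio can be checked by counting operations directly from the weight matrix shapes. This is a minimal sketch with hypothetical sizes. It assumes $d = d_h H$ and the same compressed dimension $d_c$ for the query as for the keys and values:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># An m × k weight matrix costs 2 * m * k scalar operations per token.
ops(shapes) = sum(2 * m * k for (m, k) in shapes)

d, dhH, dc = 1024, 1024, 256    # hypothetical sizes, so r = 4

# MHA: W^Q, W^K, W^V
mha_shapes = [(dhH, d), (dhH, d), (dhH, d)]
# MLA: W^DQ, W^UQ, W^DKV, W^UK, W^UV
mla_shapes = [(dc, d), (dhH, dc), (dc, d), (dhH, dc), (dhH, dc)]

ratio = ops(mla_shapes) / ops(mha_shapes)
r = dhH / dc
ratio ≈ 5 / (3r)                # 5/12 for r = 4</code></pre></figure>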

<p>DeepSeek give further enhancements to this which will be described shortly.
They also apply layer normalisation to $C^Q$ and $C^{KV}$ which I will ignore in this article.
For now, let’s see this basic version of MLA in action.</p>

<h3 id="mla-code">3.2 Code </h3>

<p>First create a struct similar to the <code class="language-plaintext highlighter-rouge">MultiHeadAttentionKVCache</code> layer.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> MultiHeadLatentAttention</span><span class="x">{</span><span class="n">D1</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">D2</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">A</span><span class="o">&lt;:</span><span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">}</span> <span class="k">where</span> <span class="n">T</span><span class="x">}</span> 
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span>
    <span class="n">denseDQ</span><span class="o">::</span><span class="n">D1</span>
    <span class="n">denseUQ</span><span class="o">::</span><span class="n">D1</span>
    <span class="n">denseDKV</span><span class="o">::</span><span class="n">D1</span>
    <span class="n">denseUK</span><span class="o">::</span><span class="n">D1</span>
    <span class="n">denseUV</span><span class="o">::</span><span class="n">D1</span>
    <span class="n">denseO</span><span class="o">::</span><span class="n">D2</span>
    <span class="n">cache_ckv</span><span class="o">::</span><span class="n">A</span>
<span class="k">end</span>

<span class="n">Flux</span><span class="o">.</span><span class="nd">@layer</span> <span class="n">MultiHeadLatentAttention</span> <span class="n">trainable</span><span class="o">=</span><span class="x">(</span><span class="n">denseDQ</span><span class="x">,</span> <span class="n">denseUQ</span><span class="x">,</span> <span class="n">denseDKV</span><span class="x">,</span> <span class="n">denseUK</span><span class="x">,</span> <span class="n">denseUV</span><span class="x">,</span> <span class="n">denseO</span><span class="x">)</span></code></pre></figure>

<p>Here is a convenience constructor to construct it from the various input dimensions:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> MultiHeadLatentAttention</span><span class="x">(;</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_in</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_head</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_lora</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_out</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span>
    <span class="n">max_seq_length</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">max_batch_size</span><span class="o">::</span><span class="kt">Int</span>
    <span class="x">)</span>
    <span class="n">denseDQ</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">dim_in</span> <span class="o">=&gt;</span> <span class="n">dim_lora</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">denseUQ</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">dim_lora</span> <span class="o">=&gt;</span> <span class="n">dim_head</span> <span class="o">*</span> <span class="n">nhead</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">denseDKV</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">dim_in</span> <span class="o">=&gt;</span> <span class="n">dim_lora</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">denseUK</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">dim_lora</span> <span class="o">=&gt;</span> <span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">denseUV</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">dim_lora</span> <span class="o">=&gt;</span> <span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">denseO</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span> <span class="o">=&gt;</span> <span class="n">dim_out</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">cache_ckv</span> <span class="o">=</span> <span class="kt">Array</span><span class="x">{</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">3</span><span class="x">}(</span><span class="nb">undef</span><span class="x">,</span> <span class="n">dim_lora</span><span class="x">,</span> <span class="n">max_seq_length</span><span class="x">,</span> <span class="n">max_batch_size</span><span class="x">)</span>
    <span class="n">MultiHeadLatentAttention</span><span class="x">(</span>
        <span class="n">nhead</span><span class="x">,</span>
        <span class="n">denseDQ</span><span class="x">,</span> <span class="n">denseUQ</span><span class="x">,</span>
        <span class="n">denseDKV</span><span class="x">,</span> <span class="n">denseUK</span><span class="x">,</span> <span class="n">denseUV</span><span class="x">,</span>
        <span class="n">denseO</span><span class="x">,</span>
        <span class="n">cache_ckv</span>
    <span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>The forward pass is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">mla</span><span class="o">::</span><span class="n">MultiHeadLatentAttention</span><span class="x">)(</span><span class="n">query</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">;</span> <span class="n">start_pos</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">1</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">::</span><span class="kt">Bool</span><span class="o">=</span><span class="nb">true</span><span class="x">,</span> <span class="n">mask</span><span class="o">::</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="n">M</span><span class="x">}</span><span class="o">=</span><span class="nb">nothing</span>
    <span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">},</span> <span class="n">M</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="kt">Bool</span><span class="x">}}</span>
    <span class="n">dm</span><span class="x">,</span> <span class="n">seq_length</span><span class="x">,</span> <span class="n">batch_dim</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">key</span><span class="x">)</span>
    <span class="n">cq</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseDQ</span><span class="x">(</span><span class="n">query</span><span class="x">)</span> <span class="c"># size(cq) == (dc, dq, B)</span>
    <span class="n">ckv</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseDKV</span><span class="x">(</span><span class="n">key</span><span class="x">)</span> <span class="c"># size(ckv) == (dc, dkv, B)</span>
    <span class="k">if</span> <span class="n">use_cache</span>
        <span class="n">end_pos</span> <span class="o">=</span> <span class="n">start_pos</span> <span class="o">+</span> <span class="n">seq_length</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="n">mla</span><span class="o">.</span><span class="n">cache_ckv</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span> <span class="o">=</span> <span class="n">ckv</span>
        <span class="n">ckv</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">cache_ckv</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span>
    <span class="k">end</span>
    <span class="n">K</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseUK</span><span class="x">(</span><span class="n">ckv</span><span class="x">)</span> <span class="c"># size(K) == (dh*nhead, dkv, B)</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseUV</span><span class="x">(</span><span class="n">ckv</span><span class="x">)</span> <span class="c"># size(V) == (dh*nhead, dkv, B)</span>
    <span class="n">Q</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseUQ</span><span class="x">(</span><span class="n">cq</span><span class="x">)</span>  <span class="c"># size(Q) == (dh*nhead, dq, B)</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">multi_head_scaled_dot_attention</span><span class="x">(</span><span class="n">mla</span><span class="o">.</span><span class="n">nhead</span><span class="x">,</span> <span class="n">Q</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">V</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseO</span><span class="x">(</span><span class="n">A</span><span class="x">)</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span>
<span class="k">end</span></code></pre></figure>

<p>Create a test layer with compression ratio $r=\tfrac{d_hH}{d_c}=4$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">nhead</span><span class="x">,</span> <span class="n">dim_head</span><span class="x">,</span> <span class="n">dim_lora</span><span class="x">,</span> <span class="n">dim_out</span> <span class="o">=</span> <span class="mi">8</span><span class="x">,</span> <span class="mi">64</span><span class="x">,</span> <span class="mi">128</span><span class="x">,</span> <span class="mi">8</span><span class="o">*</span><span class="mi">64</span>
<span class="n">dim_model</span> <span class="o">=</span> <span class="n">nhead</span> <span class="o">*</span> <span class="n">dim_head</span>
<span class="n">N</span><span class="x">,</span> <span class="n">max_seq_length</span><span class="x">,</span> <span class="n">batch_dim</span> <span class="o">=</span> <span class="mi">20</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">8</span>
<span class="n">mla</span> <span class="o">=</span> <span class="n">MultiHeadLatentAttention</span><span class="x">(</span>
    <span class="n">nhead</span><span class="o">=</span><span class="n">nhead</span><span class="x">,</span> <span class="n">dim_in</span><span class="o">=</span><span class="n">dim_model</span><span class="x">,</span> <span class="n">dim_head</span><span class="o">=</span><span class="n">div</span><span class="x">(</span><span class="n">dim_model</span><span class="x">,</span> <span class="n">nhead</span><span class="x">),</span>
    <span class="n">dim_lora</span><span class="o">=</span><span class="n">dim_lora</span><span class="x">,</span> <span class="n">dim_out</span><span class="o">=</span><span class="n">dim_out</span><span class="x">,</span>
    <span class="n">max_seq_length</span><span class="o">=</span><span class="n">max_seq_length</span><span class="x">,</span> <span class="n">max_batch_size</span><span class="o">=</span><span class="n">batch_dim</span>
    <span class="x">)</span>
<span class="n">X0</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">batch_dim</span><span class="x">)</span></code></pre></figure>
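
<p>As a rough worked check of these numbers, here is the arithmetic for the compression ratio and the cache sizes. This assumes <code class="language-plaintext highlighter-rouge">Float32</code> entries (4 bytes each) and that a standard key-value cache stores both $K$ and $V$ at width $d_h H$. The symbol $n_{\max}$ for <code class="language-plaintext highlighter-rouge">max_seq_length</code> is introduced here only for illustration.</p>

\[\begin{align}
r &amp;= \frac{d_h H}{d_c} = \frac{64 \times 8}{128} = 4 \\
\text{latent cache} &amp;= d_c \times n_{\max} \times B \times 4 = 128 \times 32 \times 8 \times 4 = 131{,}072 \text{ bytes} \\
\text{K and V caches} &amp;= 2 \times d_h H \times n_{\max} \times B \times 4 = 2 \times 512 \times 32 \times 8 \times 4 = 1{,}048{,}576 \text{ bytes}
\end{align}\]

<p>Under these assumptions the single latent cache is $2r=8$ times smaller than the two separate caches.</p>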

<p>Fill the cache:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="n">N</span><span class="x">,</span> <span class="n">N</span><span class="x">));</span>
<span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">mla</span><span class="x">(</span><span class="n">X0</span><span class="x">,</span> <span class="n">X0</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">true</span><span class="x">);</span> 
<span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">)</span> <span class="c"># (512, 20, 8)</span>
<span class="n">size</span><span class="x">(</span><span class="n">scores</span><span class="x">)</span> <span class="c"># (20, 20, 8, 8)</span></code></pre></figure>

<p>Use the cache with a new vector:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="n">batch_dim</span><span class="x">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">repeat</span><span class="x">([</span><span class="nb">true</span><span class="x">],</span> <span class="n">inner</span><span class="o">=</span><span class="x">(</span><span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">1</span><span class="x">))</span>
<span class="n">Ax</span><span class="x">,</span> <span class="n">scoresx</span> <span class="o">=</span> <span class="n">mla</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">=</span><span class="n">N</span><span class="o">+</span><span class="mi">1</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">true</span><span class="x">)</span>
<span class="n">size</span><span class="x">(</span><span class="n">Ax</span><span class="x">)</span> <span class="c"># (512, 1, 8)</span>
<span class="n">size</span><span class="x">(</span><span class="n">scoresx</span><span class="x">)</span> <span class="c"># (21, 1, 8, 8)</span></code></pre></figure>

<h3 id="mla-absorption">3.3 Absorption</h3>

<p>DeepSeek suggests a way to further decrease the computational cost by absorbing weight matrices into each other.
To quote <a href="https://github.com/deepseek-ai/DeepSeek-VL2">DeepSeek</a> directly:</p>

<blockquote>
  <p>In addition, during inference, since $W^{UK}$ can be absorbed into $W^{Q}$, and $W^{UV}$ can be absorbed
into ${W^O}$, we even do not need to compute keys and values out for attention.</p>
</blockquote>

<p>What they mean is that the weight matrices can be multiplied together ahead of time to produce a single weight matrix.
This can only be done during inference because during training the matrices need to be kept separate so that the gradients flow properly backwards through each of them.</p>

<p>This technique can be used independently of MLA.</p>

<p>To show why this works, rewrite the attention equation $\ref{eq:attention}$ as follows:</p>

\[\begin{align}
S &amp;= K^T{Q} \\
  &amp;= (W^{UK}C^{KV})^T (W^{UQ}C^Q) \\
  &amp;= (C^{KV})^T (W^{UK})^T W^{UQ} C^{Q} \\
  &amp;= (C^{KV})^T W^{KQ} C^{Q} \quad ; W^{KQ}=(W^{UK})^T W^{UQ}
\end{align}
\label{eq:absorbWKQ}
\tag{3.3.1}\]
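
<p>The identity above can be checked numerically for a single head with plain matrices. This is only a minimal sketch: the toy sizes and variable names are hypothetical, and it uses random matrices rather than a trained layer.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random  # standard library only

Random.seed!(1)
dh, dc, n = 4, 3, 5  # hypothetical toy sizes for a single head
W_UK, W_UQ = randn(dh, dc), randn(dh, dc)
C_KV, C_Q = randn(dc, n), randn(dc, n)

# Unabsorbed: project both latents up to the head dimension first
S = (W_UK * C_KV)' * (W_UQ * C_Q)

# Absorbed: precompute the (dc, dc) matrix once, then stay in the latent space
W_KQ = W_UK' * W_UQ
S_absorbed = C_KV' * W_KQ * C_Q

S ≈ S_absorbed  # the two routes agree up to floating-point error</code></pre></figure>

<p>Note that the absorbed route never forms the $d_h \times n$ keys or queries.</p>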

<p>This looks straightforward but there are further complications with the dimensions.
Here is a dimensional analysis of the above equation following the two rules of batch matrix multiplication:</p>
<ol>
  <li>The inner matrix dimensions must match. That is, the second dimension of the first matrix must match the first dimension of the second.</li>
  <li>All the batch dimensions (dimensions 3 and greater) must be equal.</li>
</ol>

\[\begin{align}
&amp; (d_c \times n \times B)^T (d_h H \times d_c)^T (d_h H \times d_c) (d_c \times n \times B) \\
&amp;= (n \times d_c \times B) (d_c \times d_h \times H) (d_h \times d_c \times H) (d_c \times n \times B) \\
&amp;= (n \times d_c \times B) (d_c \times d_c \times H) (d_c \times n \times B) \\
&amp;= (n \times d_c \times 1 \times B) (d_c \times d_c \times H \times 1) (d_c \times n \times 1 \times B) \\
&amp;= n \times n \times H \times B
\end{align}
\label{eq:absorbWQ_dimension}
\tag{3.3.2}\]

<p>where</p>

<ul>
  <li>Line 2 reshapes the weight matrices from $d_h H \times d_c$ to $d_h \times d_c \times H$. This is necessary because the non-linear softmax function must be applied independently to each head.</li>
  <li>Line 3 shows that $W^{KQ} \in \mathbb{R}^{d_c \times d_c \times H}$.</li>
  <li>Line 4 adds extra broadcast dimensions to make the batch dimensions match.</li>
</ul>

<p>Broadcasting is a technique where the smaller array is replicated along all dimensions of size 1 to match the size of the larger array.
For broadcasted batched multiplication this only needs to be done for the 3rd and higher dimensions.
Pseudo-code for this is:</p>

<div id="pseudo-broadcast-batch-mul">
<blockquote>
<u><b>Broadcasted batched multiplication</b></u> <br />
inputs: $A \in \mathbb{R}^{I\times R \times L_A \times K_A}$, $B \in \mathbb{R} ^{R \times J \times L_B \times K_B}$ <br />
<b>for</b> $l$ in $1:\max(L_A, L_B)$ <br />
$\quad$ $l_A \leftarrow$ 1 <b> if </b> $L_A=1$ <b> else </b> $l$ <br />
$\quad$ $l_B \leftarrow$ 1 <b> if </b> $L_B=1$ <b> else </b> $l$ <br />
$\quad$ <b>for</b> $k$ in $1:\max(K_A, K_B)$ <br />
$\quad\quad$ $k_A \leftarrow$ 1 <b> if </b> $K_A=1$ <b> else </b> $k$ <br />
$\quad\quad$ $k_B \leftarrow$ 1 <b> if </b> $K_B=1$ <b> else </b> $k$ <br />
$\quad\quad$ $C_{l,k} = A_{l_A,k_A} B_{l_B, k_B}$ <br />
</blockquote>
</div>
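
<p>The pseudo-code above can be sketched as a plain loop for 4D arrays. The function name <code class="language-plaintext highlighter-rouge">naive_broadcasted_batched_mul</code> is hypothetical, and this is an illustrative reading of the pseudo-code rather than the implementation in the repository.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical helper: a direct, loop-based reading of the pseudo-code for 4D arrays.
function naive_broadcasted_batched_mul(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    LA, KA = size(A, 3), size(A, 4)
    LB, KB = size(B, 3), size(B, 4)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner matrix dimensions must match"))
    (LA == LB || LA == 1 || LB == 1) || throw(DimensionMismatch("dimension 3 must match or be 1"))
    (KA == KB || KA == 1 || KB == 1) || throw(DimensionMismatch("dimension 4 must match or be 1"))
    L, K = max(LA, LB), max(KA, KB)
    C = similar(A, size(A, 1), size(B, 2), L, K)
    for k in 1:K, l in 1:L
        lA = LA == 1 ? 1 : l  # reuse the single slice when the dimension is broadcast
        lB = LB == 1 ? 1 : l
        kA = KA == 1 ? 1 : k
        kB = KB == 1 ? 1 : k
        C[:, :, l, k] = A[:, :, lA, kA] * B[:, :, lB, kB]
    end
    C
end

A = randn(2, 3, 1, 4)  # broadcast along dimension 3
B = randn(3, 5, 6, 1)  # broadcast along dimension 4
C = naive_broadcasted_batched_mul(A, B)
size(C)  # (2, 5, 6, 4)</code></pre></figure>

<p>No slice of <code class="language-plaintext highlighter-rouge">A</code> or <code class="language-plaintext highlighter-rouge">B</code> is copied to fill the broadcast dimensions. The loop simply revisits index 1.</p>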

<p>The same absorption technique can be applied to the value and output matrices:</p>

\[\begin{align}
Y &amp;= W^O V Z \\
  &amp;= W^O (W^{UV} C^{KV} Z) \\
  &amp;= W^{OV} C^{KV} Z \quad ; W^{OV}=W^O W^{UV}
\end{align}
\label{eq:absorbOV}
\tag{3.3.3}\]

<p>The dimensional analysis here is similar:</p>

\[\begin{align}
&amp; (d_o \times d_h H) (d_h H \times d_c) (d_c \times n \times B) (n \times n \times H \times B) \\
&amp;= (d_o \times d_h \times H) (d_h \times d_c \times H) (d_c \times n \times 1 \times B) (n \times n \times H \times B)\\
&amp;= (d_o \times d_h \times H) (d_h \times d_c \times H) (d_c \times n \times H \times B) \\
&amp;= (d_o \times d_c \times H) (d_c \times n \times H \times B) \\
&amp;= (d_o \times d_c H) (d_c H \times n \times B) \\
&amp;= d_o \times n \times B
\end{align}
\label{eq:absorbOV_dimension}
\tag{3.3.4}\]

<p>This shows that the $C^{KV} Z$ multiplication is a broadcasted batched multiplication.
However, whereas $W^{KQ} \in \mathbb{R}^{d_c \times d_c \times H}$ is a 3D array, $W^{OV} \in \mathbb{R}^{d_o \times d_c H}$ is an ordinary 2D matrix.
Therefore the usual matrix multiplication can be applied by reshaping the $C^{KV} Z$ result from a 3D $d_c H \times n \times B$ array to a $d_c H \times nB$ matrix.</p>
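
<p>The value-output absorption can also be checked numerically. This is a minimal sketch with a batch size of 1, assuming the heads occupy contiguous blocks of $d_h$ rows of $W^{UV}$ and $d_h$ columns of $W^O$. The toy sizes, the helper <code class="language-plaintext highlighter-rouge">rows</code> and the random stand-ins for the scores are all hypothetical.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random  # standard library only

Random.seed!(2)
dh, dc, nhead, n, dout = 4, 3, 2, 5, 6  # hypothetical toy sizes, batch size 1
W_UV, W_O = randn(dh * nhead, dc), randn(dout, dh * nhead)
C_KV = randn(dc, n)
Z = [randn(n, n) for _ in 1:nhead]      # random stand-ins for the per-head scores
rows(h) = ((h - 1) * dh + 1):(h * dh)   # rows/columns that belong to head h

# Unabsorbed: compute V per head, apply the scores, concatenate the heads, then project
Y = W_O * vcat([W_UV[rows(h), :] * C_KV * Z[h] for h in 1:nhead]...)

# Absorbed: one (dout, dc) block per head, laid side by side as a (dout, dc*nhead) matrix
W_OV = hcat([W_O[:, rows(h)] * W_UV[rows(h), :] for h in 1:nhead]...)
Y_absorbed = W_OV * vcat([C_KV * Z[h] for h in 1:nhead]...)

Y ≈ Y_absorbed  # the two routes agree up to floating-point error</code></pre></figure>

<p>The final product is an ordinary matrix multiplication because the head dimension has been folded into the rows of the stacked $C^{KV} Z$ blocks.</p>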

<h3 id="broadcasted-batched-mul">3.4 Broadcasted batched multiplication</h3>

<p>I have written an implementation in Julia directly based on the <a href="#pseudo-broadcast-batch-mul">pseudo-code</a>.
It can be seen here: <a href="https://github.com/LiorSinai/TransformersLite.jl/blob/feature/mla/src/broadcasted_batched_mul.jl">broadcasted_batched_mul.jl</a>.
However, it is inefficient and uses scalar indexing, which is extremely slow on a GPU.</p>

<p>Instead, my solution is to physically replicate the broadcasted dimensions.
This is of course less efficient than virtual replication, but it makes the function viable on a GPU.
The downside is that it can be up to 4× slower than the naive code on a CPU.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Flux</span><span class="o">:</span> <span class="n">batched_mul</span>
<span class="k">function</span><span class="nf"> broadcasted_batched_mul</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">N</span><span class="x">},</span> <span class="n">y</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">N</span><span class="x">})</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">N</span><span class="x">}</span>
    <span class="n">batch_dims_x</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">?</span> <span class="n">size</span><span class="x">(</span><span class="n">y</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="o">:</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">idx</span> <span class="k">in</span> <span class="mi">3</span><span class="o">:</span><span class="n">N</span><span class="x">)</span>
    <span class="n">dims_x</span> <span class="o">=</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="n">batch_dims_x</span><span class="o">...</span><span class="x">)</span>
    <span class="n">batch_dims_y</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">y</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">?</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="o">:</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">idx</span> <span class="k">in</span> <span class="mi">3</span><span class="o">:</span><span class="n">N</span><span class="x">)</span>
    <span class="n">dims_y</span> <span class="o">=</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="n">batch_dims_y</span><span class="o">...</span><span class="x">)</span>
    <span class="n">xb</span> <span class="o">=</span> <span class="n">repeat</span><span class="x">(</span><span class="n">x</span><span class="x">;</span> <span class="n">outer</span><span class="o">=</span><span class="n">dims_x</span><span class="x">)</span>
    <span class="n">yb</span> <span class="o">=</span> <span class="n">repeat</span><span class="x">(</span><span class="n">y</span><span class="x">;</span> <span class="n">outer</span><span class="o">=</span><span class="n">dims_y</span><span class="x">)</span>
    <span class="n">batched_mul</span><span class="x">(</span><span class="n">xb</span><span class="x">,</span> <span class="n">yb</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>The DeepSeek <a href="https://github.com/deepseek-ai/DeepSeek-V3">source code</a> meanwhile uses <a href="https://pytorch.org/docs/stable/generated/torch.einsum.html">torch.einsum</a> and carries out the multiplications right to left instead of creating a new matrix.<sup id="fnref:tensors" role="doc-noteref"><a href="#fn:tensors" class="footnote" rel="footnote">3</a></sup>
As far as I know, this has the same scalar-indexing drawbacks and none of the advantages of absorption described in their paper.
Here is the relevant code:</p>

<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">wq_b</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">q_norm</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">wq_a</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">bsz</span><span class="p">,</span> <span class="n">seqlen</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_local_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">qk_head_dim</span><span class="p">)</span>
<span class="n">kv</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">wkv_a</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">wkv_b</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">wkv_b</span><span class="p">.</span><span class="n">weight</span> 
<span class="n">wkv_b</span> <span class="o">=</span> <span class="n">wkv_b</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_local_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">kv_lora_rank</span><span class="p">)</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">einsum</span><span class="p">(</span><span class="s">"bshd,hdc-&gt;bshc"</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">wkv_b</span><span class="p">)</span>
<span class="bp">self</span><span class="p">.</span><span class="n">kv_cache</span><span class="p">[:</span><span class="n">bsz</span><span class="p">,</span> <span class="n">start_pos</span><span class="p">:</span><span class="n">end_pos</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">kv_norm</span><span class="p">(</span><span class="n">kv</span><span class="p">)</span>
<span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">einsum</span><span class="p">(</span><span class="s">"bshc,btc-&gt;bsht"</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">kv_cache</span><span class="p">[:</span><span class="n">bsz</span><span class="p">,</span> <span class="p">:</span><span class="n">end_pos</span><span class="p">])</span></code></pre></figure>

<p>Presumably fast MLA implementations for GPU kernels would make use of virtual replication.
(It would be great if someone could clarify what <a href="https://github.com/deepseek-ai/FlashMLA">FlashMLA</a> does.)</p>

<h3 id="mla-absorption-code">3.5 Code</h3>

<p>I will detail some of the code here. For the full code, see <a href="https://github.com/LiorSinai/TransformersLite.jl/blob/feature/mla/src/layers/MultiHeadLatentAttentionV2.jl">MultiHeadLatentAttentionV2.jl</a>.</p>

<p>The first step is to create the $W^{KQ}$ and $W^{OV}$ matrices.</p>

<p>First, compute $W^{KQ}=(W^{UK})^T W^{UQ}$ while reshaping both factors from $d_h H \times d_c$ to $d_h \times d_c \times H$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _absorb_WUK_WUQ</span><span class="x">(</span><span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">W_UK</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">,</span> <span class="n">W_UQ</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">)</span>
    <span class="n">dh</span> <span class="o">=</span> <span class="n">div</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">W_UK</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">nhead</span><span class="x">)</span>
    <span class="n">dim_lora</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">W_UK</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">W_UQ</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">W_UQ</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">dim_lora</span><span class="x">),</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">))</span> 
    <span class="n">W_UK</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">W_UK</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">dim_lora</span><span class="x">),</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">))</span>
    <span class="n">W_UKT</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">W_UK</span><span class="x">,</span> <span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">))</span> <span class="c"># (dh, dc, nhead)^T =&gt; (dc, dh, nhead)</span>
    <span class="n">batched_mul</span><span class="x">(</span><span class="n">W_UKT</span><span class="x">,</span> <span class="n">W_UQ</span><span class="x">)</span> 
<span class="k">end</span>
<span class="n">W_KQ</span> <span class="o">=</span> <span class="n">_absorb_WUK_WUQ</span><span class="x">(</span><span class="n">nhead</span><span class="x">,</span> <span class="n">denseUK</span><span class="o">.</span><span class="n">weight</span><span class="x">,</span> <span class="n">denseUQ</span><span class="o">.</span><span class="n">weight</span><span class="x">)</span></code></pre></figure>
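
<p>The reshape-then-permute step in the function above relies on Julia's column-major layout: each head owns a contiguous block of $d_h$ rows. Here is a small sketch that checks this on a toy matrix, with hypothetical sizes and names.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">dh, nhead, dc = 2, 3, 4  # hypothetical toy sizes
W = reshape(collect(1.0:(dh * nhead * dc)), dh * nhead, dc)  # a (dh*nhead, dc) matrix
Wh = permutedims(reshape(W, dh, nhead, dc), (1, 3, 2))       # (dh, dc, nhead)

# Each slice along the last dimension is the block of dh rows owned by that head
all(Wh[:, :, h] == W[((h - 1) * dh + 1):(h * dh), :] for h in 1:nhead)</code></pre></figure>

<p>A direct <code class="language-plaintext highlighter-rouge">reshape(W, dh, dc, nhead)</code> without the permutation would mix rows from different heads.</p>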

<p>Then $W^{OV}=W^O W^{UV}$ while preserving the head dimension through reshaping:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _absorb_WO_WUV</span><span class="x">(</span><span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">W_O</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">,</span> <span class="n">W_UV</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">)</span>
    <span class="n">dh</span> <span class="o">=</span> <span class="n">div</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">W_UV</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">nhead</span><span class="x">)</span>
    <span class="n">dim_lora</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">W_UV</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">dout</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">W_O</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">W_UVh</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">W_UV</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">dim_lora</span><span class="x">),</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">))</span> <span class="c"># (dh*nhead, dc) =&gt; (dh, dc, nhead)</span>
    <span class="n">W_Oh</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">W_O</span><span class="x">,</span> <span class="n">dout</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">)</span> <span class="c"># (dout, dh*nhead) =&gt; (dout, dh, nhead)</span>
    <span class="n">W_OVh</span> <span class="o">=</span> <span class="n">batched_mul</span><span class="x">(</span><span class="n">W_Oh</span><span class="x">,</span> <span class="n">W_UVh</span><span class="x">)</span> <span class="c"># (dout, dh, nhead) * (dh, dc, nhead)</span>
    <span class="n">reshape</span><span class="x">(</span><span class="n">W_OVh</span><span class="x">,</span> <span class="n">dout</span><span class="x">,</span> <span class="n">dim_lora</span><span class="o">*</span><span class="n">nhead</span><span class="x">)</span> <span class="c"># (dout, dc, nhead) =&gt; (dout, dc*nhead)</span>
<span class="k">end</span>
<span class="n">W_OV</span> <span class="o">=</span> <span class="n">_absorb_WO_WUV</span><span class="x">(</span><span class="n">nhead</span><span class="x">,</span> <span class="n">denseO</span><span class="o">.</span><span class="n">weight</span><span class="x">,</span> <span class="n">denseUV</span><span class="o">.</span><span class="n">weight</span><span class="x">)</span></code></pre></figure>
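
<p>As a sanity check on the reshaping, here is a minimal sketch that builds the same absorbed matrix with plain loops over the heads instead of <code class="language-plaintext highlighter-rouge">batched_mul</code>, so that it only needs Base Julia. The sizes and the variable <code class="language-plaintext highlighter-rouge">A_latent</code> are hypothetical and chosen only for illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical small sizes, chosen only for illustration.
nhead, dh, dc, dout, n = 2, 3, 4, 6, 5
W_UV = randn(dh * nhead, dc)   # up-projection for the values
W_O = randn(dout, dh * nhead)  # output projection
# Per-head attention outputs in the latent space: one (dc, n) matrix per head.
A_latent = [randn(dc, n) for h in 1:nhead]

# Without absorption: up-project each head, stack the heads along the rows, then apply W_O.
V_heads = [W_UV[(h-1)*dh+1:h*dh, :] * A_latent[h] for h in 1:nhead]
out_naive = W_O * vcat(V_heads...)

# With absorption: one (dout, dc) block per head, concatenated along the columns.
W_OV = hcat([W_O[:, (h-1)*dh+1:h*dh] * W_UV[(h-1)*dh+1:h*dh, :] for h in 1:nhead]...)
out_absorb = W_OV * vcat(A_latent...)

isapprox(out_naive, out_absorb) # true</code></pre></figure>

<p>The block layout matches the reshapes above: head $h$ owns rows $(h-1)d_h+1$ to $h d_h$ of $W^{UV}$ and the same range of columns of $W^O$.</p>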

<p>The forward pass is the same as in the original <a href="#mla-code">code</a> until the end of caching:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> mla_absorb</span><span class="x">(</span>
    <span class="n">mla</span><span class="o">::</span><span class="n">MultiHeadLatentAttention</span><span class="x">,</span> <span class="n">query</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">;</span> <span class="n">start_pos</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">1</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">::</span><span class="kt">Bool</span><span class="o">=</span><span class="nb">true</span><span class="x">,</span> <span class="n">mask</span><span class="o">::</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="n">M</span><span class="x">}</span><span class="o">=</span><span class="nb">nothing</span>
    <span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">},</span> <span class="n">M</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="kt">Bool</span><span class="x">}}</span>
    <span class="n">dm</span><span class="x">,</span> <span class="n">seq_length</span><span class="x">,</span> <span class="n">batch_dim</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">key</span><span class="x">)</span>
    <span class="n">dh</span> <span class="o">=</span> <span class="n">div</span><span class="x">(</span><span class="n">dm</span><span class="x">,</span> <span class="n">mla</span><span class="o">.</span><span class="n">nhead</span><span class="x">)</span>
    <span class="n">cq</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">norm_cq</span><span class="x">(</span><span class="n">mla</span><span class="o">.</span><span class="n">denseDQ</span><span class="x">(</span><span class="n">query</span><span class="x">))</span>  <span class="c"># size(cq) == (dc, dq, B)</span>
    <span class="n">ckv</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">norm_ckv</span><span class="x">(</span><span class="n">mla</span><span class="o">.</span><span class="n">denseDKV</span><span class="x">(</span><span class="n">key</span><span class="x">))</span> <span class="c"># size(ckv) == (dc, dkv, B)</span>
    <span class="k">if</span> <span class="n">use_cache</span>
        <span class="n">end_pos</span> <span class="o">=</span> <span class="n">start_pos</span> <span class="o">+</span> <span class="n">seq_length</span> <span class="o">-</span> <span class="mi">1</span>
        <span class="n">mla</span><span class="o">.</span><span class="n">cache_ckv</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span> <span class="o">=</span> <span class="n">ckv</span>
        <span class="n">ckv</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">cache_ckv</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">end_pos</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">batch_dim</span><span class="x">]</span>
    <span class="k">end</span></code></pre></figure>

<p>Then add the broadcast dimensions:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">ckv_</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">unsqueeze</span><span class="x">(</span><span class="n">ckv</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">3</span><span class="x">)</span>
    <span class="n">keyT</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">ckv_</span><span class="x">,</span> <span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (dc, dkv, 1, B) =&gt; (dkv, dc, 1, B)</span>
    <span class="n">cq_</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">unsqueeze</span><span class="x">(</span><span class="n">cq</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">3</span><span class="x">)</span> <span class="c"># (dc, dq, B) =&gt; (dc, dq, 1, B)</span>
    <span class="n">W_KQ</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">unsqueeze</span><span class="x">(</span><span class="n">mla</span><span class="o">.</span><span class="n">W_KQ</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">4</span><span class="x">);</span> <span class="c"># (dc, dc, nhead) =&gt; (dc, dc, nhead, 1)</span></code></pre></figure>

<p>Then apply the equations as before except using <code class="language-plaintext highlighter-rouge">broadcasted_batched_mul</code> instead of <code class="language-plaintext highlighter-rouge">batched_mul</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">atten_base</span> <span class="o">=</span> <span class="n">broadcasted_batched_mul</span><span class="x">(</span><span class="n">keyT</span><span class="x">,</span> <span class="n">broadcasted_batched_mul</span><span class="x">(</span><span class="n">W_KQ</span><span class="x">,</span> <span class="n">cq_</span><span class="x">))</span>
    <span class="n">atten</span> <span class="o">=</span> <span class="n">one</span><span class="x">(</span><span class="n">T</span><span class="x">)</span><span class="o">/</span><span class="n">convert</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">sqrt</span><span class="x">(</span><span class="n">dh</span><span class="x">))</span> <span class="o">.*</span> <span class="x">(</span><span class="n">atten_base</span><span class="x">)</span>
    <span class="n">atten</span> <span class="o">=</span> <span class="n">apply_mask</span><span class="x">(</span><span class="n">atten</span><span class="x">,</span> <span class="n">mask</span><span class="x">)</span>
    <span class="n">scores</span> <span class="o">=</span> <span class="n">softmax</span><span class="x">(</span><span class="n">atten</span><span class="x">;</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">broadcasted_batched_mul</span><span class="x">(</span><span class="n">ckv_</span><span class="x">,</span> <span class="n">scores</span><span class="x">)</span> <span class="c"># (dc, dq, nhead, B)</span>
    <span class="c"># (dc, dq, nhead, B) =&gt; (dc*nhead, dq, B)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">])</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="o">:</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">3</span><span class="x">),</span> <span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span>
    <span class="n">mla</span><span class="o">.</span><span class="n">denseOV</span><span class="x">(</span><span class="n">A</span><span class="x">),</span> <span class="n">scores</span> 
<span class="k">end</span></code></pre></figure>

<p>Test that these give the same result:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">X0</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">,</span> <span class="n">N</span><span class="x">,</span> <span class="n">batch_dim</span><span class="x">)</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="n">N</span><span class="x">,</span> <span class="n">N</span><span class="x">));</span>
<span class="n">A_naive</span><span class="x">,</span> <span class="n">scores_naive</span> <span class="o">=</span> <span class="n">mla_naive</span><span class="x">(</span><span class="n">mla</span><span class="x">,</span> <span class="n">X0</span><span class="x">,</span> <span class="n">X0</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">true</span><span class="x">);</span>
<span class="n">A_absorb</span><span class="x">,</span> <span class="n">scores_absorb</span> <span class="o">=</span> <span class="n">mla_absorb</span><span class="x">(</span><span class="n">mla</span><span class="x">,</span> <span class="n">X0</span><span class="x">,</span> <span class="n">X0</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">,</span> <span class="n">use_cache</span><span class="o">=</span><span class="nb">true</span><span class="x">);</span>
<span class="n">isapprox</span><span class="x">(</span><span class="n">A_absorb</span><span class="x">,</span> <span class="n">A_naive</span><span class="x">)</span> <span class="c"># true</span></code></pre></figure>

<h3 id="mla-rope">3.6 Decoupled RoPE</h3>

<p>The last enhancement DeepSeek adds is rotary position embeddings (<a href="https://arxiv.org/abs/2104.09864">RoPE</a>).
However, RoPE breaks the absorption property described above.
To see this, write RoPE as a matrix multiplication applied to each column $X_i$ of an input matrix $X$, where $R_i$ is the rotation matrix for position $i$:</p>

\[\text{RoPE}(X) = \begin{bmatrix} R_1 X_1 &amp; R_2 X_2 &amp; ... &amp; R_n X_n \end{bmatrix}
\label{eq:RoPE}
\tag{3.4.1}\]

<p>Applying this to the scores equation:
\(\begin{align}
S &amp;= \text{RoPE}(K)^T\text{RoPE}(Q) \\
  &amp;= \text{RoPE}(W^{UK}C^{KV})^T \text{RoPE}(W^{UQ}C^Q) \\
  &amp;= \begin{bmatrix} R_1 W^{UK} C^{KV}_1 &amp; ... &amp; R_n W^{UK} C^{KV}_n \end{bmatrix}^T \\
  &amp;\phantom{=x} \begin{bmatrix} R_1 W^{UQ} C^Q_1 &amp; ... &amp; R_n W^{UQ} C^Q_n \end{bmatrix} \\
\implies S_{ij} &amp;= (C^{KV}_{i})^T (W^{UK})^T R_i^T R_j W^{UQ} C^{Q}_j
\end{align}
\label{eq:absorb_RoPE}
\tag{3.4.2}\)</p>

<p>which shows that the position-dependent rotation $R_i^T R_j$ sits right in the middle of the product, so $(W^{UK})^T$ and $W^{UQ}$ can no longer be merged into a single fixed matrix.</p>
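
<p>To make this concrete, here is a small sketch with a single $2\times 2$ rotation block. The helper <code class="language-plaintext highlighter-rouge">rot</code> and the base angle are hypothetical and only for illustration. The product $R_i^T R_j$ is a rotation by the offset $j-i$, so it changes with every pair of positions and cannot be folded into one precomputed weight matrix:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># 2x2 rotation matrix for position m with a hypothetical base angle θ.
rot(m, θ) = [cos(m * θ) -sin(m * θ); sin(m * θ) cos(m * θ)]

θ = 0.1
R(i) = rot(i, θ)

# R_i' * R_j is a rotation by (j - i) * θ, so it depends on the relative position.
isapprox(R(2)' * R(5), rot(3, θ))     # true
isapprox(R(2)' * R(5), R(4)' * R(7))  # true: same offset, same matrix
isapprox(R(2)' * R(5), R(2)' * R(6))  # false: different offset, different matrix</code></pre></figure>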

<p>DeepSeek’s solution is to concatenate an extra matrix to the bottom of the key $K$ and of the query $Q$, and to apply RoPE only to these extra matrices.
Furthermore, because the new key matrix $K^R$ also needs to be cached, it is shared across all heads.
In other words, $K^R$ is broadcast across the head dimension during multiplication.
So for each head $h$:</p>

\[\begin{align}
K_h &amp;= \begin{bmatrix} W^{UK}_h C^{KV} \\ \text{RoPE}(W^{KR} X) \end{bmatrix} \\
Q_h &amp;= \begin{bmatrix} W^{UQ}_h C^{Q} \\ \text{RoPE}(W^{QR}_h C^{Q}) \end{bmatrix}
\end{align}
\label{eq:MLA_RoPE}
\tag{3.4.3}\]

<p>where $W^{KR} \in \mathbb{R}^{d_R \times d}$ and $W^{QR} \in \mathbb{R}^{d_R H \times d_c}$. This means that $K,Q \in \mathbb{R}^{(d_h + d_R) \times n \times H \times B}$.
The cache will consist of both $C^{KV}$ and $K^R$ for a total size of $(d_c + d_R) \times N \times B$ elements per layer.</p>
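
<p>As a rough illustration of the saving, the sketch below compares this cache size with the cache of standard multi-head attention, which stores full keys and values for every head, that is $2 d_h H \times N \times B$ elements per layer. All dimensions are hypothetical values chosen only to illustrate the formula. They are not taken from any DeepSeek model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical dimensions, chosen only to illustrate the formula.
dh, nhead = 64, 8   # head dimension and number of heads
dc, dr = 128, 32    # latent dimension and RoPE dimension
N, B = 1024, 4      # maximum sequence length and batch size

cache_mha = 2 * dh * nhead * N * B  # full keys and values for every head
cache_mla = (dc + dr) * N * B       # latent C^KV plus the shared K^R

cache_mha / cache_mla</code></pre></figure>

<p>For these values the ratio is $1024/160 = 6.4$. It depends only on the ratio of $2 d_h H$ to $d_c + d_R$.</p>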

<p>Very conveniently, this results in an addition between the original and embedded scores:</p>

\[\begin{align}
S_h &amp;= K_h^T Q_h \\
  &amp;= \begin{bmatrix} (K^0)^T &amp; (K^{R})^T \end{bmatrix} \begin{bmatrix} Q^0 \\ Q^{R} \end{bmatrix} \\
  &amp;=  (K^0)^T Q^0 + (K^{R})^T Q^{R}
\end{align}
\label{eq:MLA_RoPE_scores_}
\tag{3.4.4}\]

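<p>where $K^0 = W^{UK}_h C^{KV}$ and $Q^0 = W^{UQ}_h C^{Q}$ are the original keys and queries for head $h$, and $K^R$ and $Q^R$ are the RoPE blocks defined earlier for $K_h$ and $Q_h$. This means that the two products can be calculated separately. Note that $S_h$ must now be scaled by $1/\sqrt{d_h + d_R}$ instead of $1/\sqrt{d_h}$.</p>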
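
<p>This block-matrix identity is easy to check numerically. A minimal sketch for a single head, with hypothetical sizes and random matrices standing in for the real keys and queries:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">dh, dr, nkv, nq = 4, 2, 5, 3  # hypothetical sizes
K0, KR = randn(dh, nkv), randn(dr, nkv)
Q0, QR = randn(dh, nq), randn(dr, nq)

K = vcat(K0, KR)  # (dh + dr, nkv)
Q = vcat(Q0, QR)  # (dh + dr, nq)

S_cat = K' * Q                # scores from the concatenated matrices
S_sum = K0' * Q0 + KR' * QR   # scores from the two separate products
isapprox(S_cat, S_sum) # true</code></pre></figure>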

<h3 id="rope-code">3.7 Code</h3>

<p>I will detail some of the code here. For the full code, see <a href="https://github.com/LiorSinai/TransformersLite.jl/blob/feature/mla/src/layers/MultiHeadLatentAttentionV2.jl">MultiHeadLatentAttentionV2.jl</a>.</p>

<p>Here is a Julia implementation of RoPE:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> RoPE</span><span class="x">{</span><span class="n">T</span><span class="x">}</span>
    <span class="n">base</span><span class="o">::</span><span class="kt">Int</span>
    <span class="n">dim</span><span class="o">::</span><span class="kt">Int</span>
    <span class="n">seq_length</span><span class="o">::</span><span class="kt">Int</span>
    <span class="n">freqs_complex</span><span class="o">::</span><span class="kt">Matrix</span><span class="x">{</span><span class="kt">Complex</span><span class="x">{</span><span class="n">T</span><span class="x">}}</span>
<span class="k">end</span>

<span class="n">RoPE</span><span class="x">(</span><span class="n">dim</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">max_seq_length</span><span class="o">::</span><span class="kt">Int</span><span class="x">;</span> <span class="n">base</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">10_000</span><span class="x">)</span> <span class="o">=</span> <span class="n">RoPE</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="n">dim</span><span class="x">,</span> <span class="n">max_seq_length</span><span class="x">;</span> <span class="n">base</span><span class="o">=</span><span class="n">base</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> RoPE</span><span class="x">(</span><span class="n">T</span><span class="o">::</span><span class="kt">DataType</span><span class="x">,</span> <span class="n">dim</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">max_seq_length</span><span class="o">::</span><span class="kt">Int</span><span class="x">;</span> <span class="n">base</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">10_000</span><span class="x">)</span>
    <span class="nd">@assert</span> <span class="n">dim</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="s">"Require even dim"</span>
    <span class="n">θ</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">./</span> <span class="x">(</span><span class="n">base</span> <span class="o">.^</span> <span class="x">((</span><span class="mi">0</span><span class="o">:</span><span class="mi">2</span><span class="o">:</span><span class="x">(</span><span class="n">dim</span> <span class="o">-</span> <span class="mi">2</span><span class="x">))</span> <span class="o">/</span> <span class="n">dim</span><span class="x">))</span>
    <span class="n">angles</span> <span class="o">=</span> <span class="n">θ</span> <span class="o">*</span> <span class="n">transpose</span><span class="x">(</span><span class="mi">0</span><span class="o">:</span><span class="x">(</span><span class="n">max_seq_length</span><span class="o">-</span><span class="mi">1</span><span class="x">))</span>
    <span class="n">freqs</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">x</span> <span class="o">-&gt;</span> <span class="n">reverse</span><span class="x">(</span><span class="n">sincos</span><span class="x">(</span><span class="n">x</span><span class="x">)),</span> <span class="n">angles</span><span class="x">)</span>
    <span class="n">freqs_complex</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">cs</span> <span class="o">-&gt;</span> <span class="kt">Complex</span><span class="x">(</span><span class="n">cs</span><span class="o">...</span><span class="x">),</span> <span class="n">freqs</span><span class="x">)</span>
    <span class="n">RoPE</span><span class="x">{</span><span class="n">T</span><span class="x">}(</span><span class="n">base</span><span class="x">,</span> <span class="n">dim</span><span class="x">,</span> <span class="n">max_seq_length</span><span class="x">,</span> <span class="n">freqs_complex</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>The forward pass can be calculated with matrices, but the RoPE authors gave a more efficient implementation with complex numbers:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="x">(</span><span class="n">r</span><span class="o">::</span><span class="n">RoPE</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">)</span> <span class="o">=</span> <span class="n">apply_rope</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">r</span><span class="o">.</span><span class="n">freqs_complex</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="mi">2</span><span class="x">)])</span>
<span class="x">(</span><span class="n">r</span><span class="o">::</span><span class="n">RoPE</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">,</span> <span class="n">indices</span><span class="x">)</span> <span class="o">=</span> <span class="n">apply_rope</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">r</span><span class="o">.</span><span class="n">freqs_complex</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">indices</span><span class="x">])</span>

<span class="k">function</span><span class="nf"> apply_rope</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">freqs_complex</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Complex</span><span class="x">{</span><span class="n">T</span><span class="x">}})</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">x_complex</span> <span class="o">=</span> <span class="n">reinterpret</span><span class="x">(</span><span class="kt">Complex</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">x</span><span class="x">)</span>
    <span class="n">rx_complex</span> <span class="o">=</span> <span class="n">freqs_complex</span> <span class="o">.*</span> <span class="n">x_complex</span>
    <span class="n">T</span><span class="o">.</span><span class="x">(</span><span class="n">reinterpret</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">rx_complex</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>
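
<p>To see why the complex form is equivalent to the matrix form, here is a minimal sketch for one pair of features. The frequency, position and input values are hypothetical. Multiplying $x_1 + i x_2$ by $\cos(m\theta) + i\sin(m\theta)$ gives the same result as rotating $(x_1, x_2)$ by the angle $m\theta$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">θ, m = 0.3, 4    # hypothetical frequency and position
x = [1.5, -0.5]  # one pair of features

# Matrix form: rotate the pair by the angle m * θ.
R = [cos(m * θ) -sin(m * θ); sin(m * θ) cos(m * θ)]
y_matrix = R * x

# Complex form: treat the pair as x[1] + x[2]im and multiply by cos(mθ) + i sin(mθ).
z = Complex(x[1], x[2]) * Complex(cos(m * θ), sin(m * θ))
y_complex = [real(z), imag(z)]

isapprox(y_matrix, y_complex) # true

# reinterpret pairs consecutive elements along the first dimension.
z_pairs = reinterpret(ComplexF64, [1.0, 2.0, 3.0, 4.0])
z_pairs == [1.0 + 2.0im, 3.0 + 4.0im] # true</code></pre></figure>

<p>The last two lines show how <code class="language-plaintext highlighter-rouge">reinterpret</code> forms the pairs: consecutive elements along the first dimension become the real and imaginary parts, which halves that dimension.</p>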

<p>Then add the <code class="language-plaintext highlighter-rouge">embedding</code>, <code class="language-plaintext highlighter-rouge">denseQR</code> and <code class="language-plaintext highlighter-rouge">denseKR</code> layers to the <code class="language-plaintext highlighter-rouge">MultiHeadLatentAttention</code> struct.</p>

<p>The embeddings are applied as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _apply_embeddings</span><span class="x">(</span><span class="n">mla</span><span class="o">::</span><span class="n">MultiHeadLatentAttention</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">cq</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">idx</span><span class="o">::</span><span class="kt">UnitRange</span><span class="x">{</span><span class="kt">Int</span><span class="x">})</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">}}</span>
    <span class="n">dim_lora</span><span class="x">,</span> <span class="n">dq</span><span class="x">,</span> <span class="n">batch_dim</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">cq</span><span class="x">)</span>
    <span class="n">kr</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseKR</span><span class="x">(</span><span class="n">key</span><span class="x">)</span>
    <span class="n">qr</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">denseQR</span><span class="x">(</span><span class="n">cq</span><span class="x">)</span>
    <span class="n">kr</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">embedding</span><span class="x">(</span><span class="n">kr</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="c"># size(kr) == (dr, dkv, B)</span>
    <span class="n">qr</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">qr</span><span class="x">,</span> <span class="o">:</span><span class="x">,</span> <span class="n">mla</span><span class="o">.</span><span class="n">nhead</span><span class="x">,</span> <span class="n">dq</span><span class="x">,</span> <span class="n">batch_dim</span><span class="x">),</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (dr*nhead, dq, B) =&gt; (dr, dq, nhead, B)</span>
    <span class="n">qr</span> <span class="o">=</span> <span class="n">mla</span><span class="o">.</span><span class="n">embedding</span><span class="x">(</span><span class="n">qr</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span>
    <span class="n">kr</span><span class="x">,</span> <span class="n">qr</span>
<span class="k">end</span>
<span class="n">kr</span><span class="x">,</span> <span class="n">qr</span> <span class="o">=</span> <span class="n">_apply_embeddings</span><span class="x">(</span><span class="n">mla</span><span class="x">,</span> <span class="n">key</span><span class="x">,</span> <span class="n">cq</span><span class="x">,</span> <span class="n">start_pos</span><span class="o">:</span><span class="n">end_pos</span><span class="x">)</span></code></pre></figure>

<p>Note that embedding is done per head, hence the reshaping of <code class="language-plaintext highlighter-rouge">qr</code>.</p>

<p>For the naive method, concatenate the embeddings to each head along the feature dimension. This requires reshaping <code class="language-plaintext highlighter-rouge">qr</code> and repeating <code class="language-plaintext highlighter-rouge">kr</code> across the heads.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">Q</span><span class="x">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">_cat_decoupled_embedding</span><span class="x">(</span><span class="n">mla</span><span class="o">.</span><span class="n">nhead</span><span class="x">,</span> <span class="n">Q</span><span class="x">,</span> <span class="n">qr</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">kr</span><span class="x">)</span></code></pre></figure>

<p>where:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _cat_decoupled_embedding</span><span class="x">(</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">Qin</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">Qr</span><span class="o">::</span><span class="n">A4</span><span class="x">,</span> <span class="n">Kin</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">kr</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">},</span> <span class="n">A4</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">4</span><span class="x">}}</span>
    <span class="n">dhq</span><span class="x">,</span> <span class="n">dq</span><span class="x">,</span> <span class="n">B</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">Qin</span><span class="x">)</span>
    <span class="n">dhk</span><span class="x">,</span> <span class="n">dkv</span><span class="x">,</span> <span class="n">B</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">Kin</span><span class="x">)</span>
    <span class="n">Q</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span>
        <span class="n">cat</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">Qin</span><span class="x">,</span> <span class="o">:</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">dq</span><span class="x">,</span> <span class="n">B</span><span class="x">),</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">Qr</span><span class="x">,</span> <span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">)),</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">),</span>
        <span class="o">:</span> <span class="x">,</span> <span class="n">dq</span><span class="x">,</span> <span class="n">B</span><span class="x">)</span>
    <span class="n">Kr</span> <span class="o">=</span> <span class="n">repeat</span><span class="x">(</span><span class="n">Flux</span><span class="o">.</span><span class="n">unsqueeze</span><span class="x">(</span><span class="n">kr</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">2</span><span class="x">),</span> <span class="n">outer</span><span class="o">=</span><span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">1</span><span class="x">))</span> <span class="c"># (dr, dkv, B) =&gt; (dr, nhead, dkv, B)</span>
    <span class="n">K</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span>
        <span class="n">cat</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">Kin</span><span class="x">,</span> <span class="o">:</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">dkv</span><span class="x">,</span> <span class="n">B</span><span class="x">),</span> <span class="n">reshape</span><span class="x">(</span><span class="n">Kr</span><span class="x">,</span> <span class="o">:</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">dkv</span><span class="x">,</span> <span class="n">B</span><span class="x">),</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">),</span>
        <span class="o">:</span><span class="x">,</span> <span class="n">dkv</span><span class="x">,</span> <span class="n">B</span><span class="x">)</span>
    <span class="n">Q</span><span class="x">,</span> <span class="n">K</span>
<span class="k">end</span></code></pre></figure>

<p>Then continue as before.</p>

<p>With absorption, multiply <code class="language-plaintext highlighter-rouge">kr</code> and <code class="language-plaintext highlighter-rouge">qr</code> with a broadcasted batched multiplication and add the result to the original attention:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">krT</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">unsqueeze</span><span class="x">(</span><span class="n">permutedims</span><span class="x">(</span><span class="n">kr</span><span class="x">,</span> <span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">)),</span> <span class="n">dims</span><span class="o">=</span><span class="mi">3</span><span class="x">)</span> <span class="c"># (dr, dkv, B) =&gt; (dkv, dr, 1, B)</span>
    <span class="n">atten_base</span> <span class="o">=</span> <span class="n">broadcasted_batched_mul</span><span class="x">(</span><span class="n">keyT</span><span class="x">,</span> <span class="n">broadcasted_batched_mul</span><span class="x">(</span><span class="n">W_KQ</span><span class="x">,</span> <span class="n">cq_</span><span class="x">))</span>
    <span class="n">atten_embed</span> <span class="o">=</span> <span class="n">broadcasted_batched_mul</span><span class="x">(</span><span class="n">krT</span><span class="x">,</span> <span class="n">qr</span><span class="x">)</span>
    <span class="n">atten</span> <span class="o">=</span> <span class="n">one</span><span class="x">(</span><span class="n">T</span><span class="x">)</span><span class="o">/</span><span class="n">convert</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">sqrt</span><span class="x">(</span><span class="n">dh</span> <span class="o">+</span> <span class="n">dr</span><span class="x">))</span> <span class="o">.*</span> <span class="x">(</span><span class="n">atten_base</span> <span class="o">+</span> <span class="n">atten_embed</span><span class="x">)</span></code></pre></figure>

<h2 id="conclusion">Conclusion</h2>

<p>Overall, I think MLA is a smart and useful idea. 
However, after having explored it in depth, I am more critical of their enhancements.</p>

<p>The basic premise of Multi-Head Latent Attention is simple.
It compresses the input matrix so that a single smaller $C^{KV}$ matrix can be stored instead of the key $K$ and value $V$ matrices.
It then uncompresses this matrix into $K$ and $V$ with two additional weight matrices.
This also results in a significant performance increase, which scales with the compression ratio $\frac{d_h H}{d_c}$ by a factor of $\tfrac{2}{3}$, so the compression ratio needs to be greater than a modest 1.5 for gains to be realised.
It requires no further modifications to existing MHA code.
However, it is unclear what the qualitative effects of the compression are, and it is strange that DeepSeek did not discuss the performance benefits.</p>

<p>To this DeepSeek adds weight absorption and decoupled RoPE.
We have seen that this complicates the mathematics and requires careful dimensional analysis.
True performance gains only come with an optimised <code class="language-plaintext highlighter-rouge">broadcasted_batched_mul</code> function.
Their own open source code does not even have such optimisations.
Personally, I see no benefit to this and would recommend the naive method with RoPE applied normally.
That is, apply each of the weight matrices to their inputs individually and then apply RoPE to the entire $K$ and $Q$ matrices.</p>

<p>While I am impressed with the ingenuity behind MLA, DeepSeek’s omissions coupled with extra, unwieldy enhancements make me more skeptical of their methodology.
If I examine their other techniques, I will do so with more caution.</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:forgive" role="doc-endnote">
      <p>Indices for dimensions are shown when they are relevant and left out when they’re not. For example, the row index (along the embedding dimension) is generally ignored so $X_j$ is the $j$th column/token of $X$. But then later for the scores I’ll use $X_{i,j}$ because both indices represent a position in the token sequence. But then later I use $X_h$ to indicate the $h$th head of $X$ which is along the 3rd dimension. Please forgive me for this and other abuses of matrix notation in this post. <a href="#fnref:forgive" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:matrix_complexity" role="doc-endnote">
      <p>This is the naive matrix multiplication algorithm. For sizes $n \times d$ and $d \times m$, for each of the $nm$ output elements there are $d$ multiplications and $d-1$ additions (no addition for the first element), so there are $nm(2d-1)$ operations in total. <a href="#fnref:matrix_complexity" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:tensors" role="doc-endnote">
      <p>In general multiplication is not defined for higher order arrays. But there is a set of multidimensional algebraic objects called <a href="https://en.wikipedia.org/wiki/Tensor">tensors</a> where it is, and Einstein notation was designed for this use case.
Confusingly, Google named their machine learning framework TensorFlow and calls higher order arrays tensors.
So one should differentiate between machine learning tensors and geometric tensors.
They are not the same.
To give a simple explanation: one can think of geometric tensors as higher order arrays with severe constraints on their entries and operations because they represent geometric objects. These constraints make it harder - not easier - to code higher order arrays as geometric tensors. <a href="#fnref:tensors" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A deep dive into DeepSeek’s Multi-Head Latent Attention, including the mathematics and implementation details. The layer is recreated in Julia using Flux.jl.]]></summary></entry><entry><title type="html">Notes on the Martinez-Rueda Polygon Clipping algorithm</title><link href="https://liorsinai.github.io/mathematics/2025/01/11/bentley-ottman.html" rel="alternate" type="text/html" title="Notes on the Martinez-Rueda Polygon Clipping algorithm" /><published>2025-01-11T00:00:00+00:00</published><updated>2025-01-11T00:00:00+00:00</updated><id>https://liorsinai.github.io/mathematics/2025/01/11/bentley-ottman</id><content type="html" xml:base="https://liorsinai.github.io/mathematics/2025/01/11/bentley-ottman.html"><![CDATA[<p><em>The Martinez-Rueda algorithm computes boolean operations between polygons. It can be used for polygon intersections (polygon clipping), unions, differences and XORs. I recently implemented it by following a comprehensive guide at https://sean.fun/a/polygon-clipping-pt2/. However, it was slightly lacking in some complex scenarios, mainly resulting from the strict ordering required by the Bentley-Ottmann line intersection algorithm. This post explains my minor modifications to address this crucial part of the algorithm.</em></p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<figure class="post-figure" id="fig-spiral-star">
<img class="img-95" src="/assets/posts/polygon-clipping/spiral_star_martinez.png" alt="Boolean operations spiral star" />
<figcaption>Boolean operations between a spiral and a star, computed with the Martinez-Rueda algorithm.</figcaption>
</figure>

<p>I recently updated my <a href="https://github.com/LiorSinai/PolygonAlgorithms.jl">PolygonAlgorithms.jl</a> package to use the Martinez-Rueda algorithm for boolean operations between polygons.
I had originally implemented a version of the Weiler-Atherton algorithm, explained in detail in an earlier <a href="/mathematics/2023/09/30/polygon-clipping">blog post</a>.
However, that algorithm can only calculate intersections between polygons, whereas Martinez-Rueda simultaneously calculates 
the intersection as well as unions, differences and XORs between the polygons.
See the above example and the table below for a brief comparison between the algorithms.</p>

<table>
<thead>
  <tr>
    <th></th>
    <th>Martinez-Rueda</th>
    <th>Weiler-Atherton</th>
  </tr>
</thead>
<tbody>
  <tr>
    <td>Operation</td>
    <td>Segment level. Compare fill annotations.</td>
    <td>Point level. Walk along loops.</td>
  </tr>
  <tr>
    <td>Polygon types</td>
    <td>Convex, concave, self-intersecting, holes.</td>
    <td>Convex, concave. Can be extended to holes.</td>
  </tr>
  <tr>
    <td>Time complexity</td>
    <td>$\mathcal{O}((n+m+k)\log(n+m))$</td>
    <td>$\mathcal{O}(nm)$</td>
  </tr>
  <tr>
    <td>Return types</td>
    <td>Segments and regions.</td>
    <td>Points, segments and regions.</td>
  </tr>
</tbody>
</table>

<p>The Martinez-Rueda algorithm is more versatile because it fundamentally operates at a segment level, whereas Weiler-Atherton operates at a point level, and so it has a “bigger picture” view of the polygons.
A disadvantage of the Martinez-Rueda algorithm is that it is more sensitive to numerical inaccuracies - for reasons that will be described shortly - such as a line that is almost vertical or tiny regions of intersection. In practice I found their runtimes similar, with Martinez-Rueda running faster in some situations and slower in others. For the spiral-star example, it is about 1.5 times slower.</p>

<p>The original paper can be found <a href="https://www.researchgate.net/publication/220163820_A_new_algorithm_for_computing_Boolean_operations_on_polygons">here</a>, but I followed the guide at <a href="https://sean.fun/a/polygon-clipping-pt2/">https://sean.fun/a/polygon-clipping-pt2/</a>.</p>

<figure class="post-figure" id="fig-fill-annotations">
<img class="img-95" src="/assets/posts/polygon-clipping/martinez_rueda.png" alt="Fill annotations" />
<figcaption>Fill annotations in the Martinez-Rueda algorithm.</figcaption>
</figure>

<p>The core idea behind the Martinez-Rueda algorithm is to calculate fill annotations for each segment for each polygon: is this segment filled above and below by this polygon, and is it filled above and below by the other polygon?
Once these are known, it is easy to select the segments relevant to the given operation, and to link them up again into polygons.</p>
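
<p>As a rough sketch of this selection step, and not the code from PolygonAlgorithms.jl or the guide, suppose each segment stores four hypothetical flags saying whether each polygon fills the space just above and just below it. A segment lies on the boundary of the result exactly when the result is filled on one side of it and not on the other, so the same test can serve every operation once the flags are combined with that operation's boolean rule. The names <code class="language-plaintext highlighter-rouge">Fills</code> and <code class="language-plaintext highlighter-rouge">keep_segment</code> are made up for this illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical container for the fill annotations of one segment.
struct Fills
    first_above::Bool   # polygon 1 fills the space just above the segment
    first_below::Bool   # polygon 1 fills the space just below the segment
    second_above::Bool  # polygon 2 fills the space just above the segment
    second_below::Bool  # polygon 2 fills the space just below the segment
end

# Keep a segment if the combined fill differs between its two sides.
function keep_segment(op, f::Fills)
    above = op(f.first_above, f.second_above)
    below = op(f.first_below, f.second_below)
    above != below
end

union_op(a, b) = a || b
intersection_op(a, b) = a &amp;&amp; b
difference_op(a, b) = a &amp;&amp; !b   # polygon 1 minus polygon 2
xor_op(a, b) = a != b

# Bottom edge of polygon 1 that lies inside polygon 2.
f = Fills(true, false, true, true)
keep_segment(union_op, f)         # false: the union is filled on both sides
keep_segment(intersection_op, f)  # true: the intersection is filled only above</code></pre></figure>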

<figure class="post-figure" id="fig-sweep">
	<div class="row">
		<div class="col">
			<img class="img-fluid" src="/assets/posts/polygon-clipping/martinez-rueda-sweep.png" alt="Line sweep" />
		</div>
		<div class="col">
			<img class="img-fluid" src="/assets/posts/polygon-clipping/martinez-rueda-stack.png" alt="Line stack" />
		</div>
    </div>
<figcaption>Line sweep and line stack in the Martinez-Rueda algorithm. Source: <a href="https://sean.fun/a/polygon-clipping-pt2/">sean.fun/a/polygon-clipping-pt2</a>.</figcaption>
</figure>

<p>The genius of the Martinez-Rueda algorithm is to build on the <a href="https://en.wikipedia.org/wiki/Bentley%E2%80%93Ottmann_algorithm">Bentley-Ottmann algorithm</a> for segment intersections to do this.
It does a vertical line sweep from left to right, bottom to top.
At any given moment, we can imagine having a stack of all the lines that intersect the vertical line, ordered from top to bottom.
According to the Bentley-Ottmann algorithm, to find intersections through a segment, we only need to check for intersections with the segments immediately above and immediately below it in the stack.
At the same time, we can propagate the fill annotations from the segment below, or empty space if nothing is below it.
Hence, finding the exact segments that are above and below a segment is paramount to this algorithm, and 
even slight mistakes can cause errors that propagate to other segments.</p>
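
<p>The neighbour lookup can be sketched as follows. This is a simplified illustration rather than the package's implementation: the status is a plain vector ordered from top to bottom, <code class="language-plaintext highlighter-rouge">insert_status!</code> is a hypothetical helper, and the comparator is passed in as an argument so that the example can use plain heights in place of real segments.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># `status` is ordered from top to bottom and `is_above(a, b)` is any comparator.
# Returns the neighbours directly above and below the newly inserted item.
function insert_status!(status::Vector, seg, is_above)
    idx = findfirst(s -&gt; is_above(seg, s), status)  # first entry that seg sits above
    idx = isnothing(idx) ? length(status) + 1 : idx
    insert!(status, idx, seg)
    above = idx &gt; 1 ? status[idx - 1] : nothing
    below = idx &lt; length(status) ? status[idx + 1] : nothing
    above, below
end

# Horizontal lines represented only by their heights.
status = [5.0, 3.0, 1.0]
insert_status!(status, 2.0, &gt;=)  # (3.0, 1.0): only these two need intersection checks</code></pre></figure>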

<p>This is the gist of the algorithm.
The practicalities of handling the event queue and many edge cases such as handling coincident lines and tricky annotation situations are described in the <a href="https://sean.fun/a/polygon-clipping-pt2/">article</a>.
From now on, I will focus only on the <code class="language-plaintext highlighter-rouge">is_above?</code> algorithm to determine if a segment is above another segment.
I spent a long time debugging the whole Martinez-Rueda algorithm against a variety of test cases, and I always seemed to land back at this <code class="language-plaintext highlighter-rouge">is_above?</code> function. Getting this function right solved most of my problems.</p>

<h2 id="is-above">2 Is above?</h2>

<p>For reference, the function is called <a href="https://sean.fun/a/polygon-clipping-pt2/#finding-the-status-transition">statusCompare</a> in the article.</p>

<p>The goal of this algorithm is to sort lines by height.
This will then give a sweep status like:</p>

<figure class="post-figure" id="fig-sweep-status">
<img class="img-95" src="/assets/posts/polygon-clipping/sweep_status.png" alt="Lines sorted by height" />
<figcaption>Lines sorted by height.</figcaption>
</figure>

<p>Given line segments $AB$ and $CD$, it is tempting to sort them only by the $y$-coordinate of their starting points:</p>

\[y_A \geq y_C
\tag{2.1}
\label{eq:greater}\]

<p>This will work in most cases.
However already in the figure we can see an example where it does not.
Line 2’s starting point is below line 3’s, but it makes more sense to consider line 2 as “above” line 3.</p>

<p>A better definition for “above” is needed.
Instead, we will consider one segment above another if its starting point is above its projection on the other line:</p>

<figure class="post-figure" id="fig-project">
<img class="img-60" src="/assets/posts/polygon-clipping/projections-2.png" alt="Start points projected onto the other line" />
<figcaption></figcaption>
</figure>

<p>That is, $y_p \leq y_A$, where $y_p$ is:</p>

\[\begin{align}
  y_p &amp;= \frac{y_D - y_C}{x_D - x_C}(x_A - x_C) + y_C \\
\implies &amp; 0 \leq (y_A - y_C)(x_D - x_C) - (y_D - y_C)(x_A - x_C) \; ; x_D \neq x_C 
\end{align}
\tag{2.2}
\label{eq:projection}\]

<figure class="post-figure" id="fig-symmetry">
<img class="img-60" src="/assets/posts/polygon-clipping/symmetry.png" alt="Symmetry between projections" />
<figcaption></figcaption>
</figure>

<p>However one problem is this equation is not symmetrical. (This is the case in the original <a href="https://sean.fun/a/polygon-clipping-pt2/#finding-the-status-transition">statusCompare</a>.) 
In the figure above, both projections are below the other line.
Hence <code class="language-plaintext highlighter-rouge">is_above</code> will return false for both segments. Yet one must be above the other.
Therefore to maintain symmetry, the function will always only consider the right segment.
If it is the segment of interest, we check if its starting point is above its projection on the left line.
Otherwise if we are checking if the left segment is above, we check if the right segment’s starting point is below its projection on the left segment.</p>

<p>There are two other special cases. The first is if the starting point is colinear or coincident with the other line:</p>

<figure class="post-figure" id="fig-coincident">
<img class="img-60" src="/assets/posts/polygon-clipping/coincident.png" alt="Coincident lines" />
<figcaption></figcaption>
</figure>

<p>In this case, the endpoint is used instead.</p>

<p>The second and final case is a vertical line:</p>

<figure class="post-figure" id="fig-vertical">
<img class="img-60" src="/assets/posts/polygon-clipping/vertical.png" alt="A vertical line" />
<figcaption></figcaption>
</figure>

<p>As implied by equation $\ref{eq:projection}$, if the line is vertical the projection equation is indeterminate.
In fact, if the line were slightly sloped towards the left or towards the right, the answer would differ.
Here instead we will simply compare y-values. That is, fall back to $\ref{eq:greater}$.
(The original <a href="https://sean.fun/a/polygon-clipping-pt2/#finding-the-status-transition">statusCompare</a> did not account for this case.)</p>

<div class="card">
  <div class="card-body">
    <h5 class="card-title">No fallback</h5>
    <p class="card-text">
		If there is no fallback, then when $x_C=x_D$ equation $\ref{eq:projection}$ becomes:
    $$
    0 \geq (y_D - y_C)(x_A - x_C)
    $$
    In the algorithm, vertical events are always constructed from bottom to top, so $y_D &gt; y_C$ and this becomes a test of whether or not $A$ is to the left of the vertical $CD$ segment.
	</p>
  </div>
</div>

<p>Hence the <code class="language-plaintext highlighter-rouge">is_above</code> algorithm is:</p>

<blockquote>
<u><b>Is segment AB above CD?</b></u> <br />
inputs: $AB$, $CD$ <br />
<b>if</b> colinear($A, C, D$) <br />
$\quad$ <b>return</b> point_above_line($B, CD$) <br />
<b>if</b> $x_C &lt; x_A$ <br />
$\quad$ <b>return</b> point_above_line($A, CD$) <br />
<b>else</b> <br />
$\quad$ <b>return not</b> point_above_line($C, AB$)
</blockquote>

<p>where <code class="language-plaintext highlighter-rouge">point_above_line</code> is:</p>
<blockquote>
<u><b>point_above_line</b></u> <br />
inputs: $P$, $CD$ <br />
<b>if</b> $x_C = x_D$ <br />
$\quad$ <b>return</b> $y_P \geq \text{min}(y_C, y_D)$ <br />
<b>return</b> $ (y_P - y_C)(x_D - x_C) - (y_D - y_C)(x_P - x_C) \geq 0$
</blockquote>

<p>This final algorithm is simple, but absolutely crucial for the Martinez-Rueda algorithm as a whole.</p>
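
<p>The two pseudocode blocks above could be translated to Julia roughly as follows. This is a minimal sketch and not the code in PolygonAlgorithms.jl: points are plain tuples, segments are tuples of two points, <code class="language-plaintext highlighter-rouge">side</code> and <code class="language-plaintext highlighter-rouge">colinear</code> are helper names chosen here, and the tolerance value is an assumption.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Left-hand side of the final test in point_above_line.
# It is non-negative when P is on or above the line CD, for C to the left of D.
side(p, c, d) = (p[2] - c[2]) * (d[1] - c[1]) - (d[2] - c[2]) * (p[1] - c[1])

function point_above_line(p, seg)
    c, d = seg
    if c[1] == d[1]                  # vertical line: fall back to comparing heights
        return p[2] &gt;= min(c[2], d[2])
    end
    side(p, c, d) &gt;= 0
end

# The tolerance is an assumed value, not one taken from the post.
colinear(p, c, d; atol=1e-9) = abs(side(p, c, d)) &lt;= atol

function is_above(ab, cd)
    a, b = ab
    c, d = cd
    if colinear(a, c, d)
        return point_above_line(b, cd)   # use the end point instead
    end
    if c[1] &lt; a[1]                       # AB is the right segment
        return point_above_line(a, cd)
    end
    !point_above_line(c, ab)             # CD is the right segment
end

ab = ((0.0, 2.0), (4.0, -2.0))
cd = ((1.0, 0.0), (2.0, 5.0))
is_above(ab, cd), is_above(cd, ab)       # (true, false)</code></pre></figure>

<p>In this example $CD$ starts below $AB$, and the two calls give opposite answers, which is the symmetry the right-segment rule is meant to guarantee.</p>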

<h2 id="compare-events">3 Compare events</h2>

<p>For reference, the function is called <a href="https://sean.fun/a/polygon-clipping-pt2/#initializing-events">eventCompare</a> in the article.</p>

<p>It sorts segment events from left to right, bottom to top.
There are two events per segment: a start event and an end event.
An example ordering is:</p>

<figure class="post-figure" id="fig-events">
<img class="img-60" src="/assets/posts/polygon-clipping/event_queue.png" alt="Event queue" />
<figcaption></figcaption>
</figure>

<p>The algorithm is:</p>
<ol>
  <li>If the points are not the same, the smaller event is the one to the left, or the lower one if they are on a vertical line.</li>
  <li>If the other points are also the same, this event is not smaller. (Equal segments.)</li>
  <li>If one is a start event and the other an end event, the end event is considered smaller. (Common points.)</li>
  <li>The smaller event is below the other one according to <code class="language-plaintext highlighter-rouge">not is_above</code>, unless the segment of interest is vertical, in which case the smaller event is “not above” if it is to the right. (Common start/end points.)</li>
</ol>

<p>For example, in the picture:</p>
<ul>
  <li>Event 1 is smaller than event 2 by step 4: the lower event is to the right of a vertical segment.</li>
  <li>Event 2 is smaller than event 5 by step 1: they are on the same segment, but event 2 is defined by the lower start point.</li>
  <li>Event 3 is smaller than event 4 by step 4: common start point but segment 3 is lower than segment 4.</li>
  <li>Event 5 is smaller than event 6 by step 3: same point but event 5 is an end event, while event 6 is a start event.</li>
</ul>

<p>And so on.</p>
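
<p>A simplified sketch of these four steps is given below. The <code class="language-plaintext highlighter-rouge">Event</code> type and the name <code class="language-plaintext highlighter-rouge">is_smaller</code> are hypothetical, and step 4 does not reproduce the vertical-segment exception described above; it only orients the other segment from left to right and tests this segment's far point against it.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical event type: `point` is where the event fires and `other` is the
# opposite end point of the same segment.
struct Event
    point::NTuple{2,Float64}
    other::NTuple{2,Float64}
    is_start::Bool
end

side(p, c, d) = (p[2] - c[2]) * (d[1] - c[1]) - (d[2] - c[2]) * (p[1] - c[1])
point_above_line(p, c, d) = c[1] == d[1] ? p[2] &gt;= min(c[2], d[2]) : side(p, c, d) &gt;= 0

function is_smaller(ev::Event, other::Event)
    # Step 1: different points. Tuples compare lexicographically: x first, then y.
    ev.point != other.point &amp;&amp; return ev.point &lt; other.point
    # Step 2: equal segments.
    ev.other == other.other &amp;&amp; return false
    # Step 3: an end event comes before a start event at the same point.
    ev.is_start != other.is_start &amp;&amp; return !ev.is_start
    # Step 4 (simplified, no vertical exception): is this segment below the other?
    c, d = other.is_start ? (other.point, other.other) : (other.other, other.point)
    !point_above_line(ev.other, c, d)
end

# Two start events sharing a start point; the first segment is the lower one.
e3 = Event((0.0, 0.0), (2.0, 0.0), true)
e4 = Event((0.0, 0.0), (2.0, 2.0), true)
is_smaller(e3, e4)  # true</code></pre></figure>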

<h2 id="conclusion">4 Conclusion</h2>

<p>This was a short post to address minor issues and some improvements to two parts of the Martinez-Rueda implementation from 
<a href="https://sean.fun/a/polygon-clipping-pt2/">https://sean.fun/a/polygon-clipping-pt2/</a>.
Otherwise that article did a very good job at explaining this algorithm and I highly recommend it.</p>

<hr />]]></content><author><name>Lior Sinai</name></author><category term="mathematics" /><category term="polygons" /><summary type="html"><![CDATA[The Martinez-Rueda algorithm computes boolean operations between polygons. It can be used for polygon intersections (polygon clipping), unions, differences and XORs. I recently implemented it by following a comprehensive guide at https://sean.fun/a/polygon-clipping-pt2/. However, it was slightly lacking in some complex scenarios, mainly resulting from the strict ordering required by the Bentley-Ottmann line intersection algorithm. This post explains my minor modifications to address this crucial part of the algorithm.]]></summary></entry><entry><title type="html">MicroGrad.jl: Part 5 MLP</title><link href="https://liorsinai.github.io/machine-learning/2024/08/19/micrograd-5-mlp.html" rel="alternate" type="text/html" title="MicroGrad.jl: Part 5 MLP" /><published>2024-08-19T00:00:00+00:00</published><updated>2024-08-24T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2024/08/19/micrograd-5-mlp</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2024/08/19/micrograd-5-mlp.html"><![CDATA[<p><em>A series on automatic differentiation in Julia. Part 5 shows how the MicroGrad.jl code can be used for a machine learning framework like Flux.jl. The working example is a multi-layer perceptron trained on the moons dataset.</em></p>

<p>This is part of a series. The other articles are:</p>
<ul>
  <li><a href="/machine-learning/2024/07/27/micrograd-1-chainrules">Part 1: ChainRules</a>.</li>
  <li><a href="/machine-learning/2024/08/03/micrograd-2-expr">Part 2: Automation with expressions</a>.</li>
  <li><a href="/machine-learning/2024/08/10/micrograd-3-ir">Part 3: Automation with IR</a>.</li>
  <li><a href="/machine-learning/2024/08/17/micrograd-4-ext">Part 4: Extensions</a>.</li>
</ul>

<p>All source code can be found at <a href="https://github.com/LiorSinai/MicroGrad.jl">MicroGrad.jl</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<figure class="post-figure">
<img class="img-30" src="/assets/posts/micrograd/mlp.png" alt="multi-layer perceptron" />
<figcaption>A 2×6×2 multi-layer perceptron</figcaption>
</figure>

<p>The previous four parts have developed a minimal automatic differentiation package.
The aim of this part is to demonstrate how it can be used as the backbone for a machine learning framework like Flux.jl.
In this post we will create a multi-layer perceptron, also known as a fully connected neural network.
This is an extremely popular and powerful machine learning model.
New code will be needed for the forward pass and for some extra <code class="language-plaintext highlighter-rouge">rrule</code>s. 
Otherwise, the rest is handled by code from the previous parts.</p>

<h2 id="moons-dataset">2 Moons dataset </h2>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/moons.png" alt="Moons dataset" />
<figcaption></figcaption>
</figure>

<p>The moons dataset is a toy dataset for testing and visualising classification algorithms.
While clearly distinct, the curved nature of the two classes requires a non-linear algorithm to discern them. 
This was the dataset chosen by Karpathy to demonstrate his <a href="https://github.com/karpathy/micrograd">micrograd</a> package, and so it will be used here too.</p>

<p>This dataset can be reconstructed in Julia as follows, based on the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html">Scikit-Learn</a> function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Random</span>

<span class="k">function</span><span class="nf"> make_moons</span><span class="x">(</span><span class="n">rng</span><span class="o">::</span><span class="kt">AbstractRNG</span><span class="x">,</span> <span class="n">n_samples</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">100</span><span class="x">;</span> <span class="n">noise</span><span class="o">::</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="kt">AbstractFloat</span><span class="x">}</span><span class="o">=</span><span class="nb">nothing</span><span class="x">)</span>
    <span class="n">n_moons</span> <span class="o">=</span> <span class="n">floor</span><span class="x">(</span><span class="kt">Int</span><span class="x">,</span> <span class="n">n_samples</span> <span class="o">/</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">t_min</span> <span class="o">=</span> <span class="mf">0.0</span>
    <span class="n">t_max</span> <span class="o">=</span> <span class="nb">π</span>
    <span class="n">t_inner</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">n_moons</span><span class="x">)</span> <span class="o">*</span> <span class="x">(</span><span class="n">t_max</span> <span class="o">-</span> <span class="n">t_min</span><span class="x">)</span> <span class="o">.+</span> <span class="n">t_min</span>
    <span class="n">t_outer</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">n_moons</span><span class="x">)</span> <span class="o">*</span> <span class="x">(</span><span class="n">t_max</span> <span class="o">-</span> <span class="n">t_min</span><span class="x">)</span> <span class="o">.+</span> <span class="n">t_min</span>
    <span class="n">outer_circ_x</span> <span class="o">=</span> <span class="n">cos</span><span class="o">.</span><span class="x">(</span><span class="n">t_outer</span><span class="x">)</span>
    <span class="n">outer_circ_y</span> <span class="o">=</span> <span class="n">sin</span><span class="o">.</span><span class="x">(</span><span class="n">t_outer</span><span class="x">)</span>
    <span class="n">inner_circ_x</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">.-</span> <span class="n">cos</span><span class="o">.</span><span class="x">(</span><span class="n">t_inner</span><span class="x">)</span>
    <span class="n">inner_circ_y</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">.-</span> <span class="n">sin</span><span class="o">.</span><span class="x">(</span><span class="n">t_inner</span><span class="x">)</span> <span class="o">.-</span> <span class="mf">0.5</span>

    <span class="n">data</span> <span class="o">=</span> <span class="x">[</span><span class="n">outer_circ_x</span> <span class="n">outer_circ_y</span><span class="x">;</span> <span class="n">inner_circ_x</span> <span class="n">inner_circ_y</span><span class="x">]</span>
    <span class="n">z</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">1</span><span class="x">))</span>
    <span class="k">if</span> <span class="o">!</span><span class="n">isnothing</span><span class="x">(</span><span class="n">noise</span><span class="x">)</span>
        <span class="n">z</span> <span class="o">+=</span> <span class="n">noise</span> <span class="o">*</span> <span class="n">randn</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">z</span><span class="x">))</span>
    <span class="k">end</span>
    <span class="n">z</span>
<span class="k">end</span>

<span class="n">make_moons</span><span class="x">(</span><span class="n">n_samples</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">100</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">make_moons</span><span class="x">(</span><span class="n">Random</span><span class="o">.</span><span class="n">default_rng</span><span class="x">(),</span> <span class="n">n_samples</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span></code></pre></figure>

<p>Creating the moons and labels:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">n</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">make_moons</span><span class="x">(</span><span class="mi">2</span><span class="n">n</span><span class="x">;</span> <span class="n">noise</span><span class="o">=</span><span class="mf">0.1</span><span class="x">)</span> <span class="c"># 2×200 Matrix </span>
<span class="n">y</span> <span class="o">=</span> <span class="n">vcat</span><span class="x">(</span><span class="n">fill</span><span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="n">n</span><span class="x">),</span> <span class="n">fill</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="n">n</span><span class="x">))</span> <span class="c"># 200-element Vector{Int64}</span></code></pre></figure>

<h2 id="layers">3 Layers</h2>
<h3 id="relu">3.1 ReLU</h3>

<p>The Rectified Linear Unit (ReLU) is a common activation function in machine learning. It is defined as follows:</p>

\[\text{relu}(x)=\begin{cases} 
x, &amp; \text{if $x&gt; 0$}  \\
0, &amp; \text{otherwise}
\end{cases}\]

<p>This can be realised as a broadcast of the <code class="language-plaintext highlighter-rouge">max</code> function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">relu</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">)</span> <span class="o">=</span> <span class="n">max</span><span class="o">.</span><span class="x">(</span><span class="mi">0</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span></code></pre></figure>

<p>The derivative is:</p>

\[\frac{\partial \text{relu}}{\partial x}=\begin{cases} 
1, &amp; \text{if $x&gt; 0$}  \\
0, &amp; \text{otherwise}
\end{cases}\]

<p>In code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">relu</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">)</span>
    <span class="n">relu_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">ifelse</span><span class="o">.</span><span class="x">(</span><span class="n">x</span> <span class="o">.&gt;</span> <span class="mi">0</span><span class="x">,</span> <span class="n">Δ</span><span class="x">,</span> <span class="mi">0</span><span class="x">))</span>
    <span class="n">relu</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">relu_back</span>
<span class="k">end</span></code></pre></figure>
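
<p>As a quick sanity check, the pullback can be compared against central finite differences of $f(x)=\sum_i \text{relu}(x_i)$. The sketch below repeats the two definitions above so that it runs on its own. The helper <code class="language-plaintext highlighter-rouge">finite_difference</code> is a name introduced here, and the test points are chosen away from the kink at $x=0$, where the derivative is not defined.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">relu(x::AbstractArray) = max.(0, x)

function rrule(::typeof(relu), x::AbstractArray)
    relu_back(Δ) = (nothing, ifelse.(x .&gt; 0, Δ, 0))
    relu(x), relu_back
end

# Central finite differences of f(x) = sum(relu(x)), one coordinate at a time.
function finite_difference(x; h=1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = h
        (sum(relu(x .+ e)) - sum(relu(x .- e))) / (2h)
    end
end

x = [-1.0, 0.5, 2.0]
y, back = rrule(relu, x)
_, grad = back(ones(length(x)))  # the pullback of sum passes back a vector of ones
maximum(abs.(grad .- finite_difference(x))) &lt; 1e-6  # true</code></pre></figure>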

<h3 id="dense-layer">3.2 Dense layer</h3>

<p>The fully connected layer equation is:</p>

\[Y_{ij} = a\left(\sum_k W_{ik}X_{kj} + b_{i} \right)\]

<p>This is the code from Flux.jl to create this fully connected layer (<a href="https://github.com/FluxML/Flux.jl/blob/033f4b22c07d4bbd42fb3c13c2a138cecf722122/src/layers/basic.jl#L154">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Random</span>
<span class="k">struct</span><span class="nc"> Dense</span><span class="x">{</span><span class="n">M</span><span class="o">&lt;:</span><span class="kt">AbstractMatrix</span><span class="x">,</span> <span class="n">B</span><span class="o">&lt;:</span><span class="kt">AbstractMatrix</span><span class="x">,</span> <span class="n">F</span><span class="x">}</span>
    <span class="n">weight</span><span class="o">::</span><span class="n">M</span>
    <span class="n">bias</span><span class="o">::</span><span class="n">B</span>
    <span class="n">activation</span><span class="o">::</span><span class="n">F</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">a</span><span class="o">::</span><span class="n">Dense</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">)</span>
    <span class="n">a</span><span class="o">.</span><span class="n">activation</span><span class="x">(</span><span class="n">a</span><span class="o">.</span><span class="n">weight</span> <span class="o">*</span> <span class="n">x</span> <span class="o">.+</span> <span class="n">a</span><span class="o">.</span><span class="n">bias</span><span class="x">)</span>
<span class="k">end</span>

<span class="n">Dense</span><span class="x">((</span><span class="k">in</span><span class="x">,</span> <span class="n">out</span><span class="x">)</span><span class="o">::</span><span class="kt">Pair</span><span class="x">;</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">)</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="n">glorot_uniform</span><span class="x">(</span><span class="k">in</span><span class="x">,</span> <span class="n">out</span><span class="x">),</span> <span class="n">zeros</span><span class="x">(</span><span class="n">out</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">activation</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> glorot_uniform</span><span class="x">(</span><span class="n">rng</span><span class="o">::</span><span class="kt">AbstractRNG</span><span class="x">,</span> <span class="n">fan_in</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">fan_out</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">scale</span> <span class="o">=</span> <span class="n">sqrt</span><span class="x">(</span><span class="mi">24</span> <span class="o">/</span> <span class="x">(</span><span class="n">fan_in</span> <span class="o">+</span> <span class="n">fan_out</span><span class="x">))</span>  <span class="c"># 0.5 * sqrt(24) = sqrt(1/4 * 24) = sqrt(6)</span>
    <span class="x">(</span><span class="n">rand</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">fan_out</span><span class="x">,</span> <span class="n">fan_in</span><span class="x">)</span> <span class="o">.-</span> <span class="mf">0.5</span><span class="x">)</span> <span class="o">.*</span> <span class="n">scale</span>
<span class="k">end</span>

<span class="n">glorot_uniform</span><span class="x">(</span><span class="n">fan_in</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">fan_out</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span> <span class="o">=</span> <span class="n">glorot_uniform</span><span class="x">(</span><span class="n">Random</span><span class="o">.</span><span class="n">default_rng</span><span class="x">(),</span> <span class="n">fan_in</span><span class="x">,</span> <span class="n">fan_out</span><span class="x">)</span></code></pre></figure>
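<p>The comment on <code class="language-plaintext highlighter-rouge">scale</code> can be checked directly. Since <code class="language-plaintext highlighter-rouge">rand</code> samples from $[0, 1)$, the shifted values lie in $[-0.5, 0.5)$, so the weights are bounded by $\sqrt{6/(\text{fan\_in} + \text{fan\_out})}$. The following is a small sketch of that check; the variable <code class="language-plaintext highlighter-rouge">limit</code> is introduced here for illustration only.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random

function glorot_uniform(rng::AbstractRNG, fan_in::Int, fan_out::Int)
    scale = sqrt(24 / (fan_in + fan_out))
    (rand(rng, fan_out, fan_in) .- 0.5) .* scale
end

fan_in, fan_out = 2, 3
W = glorot_uniform(Random.default_rng(), fan_in, fan_out)
limit = sqrt(6 / (fan_in + fan_out)) # bound implied by 0.5 * scale

size(W)                      # (3, 2): one row per output, one column per input
maximum(abs.(W)) &lt;= limit    # true</code></pre></figure>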

<p>Also add a method to <code class="language-plaintext highlighter-rouge">parameters</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">parameters</span><span class="x">(</span><span class="n">a</span><span class="o">::</span><span class="n">Dense</span><span class="x">)</span> <span class="o">=</span> <span class="x">(;</span><span class="n">weight</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">weight</span><span class="x">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">a</span><span class="o">.</span><span class="n">bias</span><span class="x">)</span></code></pre></figure>

<p>Create and test:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">X</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="mi">2</span> <span class="o">=&gt;</span> <span class="mi">3</span><span class="x">;</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">)</span>
<span class="n">layer</span><span class="x">(</span><span class="n">X</span><span class="x">)</span> <span class="c"># 3×4 Matrix{Float64}</span></code></pre></figure>

<h3 id="reverse-broadcast">3.3 Reverse broadcast</h3>

<p>Inspect the IR <code class="language-plaintext highlighter-rouge">@code_ir layer(X)</code>:</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">1: (%1, %2)
  %3 = Base.getproperty(%1, :activation)
  %4 = Main.:+
  %5 = Base.getproperty(%1, :weight)
  %6 = %5 * %2
  %7 = Base.getproperty(%1, :bias)
  %8 = Base.broadcasted(%4, %6, %7)
  %9 = Base.materialize(%8)
  %10 = (%3)(%9)
  return %10</code></pre></figure>

<p>From part 1 and part 4 we have <code class="language-plaintext highlighter-rouge">rrule</code>s for <code class="language-plaintext highlighter-rouge">getproperty</code> (<code class="language-plaintext highlighter-rouge">getfield</code>), matrix multiplication (<code class="language-plaintext highlighter-rouge">*</code>) and for the activation (<code class="language-plaintext highlighter-rouge">relu</code>). We still need <code class="language-plaintext highlighter-rouge">rrule</code>s for <code class="language-plaintext highlighter-rouge">broadcasted</code> and <code class="language-plaintext highlighter-rouge">materialize</code>.</p>

<p>Creating rules for broadcasting in general is complex<sup id="fnref:broadcast" role="doc-noteref"><a href="#fn:broadcast" class="footnote" rel="footnote">1</a></sup>, so instead create a specific rule for the broadcast invoked here:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">Broadcast</span><span class="o">.</span><span class="n">broadcasted</span><span class="x">),</span> <span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="o">+</span><span class="x">),</span> <span class="n">A</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Real</span><span class="x">},</span> <span class="n">B</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Real</span><span class="x">})</span>
    <span class="n">broadcast_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">unbroadcast</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="n">Δ</span><span class="x">),</span> <span class="n">unbroadcast</span><span class="x">(</span><span class="n">B</span><span class="x">,</span> <span class="n">Δ</span><span class="x">))</span>
    <span class="n">broadcast</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="n">A</span><span class="x">,</span> <span class="n">B</span><span class="x">),</span> <span class="n">broadcast_back</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> unbroadcast</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">,</span> <span class="n">x̄</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">length</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">==</span> <span class="n">length</span><span class="x">(</span><span class="n">x̄</span><span class="x">)</span>
        <span class="n">x̄</span>
    <span class="k">else</span>
        <span class="n">dims</span> <span class="o">=</span> <span class="n">ntuple</span><span class="x">(</span><span class="n">d</span> <span class="o">-&gt;</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">d</span><span class="x">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">?</span> <span class="n">d</span> <span class="o">:</span> <span class="n">ndims</span><span class="x">(</span><span class="n">x̄</span><span class="x">)</span><span class="o">+</span><span class="mi">1</span><span class="x">,</span> <span class="n">ndims</span><span class="x">(</span><span class="n">x̄</span><span class="x">))</span>
        <span class="n">dx</span> <span class="o">=</span> <span class="n">sum</span><span class="x">(</span><span class="n">x̄</span><span class="x">;</span> <span class="n">dims</span> <span class="o">=</span> <span class="n">dims</span><span class="x">)</span>
        <span class="n">check_dims</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">size</span><span class="x">(</span><span class="n">dx</span><span class="x">))</span>
        <span class="n">dx</span>
    <span class="k">end</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> check_dims</span><span class="x">(</span><span class="n">size_x</span><span class="x">,</span> <span class="n">size_dx</span><span class="x">)</span> <span class="c"># see ChainRulesCore.ProjectTo</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">d</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">size_x</span><span class="x">)</span>
        <span class="n">dd</span> <span class="o">=</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="n">length</span><span class="x">(</span><span class="n">size_dx</span><span class="x">)</span> <span class="o">?</span> <span class="n">size_dx</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">:</span> <span class="mi">1</span> <span class="c"># broadcasted dim</span>
        <span class="k">if</span> <span class="n">d</span> <span class="o">!=</span> <span class="n">dd</span> 
            <span class="n">throw</span><span class="x">(</span><span class="kt">DimensionMismatch</span><span class="x">(</span><span class="s">"variable with size(x) == </span><span class="si">$</span><span class="s">size_x cannot have a gradient with size(dx) == </span><span class="si">$</span><span class="s">size_dx"</span><span class="x">))</span>
        <span class="k">end</span>
    <span class="k">end</span>
<span class="k">end</span></code></pre></figure>

<p>Testing:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">X</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">2</span><span class="x">)</span>
<span class="n">Z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">Base</span><span class="o">.</span><span class="n">broadcasted</span><span class="x">,</span> <span class="o">+</span><span class="x">,</span> <span class="n">X</span><span class="x">,</span> <span class="n">b</span><span class="x">)</span> <span class="c"># (2×4 Matrix{Float64}, broadcast_back)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (nothing, nothing, ones(2, 4), [4.0; 4.0;;])</span></code></pre></figure>
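<p>The sum over columns follows from the chain rule. As a short worked derivation, assume a scalar loss $L$ with incoming gradient $\Delta_{ij} = \partial L / \partial Z_{ij}$ (notation introduced here for illustration). Each $b_i$ is reused in every column $j$ of $Z_{ij} = X_{ij} + b_i$, so its gradient accumulates over $j$:</p>

\[\frac{\partial L}{\partial b_i} = \sum_j \Delta_{ij} \frac{\partial Z_{ij}}{\partial b_i} = \sum_j \Delta_{ij}\]

<p>With $\Delta$ set to <code class="language-plaintext highlighter-rouge">ones(2, 4)</code>, each entry of the bias gradient is therefore 4, which matches the output above.</p>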

<p>The definition for <code class="language-plaintext highlighter-rouge">Base.Broadcast.materialize</code> is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="nd">@inline</span> <span class="n">materialize</span><span class="x">(</span><span class="n">bc</span><span class="o">::</span><span class="n">Broadcasted</span><span class="x">)</span> <span class="o">=</span> <span class="n">copy</span><span class="x">(</span><span class="n">instantiate</span><span class="x">(</span><span class="n">bc</span><span class="x">))</span>
<span class="n">materialize</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span></code></pre></figure>

<p>Hence we need <code class="language-plaintext highlighter-rouge">rrule</code>s for <code class="language-plaintext highlighter-rouge">copy</code> and <code class="language-plaintext highlighter-rouge">instantiate</code> (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/broadcast.jl#L5">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">copy</span><span class="x">),</span> <span class="n">bc</span><span class="o">::</span><span class="n">Broadcast</span><span class="o">.</span><span class="n">Broadcasted</span><span class="x">)</span>
    <span class="n">uncopy</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="k">return</span> <span class="n">copy</span><span class="x">(</span><span class="n">bc</span><span class="x">),</span> <span class="n">uncopy</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">Broadcast</span><span class="o">.</span><span class="n">instantiate</span><span class="x">),</span> <span class="n">bc</span><span class="o">::</span><span class="n">Broadcast</span><span class="o">.</span><span class="n">Broadcasted</span><span class="x">)</span>
    <span class="n">uninstantiate</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="k">return</span> <span class="n">Broadcast</span><span class="o">.</span><span class="n">instantiate</span><span class="x">(</span><span class="n">bc</span><span class="x">),</span> <span class="n">uninstantiate</span>
<span class="k">end</span></code></pre></figure>

<p>Now the pullback for the <code class="language-plaintext highlighter-rouge">Dense</code> layer works:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Y</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">layer</span><span class="x">,</span> <span class="n">X</span><span class="x">)</span> <span class="c"># (3×4 Matrix, Pullback)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">3</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># ((;weight=...,bias=...,activation=nothing), 2×4 Matrix)</span>
<span class="n">Y</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">m</span><span class="x">(</span><span class="n">X</span><span class="x">),</span> <span class="n">layer</span><span class="x">)</span> <span class="c"># (3×4 Matrix, Pullback)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">3</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (nothing, (;weight=...,bias=...,activation=nothing))</span></code></pre></figure>

<h3 id="chain">3.4 Chain</h3>

<p>Here is the Flux code to create a generic chain (<a href="https://github.com/FluxML/Flux.jl/blob/033f4b22c07d4bbd42fb3c13c2a138cecf722122/src/layers/basic.jl#L35">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Chain</span><span class="x">{</span><span class="n">T</span><span class="o">&lt;:</span><span class="kt">Tuple</span><span class="x">}</span>
    <span class="n">layers</span><span class="o">::</span><span class="n">T</span>
<span class="k">end</span>
  
<span class="n">Chain</span><span class="x">(</span><span class="n">xs</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">Chain</span><span class="x">(</span><span class="n">xs</span><span class="x">)</span>

<span class="x">(</span><span class="n">c</span><span class="o">::</span><span class="n">Chain</span><span class="x">)(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">_apply_chain</span><span class="x">(</span><span class="n">c</span><span class="o">.</span><span class="n">layers</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>

<span class="nd">@generated</span> <span class="k">function</span><span class="nf"> _apply_chain</span><span class="x">(</span><span class="n">layers</span><span class="o">::</span><span class="kt">Tuple</span><span class="x">{</span><span class="kt">Vararg</span><span class="x">{</span><span class="kt">Any</span><span class="x">,</span><span class="n">N</span><span class="x">}},</span> <span class="n">x</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">N</span><span class="x">}</span>
  <span class="n">symbols</span> <span class="o">=</span> <span class="n">vcat</span><span class="x">(</span><span class="o">:</span><span class="n">x</span><span class="x">,</span> <span class="x">[</span><span class="n">gensym</span><span class="x">()</span> <span class="k">for</span> <span class="n">_</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">N</span><span class="x">])</span>
  <span class="n">calls</span> <span class="o">=</span> <span class="x">[</span><span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="x">(</span><span class="n">symbols</span><span class="x">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="x">])</span> <span class="o">=</span> <span class="n">layers</span><span class="x">[</span><span class="o">$</span><span class="n">i</span><span class="x">](</span><span class="o">$</span><span class="x">(</span><span class="n">symbols</span><span class="x">[</span><span class="n">i</span><span class="x">])))</span> <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">N</span><span class="x">]</span>
  <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">block</span><span class="x">,</span> <span class="n">calls</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>
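<p>To see what the generated function produces, the same expression-building code can be run as an ordinary function. This is a sketch for illustration only: <code class="language-plaintext highlighter-rouge">chain_expr</code> is a hypothetical helper that takes $N$ as a value instead of reading it from the tuple type, and it is not part of Flux.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical helper: same body as _apply_chain, but N is passed as a value.
function chain_expr(N::Int)
    symbols = vcat(:x, [gensym() for _ in 1:N])
    calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
    Expr(:block, calls...)
end

ex = chain_expr(2)
# ex is a block of two assignments of the form (the gensym names will differ):
#   y1 = layers[1](x)
#   y2 = layers[2](y1)</code></pre></figure>

<p>The value of the block is its last assignment, that is, the output of the final layer. Unrolling the calls this way gives one explicit statement per layer.</p>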

<p>Add a method to <code class="language-plaintext highlighter-rouge">parameters</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">parameters</span><span class="x">(</span><span class="n">c</span><span class="o">::</span><span class="n">Chain</span><span class="x">)</span> <span class="o">=</span> <span class="x">(;</span><span class="n">layers</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">parameters</span><span class="x">,</span> <span class="n">c</span><span class="o">.</span><span class="n">layers</span><span class="x">))</span></code></pre></figure>

<p>Each generated call indexes into <code class="language-plaintext highlighter-rouge">layers</code>, so we will need an <code class="language-plaintext highlighter-rouge">rrule</code> for <code class="language-plaintext highlighter-rouge">getindex</code>. This can be seen by inspecting the generated pullback:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">pr1</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">_apply_chain</span><span class="x">),</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">cos</span><span class="x">),</span> <span class="n">typeof</span><span class="x">(</span><span class="n">sin</span><span class="x">)},</span> <span class="kt">Float64</span><span class="x">)</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">rrule</code> for <code class="language-plaintext highlighter-rouge">getindex</code> is as follows (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/indexing.jl#L22">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">getindex</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="n">T</span><span class="x">,</span> <span class="n">i</span><span class="o">::</span><span class="kt">Integer</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="o">&lt;:</span><span class="kt">Tuple</span><span class="x">}</span>
    <span class="k">function</span><span class="nf"> getindex_back_1</span><span class="x">(</span><span class="n">Δy</span><span class="x">)</span>
        <span class="n">dx</span> <span class="o">=</span> <span class="n">ntuple</span><span class="x">(</span><span class="n">j</span> <span class="o">-&gt;</span> <span class="n">j</span> <span class="o">==</span> <span class="n">i</span> <span class="o">?</span> <span class="n">Δy</span> <span class="o">:</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">x</span><span class="x">))</span>
        <span class="k">return</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="x">(</span><span class="n">dx</span><span class="o">...</span><span class="x">,),</span> <span class="nb">nothing</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="k">return</span> <span class="n">x</span><span class="x">[</span><span class="n">i</span><span class="x">],</span> <span class="n">getindex_back_1</span>
<span class="k">end</span></code></pre></figure>

<p>Test (compare the results in <a href="/machine-learning/2024/07/27/micrograd-1-chainrules#chainrules-trigonometry">part 1</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Chain</span><span class="x">(</span><span class="n">cos</span><span class="x">,</span> <span class="n">sin</span><span class="x">)</span>
<span class="n">model</span><span class="x">(</span><span class="mf">0.9</span><span class="x">)</span> <span class="c"># 0.5823</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="mf">0.9</span><span class="x">)</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># ((layers=(nothing, nothing),), -0.6368)</span></code></pre></figure>
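<p>The value can be confirmed by hand with the chain rule, as a short worked check:</p>

\[\frac{d}{dx}\sin(\cos x) = -\cos(\cos x)\sin x\]

<p>At $x = 0.9$, $\cos(0.9) \approx 0.6216$, so the derivative is $-\cos(0.6216)\sin(0.9) \approx -0.8129 \times 0.7833 \approx -0.6368$. The entries under <code class="language-plaintext highlighter-rouge">layers</code> are <code class="language-plaintext highlighter-rouge">nothing</code> because <code class="language-plaintext highlighter-rouge">cos</code> and <code class="language-plaintext highlighter-rouge">sin</code> hold no parameters.</p>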

<p>Test a multi-layer perceptron:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Chain</span><span class="x">(</span>
    <span class="n">Dense</span><span class="x">(</span><span class="mi">2</span> <span class="o">=&gt;</span> <span class="mi">16</span><span class="x">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">),</span>
    <span class="n">Dense</span><span class="x">(</span><span class="mi">16</span> <span class="o">=&gt;</span> <span class="mi">16</span><span class="x">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">),</span>
    <span class="n">Dense</span><span class="x">(</span><span class="mi">16</span><span class="o">=&gt;</span><span class="mi">2</span><span class="x">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">)</span>
<span class="x">)</span>
<span class="n">model</span><span class="x">(</span><span class="n">X</span><span class="x">)</span> <span class="c"># 2×4 Matrix</span>
<span class="n">Z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">m</span><span class="x">(</span><span class="n">X</span><span class="x">),</span> <span class="n">model</span><span class="x">)</span>  <span class="c"># (2×4 Matrix, Pullback)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (nothing, (layers=((weight=...), (weight=...), (weight=...))))</span></code></pre></figure>

<h2 id="loss">4 Loss</h2>
<h3 id="Cross-entropy">4.1 Cross entropy</h3>

<p>The output of the machine learning model will be a probability $p_j$ for a sample $j$ being in a certain class. This will be compared to a probability for a known label $y_j$, which is either 1 if that sample is in the class or 0 if it is not.
An obvious value to maximise is their product:</p>

\[y_j p_j
\tag{4.1}\]

<p>with range $[0, 1]$.</p>

<figure class="post-figure">
<img class="img-30" src="/assets/posts/micrograd/logx.png" alt="-log(x)" align="right" />
<figcaption></figcaption>
</figure>

<p>However most machine learning optimisation algorithms aim to minimise a loss.
So instead $p_j$ is scaled as $-\log(p_j)$, so that the loss ranges over $[0, \infty)$ and reaches its minimum of 0 when $p_j = 1$.
This is called the cross entropy loss:</p>

\[L(p_j, y_j) = -y_j \log(p_j)
\tag{4.2}
\label{eq:cross_entropy}\]
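<p>A few values give a feel for the scale of this loss. This is a minimal sketch; the function name <code class="language-plaintext highlighter-rouge">cross_entropy</code> is chosen here for illustration.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Equation 4.2 for a single sample with predicted probability p and label y.
cross_entropy(p, y) = -y * log(p)

cross_entropy(0.9, 1) # ≈ 0.105, a confident correct prediction is penalised lightly
cross_entropy(0.5, 1) # ≈ 0.693
cross_entropy(0.1, 1) # ≈ 2.303, a confident wrong prediction is penalised heavily</code></pre></figure>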

<h3 id="logit-ross-entropy">4.2 Logit cross entropy</h3>

<p>The outputs of the neural network are not probabilities but instead a vector of logits containing $N$ real values for $N$ classes.
By convention these logits are scaled to a probability distribution using the softmax function:</p>

\[s(x)_i = \frac{e^{x_i}}{\sum_{r=1}^{N} e^{x_r}}
\tag{4.3}
\label{eq:softmax}\]
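<p>As a small illustration of equation 4.3, the sketch below applies a column-wise softmax to a matrix with one column of logits per sample. It assumes the same layout as the rest of this section (classes along the first dimension). This direct form can overflow for large logits, so it is meant as a sketch and not as a production implementation.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Column-wise softmax following equation 4.3.
softmax(x::AbstractArray) = exp.(x) ./ sum(exp.(x), dims=1)

x = [1.0 0.0; 2.0 0.0; 3.0 0.0] # 3 classes × 2 samples
p = softmax(x)
sum(p, dims=1) # each column sums to 1
p[:, 2]        # equal logits give a uniform distribution of 1/3 each</code></pre></figure>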

<p>Combining equations $\ref{eq:cross_entropy}$ and $\ref{eq:softmax}$ and taking a mean across samples gives the mean logit cross entropy loss:</p>

\[\begin{align}
 L(x, y) &amp;= -\frac{1}{n}\sum_{j=1}^n \sum_{i=1}^N y_{ij} z_{ij} \\
         &amp;= -\frac{1}{n}\sum_{j=1}^n \sum_{i=1}^N y_{ij} \left(x_{ij} - \log\left(\sum_{r=1}^{N} e^{x_{rj}}\right) \right)
\end{align}
\tag{4.4}
\label{eq:logit_cross_entropy}\]

<p>where $z_{ij}$ is the output of the logsoftmax function. Assuming that $y_{ij}$ is 1 for exactly one class $i = c_j$ and 0 otherwise, this can be simplified to:</p>

\[\begin{align}
 L(x, y) = -\frac{1}{n}\sum_{j=1}^n \left(x_{c_j j} - \log\left(\sum_{r=1}^{N} e^{x_{rj}}\right) \right)
\end{align}
\tag{4.5}
\label{eq:logit_cross_entropy_2}\]

<p>In Julia this can be implemented as follows (<a href="https://github.com/FluxML/Flux.jl/blob/dd9b644c9b71d313749d9ab139334ac16df6488e/src/losses/functions.jl#L273">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">StatsBase</span>
<span class="n">logsoftmax</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">.-</span> <span class="n">log</span><span class="o">.</span><span class="x">(</span><span class="n">sum</span><span class="x">(</span><span class="n">exp</span><span class="o">.</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">))</span>
<span class="k">function</span><span class="nf"> logit_cross_entropy</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">)</span>
    <span class="n">mean</span><span class="x">(</span><span class="o">-</span><span class="n">sum</span><span class="x">(</span><span class="n">y</span> <span class="o">.*</span> <span class="n">logsoftmax</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>
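<p>A short usage example, comparing the function against the simplified form for one-hot labels. This is a sketch under a few assumptions: it repeats the definitions above so that it runs on its own, it takes <code class="language-plaintext highlighter-rouge">mean</code> from the standard library <code class="language-plaintext highlighter-rouge">Statistics</code> instead of <code class="language-plaintext highlighter-rouge">StatsBase</code>, and the names <code class="language-plaintext highlighter-rouge">lse</code> and <code class="language-plaintext highlighter-rouge">by_hand</code> are illustrative.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Statistics # provides mean

logsoftmax(x::AbstractArray) = x .- log.(sum(exp.(x), dims=1))
function logit_cross_entropy(x::AbstractVecOrMat, y::AbstractVecOrMat)
    mean(-sum(y .* logsoftmax(x), dims=1))
end

x = [1.0 0.0; 2.0 0.0; 3.0 0.0] # logits: 3 classes × 2 samples
y = [0 1; 0 0; 1 0]             # one-hot labels: class 3, then class 1
loss = logit_cross_entropy(x, y) # ≈ 0.7531

# By hand: true-class logit minus log-sum-exp, averaged over the 2 samples
lse = log.(sum(exp.(x), dims=1))
by_hand = -((x[3, 1] - lse[1]) + (x[1, 2] - lse[2])) / 2
loss ≈ by_hand # true</code></pre></figure>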

<p>According to the multivariable chain rule, the derivative with respect to one logit $x_{ij}$ in the vector for sample $j$ is (gradients come from the main case $k=i$ case as well as the sum in the softmax for $k\neq i$):</p>

\[\begin{align}
\frac{\partial L}{\partial x_{ij}} &amp;= \sum_{k=1}^N \frac{\partial L}{\partial z_{kj}} \frac{\partial z_{kj}}{\partial x_{ij}} \\
 &amp;= \sum_{k=1}^N \left( -\frac{y_{kj}\Delta}{n}  \frac{\partial}{\partial x_{ij}}\left(x_{kj} - \log\left(\sum_{r=1}^{N} e^{x_{rj}}\right) \right) \right) \\
 &amp;= \sum_{k=1}^N \left(-\frac{y_{kj} \Delta}{n} \left(\delta_{ki} - \frac{e^{x_{ij}}}{\sum_{r=1}^{N} e^{x_{rj}}} \right) \right) \\
 &amp;= -\frac{\Delta}{n}  \left(y_{ij} - s(x_j)_{i} \sum_{k=1}^N y_{kj}\right)
 \end{align}
\tag{4.6}
\label{eq:back_logitcrossentropy}\]

<p>where $\delta_{ki}$ is the Kronecker delta. Assuming that $y_{kj}$ is 1 for one value of $k$ and 0 otherwise, this simplifies to:</p>

\[\begin{align}
\frac{\partial L}{\partial x_{ij}} &amp;= -\frac{\Delta}{n}(y_{ij} - s(x_j)_{i})
 \end{align}
\tag{4.7}
\label{eq:back_logitcrossentropy_2}\]
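
<p>As a sanity check, equation 4.7 can be compared against central finite differences. This is a minimal sketch with $\Delta=1$; the helper <code class="language-plaintext highlighter-rouge">finite_difference</code> and the small test matrices are made up for this illustration and are not part of MicroGrad.jl:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Statistics: mean

logsoftmax(x::AbstractArray) = x .- log.(sum(exp.(x), dims=1))
softmax(x::AbstractArray) = exp.(x) ./ sum(exp.(x), dims=1)
logit_cross_entropy(x, y) = mean(-sum(y .* logsoftmax(x), dims=1))

X = [0.2 -1.0 0.5; 1.5 0.3 -0.7] # 2 classes × 3 samples
Y = [1.0 0.0 0.0; 0.0 1.0 1.0]   # one-hot labels
n = size(X, 2)

# Equation 4.7 with Δ = 1
analytical = -(Y .- softmax(X)) ./ n

# Illustrative helper: perturb one logit at a time and take the central difference.
function finite_difference(f, X; h=1e-6)
    grad = similar(X)
    for idx in eachindex(X)
        Xp, Xm = copy(X), copy(X)
        Xp[idx] += h
        Xm[idx] -= h
        grad[idx] = (f(Xp) - f(Xm)) / (2h)
    end
    grad
end

numerical = finite_difference(x -&gt; logit_cross_entropy(x, Y), X)
isapprox(analytical, numerical; atol=1e-6) # true</code></pre></figure>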

<p>In Julia this can be implemented as follows (<a href="https://github.com/FluxML/NNlib.jl/blob/013aa51f7ff9c2e035afa8763b5d02e105d81b78/src/softmax.jl#L123">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">logsoftmax</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">)</span>
    <span class="n">expx</span> <span class="o">=</span> <span class="n">exp</span><span class="o">.</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
    <span class="n">Σ</span> <span class="o">=</span> <span class="n">sum</span><span class="x">(</span><span class="n">expx</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span>
    <span class="k">function</span><span class="nf"> logsoftmax_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
        <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">Δ</span> <span class="o">.-</span> <span class="n">sum</span><span class="x">(</span><span class="n">Δ</span><span class="x">;</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span> <span class="o">.*</span> <span class="n">expx</span> <span class="o">./</span> <span class="n">Σ</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">x</span> <span class="o">.-</span> <span class="n">log</span><span class="o">.</span><span class="x">(</span><span class="n">Σ</span><span class="x">),</span> <span class="n">logsoftmax_back</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">logit_cross_entropy</span><span class="x">),</span>  <span class="n">x</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">)</span>
    <span class="n">ls</span><span class="x">,</span> <span class="n">logsoftmax_back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">logsoftmax</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
    <span class="k">function</span><span class="nf"> logit_cross_entropy_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
        <span class="n">size_ls</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">ls</span><span class="x">)</span>
        <span class="n">n</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">size_ls</span><span class="x">)</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="o">?</span> <span class="n">prod</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">ls</span><span class="x">)[</span><span class="mi">2</span><span class="o">:</span><span class="k">end</span><span class="x">])</span> <span class="o">:</span> <span class="mi">1</span>
        <span class="n">∂x</span> <span class="o">=</span> <span class="n">logsoftmax_back</span><span class="x">(</span><span class="o">-</span><span class="n">y</span> <span class="o">*</span> <span class="n">Δ</span><span class="o">/</span><span class="n">n</span><span class="x">)[</span><span class="mi">2</span><span class="x">]</span>
        <span class="n">∂y</span> <span class="o">=</span> <span class="o">-</span><span class="n">Δ</span><span class="o">/</span><span class="n">n</span> <span class="o">.*</span> <span class="n">ls</span>
        <span class="k">return</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">∂x</span> <span class="x">,</span> <span class="n">∂y</span>
    <span class="k">end</span>
    <span class="n">mean</span><span class="x">(</span><span class="o">-</span><span class="n">sum</span><span class="x">(</span><span class="n">y</span> <span class="o">.*</span> <span class="n">ls</span><span class="x">,</span> <span class="n">dims</span> <span class="o">=</span> <span class="mi">1</span><span class="x">)),</span> <span class="n">logit_cross_entropy_back</span>
<span class="k">end</span></code></pre></figure>

<p>Testing:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">y1</span><span class="x">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">4</span><span class="x">),</span> <span class="n">rand</span><span class="x">(</span><span class="mi">4</span><span class="x">)</span>
<span class="n">l</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">logit_cross_entropy</span><span class="x">,</span> <span class="n">y1</span><span class="x">,</span> <span class="n">y2</span><span class="x">)</span> <span class="c"># (2.69, logit_cross_entropy_back)</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, [0.4,...], [1.37,...] )</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">)</span> 
<span class="n">Y</span> <span class="o">=</span> <span class="x">[</span><span class="mf">1.0</span> <span class="mf">1.0</span> <span class="mf">0.0</span> <span class="mf">0.0</span> <span class="x">;</span> <span class="mf">0.0</span> <span class="mf">0.0</span> <span class="mf">1.0</span> <span class="mf">1.0</span><span class="x">]</span> <span class="c"># one hot encoded</span>
<span class="n">l</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">logit_cross_entropy</span><span class="x">,</span> <span class="n">X</span><span class="x">,</span> <span class="n">Y</span><span class="x">)</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 2×4 Matrix, 2×4 Matrix)</span></code></pre></figure>

<h2 id="train-and-evaluate">5 Train and Evaluate</h2>
<h3 id="train">5.1 Train </h3>

<p>Create the moons data and labels:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">n</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">make_moons</span><span class="x">(</span><span class="mi">2</span><span class="n">n</span><span class="x">;</span> <span class="n">noise</span><span class="o">=</span><span class="mf">0.1</span><span class="x">)</span> <span class="c"># 2×200 Matrix </span>
<span class="n">y</span> <span class="o">=</span> <span class="n">vcat</span><span class="x">(</span><span class="n">fill</span><span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="n">n</span><span class="x">)</span><span class="o">...</span><span class="x">,</span> <span class="n">fill</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="n">n</span><span class="x">)</span><span class="o">...</span><span class="x">)</span> <span class="c"># 200-element Vector{Int64}</span></code></pre></figure>

<p>Convert the labels to a one-hot representation:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> onehot</span><span class="x">(</span><span class="n">y</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span> <span class="n">labels</span><span class="x">)</span>
    <span class="n">num_classes</span> <span class="o">=</span> <span class="n">maximum</span><span class="x">(</span><span class="n">labels</span><span class="x">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">zeros</span><span class="x">(</span><span class="n">num_classes</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">y</span><span class="x">))</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">j</span><span class="x">,</span> <span class="n">label</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">y</span><span class="x">)</span>
        <span class="n">Y</span><span class="x">[</span><span class="n">label</span><span class="x">,</span> <span class="n">j</span><span class="x">]</span> <span class="o">+=</span> <span class="mi">1</span>
    <span class="k">end</span>
    <span class="n">Y</span>
<span class="k">end</span>
<span class="n">Y</span> <span class="o">=</span> <span class="n">onehot</span><span class="x">(</span><span class="n">y</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="mi">2</span><span class="x">)</span></code></pre></figure>
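
<p>A quick way to check the encoding is to invert it: the row index of the largest entry in each column should recover the original label. In the sketch below, <code class="language-plaintext highlighter-rouge">onecold</code>, <code class="language-plaintext highlighter-rouge">y_demo</code> and <code class="language-plaintext highlighter-rouge">Y_demo</code> are names introduced only for this illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Same definition as above, repeated so that this snippet runs on its own.
function onehot(y::AbstractVector, labels)
    num_classes = maximum(labels)
    Y = zeros(num_classes, length(y))
    for (j, label) in enumerate(y)
        Y[label, j] += 1
    end
    Y
end

# Illustrative inverse: row index of the maximum in each column.
onecold(Y::AbstractMatrix) = vec(map(idx -&gt; idx[1], argmax(Y, dims=1)))

y_demo = [1, 2, 2, 1]
Y_demo = onehot(y_demo, 1:2) # [1.0 0.0 0.0 1.0; 0.0 1.0 1.0 0.0]
onecold(Y_demo) == y_demo    # true</code></pre></figure>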

<p>Create the model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Chain</span><span class="x">(</span>
    <span class="n">Dense</span><span class="x">(</span><span class="mi">2</span> <span class="o">=&gt;</span> <span class="mi">16</span><span class="x">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">),</span>
    <span class="n">Dense</span><span class="x">(</span><span class="mi">16</span> <span class="o">=&gt;</span> <span class="mi">16</span><span class="x">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">),</span>
    <span class="n">Dense</span><span class="x">(</span><span class="mi">16</span><span class="o">=&gt;</span><span class="mi">2</span><span class="x">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">relu</span><span class="x">)</span>
<span class="x">)</span></code></pre></figure>

<p>Test the loss function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">l</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">logit_cross_entropy</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">X</span><span class="x">),</span> <span class="n">Y</span><span class="x">),</span> <span class="n">model</span><span class="x">);</span> <span class="c"># (0.69, Pullback{...}(...))</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, layers=((weight=...),(weight=...),(weight=...),))</span></code></pre></figure>

<p>Use the exact same <code class="language-plaintext highlighter-rouge">gradient_descent!</code> function from <a href="/machine-learning/2024/08/17/micrograd-4-ext#generic-gradient-descent">part 4</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">history</span> <span class="o">=</span> <span class="n">gradient_descent!</span><span class="x">(</span>
    <span class="n">model</span><span class="x">,</span> <span class="n">logit_cross_entropy</span><span class="x">,</span> <span class="n">X</span><span class="x">,</span> <span class="n">Y</span>
    <span class="x">;</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.9</span><span class="x">,</span> <span class="n">max_iters</span><span class="o">=</span><span class="mi">200</span>
<span class="x">)</span></code></pre></figure>

<h3 id="evaluate">5.2 Evaluate </h3>

<p>Plot the history:</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/moons_history.png" alt="Training history" />
<figcaption>Training history</figcaption>
</figure>

<p>Calculate accuracy:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Y_pred</span> <span class="o">=</span> <span class="n">model</span><span class="x">(</span><span class="n">X</span><span class="x">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">vec</span><span class="x">(</span><span class="n">map</span><span class="x">(</span><span class="n">idx</span> <span class="o">-&gt;</span> <span class="n">idx</span><span class="x">[</span><span class="mi">1</span><span class="x">],</span> <span class="n">argmax</span><span class="x">(</span><span class="n">Y_pred</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)))</span>
<span class="n">mean</span><span class="x">(</span><span class="n">y_pred</span> <span class="o">.==</span> <span class="n">y</span><span class="x">)</span> <span class="c"># 100%</span></code></pre></figure>

<p>Plot decision boundary:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Plots</span>
<span class="n">xmin</span><span class="x">,</span> <span class="n">xmax</span> <span class="o">=</span> <span class="n">extrema</span><span class="x">(</span><span class="n">X</span><span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="o">:</span><span class="x">])</span>
<span class="n">ymin</span><span class="x">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="n">extrema</span><span class="x">(</span><span class="n">X</span><span class="x">[</span><span class="mi">2</span><span class="x">,</span> <span class="o">:</span><span class="x">])</span>
<span class="n">h</span> <span class="o">=</span> <span class="mf">0.01</span>
<span class="n">xrange</span> <span class="o">=</span> <span class="x">(</span><span class="n">xmin</span><span class="o">-</span><span class="mf">0.1</span><span class="x">)</span><span class="o">:</span><span class="n">h</span><span class="o">:</span><span class="x">(</span><span class="n">xmax</span><span class="o">+</span><span class="mf">0.1</span><span class="x">)</span>
<span class="n">yrange</span> <span class="o">=</span> <span class="x">(</span><span class="n">ymin</span><span class="o">-</span><span class="mf">0.1</span><span class="x">)</span><span class="o">:</span><span class="n">h</span><span class="o">:</span><span class="x">(</span><span class="n">ymax</span><span class="o">+</span><span class="mf">0.1</span><span class="x">)</span>

<span class="n">x_grid</span> <span class="o">=</span> <span class="n">xrange</span><span class="err">'</span> <span class="o">.*</span> <span class="n">ones</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">yrange</span><span class="x">))</span>
<span class="n">y_grid</span> <span class="o">=</span> <span class="n">ones</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">xrange</span><span class="x">))</span><span class="err">'</span> <span class="o">.*</span> <span class="n">yrange</span>
<span class="n">Z</span> <span class="o">=</span> <span class="n">similar</span><span class="x">(</span><span class="n">x_grid</span><span class="x">)</span>
<span class="k">for</span> <span class="n">idx</span> <span class="k">in</span> <span class="n">eachindex</span><span class="x">(</span><span class="n">x_grid</span><span class="x">)</span>
    <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="x">([</span><span class="n">x_grid</span><span class="x">[</span><span class="n">idx</span><span class="x">],</span> <span class="n">y_grid</span><span class="x">[</span><span class="n">idx</span><span class="x">]])</span>
    <span class="n">Z</span><span class="x">[</span><span class="n">idx</span><span class="x">]</span> <span class="o">=</span> <span class="n">softmax</span><span class="x">(</span><span class="n">logits</span><span class="x">)[</span><span class="mi">1</span><span class="x">]</span>
<span class="k">end</span>
<span class="n">canvas</span> <span class="o">=</span> <span class="n">heatmap</span><span class="x">(</span><span class="n">xrange</span><span class="x">,</span> <span class="n">yrange</span><span class="x">,</span> <span class="n">Z</span><span class="x">,</span> <span class="n">size</span><span class="o">=</span><span class="x">(</span><span class="mi">800</span><span class="x">,</span> <span class="mi">500</span><span class="x">))</span></code></pre></figure>
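
<p>The two grid matrices above are built by broadcasting a row vector against a column vector. A small sketch with toy ranges (the names and values here are illustrative only) shows the resulting shapes, which match the rows-as-$y$, columns-as-$x$ layout that <code class="language-plaintext highlighter-rouge">heatmap</code> expects:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">xr = 0.0:0.5:1.0   # 3 values
yr = 10.0:1.0:11.0 # 2 values

xg = xr' .* ones(length(yr)) # 2×3: every row is a copy of xr
yg = ones(length(xr))' .* yr # 2×3: every column is a copy of yr

size(xg)             # (2, 3)
(xg[2, 3], yg[2, 3]) # (1.0, 11.0)</code></pre></figure>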

<p>Plot points over the boundary:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">scatter!</span><span class="x">(</span>
    <span class="n">X</span><span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="o">:</span><span class="x">],</span> <span class="n">X</span><span class="x">[</span><span class="mi">2</span><span class="x">,</span> <span class="o">:</span><span class="x">],</span> <span class="n">color</span><span class="o">=</span><span class="n">y</span><span class="x">,</span> <span class="n">label</span><span class="o">=</span><span class="s">""</span><span class="x">,</span> <span class="n">aspectratio</span><span class="o">=:</span><span class="n">equal</span><span class="x">,</span>
    <span class="n">xlims</span> <span class="o">=</span> <span class="n">xlims</span><span class="x">(</span><span class="n">canvas</span><span class="x">),</span>
    <span class="n">ylims</span> <span class="o">=</span> <span class="n">ylims</span><span class="x">(</span><span class="n">canvas</span><span class="x">),</span>
<span class="x">)</span></code></pre></figure>

<p>The result:</p>
<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/moons_decision_boundary.png" alt="Decision boundary" />
<figcaption>The probability boundaries of a multi-layer perceptron trained on the moons dataset.</figcaption>
</figure>

<h2 id="conclusion">6 Conclusion</h2>

<p>That was a long and difficult journey.
I hope you understand how automatic differentiation with Zygote.jl works now!</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:broadcast" role="doc-endnote">
      <p>The Zygote.jl code for <a href="https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl">broadcast</a> has this gem of a comment:</p>
      <blockquote>
<p>
    There's a saying that debugging code is about twice as hard as 
    writing it in the first place. So if you're as clever as you can
    be when writing code, how will you ever debug it?
</p>
<p>
    AD faces a similar dilemma: if you write code that's as clever as
    the compiler can handle, how will you ever differentiate it? 
    Differentiating makes clever code that bit more complex and the 
    compiler gives up, usually resulting in 100x worse performance.
</p>
<p>
    Base's broadcasting is very cleverly written, and this makes 
    differentiating it... somewhat tricky.
</p>
</blockquote>
      <p><a href="#fnref:broadcast" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A series on automatic differentiation in Julia. Part 5 shows how the MicroGrad.jl code can be used for a machine learning framework like Flux.jl. The working example is a multi-layer perceptron trained on the moons dataset.]]></summary></entry><entry><title type="html">MicroGrad.jl: Part 4 Extensions</title><link href="https://liorsinai.github.io/machine-learning/2024/08/17/micrograd-4-ext.html" rel="alternate" type="text/html" title="MicroGrad.jl: Part 4 Extensions" /><published>2024-08-17T00:00:00+00:00</published><updated>2024-08-17T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2024/08/17/micrograd-4-ext</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2024/08/17/micrograd-4-ext.html"><![CDATA[<p><em>A series on automatic differentiation in Julia. Part 4 extends part 3 to handle maps, getfield and anonymous functions. It creates a generic gradient descent and uses this to fit a polynomial.</em></p>

<p>This is part of a series. The other articles are:</p>
<ul>
  <li><a href="/machine-learning/2024/07/27/micrograd-1-chainrules">Part 1: ChainRules</a>.</li>
  <li><a href="/machine-learning/2024/08/03/micrograd-2-expr">Part 2: Automation with expressions</a>.</li>
  <li><a href="/machine-learning/2024/08/10/micrograd-3-ir">Part 3: Automation with IR</a>.</li>
  <li><a href="/machine-learning/2024/08/19/micrograd-5-mlp">Part 5: MLP</a>.</li>
</ul>

<p>All source code can be found at <a href="https://github.com/LiorSinai/MicroGrad.jl">MicroGrad.jl</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p>By the end of part 3 we had code that could automatically differentiate many functions as long as we had <code class="language-plaintext highlighter-rouge">rrule</code>s and there was no control flow.</p>

<p>However, the code failed for the polynomial model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Polynomial</span><span class="x">{</span><span class="n">V</span><span class="o">&lt;:</span><span class="kt">AbstractVector</span><span class="x">}</span>
    <span class="n">weights</span><span class="o">::</span><span class="n">V</span>
<span class="k">end</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">m</span><span class="o">.</span><span class="n">weights</span><span class="x">)</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Polynomial</span><span class="x">([</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="o">-</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span>
<span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">]</span>
<span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="c"># ERROR: No method found for Tuple{typeof(fieldtype) ....}</span></code></pre></figure>

<p>Calling <code class="language-plaintext highlighter-rouge">@code_ir model(x)</code>, we can see that the code is lowered as follows:</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">1: (%1, %2)
  %3 = Main.map(%1, %2)
  return %3</code></pre></figure>

<p>And further that <code class="language-plaintext highlighter-rouge">model(1.0)</code> is lowered to:</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">1: (%1, %2)
  %3 = Base.getproperty(%1, :weights)
  %4 = Main.evalpoly(%2, %3)
  return %4</code></pre></figure>

<p>We could have also defined the <code class="language-plaintext highlighter-rouge">map</code> using an anonymous function:</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">(m::Polynomial)(x::AbstractVector) = map(x-&gt;evalpoly(x, m.weights), x)</code></pre></figure>

<p>In which case it would have been lowered to:</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">1: (%1, %2)
  %3 = Main.:(var"#43#44")
  %4 = Core.typeof(%1)
  %5 = Core.apply_type(%3, %4)
  %6 = %new(%5, %1)
  %7 = Main.map(%6, %2)
  return %7</code></pre></figure>

<p>The calls to <code class="language-plaintext highlighter-rouge">Core.typeof</code> and <code class="language-plaintext highlighter-rouge">Core.apply_type</code> are in the list of ignored functions.
However, we need to handle <code class="language-plaintext highlighter-rouge">map</code>, <code class="language-plaintext highlighter-rouge">getproperty</code> and <code class="language-plaintext highlighter-rouge">%new</code>.
These sorts of functions do not have formal mathematical derivatives and so they do not have <code class="language-plaintext highlighter-rouge">rrule</code>s in ChainRules.jl.
Instead, Zygote.jl handles these functions with their own custom pullbacks.
Zygote also replaces some low level functions like <code class="language-plaintext highlighter-rouge">new</code>, <code class="language-plaintext highlighter-rouge">getproperty</code> and <code class="language-plaintext highlighter-rouge">getindex</code> entirely with custom code.</p>

<h2 id="extending-pullback">2 Extending pullback</h2>
<h3 id="pullback-map">2.1 map</h3>

<p>The pullback for <code class="language-plaintext highlighter-rouge">map</code> is fairly complex. What will be presented here is a simplified version.
It might also help to look at the less generic code in the example in <a href="/machine-learning/2024/07/27/micrograd-1-chainrules#gradient-descent-map">part 1</a>.</p>

<p>Consider the following code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">sin</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">0.1</span><span class="x">,</span> <span class="mf">0.2</span><span class="x">,</span> <span class="mf">0.5</span><span class="x">]</span>
<span class="n">map</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">pullback</code> for <code class="language-plaintext highlighter-rouge">map</code> should return 3 values: $\text{s̄elf}$ for <code class="language-plaintext highlighter-rouge">map</code>, $\bar{f}$ for the function <code class="language-plaintext highlighter-rouge">f</code> and $\bar{x}$ for each value in <code class="language-plaintext highlighter-rouge">x</code>.</p>

<p>The code will start by getting pullbacks for each value in <code class="language-plaintext highlighter-rouge">x</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">ys_and_backs</span> <span class="o">=</span> <span class="n">map</span><span class="x">((</span><span class="n">xs</span><span class="o">...</span><span class="x">)</span> <span class="o">-&gt;</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">xs</span><span class="o">...</span><span class="x">),</span> <span class="n">x</span><span class="x">)</span> <span class="c"># ((0.099, Pullback), (0.198, Pullback), (0.479, Pullback))</span></code></pre></figure>

<p>This list is in a “zipped” format: there are $n$ entries of $(y_i, \mathcal{B}_i)$ for an array of length $n$.
This will be unzipped into two lists each of length $n$: $(y_1,\ldots,y_n), (\mathcal{B}_1,\ldots,\mathcal{B}_n)$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Δ</span> <span class="o">=</span> <span class="n">ones</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">x</span><span class="x">))</span>
<span class="n">ys</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">first</span><span class="x">,</span> <span class="n">ys_and_backs</span><span class="x">)</span> <span class="c"># (0.099, 0.198, 0.479)</span>
<span class="n">∂f_and_∂x_zipped</span> <span class="o">=</span> <span class="n">map</span><span class="x">(((</span><span class="n">_</span><span class="x">,</span> <span class="n">pb</span><span class="x">),</span> <span class="n">δ</span><span class="x">)</span> <span class="o">-&gt;</span> <span class="n">pb</span><span class="x">(</span><span class="n">δ</span><span class="x">),</span> <span class="n">ys_and_backs</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span> <span class="c"># ((nothing, 0.995), (nothing, 0.980), (nothing, 0.877))</span></code></pre></figure>
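
<p>To trace these steps without the rest of the MicroGrad.jl machinery, the same pattern can be run with a hand-written pullback for <code class="language-plaintext highlighter-rouge">sin</code>. This is only a sketch: <code class="language-plaintext highlighter-rouge">sin_pullback</code> is a stand-in for <code class="language-plaintext highlighter-rouge">pullback(sin, x)</code> and is not part of the library:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Stand-in for pullback(sin, x): returns the value and a closure for the gradients.
sin_pullback(x) = (sin(x), δ -&gt; (nothing, δ * cos(x)))

x = [0.1, 0.2, 0.5]
Δ = ones(length(x))
ys_and_backs = map(sin_pullback, x)                    # n entries of (y, back)
ys = map(first, ys_and_backs)                          # forward values
zipped = map(((_, pb), δ) -&gt; pb(δ), ys_and_backs, Δ)   # n entries of (self gradient, x gradient)
∂x = map(last, zipped)                                 # one gradient per element of x

ys ≈ sin.(x) # true
∂x ≈ cos.(x) # true</code></pre></figure>

<p>Here <code class="language-plaintext highlighter-rouge">first</code> and <code class="language-plaintext highlighter-rouge">last</code> are enough because each entry has only two elements.</p>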

<p>The gradients list of $n$ entries</p>

\[((\text{s̄elf}_1, \bar{x}_{11}, ..., \bar{x}_{k1}), ...,(\text{s̄elf}_n, \bar{x}_{1n}, ..., \bar{x}_{kn}))\]

<p>needs to be further unzipped into $k+1$ lists for $\text{s̄elf}$ and $k$ arguments:</p>

\[(\text{s̄elf}_1,...,\text{s̄elf}_{n}), (\bar{x}_{11},...,\bar{x}_{1n}), ... (\bar{x}_{k1},...,\bar{x}_{kn})\]

<p>This is done with an <code class="language-plaintext highlighter-rouge">unzip</code> function which generalises <code class="language-plaintext highlighter-rouge">first</code> to any index <code class="language-plaintext highlighter-rouge">i</code> (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/array.jl#L137">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> StaticGetter</span><span class="x">{</span><span class="n">i</span><span class="x">}</span> <span class="k">end</span>
<span class="x">(</span><span class="o">::</span><span class="n">StaticGetter</span><span class="x">{</span><span class="n">i</span><span class="x">})(</span><span class="n">v</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">i</span><span class="x">}</span> <span class="o">=</span> <span class="n">v</span><span class="x">[</span><span class="n">i</span><span class="x">]</span>
<span class="x">(</span><span class="o">::</span><span class="n">StaticGetter</span><span class="x">{</span><span class="n">i</span><span class="x">})(</span><span class="o">::</span><span class="kt">Nothing</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">i</span><span class="x">}</span> <span class="o">=</span> <span class="nb">nothing</span>

<span class="k">function</span><span class="nf"> _unzip</span><span class="x">(</span><span class="n">tuples</span><span class="x">,</span> <span class="o">::</span><span class="kt">Val</span><span class="x">{</span><span class="n">N</span><span class="x">})</span> <span class="k">where</span> <span class="x">{</span><span class="n">N</span><span class="x">}</span>
  <span class="n">getters</span> <span class="o">=</span> <span class="n">ntuple</span><span class="x">(</span><span class="n">n</span> <span class="o">-&gt;</span> <span class="n">StaticGetter</span><span class="x">{</span><span class="n">n</span><span class="x">}(),</span> <span class="n">N</span><span class="x">)</span>
  <span class="n">map</span><span class="x">(</span><span class="n">g</span> <span class="o">-&gt;</span> <span class="n">map</span><span class="x">(</span><span class="n">g</span><span class="x">,</span> <span class="n">tuples</span><span class="x">),</span> <span class="n">getters</span><span class="x">)</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> unzip</span><span class="x">(</span><span class="n">tuples</span><span class="x">)</span>
  <span class="n">N</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">first</span><span class="x">(</span><span class="n">tuples</span><span class="x">))</span>
  <span class="n">_unzip</span><span class="x">(</span><span class="n">tuples</span><span class="x">,</span> <span class="kt">Val</span><span class="x">(</span><span class="n">N</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>

<p>The result:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">∂f_and_∂x</span> <span class="o">=</span> <span class="n">unzip</span><span class="x">(</span><span class="n">∂f_and_∂x_zipped</span><span class="x">)</span> <span class="c"># [nothing, nothing, nothing], [0.995, 0.98, 0.877]</span></code></pre></figure>
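
<p>The same transposition can be sketched with Base functions only, which makes the shape of the result easier to see. The helper <code class="language-plaintext highlighter-rouge">unzip_sketch</code> is a hypothetical name used for illustration and is not part of Zygote.jl. Unlike <code class="language-plaintext highlighter-rouge">StaticGetter</code>, it does not handle an entry that is itself <code class="language-plaintext highlighter-rouge">nothing</code>, and it makes no attempt to be type stable:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical helper: transpose n tuples of length N into N collections of length n.
unzip_sketch(tuples) = ntuple(i -&gt; map(t -&gt; t[i], tuples), length(first(tuples)))

# Two calls (n = 2) of a function with two arguments (k = 2),
# so each entry has the form (s̄elf, x̄₁, x̄₂).
zipped = [(nothing, 0.1, -0.2), (nothing, 0.3, -0.4)]
selfs, x1s, x2s = unzip_sketch(zipped)
# selfs == [nothing, nothing]
# x1s == [0.1, 0.3]
# x2s == [-0.2, -0.4]</code></pre></figure>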

<p>As a final step, all the gradients for the function are accumulated into one value:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">∂f</span> <span class="o">=</span> <span class="n">reduce</span><span class="x">(</span><span class="n">accum</span><span class="x">,</span> <span class="n">∂f_and_∂x</span><span class="x">[</span><span class="mi">1</span><span class="x">])</span> <span class="c"># nothing</span></code></pre></figure>
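
<p>To see what the reduction does when the gradients are not all <code class="language-plaintext highlighter-rouge">nothing</code>, here is a minimal sketch. It assumes that <code class="language-plaintext highlighter-rouge">accum</code> treats <code class="language-plaintext highlighter-rouge">nothing</code> as zero and otherwise adds its arguments. The name <code class="language-plaintext highlighter-rouge">accum_sketch</code> is a hypothetical stand-in and not the function used above:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical stand-in for accum: `nothing` acts as the additive identity.
accum_sketch(x, y) = x === nothing ? y : (y === nothing ? x : x + y)
# Named tuples (the form s̄elf takes for a struct) are accumulated field by field.
accum_sketch(x::NamedTuple, y::NamedTuple) = map(accum_sketch, x, y)

no_grads = reduce(accum_sketch, [nothing, nothing, nothing])
# nothing
summed = reduce(accum_sketch, [(weights = [1.0, 1.0],), (weights = [1.0, 2.0],)])
# (weights = [2.0, 3.0],)</code></pre></figure>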

<p>Putting all this code in a single function (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/array.jl#L185">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">map</span><span class="x">),</span> <span class="n">f</span><span class="o">::</span><span class="n">F</span><span class="x">,</span> <span class="n">args</span><span class="o">::</span><span class="kt">Vararg</span><span class="x">{</span><span class="kt">Any</span><span class="x">,</span> <span class="n">N</span><span class="x">})</span> <span class="k">where</span> <span class="x">{</span><span class="n">F</span><span class="x">,</span> <span class="n">N</span><span class="x">}</span>
    <span class="n">ys_and_backs</span> <span class="o">=</span> <span class="n">map</span><span class="x">((</span><span class="n">xs</span><span class="o">...</span><span class="x">)</span> <span class="o">-&gt;</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">xs</span><span class="o">...</span><span class="x">),</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
    <span class="n">ys</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">first</span><span class="x">,</span> <span class="n">ys_and_backs</span><span class="x">)</span>
    <span class="k">function</span><span class="nf"> map_pullback</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
      <span class="c"># technically should apply f in reverse and reverse back afterwards in case f is stateful</span>
      <span class="n">∂f_and_∂x_zipped</span> <span class="o">=</span> <span class="n">map</span><span class="x">(((</span><span class="n">_</span><span class="x">,</span> <span class="n">pb</span><span class="x">),</span> <span class="n">δ</span><span class="x">)</span> <span class="o">-&gt;</span> <span class="n">pb</span><span class="x">(</span><span class="n">δ</span><span class="x">),</span> <span class="n">ys_and_backs</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
      <span class="n">∂f_and_∂x</span> <span class="o">=</span> <span class="n">unzip</span><span class="x">(</span><span class="n">∂f_and_∂x_zipped</span><span class="x">)</span> 
      <span class="n">∂f</span> <span class="o">=</span> <span class="n">reduce</span><span class="x">(</span><span class="n">accum</span><span class="x">,</span> <span class="n">∂f_and_∂x</span><span class="x">[</span><span class="mi">1</span><span class="x">])</span>
      <span class="n">∂args</span> <span class="o">=</span> <span class="n">∂f_and_∂x</span><span class="x">[</span><span class="mi">2</span><span class="o">:</span><span class="k">end</span><span class="x">]</span>
      <span class="k">return</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">∂f</span><span class="x">,</span> <span class="n">∂args</span><span class="o">...</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">ys</span><span class="x">,</span> <span class="n">map_pullback</span>
<span class="k">end</span></code></pre></figure>

<p>Testing:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">0.1</span><span class="x">,</span> <span class="mf">0.2</span><span class="x">,</span> <span class="mf">0.5</span><span class="x">]</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">map</span><span class="x">,</span> <span class="n">sin</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> 
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">x</span><span class="x">)))</span> <span class="c"># (nothing, nothing, [0.995, 0.98, 0.877])</span></code></pre></figure>

<p>And also:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">a</span><span class="x">,</span><span class="n">b</span><span class="x">)</span><span class="o">=</span><span class="n">a</span><span class="o">/</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">b</span><span class="o">*</span><span class="n">b</span><span class="x">)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">map</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="x">[</span><span class="mf">2.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">],</span> <span class="x">[</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">5.0</span><span class="x">])</span> 
<span class="n">back</span><span class="x">([</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span> <span class="c"># (nothing, nothing, [0.074, 0.029], [-0.099, -0.047])</span></code></pre></figure>
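
<p>These values can be cross-checked against the analytic partial derivatives $\frac{\partial f}{\partial a} = \frac{b^2}{(a+b^2)^2}$ and $\frac{\partial f}{\partial b} = \frac{-2ab}{(a+b^2)^2}$. The names <code class="language-plaintext highlighter-rouge">dfda</code> and <code class="language-plaintext highlighter-rouge">dfdb</code> are introduced here only for this check:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">dfda(a, b) = b^2 / (a + b^2)^2        # ∂f/∂a
dfdb(a, b) = -2 * a * b / (a + b^2)^2 # ∂f/∂b

a_vals, b_vals = [2.0, 4.0], [3.0, 5.0]
grad_a = map(dfda, a_vals, b_vals) # [9/121, 25/841] ≈ [0.0744, 0.0297]
grad_b = map(dfdb, a_vals, b_vals) # [-12/121, -40/841] ≈ [-0.0992, -0.0476]</code></pre></figure>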

<h3 id="pullback-instrument">2.2 Instrument</h3>

<p>Zygote.jl modifies some of the source code before creating the primal and reverse passes.
Here is a simplified version of this <code class="language-plaintext highlighter-rouge">instrument</code> function which only replaces <code class="language-plaintext highlighter-rouge">new</code> and <code class="language-plaintext highlighter-rouge">getfield</code> (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L121">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> instrument</span><span class="x">(</span><span class="n">ir</span><span class="o">::</span><span class="n">IR</span><span class="x">)</span>
    <span class="n">pr</span> <span class="o">=</span> <span class="kt">Pipe</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">v</span><span class="x">,</span> <span class="n">st</span><span class="x">)</span> <span class="k">in</span> <span class="n">pr</span>
        <span class="n">ex</span> <span class="o">=</span> <span class="n">st</span><span class="o">.</span><span class="n">expr</span>
        <span class="k">if</span> <span class="n">isexpr</span><span class="x">(</span><span class="n">ex</span><span class="x">,</span> <span class="o">:</span><span class="n">new</span><span class="x">)</span>
            <span class="n">pr</span><span class="x">[</span><span class="n">v</span><span class="x">]</span> <span class="o">=</span> <span class="n">xcall</span><span class="x">(</span><span class="n">Main</span><span class="x">,</span> <span class="o">:</span><span class="n">__new__</span><span class="x">,</span> <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="k">elseif</span> <span class="n">is_literal_getfield</span><span class="x">(</span><span class="n">ex</span><span class="x">)</span>
            <span class="n">pr</span><span class="x">[</span><span class="n">v</span><span class="x">]</span> <span class="o">=</span> <span class="n">xcall</span><span class="x">(</span><span class="n">Main</span><span class="x">,</span> <span class="o">:</span><span class="n">literal_getfield</span><span class="x">,</span> <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="mi">2</span><span class="x">],</span> <span class="kt">Val</span><span class="x">(</span><span class="n">unwrapquote</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="mi">3</span><span class="x">])))</span>
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="n">finish</span><span class="x">(</span><span class="n">pr</span><span class="x">)</span>
<span class="k">end</span>

<span class="n">iscall</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">m</span><span class="o">::</span><span class="kt">Module</span><span class="x">,</span> <span class="n">n</span><span class="o">::</span><span class="kt">Symbol</span><span class="x">)</span> <span class="o">=</span> <span class="n">isexpr</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="o">:</span><span class="n">call</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="n">x</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="o">==</span> <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">n</span><span class="x">)</span>
<span class="n">unwrapquote</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span>
<span class="n">unwrapquote</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">QuoteNode</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">value</span>

<span class="n">is_literal_getfield</span><span class="x">(</span><span class="n">ex</span><span class="x">)</span> <span class="o">=</span>
  <span class="x">(</span><span class="n">iscall</span><span class="x">(</span><span class="n">ex</span><span class="x">,</span> <span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">getfield</span><span class="x">)</span> <span class="o">||</span> <span class="n">iscall</span><span class="x">(</span><span class="n">ex</span><span class="x">,</span> <span class="n">Base</span><span class="x">,</span> <span class="o">:</span><span class="n">getfield</span><span class="x">))</span> <span class="o">&amp;&amp;</span>
  <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="mi">3</span><span class="x">]</span> <span class="k">isa</span> <span class="kt">Union</span><span class="x">{</span><span class="kt">QuoteNode</span><span class="x">,</span><span class="kt">Integer</span><span class="x">}</span></code></pre></figure>
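
<p>A quick way to see what <code class="language-plaintext highlighter-rouge">is_literal_getfield</code> matches is to apply it to a few hand-built expressions. This is only a sketch: the expressions below are constructed manually and use symbols where the real IR would hold IRTools variables such as <code class="language-plaintext highlighter-rouge">%1</code>. The two predicates are repeated so that the snippet runs on its own:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Base.Meta: isexpr

iscall(x, m::Module, n::Symbol) = isexpr(x, :call) &amp;&amp; x.args[1] == GlobalRef(m, n)
is_literal_getfield(ex) =
  (iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) &amp;&amp;
  ex.args[3] isa Union{QuoteNode,Integer}

# Hand-built statements (assumed shapes, for illustration only).
literal = Expr(:call, GlobalRef(Base, :getfield), :m, QuoteNode(:weights)) # field name known statically
dynamic = Expr(:call, GlobalRef(Base, :getfield), :m, :name)               # field name held in a variable
other   = Expr(:call, GlobalRef(Base, :sin), :x)                           # not a getfield call

is_literal_getfield(literal) # true
is_literal_getfield(dynamic) # false
is_literal_getfield(other)   # false</code></pre></figure>

<p>Only the first form is rewritten, because the field name can be lifted into a <code class="language-plaintext highlighter-rouge">Val</code> type parameter at compile time.</p>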

<p>Modify the existing <code class="language-plaintext highlighter-rouge">_generate_pullback_via_decomposition</code> and <code class="language-plaintext highlighter-rouge">_generate_callable_pullback</code> functions to call it:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">T</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="n">isnothing</span><span class="x">(</span><span class="n">m</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="nb">nothing</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">IR</span><span class="x">(</span><span class="n">m</span><span class="x">)</span>
    <span class="n">length</span><span class="x">(</span><span class="n">blocks</span><span class="x">(</span><span class="n">ir</span><span class="x">))</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">||</span> <span class="n">error</span><span class="x">(</span><span class="s">"control flow is not supported"</span><span class="x">)</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">instrument</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span> <span class="c"># new</span>
    <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">primal</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">T</span><span class="x">)</span>
    <span class="n">m</span><span class="x">,</span> <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> _generate_callable_pullback</span><span class="x">(</span><span class="n">j</span><span class="o">::</span><span class="kt">Type</span><span class="x">{</span><span class="o">&lt;:</span><span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span> <span class="n">T</span><span class="x">}},</span> <span class="n">world</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">S</span><span class="x">,</span> <span class="n">T</span><span class="x">}</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">S</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">IR</span><span class="x">(</span><span class="n">m</span><span class="x">)</span>
    <span class="n">isnothing</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"Non-differentiable function "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="n">args</span><span class="x">[</span><span class="mi">1</span><span class="x">])))</span>
    <span class="n">length</span><span class="x">(</span><span class="n">blocks</span><span class="x">(</span><span class="n">ir</span><span class="x">))</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">||</span> <span class="n">error</span><span class="x">(</span><span class="s">"control flow is not supported"</span><span class="x">)</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">instrument</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span> <span class="c"># new</span>
    <span class="n">back</span> <span class="o">=</span> <span class="n">reverse_differentiate</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span>
    <span class="n">back</span> <span class="o">=</span> <span class="n">slots!</span><span class="x">(</span><span class="n">inlineable!</span><span class="x">(</span><span class="n">back</span><span class="x">))</span>
    <span class="n">ci</span> <span class="o">=</span> <span class="n">build_codeinfo_</span><span class="x">(</span><span class="n">back</span><span class="x">)</span>
    <span class="n">ci</span><span class="o">.</span><span class="n">slotnames</span> <span class="o">=</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"#self#"</span><span class="x">),</span> <span class="o">:</span><span class="n">Δ</span><span class="x">]</span>
    <span class="n">ci</span>
<span class="k">end</span> </code></pre></figure>

<p>Now we need to define <code class="language-plaintext highlighter-rouge">literal_getfield</code> and <code class="language-plaintext highlighter-rouge">__new__</code> and their pullbacks.</p>

<h3 id="pullback-getfield">2.3 getfield</h3>

<p>Calls to <code class="language-plaintext highlighter-rouge">getproperty</code> default to <code class="language-plaintext highlighter-rouge">getfield</code>, where a field is a member declared in the struct’s definition.
The <code class="language-plaintext highlighter-rouge">getfield</code> function is substituted with <code class="language-plaintext highlighter-rouge">literal_getfield</code> (<a href="https://github.com/FluxML/ZygoteRules.jl/blob/f9bf0e367fa259c5aa68f0e14ccbf2125d734bd6/src/ZygoteRules.jl#L19">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">literal_getfield</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="o">::</span><span class="kt">Val</span><span class="x">{</span><span class="n">f</span><span class="x">})</span> <span class="k">where</span> <span class="n">f</span> <span class="o">=</span> <span class="n">getfield</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">f</span><span class="x">)</span></code></pre></figure>

<p>The pullback returns a single <code class="language-plaintext highlighter-rouge">NamedTuple</code> with one entry per field, where the gradient is <code class="language-plaintext highlighter-rouge">Δ</code> for the requested field and <code class="language-plaintext highlighter-rouge">nothing</code> for the others (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/lib.jl#L228">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="nd">@generated</span> <span class="n">nt_nothing</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">f</span><span class="o">=</span><span class="nb">nothing</span><span class="x">)</span> <span class="k">for</span> <span class="n">f</span> <span class="k">in</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">x</span><span class="x">)]</span><span class="o">...</span><span class="x">)</span>
<span class="nd">@generated</span> <span class="n">pair</span><span class="x">(</span><span class="o">::</span><span class="kt">Val</span><span class="x">{</span><span class="n">k</span><span class="x">},</span> <span class="n">v</span><span class="x">,</span> <span class="n">_</span><span class="o">=</span><span class="nb">nothing</span><span class="x">)</span> <span class="k">where</span> <span class="n">k</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">k</span> <span class="o">=</span> <span class="n">v</span><span class="x">,)</span>

<span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">literal_getfield</span><span class="x">),</span> <span class="n">x</span><span class="x">,</span> <span class="o">::</span><span class="kt">Val</span><span class="x">{</span><span class="n">f</span><span class="x">})</span> <span class="k">where</span> <span class="n">f</span>
  <span class="n">val</span> <span class="o">=</span> <span class="n">getfield</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">f</span><span class="x">)</span>
  <span class="k">function</span><span class="nf"> literal_getfield_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">isimmutable</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
      <span class="n">dx</span> <span class="o">=</span> <span class="x">(;</span> <span class="n">nt_nothing</span><span class="x">(</span><span class="n">x</span><span class="x">)</span><span class="o">...</span><span class="x">,</span> <span class="n">pair</span><span class="x">(</span><span class="kt">Val</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="n">Δ</span><span class="x">)</span><span class="o">...</span><span class="x">)</span>
      <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">dx</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">)</span>
    <span class="k">else</span>
      <span class="n">error</span><span class="x">(</span><span class="s">"mutable structs not supported"</span><span class="x">)</span>
    <span class="k">end</span>
  <span class="k">end</span>
  <span class="n">val</span><span class="x">,</span> <span class="n">literal_getfield_back</span>
<span class="k">end</span>

<span class="n">pullback</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">getfield</span><span class="x">),</span> <span class="n">x</span><span class="x">,</span> <span class="n">field_name</span><span class="o">::</span><span class="kt">Symbol</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">literal_getfield</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="kt">Val</span><span class="x">(</span><span class="n">field_name</span><span class="x">))</span></code></pre></figure>
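
<p>The named tuple construction inside <code class="language-plaintext highlighter-rouge">literal_getfield_back</code> can be reproduced without generated functions. In this sketch the struct <code class="language-plaintext highlighter-rouge">Point3</code> and the variable <code class="language-plaintext highlighter-rouge">blank</code> are hypothetical names for illustration. It relies on the fact that the splatting form <code class="language-plaintext highlighter-rouge">(; a..., b...)</code> behaves like <code class="language-plaintext highlighter-rouge">merge(a, b)</code>: keys keep the order of the first tuple and later values override earlier ones:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">struct Point3 # hypothetical struct for illustration
    a
    b
    c
end

# One `nothing` per field, in declaration order.
blank = NamedTuple{fieldnames(Point3)}(ntuple(_ -&gt; nothing, fieldcount(Point3)))
# Only field `b` receives the incoming gradient Δ; the others stay `nothing`.
Δ = 1.0
dx = merge(blank, (b = Δ,)) # (a = nothing, b = 1.0, c = nothing)</code></pre></figure>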

<p>For example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Foo</span>
    <span class="n">a</span>
    <span class="n">b</span>
    <span class="n">c</span>
<span class="k">end</span>
<span class="n">foo</span> <span class="o">=</span> <span class="n">Foo</span><span class="x">(</span><span class="mf">1.0</span><span class="x">,</span> <span class="sc">'a'</span><span class="x">,</span> <span class="s">"hello"</span><span class="x">)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">getfield</span><span class="x">,</span> <span class="n">foo</span><span class="x">,</span> <span class="o">:</span><span class="n">b</span><span class="x">)</span> <span class="c"># ('a', literal_getfield_back)</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, (a = nothing, b = 1.0, c = nothing), nothing)</span></code></pre></figure>

<p>And for the polynomial model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">)</span>
<span class="n">back</span><span class="x">(</span><span class="mf">2.3</span><span class="x">)</span> <span class="c"># ((weights = [2.3, 2.3, 2.3, 2.3],), -2.3)</span></code></pre></figure>

<p>For the first time we have a value for $\text{s̄elf}$: a named tuple with one entry per field of the model.</p>

<h3 id="pullback-new">2.4 new</h3>

<p>The code now works with:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span></code></pre></figure>

<p>It returns $\text{s̄elf}$ and $\bar{x}$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Polynomial</span><span class="x">([</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="o">-</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span>
<span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">]</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">4</span><span class="x">))</span> <span class="c"># ((weights = [4.0, 10.0, 30.0, 100.0],), [-1.0, 2.0, 11.0, 26.0])</span></code></pre></figure>

<p>However with an anonymous function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">x</span><span class="o">-&gt;</span><span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">m</span><span class="o">.</span><span class="n">weights</span><span class="x">),</span> <span class="n">x</span><span class="x">)</span></code></pre></figure>

<p><code class="language-plaintext highlighter-rouge">nothing</code> is returned for $\text{s̄elf}$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">4</span><span class="x">))</span> <span class="c"># (nothing, [-1.0, 2.0, 11.0, 26.0])</span></code></pre></figure>

<p>Inspecting <code class="language-plaintext highlighter-rouge">primal(ir)</code> shows why: no pullbacks, and hence no gradients, are recorded against variable <code class="language-plaintext highlighter-rouge">%1</code> (<code class="language-plaintext highlighter-rouge">self</code>):</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">1: (%1, %2)
  %3 = Main.:(var"#74#75")
  %4 = Core.typeof(%1)
  %5 = Core.apply_type(%3, %4)
  %6 = %new(%5, %1)
  %7 = Main.pullback(Main.map, %6, %2)
  %8 = Base.getindex(%7, 1)
  %9 = Base.getindex(%7, 2)
  %10 = Base.tuple(%9)
  %11 = (Pullback{Any})(%10)
  %12 = Base.tuple(%8, %11)
  return %12</code></pre></figure>
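
<p>The <code class="language-plaintext highlighter-rouge">%new</code> statement appears because Julia lowers an anonymous function that captures variables into an instance of a generated struct (here <code class="language-plaintext highlighter-rouge">var"#74#75"</code>), whose fields hold the captured values. In the IR above the captured value is <code class="language-plaintext highlighter-rouge">%1</code>, the model itself. A small sketch of this behaviour, where <code class="language-plaintext highlighter-rouge">make_poly</code> is a hypothetical helper:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical helper that returns a closure capturing `weights`.
make_poly(weights) = x -&gt; evalpoly(x, weights)

p = make_poly([3.0, 2.0, -3.0, 1.0])
fieldnames(typeof(p)) # (:weights,)
p.weights             # [3.0, 2.0, -3.0, 1.0]
p(2.0)                # 3 + 2*2 - 3*4 + 1*8 = 3.0</code></pre></figure>

<p>Since the closure is constructed with <code class="language-plaintext highlighter-rouge">%new</code> and no pullback is recorded for that statement, gradients with respect to its fields never reach the captured model.</p>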

<p>The solution is to swap <code class="language-plaintext highlighter-rouge">%new</code> with a call to a custom function <code class="language-plaintext highlighter-rouge">__new__</code> with a pullback.
This function is as follows (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/tools/builtins.jl">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">macro</span><span class="nf"> __splatnew__</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">args</span><span class="x">)</span>
  <span class="n">esc</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">splatnew</span><span class="x">,</span> <span class="n">T</span><span class="x">,</span> <span class="n">args</span><span class="x">))</span>
<span class="k">end</span>

<span class="nd">@inline</span> <span class="n">__new__</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="nd">@__splatnew__</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">args</span><span class="x">)</span></code></pre></figure>

<p>And the pullback is (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/lib.jl#L289">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Base</span><span class="o">:</span> <span class="n">RefValue</span>
<span class="k">struct</span><span class="nc"> Jnew</span><span class="x">{</span><span class="n">T</span><span class="x">,</span><span class="n">G</span><span class="x">}</span>
  <span class="n">g</span><span class="o">::</span><span class="n">G</span>
<span class="k">end</span>

<span class="n">Jnew</span><span class="x">{</span><span class="n">T</span><span class="x">}(</span><span class="n">g</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span> <span class="o">=</span> <span class="n">Jnew</span><span class="x">{</span><span class="n">T</span><span class="x">,</span><span class="n">typeof</span><span class="x">(</span><span class="n">g</span><span class="x">)}(</span><span class="n">g</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">__new__</span><span class="x">),</span> <span class="n">T</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
  <span class="n">x</span> <span class="o">=</span> <span class="n">__new__</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
  <span class="n">g</span> <span class="o">=</span> <span class="o">!</span><span class="n">ismutabletype</span><span class="x">(</span><span class="n">T</span><span class="x">)</span> <span class="o">||</span> <span class="n">fieldcount</span><span class="x">(</span><span class="n">T</span><span class="x">)</span> <span class="o">==</span> <span class="mi">0</span> <span class="o">?</span> <span class="nb">nothing</span> <span class="o">:</span> <span class="n">grad_mut</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
  <span class="n">x</span><span class="x">,</span> <span class="n">Jnew</span><span class="x">{</span><span class="n">T</span><span class="x">,</span><span class="n">typeof</span><span class="x">(</span><span class="n">g</span><span class="x">)}(</span><span class="n">g</span><span class="x">)</span>
<span class="k">end</span>

<span class="nd">@generated</span> <span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">back</span><span class="o">::</span><span class="n">Jnew</span><span class="x">{</span><span class="n">T</span><span class="x">,</span><span class="n">G</span><span class="x">})(</span><span class="n">Δ</span><span class="o">::</span><span class="kt">Union</span><span class="x">{</span><span class="kt">NamedTuple</span><span class="x">,</span><span class="kt">Nothing</span><span class="x">,</span><span class="n">RefValue</span><span class="x">})</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span><span class="n">G</span><span class="x">}</span>
  <span class="o">!</span><span class="n">ismutabletype</span><span class="x">(</span><span class="n">T</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="n">Δ</span> <span class="o">==</span> <span class="kt">Nothing</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="o">:</span><span class="nb">nothing</span>
  <span class="n">Δ</span> <span class="o">=</span> <span class="n">G</span> <span class="o">==</span> <span class="kt">Nothing</span> <span class="o">?</span> <span class="o">:</span><span class="n">Δ</span> <span class="o">:</span>
      <span class="n">Δ</span> <span class="o">&lt;:</span> <span class="n">RefValue</span> <span class="o">?</span> <span class="o">:</span><span class="x">(</span><span class="n">back</span><span class="o">.</span><span class="n">g</span><span class="x">[])</span> <span class="o">:</span>
      <span class="o">:</span><span class="x">(</span><span class="n">accum</span><span class="x">(</span><span class="n">back</span><span class="o">.</span><span class="n">g</span><span class="x">[],</span> <span class="n">Δ</span><span class="x">))</span>
  <span class="k">quote</span>
    <span class="n">x̄</span> <span class="o">=</span> <span class="o">$</span><span class="n">Δ</span>
    <span class="o">$</span><span class="x">(</span><span class="n">G</span> <span class="o">==</span> <span class="kt">Nothing</span> <span class="o">||</span> <span class="o">:</span><span class="x">(</span><span class="n">back</span><span class="o">.</span><span class="n">g</span><span class="x">[]</span> <span class="o">=</span> <span class="n">nt_nothing</span><span class="x">(</span><span class="o">$</span><span class="n">Δ</span><span class="x">)))</span>
    <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">,</span> <span class="o">$</span><span class="x">(</span><span class="n">map</span><span class="x">(</span><span class="n">f</span> <span class="o">-&gt;</span> <span class="o">:</span><span class="x">(</span><span class="n">x̄</span><span class="o">.$</span><span class="n">f</span><span class="x">),</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">T</span><span class="x">))</span><span class="o">...</span><span class="x">))</span>
  <span class="k">end</span>
<span class="k">end</span></code></pre></figure>

<p>Now if we try the following (after redefining <code class="language-plaintext highlighter-rouge">@generated function pullback</code> and <code class="language-plaintext highlighter-rouge">function (methodinstance::Pullback)</code>) we should get the same results:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">4</span><span class="x">))</span> <span class="c"># ((weights = [4.0, 10.0, 30.0, 100.0],), [-1.0, 2.0, 11.0, 26.0])</span></code></pre></figure>

<h2 id="gradient-descent-revisited">3 Gradient Descent revisited</h2>
<h3 id="generic-gradient-descent">3.1 Generic Gradient Descent</h3>

<p>Now that we have an automatic differentiation engine, it is possible to create a much more generic gradient descent function than in <a href="/machine-learning/2024/07/27/micrograd-1-chainrules.html#gradient-descent">part 1</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> gradient_descent!</span><span class="x">(</span>
    <span class="n">model</span><span class="x">,</span>
    <span class="n">loss</span><span class="x">,</span>
    <span class="n">X</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">,</span>
    <span class="n">Y</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span>
    <span class="x">;</span> <span class="n">learning_rate</span><span class="o">::</span><span class="kt">AbstractFloat</span><span class="o">=</span><span class="mf">0.1</span><span class="x">,</span>
    <span class="n">max_iters</span><span class="o">::</span><span class="kt">Integer</span><span class="o">=</span><span class="mi">100</span>
    <span class="x">)</span>
    <span class="n">losses</span> <span class="o">=</span> <span class="kt">Float64</span><span class="x">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">max_iters</span>
        <span class="n">loss_iter</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">)</span> <span class="k">do</span> <span class="n">m</span>
            <span class="n">result</span> <span class="o">=</span> <span class="n">m</span><span class="x">(</span><span class="n">X</span><span class="x">)</span>
            <span class="n">loss</span><span class="x">(</span><span class="n">result</span><span class="x">,</span> <span class="n">Y</span><span class="x">)</span>
        <span class="k">end</span> 
        <span class="n">Δf</span><span class="x">,</span> <span class="n">Δm</span> <span class="o">=</span> <span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span>
        <span class="n">update_params!</span><span class="x">(</span><span class="n">parameters</span><span class="x">(</span><span class="n">model</span><span class="x">),</span> <span class="n">Δm</span><span class="x">;</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="x">)</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">losses</span><span class="x">,</span> <span class="n">loss_iter</span><span class="x">)</span>  
    <span class="k">end</span>
    <span class="n">losses</span>
<span class="k">end</span></code></pre></figure>

<p>Note that <code class="language-plaintext highlighter-rouge">pullback(m-&gt;f(m), model)</code> is directly equivalent to <code class="language-plaintext highlighter-rouge">pullback(model) do m; f(m) end</code>.</p>
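
<p>As a small illustration of the same equivalence with a function from Base, the two calls to <code class="language-plaintext highlighter-rouge">map</code> below should behave identically, because a <code class="language-plaintext highlighter-rouge">do</code> block is passed as the first argument (the variable names are only for this sketch):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># anonymous function passed explicitly as the first argument
squares_lambda = map(x -&gt; x^2, [1, 2, 3])
# the do block creates the same anonymous function and passes it first
squares_do = map([1, 2, 3]) do x
    x^2
end
squares_lambda == squares_do # true</code></pre></figure>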

<p>The <code class="language-plaintext highlighter-rouge">update_params!</code> function is defined as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> update_params!</span><span class="x">(</span><span class="n">params</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">,</span> <span class="n">grads</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span>
    <span class="k">for</span> <span class="n">key</span> <span class="k">in</span> <span class="n">keys</span><span class="x">(</span><span class="n">params</span><span class="x">)</span>
        <span class="n">update_params!</span><span class="x">(</span><span class="n">params</span><span class="x">[</span><span class="n">key</span><span class="x">],</span> <span class="n">grads</span><span class="x">[</span><span class="n">key</span><span class="x">];</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span>
    <span class="k">end</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> update_params!</span><span class="x">(</span><span class="n">params</span><span class="o">::</span><span class="kt">Tuple</span><span class="x">,</span> <span class="n">grads</span><span class="o">::</span><span class="kt">Tuple</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">p</span><span class="x">,</span> <span class="n">g</span><span class="x">)</span> <span class="k">in</span> <span class="n">zip</span><span class="x">(</span><span class="n">params</span><span class="x">,</span> <span class="n">grads</span><span class="x">)</span>
        <span class="n">update_params!</span><span class="x">(</span><span class="n">p</span><span class="x">,</span> <span class="n">g</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span>
    <span class="k">end</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> update_params!</span><span class="x">(</span><span class="n">params</span><span class="x">,</span> <span class="n">grads</span><span class="x">;</span> <span class="n">learning_rate</span><span class="o">::</span><span class="kt">AbstractFloat</span><span class="o">=</span><span class="mf">0.1</span><span class="x">)</span>
    <span class="n">params</span> <span class="o">.-=</span> <span class="n">learning_rate</span> <span class="o">.*</span> <span class="n">grads</span> <span class="c"># must broadcast to edit elements and not copies!</span>
<span class="k">end</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">parameters</code> function is defined per model.
(Flux uses the generic Functors.jl library to accomplish something similar.)</p>

<h3 id="polynomial-curve-fitting-revisited">3.2 Polynomial curve fitting revisited</h3>

<p>Let’s create the exact same data set as in part 1:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">StatsBase</span>
<span class="n">target_weights</span> <span class="o">=</span> <span class="x">[</span><span class="mf">15.0</span><span class="x">,</span> <span class="o">-</span><span class="mf">2.1</span><span class="x">,</span> <span class="mf">13.9</span><span class="x">,</span> <span class="mf">1.5</span><span class="x">]</span>
<span class="n">noise_factor</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">xs</span> <span class="o">=</span> <span class="x">(</span><span class="n">rand</span><span class="x">(</span><span class="mi">100</span><span class="x">)</span> <span class="o">.-</span> <span class="mf">0.5</span><span class="x">)</span> <span class="o">.*</span> <span class="mi">10</span>
<span class="n">ys</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">x</span> <span class="o">-&gt;</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">target_weights</span><span class="x">),</span> <span class="n">xs</span><span class="x">)</span>
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">mean</span><span class="x">(</span><span class="n">abs</span><span class="o">.</span><span class="x">(</span><span class="n">ys</span><span class="x">))</span>
<span class="n">ys</span> <span class="o">.+=</span> <span class="n">randn</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">ys</span><span class="x">))</span> <span class="o">*</span> <span class="n">scale_factor</span> <span class="o">*</span> <span class="n">noise_factor</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">Polynomial</code> model is defined in the introduction. We also need a custom method for <code class="language-plaintext highlighter-rouge">parameters</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">parameters</span><span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)</span> <span class="o">=</span> <span class="x">(;</span><span class="n">weights</span><span class="o">=</span><span class="n">m</span><span class="o">.</span><span class="n">weights</span><span class="x">)</span></code></pre></figure>

<p>Define the model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Polynomial</span><span class="x">(</span><span class="n">rand</span><span class="x">(</span><span class="mi">4</span><span class="x">))</span></code></pre></figure>

<p>Some sanity checks:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">]</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="c"># ([1.68, 7.21, 21.2], Pullback) </span>
<span class="n">back</span><span class="x">([</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span> <span class="c"># ((weights = [3.0, 6.0, 14.0, 36.0],), [-1.0, 2.0, 11.0])</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">m</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">model</span><span class="x">)</span> 
<span class="n">back</span><span class="x">([</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span> <span class="c"># (nothing, (weights = [3.0, 6.0, 14.0, 36.0],))</span>
<span class="n">y</span> <span class="o">=</span> <span class="x">[</span><span class="mf">2.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">,</span> <span class="mf">8.0</span><span class="x">]</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">mse</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">y</span><span class="x">),</span> <span class="n">model</span><span class="x">)</span> 
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, (weights = [10.7, 30.5, 87.6, 254.6],))</span></code></pre></figure>

<p>Train the model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">history</span> <span class="o">=</span> <span class="n">gradient_descent!</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">mse</span><span class="x">,</span> <span class="n">xs</span><span class="x">,</span> <span class="n">ys</span><span class="x">;</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-5</span><span class="x">,</span> <span class="n">max_iters</span><span class="o">=</span><span class="mi">2000</span><span class="x">)</span></code></pre></figure>

<p>This works just as well as before.</p>

<h2 id="conclusion">4 Conclusion</h2>

<p>We now have a fully working AD package.
It has some limitations: for example, it cannot handle control flow or keyword arguments.
However, it can already work on a wide variety of code.
All that might be needed is new <code class="language-plaintext highlighter-rouge">rrule</code> definitions.
The next and final <a href="/machine-learning/2024/08/19/micrograd-5-mlp">part</a> of this series is a demonstration of exactly that.</p>

<hr />]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A series on automatic differentiation in Julia. Part 4 extends part 3 to handle maps, getfield and anonymous functions. It creates a generic gradient descent and uses this to fit a polynomial.]]></summary></entry><entry><title type="html">MicroGrad.jl: Part 3 Automation with IRTools</title><link href="https://liorsinai.github.io/machine-learning/2024/08/10/micrograd-3-ir.html" rel="alternate" type="text/html" title="MicroGrad.jl: Part 3 Automation with IRTools" /><published>2024-08-10T00:00:00+00:00</published><updated>2024-08-17T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2024/08/10/micrograd-3-ir</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2024/08/10/micrograd-3-ir.html"><![CDATA[<p><em>A series on automatic differentiation in Julia. Part 3 uses metaprogramming based on IRTools.jl to generate a modified (primal) forward pass and to reverse differentiate it into a backward pass. This is a more robust approach than the expression based approach in Part 2.</em></p>

<p>This is part of a series. The other articles are:</p>
<ul>
  <li><a href="/machine-learning/2024/07/27/micrograd-1-chainrules">Part 1: ChainRules</a>.</li>
  <li><a href="/machine-learning/2024/08/03/micrograd-2-expr">Part 2: Automation with expressions</a>.</li>
  <li><a href="/machine-learning/2024/08/17/micrograd-4-ext">Part 4: Extensions</a>.</li>
  <li><a href="/machine-learning/2024/08/19/micrograd-5-mlp">Part 5: MLP</a>.</li>
</ul>

<p>All source code can be found at <a href="https://github.com/LiorSinai/MicroGrad.jl">MicroGrad.jl</a>.
The code here is based on the example at <a href="https://github.com/FluxML/IRTools.jl/blob/master/examples/reverse.jl">IRTools.jl</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p><a href="/machine-learning/2024/07/27/micrograd-1-chainrules">Part 1</a> introduced the <code class="language-plaintext highlighter-rouge">rrule</code> for implementing chain rules 
and <a href="/machine-learning/2024/08/03/micrograd-2-expr">Part 2</a> defined a <code class="language-plaintext highlighter-rouge">@generated pullback</code> function for inspecting and decomposing complex code.
The goal here is to replicate the results of Part 2 except in a more robust manner using the <a href="https://fluxml.ai/IRTools.jl/latest/">IRTools.jl</a> package.</p>

<div class="message-container warning-message">
    <div class="message-icon fa fa-fw fa-2x fa-exclamation-triangle">
    </div>
    <div class="content-container">
        <div class="message-body">
        Metaprogramming is a powerful tool, but it introduces complexity that can make code more difficult to understand. It can easily introduce critical bugs that can crash a program.
        Care should be taken when using it.
        </div>
    </div>
</div>

<p>For example, from part 1 there are <code class="language-plaintext highlighter-rouge">rrule</code>s for <code class="language-plaintext highlighter-rouge">+</code>, <code class="language-plaintext highlighter-rouge">*</code> and <code class="language-plaintext highlighter-rouge">/</code>.
The goal is then to automatically differentiate the following:</p>

\[f(a, b) = \frac{a}{a + b^2}\]

<p>like so:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">a</span><span class="x">,</span> <span class="n">b</span><span class="x">)</span> <span class="o">=</span> <span class="n">a</span> <span class="o">/</span> <span class="x">(</span><span class="n">a</span> <span class="o">+</span> <span class="n">b</span><span class="o">*</span><span class="n">b</span><span class="x">)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">)</span> <span class="c"># (0.1818, ∂(f))</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 0.0744, -0.099)</span></code></pre></figure>

<p>where <code class="language-plaintext highlighter-rouge">pullback</code> is a <code class="language-plaintext highlighter-rouge">@generated</code> function that inspects the Intermediate Representation (IR) code for <code class="language-plaintext highlighter-rouge">f</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span>
<span class="n">ir</span> <span class="o">=</span> <span class="nd">@code_ir</span> <span class="n">f</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
<span class="cm">#= 1: (%1, %2, %3)
  %4 = %3 * %3
  %5 = %2 + %4
  %6 = %2 / %5
  return %6
=#</span></code></pre></figure>

<p>This is an advanced use of the Julia programming language.
You should be comfortable with the language before reading this post.
At the very least, the Julia documentation page on <a href="https://docs.julialang.org/en/v1/manual/metaprogramming/">metaprogramming</a> is required reading for this post and is treated as assumed knowledge, especially the sections on “Expressions and evaluation”, “Code Generation” and “Generated Functions”. I also suggest going through the <a href="https://fluxml.ai/IRTools.jl/latest/">IRTools.jl</a> documentation first.</p>

<p>This post can be read independently of Part 2 and will repeat parts of it.
However, it is advised to read Part 2 first because it is easier to understand than this post.</p>

<h2 id="wengert-lists">2 Differentiating Wengert Lists</h2>

<p>The <a href="https://fluxml.ai/Zygote.jl/stable/">Zygote.jl</a> automatic differentiation (AD) package is a realisation of the paper <a href="https://arxiv.org/abs/1810.07951">Don’t Unroll Adjoint: Differentiating SSA-Form Programs (2019)</a> by Michael J Innes.<br />
The paper works with Wengert lists, also known as tapes, and a generalisation of them called Static Single Assignment (SSA) form.
The aim here is to develop a minimal AD package, so this series only focuses on the sections on Wengert lists.
A consequence is that the code will not be able to handle any non-linear logic in Julia, for example any control flow like <code class="language-plaintext highlighter-rouge">if</code>, <code class="language-plaintext highlighter-rouge">while</code> or <code class="language-plaintext highlighter-rouge">for</code> blocks.</p>

<p>The paper uses the same example as the introduction:</p>

\[f(a, b) = \frac{a}{a + b^2}
\tag{2.1}
\label{eq:f}\]

<p>This can be broken down into smaller steps where each intermediate variable is saved.
This is known as a Wengert list, or tape, or (backpropagation) graph:</p>

\[\begin{align}
y_1 &amp;= b \times b \\
y_2 &amp;= a + y_1 \\
y_3 &amp;= a / y_2
\end{align}
\tag{2.2}
\label{eq:f_wengert}\]

<p>To differentiate this, all function calls are wrapped with a differentiation function $\mathcal{J}$ which returns both the output $y$ and a pullback function $\mathcal{B}$.
This is called the <em>primal</em> form:</p>

\[\begin{align}
y_1, \mathcal{B}_1 &amp;\leftarrow \mathcal{J}(\times, b, b) \\
y_2, \mathcal{B}_2 &amp;\leftarrow \mathcal{J}(+, a, y_1) \\
y_3, \mathcal{B}_3 &amp;\leftarrow \mathcal{J}(/, a, y_2)
\end{align}
\tag{2.3}
\label{eq:primal}\]

<p>The pullback function $\mathcal{B}$ of a function $y(x)$ takes as input the gradient of a scalar $l$ (typically a loss) with respect to the output $y$, and returns the gradient with respect to the input $x$.
This partial gradient $\frac{\partial l}{\partial x}$ is written as $\bar{x}$.</p>

\[\begin{align}
\bar{x} &amp;= \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x}
\end{align}
\tag{2.4}
\label{eq:bar_x}\]

<p>so the pullback can be written in mathematical notation as:</p>

\[\begin{align}
\bar{x} &amp;\leftarrow \mathcal{B}(\bar{y}) = \bar{y} \frac{\partial y}{\partial x}\\
\text{or} \quad \bar{x} &amp;\leftarrow  \mathcal{B}(\bar{y}) = J^{\dagger}\bar{y}
\end{align}
\tag{2.5}
\label{eq:pullback}\]

<p>where $\bar{y}=\frac{\partial l}{\partial y}$ and $J=\frac{\partial y}{\partial x}$ is the Jacobian (gradient) for arrays.</p>
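
<p>As a concrete scalar instance, consider the division $y = a / b$, which is the last step of the Wengert list. Assuming the usual derivative of a quotient, its pullback returns one gradient per argument:</p>

\[\begin{align}
\bar{a} &amp;= \bar{y} \frac{\partial y}{\partial a} = \frac{\bar{y}}{b} \\
\bar{b} &amp;= \bar{y} \frac{\partial y}{\partial b} = -\frac{\bar{y} a}{b^2}
\end{align}\]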

<p>The various partial gradients are calculated by reversing the list.
Each pullback function $\mathcal{B}_i$ takes as input the gradient $\bar{y}_i$ of its own output, which is returned by the pullback that runs before it in the reversed list.
The first pullback takes an existing gradient $\Delta$, which is usually set to 1:</p>

\[\begin{align}
\text{s̄elf}_3, \bar{a}_{3,1}, \bar{y}_2 &amp;\leftarrow \mathcal{B}_3(\Delta) \\
\text{s̄elf}_2, \bar{a}_{2,1}, \bar{y}_1 &amp;\leftarrow \mathcal{B}_2(\bar{y}_2) \\
\text{s̄elf}_1, \bar{b}_{1,1}, \bar{b}_{1,2} &amp;\leftarrow \mathcal{B}_1(\bar{y}_1)
\end{align}
\tag{2.6}
\label{eq:reverse}\]

<p>The final step is to accumulate the gradients for variables which are used multiple times:</p>

\[\begin{align}
\bar{a} &amp;\leftarrow \bar{a}_{3,1} + \bar{a}_{2,1} \\
\bar{b} &amp;\leftarrow \bar{b}_{1,1} + \bar{b}_{1,2} \\
\end{align}
\tag{2.7}
\label{eq:accumulate}\]
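
<p>The primal, reverse and accumulation steps can be sketched by hand for this one function. This is only an illustration of what the generated code is meant to do automatically: the helper <code class="language-plaintext highlighter-rouge">J</code> and the function <code class="language-plaintext highlighter-rouge">f_pullback</code> are hypothetical names, and the three methods of <code class="language-plaintext highlighter-rouge">J</code> are stand-ins for the <code class="language-plaintext highlighter-rouge">rrule</code>s of part 1.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical stand-ins for the rrules of part 1.
# Each returns the output and a pullback giving (self, first arg, second arg) gradients.
J(::typeof(*), a, b) = a * b, Δ -&gt; (nothing, Δ * b, Δ * a)
J(::typeof(+), a, b) = a + b, Δ -&gt; (nothing, Δ, Δ)
J(::typeof(/), a, b) = a / b, Δ -&gt; (nothing, Δ / b, -Δ * a / b^2)

function f_pullback(a, b)
    # primal pass: save each output and its pullback
    y1, B1 = J(*, b, b)
    y2, B2 = J(+, a, y1)
    y3, B3 = J(/, a, y2)
    function back(Δ)
        # reverse pass: call the pullbacks in reverse order
        _, a31, dy2 = B3(Δ)
        _, a21, dy1 = B2(dy2)
        _, b11, b12 = B1(dy1)
        # accumulate gradients of variables that are used more than once
        (nothing, a31 + a21, b11 + b12)
    end
    y3, back
end

z, back = f_pullback(2.0, 3.0) # z ≈ 0.1818
grads = back(1.0) # ≈ (nothing, 0.0744, -0.0992)</code></pre></figure>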

<p>This end result is equivalent to rolling everything up into one function using the multivariable chain rule:</p>

\[\begin{align}
\bar{a} &amp;= \frac{\partial l}{\partial a} = \mathcal{B}_{3,a}(\Delta) + \mathcal{B}_{2,a}(\bar{y}_2) \\
        &amp;= \frac{\partial l}{\partial y_3} \frac{\partial y_3}{\partial a} + \frac{\partial l}{\partial y_2} \frac{\partial y_2}{\partial a} \\
        &amp;= \Delta \cdot \frac{\partial }{\partial a} \left( \frac{a}{y_2}\right) + 
        \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2} \right)\frac{\partial}{\partial a}(a + y_1) \\
        &amp;= \Delta  \frac{1}{y_2} + \left(\Delta \frac{-a}{y_2^2} \right) (1+0) \\
        &amp;= \Delta \frac{b^2}{(a+b^2)^2} \\
\bar{b} &amp;= \frac{\partial l}{\partial b} = 2 \mathcal{B}_{1,b}(\bar{y}_1) \\
        &amp;= 2\frac{\partial l}{\partial y_1} \frac{\partial y_1}{\partial b} \\
        &amp;= 2 \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2}\frac{\partial y_2}{\partial y_1} \right) \frac{\partial y_1}{\partial b} \\
        &amp;= 2 \left(\Delta \cdot \frac{\partial}{\partial y_2}\left(\frac{a}{y_2}\right) \cdot \frac{\partial}{\partial y_1}(a + y_1) \right)\frac{\partial}{\partial b'}(b'\times b) \\
        &amp;= 2\left(\Delta \left(-\frac{a}{y_2^2}\right)(0+1)\right)b \\
        &amp;= -\frac{2ab\Delta}{(a+b^2)^2}
\end{align}
\tag{2.8}
\label{eq:rollup}\]

<h2 id="pullback">3 Pullback</h2>
<h3 id="pullback-definition">3.1 Definition</h3>

<p>The goal is to generate code which automatically implements the equations of section 2.</p>

<div class="message-container info-message">
  <div class="message-icon fa fa-fw fa-2x fa-exclamation-circle"></div>
    <div class="content-container">
      <div class="message-body">
        The <code>pullback</code> function that is implemented here is equivalent to the internal <code>Zygote._pullback</code> function, which returns all partial gradients including for $\frac{\partial l}{\partial \text{self}}$. <code>Zygote.pullback</code> is a thin wrapper around <code>Zygote._pullback</code> which discards that first gradient.
      </div>
    </div>
</div>

<p>To start, define a <code class="language-plaintext highlighter-rouge">pullback</code> function (<a href="https://github.com/FluxML/ZygoteRules.jl/blob/f9bf0e367fa259c5aa68f0e14ccbf2125d734bd6/src/adjoint.jl#L33">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> pullback</span> <span class="k">end</span></code></pre></figure>

<p>This will be turned into a <a href="https://docs.julialang.org/en/v1/manual/metaprogramming/#Generated-functions">generated function</a>.</p>
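
<p>As a quick reminder of how generated functions behave, here is a minimal sketch. The name <code class="language-plaintext highlighter-rouge">describe</code> is hypothetical and is not used anywhere else. The body of the generator runs with the argument names bound to their types, and the expression it returns is then run with the actual values:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># `describe` is a hypothetical name used only for this sketch.
@generated function describe(x)
    # Inside the generator, `x` is bound to the type of the argument, not its value.
    name = string(x)
    return :(($name, x)) # in the returned expression, `x` is the runtime value
end

describe(1.0) # ("Float64", 1.0)</code></pre></figure>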

<p>Julia changed the behaviour of generated functions in <a href="https://github.com/JuliaLang/julia/issues/49715">version 1.10</a>.
Before 1.10, they always had access to the <a href="https://docs.julialang.org/en/v1/manual/methods/">world age counter</a>.
This is a single number that is incremented every time a method is defined, and helps optimise compilations.
However, from version 1.10, calling <code class="language-plaintext highlighter-rouge">Base.get_world_counter()</code> inside a generated function only returns <code class="language-plaintext highlighter-rouge">typemax(UInt)</code>.
This is to prevent reflection - code inspection - in generated functions.<sup id="fnref:generated_reflection" role="doc-noteref"><a href="#fn:generated_reflection" class="footnote" rel="footnote">1</a></sup>
The code here, however, relies on reflection.
Thankfully, there is a hack that <a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/interface2.jl#L69C17-L69C31">Zygote.jl</a> uses to access the world age in <code class="language-plaintext highlighter-rouge">pullback</code>.
Because of this, the definition of <code class="language-plaintext highlighter-rouge">pullback</code> is different based on the version, but both will forward to a common internal <code class="language-plaintext highlighter-rouge">_generate_pullback</code> function.</p>
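
<p>The world age counter itself is easy to observe outside of generated functions. A minimal sketch, assuming each line is run at the top level of the REPL or a script (<code class="language-plaintext highlighter-rouge">world_demo</code> is a hypothetical method defined only to advance the counter):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">w1 = Base.get_world_counter()
world_demo() = 1 # hypothetical method, defined only to advance the counter
w2 = Base.get_world_counter()
w2 &gt; w1 # true: defining a method started a new world</code></pre></figure>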

<div class="message-container info-message">
    <div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
    </div>
    <div class="content-container">
        <div class="message-body">
    Generated functions should only be defined after all other functions. That is, at the bottom of the file or after all functions have been defined in the REPL. Otherwise they will either not be able to access those functions or will only see old versions of them. These functions are defined here at the top only for explanatory purposes.
        </div>
    </div>
</div>

<div class="accordion" id="accordianJuliaVersions">
  <div class="card">
    <div class="card-header" id="generatedJuliaPre10">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left" type="button" data-toggle="collapse" data-target="#collapsePre10" aria-expanded="true" aria-controls="collapsePre10">
          Julia Version before 1.10
        </button>
      </div>
    </div>
    <div id="collapsePre10" class="collapse show" aria-labelledby="generatedJuliaPre10" data-parent="#accordianJuliaVersions">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="nd">@generated</span> <span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="n">_generate_pullback</span><span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
  <div class="card">
    <div class="card-header" id="generatedJuliaPost10">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left collapsed" type="button" data-toggle="collapse" data-target="#collapsePost10" aria-expanded="false" aria-controls="collapsePost10">
          Julia Version 1.10 and later
        </button>
      </div>
    </div>
    <div id="collapsePost10" class="collapse" aria-labelledby="generatedJuliaPost10" data-parent="#accordianJuliaVersions">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _pullback_generator</span><span class="x">(</span><span class="n">world</span><span class="o">::</span><span class="kt">UInt</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="x">)</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="n">ret</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="n">ret</span>
        <span class="n">stub</span> <span class="o">=</span> <span class="n">Core</span><span class="o">.</span><span class="n">GeneratedFunctionStub</span><span class="x">(</span><span class="n">identity</span><span class="x">,</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">(</span><span class="o">:</span><span class="n">methodinstance</span><span class="x">,</span> <span class="o">:</span><span class="n">f</span><span class="x">,</span> <span class="o">:</span><span class="n">args</span><span class="x">),</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">())</span>
        <span class="n">stub</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">ret</span><span class="x">)</span>
<span class="k">end</span>

<span class="nd">@eval</span> <span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated</span><span class="x">,</span> <span class="n">_pullback_generator</span><span class="x">))</span>
        <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated_only</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
</div>

<h3 id="chainrules">3.2 ChainRules</h3>

<p>The first goal of <code class="language-plaintext highlighter-rouge">_generate_pullback</code> will be to forward the function and its arguments to a matching <code class="language-plaintext highlighter-rouge">rrule</code> if it exists.
For now it will throw an error if it cannot find one.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
    <span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">}</span>
    <span class="k">if</span> <span class="x">(</span><span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">))</span>
        <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">rrule</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">))</span>
    <span class="k">end</span>
    <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"No rrule found for "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="o">$</span><span class="n">T</span><span class="x">)))</span>
<span class="k">end</span></code></pre></figure>
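
<p>Note the difference between the two returned expressions: <code class="language-plaintext highlighter-rouge">f</code> and <code class="language-plaintext highlighter-rouge">args</code> are left as symbols, to be resolved when the generated code runs, while <code class="language-plaintext highlighter-rouge">$T</code> is spliced in as a value when the expression is built. A small sketch that inspects the expressions directly (the variable names <code class="language-plaintext highlighter-rouge">ex</code> and <code class="language-plaintext highlighter-rouge">rule</code> are only for illustration):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">T = Tuple{typeof(+), Float64, Float64}
ex = :(error("No rrule found for ", repr($T)))
ex.head                   # :call
ex.args[1]                # :error
ex.args[3].args[2] === T  # true: the type itself is embedded in the expression

rule = :(rrule(f, args...))
rule.args[2]              # :f, still a symbol</code></pre></figure>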

<p>In <a href="/machine-learning/2024/07/27/micrograd-1-chainrules#chainrules-definition">part 1</a> the most generic method of <code class="language-plaintext highlighter-rouge">rrule</code> was defined for an <code class="language-plaintext highlighter-rouge">Any</code> first argument, so if the compiler dispatches to this method it means no specific <code class="language-plaintext highlighter-rouge">rrule</code> was found.<sup id="fnref:has_chain_rule" role="doc-noteref"><a href="#fn:has_chain_rule" class="footnote" rel="footnote">2</a></sup></p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">meta</span>
<span class="k">function</span><span class="nf"> has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="n">Tr</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">rrule</span><span class="x">),</span> <span class="n">T</span><span class="o">.</span><span class="n">parameters</span><span class="o">...</span><span class="x">}</span>
    <span class="n">meta_T</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">Tr</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">meta_T</span><span class="x">)</span>
        <span class="k">return</span> <span class="nb">false</span>
    <span class="k">end</span>
    <span class="n">method_</span> <span class="o">=</span> <span class="n">meta_T</span><span class="o">.</span><span class="n">method</span>
    <span class="n">sig</span> <span class="o">=</span> <span class="n">method_</span><span class="o">.</span><span class="n">sig</span>
    <span class="o">!</span><span class="x">(</span><span class="n">sig</span> <span class="k">isa</span> <span class="kt">DataType</span><span class="x">)</span> <span class="o">||</span> <span class="x">(</span><span class="n">sig</span><span class="o">.</span><span class="n">parameters</span><span class="x">[</span><span class="mi">2</span><span class="x">]</span> <span class="o">!==</span> <span class="kt">Any</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>
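
<p>The idea behind this check can be reproduced with <code class="language-plaintext highlighter-rouge">which</code> from Base, which uses the current world age and so only works outside of generated functions. This is a sketch with hypothetical stand-ins (<code class="language-plaintext highlighter-rouge">rrule_demo</code> and <code class="language-plaintext highlighter-rouge">g_demo</code>) rather than the definitions from part 1:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical stand-ins for the part 1 definitions.
rrule_demo(::Any, args...) = nothing                           # generic fallback
rrule_demo(::typeof(+), x, y) = (x + y, Δ -&gt; (nothing, Δ, Δ))  # a specific rule

g_demo(a, b) = a / (a + b * b) # no specific rule exists for this function

m_plus = which(rrule_demo, Tuple{typeof(+), Float64, Float64})
m_g = which(rrule_demo, Tuple{typeof(g_demo), Float64, Float64})

# The second parameter of the signature is the type of the differentiated function.
m_plus.sig.parameters[2] # typeof(+)
m_g.sig.parameters[2]    # Any, so dispatch fell through to the fallback</code></pre></figure>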

<p>Let’s test all this code from bottom to top for a function with an <code class="language-plaintext highlighter-rouge">rrule</code> and one without: <code class="language-plaintext highlighter-rouge">+</code> and <code class="language-plaintext highlighter-rouge">f(a,b)=a/(a+b*b)</code>.
As a reminder, generated functions only have access to the types of their arguments, so to test <code class="language-plaintext highlighter-rouge">_generate_pullback</code> and all functions under it, we can only work with the types.</p>

<p>Firstly, for <code class="language-plaintext highlighter-rouge">+</code> acting on floats (redefine <code class="language-plaintext highlighter-rouge">@generated pullback</code> if necessary):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="o">+</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span> <span class="c"># true</span>
<span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="o">+</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># :(rrule(f, args...))</span>
<span class="n">pullback</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">)</span> <span class="c"># (3.0, var"#add_back#5"())</span></code></pre></figure>

<p>Now for <code class="language-plaintext highlighter-rouge">f</code>, also acting on floats:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span> <span class="c"># false</span>
<span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># :(error(...))</span>
<span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">)</span> <span class="c"># ERROR: No rrule found for ...</span></code></pre></figure>

<p>The more interesting task is to inspect <code class="language-plaintext highlighter-rouge">f</code> and apply the equations of section 2 to fully differentiate with respect to all input parameters.</p>

<h3 id="ir">3.3 IR</h3>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/compiler_diagram.png" alt="Julia compiler steps" />
<figcaption>Source: <a href="https://docs.julialang.org/en/v1/devdocs/eval/">Julia Docs eval</a></figcaption>
</figure>

<p>The first step is to create a Wengert list for <code class="language-plaintext highlighter-rouge">f</code> in Intermediate Representation (IR) form.
Julia already does this as part of the compilation process.
IRTools.jl mimics this internal IR form with its own custom IR struct.
It can be generated as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">IR</span><span class="x">,</span> <span class="n">meta</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">T</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">())</span>
<span class="n">ir</span> <span class="o">=</span> <span class="n">IR</span><span class="x">(</span><span class="n">m</span><span class="x">)</span>
<span class="cm">#=
1: (%1, %2, %3)
  %4 = %3 * %3
  %5 = %2 + %4
  %6 = %2 / %5
  return %6
=#</span></code></pre></figure>

<p>The returned object corresponds exactly to equation $\ref{eq:f_wengert}$.</p>
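
<p>For comparison, Julia's own lowered form can be inspected without IRTools.jl using <code class="language-plaintext highlighter-rouge">code_lowered</code> from Base. This is only a sketch for orientation: the exact printed form of the statements differs between Julia versions, so no output is shown for them.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">f(a, b) = a / (a + b * b)

# `code_lowered` returns one CodeInfo per matching method.
ci = only(code_lowered(f, (Float64, Float64)))
ci isa Core.CodeInfo # true
ci.code # the lowered statements for *, + and /, followed by a return</code></pre></figure>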

<p>Using this knowledge, we can now create a new function <code class="language-plaintext highlighter-rouge">_generate_pullback_via_decomposition</code> which will be called if no <code class="language-plaintext highlighter-rouge">rrule</code> exists.
It uses the IR to create the primal (equation $\ref{eq:primal}$) (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/emit.jl#L98">source</a>).</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">meta</span><span class="x">,</span> <span class="n">IR</span><span class="x">,</span> <span class="n">blocks</span>
<span class="k">function</span><span class="nf"> _generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">T</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="n">isnothing</span><span class="x">(</span><span class="n">m</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="nb">nothing</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">IR</span><span class="x">(</span><span class="n">m</span><span class="x">)</span>
    <span class="n">length</span><span class="x">(</span><span class="n">blocks</span><span class="x">(</span><span class="n">ir</span><span class="x">))</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">||</span> <span class="n">error</span><span class="x">(</span><span class="s">"control flow is not supported"</span><span class="x">)</span>
    <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">primal</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">T</span><span class="x">)</span>
    <span class="n">m</span><span class="x">,</span> <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span>
<span class="k">end</span></code></pre></figure>

<h3 id="primal">3.4 Primal</h3>

<p>The goal here is to create an IR for equation $\ref{eq:primal}$.
This is what it will look like:</p>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">1: (%1, %2, %3)
  %4 = Main.pullback(Main.:*, %3, %3)
  %5 = Base.getindex(%4, 1)
  %6 = Base.getindex(%4, 2)
  %7 = Main.pullback(Main.:+, %2, %5)
  %8 = Base.getindex(%7, 1)
  %9 = Base.getindex(%7, 2)
  %10 = Main.pullback(Main.:/, %2, %8)
  %11 = Base.getindex(%10, 1)
  %12 = Base.getindex(%10, 2)
  %13 = Base.tuple(%6, %9, %12)
  %14 = (Pullback{Tuple{typeof(f), Float64, Float64}})(%13)
  %15 = Base.tuple(%11, %14)
  return %15</code></pre></figure>

<p>Although harder to read, this IR represents the same code as the expressions in <a href="/machine-learning/2024/08/03/micrograd-2-expr#primal">part 2</a>.</p>
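
<p>To make the correspondence concrete, the IR above can be written out by hand as ordinary Julia. This is a sketch under stated assumptions: <code class="language-plaintext highlighter-rouge">pullback_demo</code> and <code class="language-plaintext highlighter-rouge">primal_demo</code> are hypothetical names, the three rules are hard-coded here instead of going through <code class="language-plaintext highlighter-rouge">rrule</code>, and a plain tuple stands in for the <code class="language-plaintext highlighter-rouge">Pullback</code> struct.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical stand-ins for `pullback`, hard-coded for the three operators used by `f`.
# The leading `nothing` is the gradient with respect to the function itself.
pullback_demo(::typeof(*), x, y) = (x * y, Δ -&gt; (nothing, Δ * y, Δ * x))
pullback_demo(::typeof(+), x, y) = (x + y, Δ -&gt; (nothing, Δ, Δ))
pullback_demo(::typeof(/), x, y) = (x / y, Δ -&gt; (nothing, Δ / y, -Δ * x / y^2))

function primal_demo(a, b)
    t1 = pullback_demo(*, b, b)   # %4
    y1, back1 = t1[1], t1[2]      # %5, %6
    t2 = pullback_demo(+, a, y1)  # %7
    y2, back2 = t2[1], t2[2]      # %8, %9
    t3 = pullback_demo(/, a, y2)  # %10
    y3, back3 = t3[1], t3[2]      # %11, %12
    backs = (back1, back2, back3) # %13 (the IR wraps this in a Pullback struct, %14)
    return (y3, backs)            # %15
end

y, backs = primal_demo(2.0, 3.0)
y == 2.0 / (2.0 + 3.0 * 3.0) # true
length(backs)                # 3</code></pre></figure>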

<p>The primal function first wraps the existing IR with <code class="language-plaintext highlighter-rouge">Pipe</code> to make inserts more efficient.
It defines two arrays to store information (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L201">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">block</span><span class="x">,</span> <span class="n">isexpr</span><span class="x">,</span> <span class="n">finish</span><span class="x">,</span> <span class="kt">Pipe</span><span class="x">,</span> <span class="n">Variable</span><span class="x">,</span> <span class="k">return</span><span class="o">!</span><span class="x">,</span> <span class="n">returnvalue</span><span class="x">,</span> <span class="n">stmt</span><span class="x">,</span> <span class="n">xcall</span>
<span class="k">function</span><span class="nf"> primal</span><span class="x">(</span><span class="n">ir</span><span class="o">::</span><span class="n">IR</span><span class="x">,</span> <span class="n">T</span><span class="o">=</span><span class="kt">Any</span><span class="x">)</span>
    <span class="n">pr</span> <span class="o">=</span> <span class="n">IRTools</span><span class="o">.</span><span class="kt">Pipe</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span>
    <span class="n">calls</span> <span class="o">=</span> <span class="x">[]</span>
    <span class="n">pullbacks</span> <span class="o">=</span> <span class="x">[]</span></code></pre></figure>

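<p>The <code class="language-plaintext highlighter-rouge">calls</code> array stores the subset of variables that require a pullback, and the <code class="language-plaintext highlighter-rouge">pullbacks</code> array stores the variables that hold the corresponding back functions.
Because the IR can be indexed like a dictionary - <code class="language-plaintext highlighter-rouge">ir[Variable(i)]</code> returns statement <code class="language-plaintext highlighter-rouge">i</code> - each entry in <code class="language-plaintext highlighter-rouge">calls</code> is a direct link to the statement called.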
These will be used to generate the reverse code (equation $\ref{eq:reverse}$) in the next section.</p>

<p>Next, iterate over each statement in the IR.
For each statement, if it is a <code class="language-plaintext highlighter-rouge">:call</code> expression and not part of a special ignored list, replace it with three calls: first a call to <code class="language-plaintext highlighter-rouge">pullback</code>, and then two calls to <code class="language-plaintext highlighter-rouge">getindex</code> to get the output variable <code class="language-plaintext highlighter-rouge">v</code> and back function <code class="language-plaintext highlighter-rouge">J</code> from the tuple <code class="language-plaintext highlighter-rouge">t</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="k">for</span> <span class="x">(</span><span class="n">v</span><span class="x">,</span> <span class="n">st</span><span class="x">)</span> <span class="k">in</span> <span class="n">pr</span>
        <span class="n">ex</span> <span class="o">=</span> <span class="n">st</span><span class="o">.</span><span class="n">expr</span>
        <span class="k">if</span> <span class="n">isexpr</span><span class="x">(</span><span class="n">ex</span><span class="x">,</span> <span class="o">:</span><span class="n">call</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="o">!</span><span class="n">ignored</span><span class="x">(</span><span class="n">ex</span><span class="x">)</span>
            <span class="n">t</span> <span class="o">=</span> <span class="n">insert!</span><span class="x">(</span><span class="n">pr</span><span class="x">,</span> <span class="n">v</span><span class="x">,</span> <span class="n">stmt</span><span class="x">(</span><span class="n">xcall</span><span class="x">(</span><span class="n">Main</span><span class="x">,</span> <span class="o">:</span><span class="n">pullback</span><span class="x">,</span> <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="o">...</span><span class="x">),</span> <span class="n">line</span> <span class="o">=</span> <span class="n">st</span><span class="o">.</span><span class="n">line</span><span class="x">))</span>
            <span class="n">pr</span><span class="x">[</span><span class="n">v</span><span class="x">]</span> <span class="o">=</span> <span class="n">xcall</span><span class="x">(</span><span class="n">Base</span><span class="x">,</span> <span class="o">:</span><span class="n">getindex</span><span class="x">,</span> <span class="n">t</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
            <span class="n">J</span> <span class="o">=</span> <span class="n">push!</span><span class="x">(</span><span class="n">pr</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">getindex</span><span class="x">,</span> <span class="n">t</span><span class="x">,</span> <span class="mi">2</span><span class="x">))</span>
            <span class="n">push!</span><span class="x">(</span><span class="n">calls</span><span class="x">,</span> <span class="n">v</span><span class="x">)</span>
            <span class="n">push!</span><span class="x">(</span><span class="n">pullbacks</span><span class="x">,</span> <span class="n">J</span><span class="x">)</span>
        <span class="k">end</span>
    <span class="k">end</span></code></pre></figure>

<p>After working through all the statements, a final statement is added which returns a tuple with the output of the function and a <code class="language-plaintext highlighter-rouge">Pullback</code> struct which stores all the pullbacks.
In the last step the pipe is converted back into an IR.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">pb</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">call</span><span class="x">,</span> <span class="n">Pullback</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="n">pullbacks</span><span class="o">...</span><span class="x">))</span>
    <span class="k">return</span><span class="o">!</span><span class="x">(</span><span class="n">pr</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="n">returnvalue</span><span class="x">(</span><span class="n">block</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="mi">1</span><span class="x">)),</span> <span class="n">pb</span><span class="x">))</span>
    <span class="n">finish</span><span class="x">(</span><span class="n">pr</span><span class="x">),</span> <span class="n">calls</span>
<span class="k">end</span></code></pre></figure>

<p>This code requires a definition for the <code class="language-plaintext highlighter-rouge">Pullback</code> struct as well as the <code class="language-plaintext highlighter-rouge">ignored</code> function.</p>

<p>There are no closures in lowered Julia code, so instead <a href="https://fluxml.ai/Zygote.jl/stable/internals/#Closure-Conversion-1">Zygote.jl</a> stores the pullbacks in a generic struct:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span><span class="n">T</span><span class="x">}</span>
    <span class="n">data</span><span class="o">::</span><span class="n">T</span>
<span class="k">end</span>
<span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">}(</span><span class="n">data</span><span class="x">)</span> <span class="k">where</span> <span class="n">S</span> <span class="o">=</span> <span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span><span class="n">typeof</span><span class="x">(</span><span class="n">data</span><span class="x">)}(</span><span class="n">data</span><span class="x">)</span></code></pre></figure>

<p>In the next section this struct will be turned into a callable struct.
That is, for <code class="language-plaintext highlighter-rouge">back=Pullback{S}(data)</code>, we will create a generated function that dispatches on itself: <code class="language-plaintext highlighter-rouge">(j::Pullback)(Δ)</code> so that we can call <code class="language-plaintext highlighter-rouge">back(Δ)</code>. This <code class="language-plaintext highlighter-rouge">back</code> has all the information to generate the reverse pass independently of the forward pass: the method can be retrieved using <code class="language-plaintext highlighter-rouge">meta(S)</code> and the relevant data and input parameters from  <code class="language-plaintext highlighter-rouge">back.data</code>.</p>
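
<p>The key point is that the signature <code class="language-plaintext highlighter-rouge">S</code> lives in the type, so it is available to anything that only sees types, such as a generated function. A small sketch using a copy of the struct under the hypothetical name <code class="language-plaintext highlighter-rouge">PullbackDemo</code>, with placeholder symbols as data and a hypothetical <code class="language-plaintext highlighter-rouge">signature</code> helper:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># A copy of the struct under a hypothetical name, to keep this sketch self-contained.
struct PullbackDemo{S,T}
    data::T
end
PullbackDemo{S}(data) where S = PullbackDemo{S,typeof(data)}(data)

# Recover the signature S from the type alone (hypothetical helper).
signature(::PullbackDemo{S}) where S = S

back = PullbackDemo{Tuple{typeof(sin), Float64}}((:b1, :b2)) # placeholder data
signature(back) # Tuple{typeof(sin), Float64}
back.data       # (:b1, :b2)</code></pre></figure>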

<p>Here is the ignored functions list (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L171">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> ignored</span><span class="x">(</span><span class="n">ex</span><span class="o">::</span><span class="kt">Expr</span><span class="x">)</span>
    <span class="n">f</span> <span class="o">=</span> <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span>
    <span class="n">ignored_f</span><span class="x">(</span><span class="n">f</span><span class="x">)</span>
<span class="k">end</span>

<span class="n">ignored_f</span><span class="x">(</span><span class="n">f</span><span class="x">)</span> <span class="o">=</span> <span class="n">f</span> <span class="k">in</span> <span class="x">(</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Base</span><span class="x">,</span> <span class="o">:</span><span class="n">not_int</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="o">.</span><span class="n">Intrinsics</span><span class="x">,</span> <span class="o">:</span><span class="n">not_int</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="o">===</span><span class="x">)),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">apply_type</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">typeof</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">throw</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Base</span><span class="x">,</span> <span class="o">:</span><span class="n">kwerr</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">kwfunc</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">isdefined</span><span class="x">)</span>
<span class="x">)</span></code></pre></figure>

<p>Running this code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">m</span><span class="x">,</span> <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">_generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span></code></pre></figure>

<p>gives the IR at the start.</p>

<h3 id="convert">3.5 Convert</h3>

<p>To evaluate the IR it needs to be converted into a <code class="language-plaintext highlighter-rouge">CodeInfo</code> struct.
Zygote.jl uses <code class="language-plaintext highlighter-rouge">IRTools.Inner.update!</code> to modify the existing struct in <code class="language-plaintext highlighter-rouge">meta_T.code</code>.
To me, it makes more sense to construct a new code info block directly from the IR using a slightly modified version of <code class="language-plaintext highlighter-rouge">IRTools.Inner.build_codeinfo</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">arguments</span>
<span class="k">using</span> <span class="n">IRTools</span><span class="o">.</span><span class="n">Inner</span><span class="o">:</span> <span class="n">dummy_m</span><span class="x">,</span> <span class="n">update!</span>
<span class="k">function</span><span class="nf"> build_codeinfo_</span><span class="x">(</span><span class="n">ir</span><span class="o">::</span><span class="n">IR</span><span class="x">)</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">copy</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span>
    <span class="n">ci</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">uncompressed_ir</span><span class="x">(</span><span class="n">dummy_m</span><span class="x">)</span>
    <span class="n">ci</span><span class="o">.</span><span class="n">inlineable</span> <span class="o">=</span> <span class="nb">true</span>
    <span class="k">for</span> <span class="n">arg</span> <span class="k">in</span> <span class="n">arguments</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span>
        <span class="nd">@static</span> <span class="k">if</span> <span class="nb">VERSION</span> <span class="o">&gt;=</span> <span class="n">v</span><span class="s">"1.10.0-DEV.870"</span>
            <span class="n">isnothing</span><span class="x">(</span><span class="n">ci</span><span class="o">.</span><span class="n">slottypes</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="x">(</span><span class="n">ci</span><span class="o">.</span><span class="n">slottypes</span> <span class="o">=</span> <span class="kt">Any</span><span class="x">[])</span>
            <span class="n">push!</span><span class="x">(</span><span class="n">ci</span><span class="o">.</span><span class="n">slottypes</span><span class="x">,</span> <span class="kt">Type</span><span class="x">)</span>
        <span class="k">end</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">ci</span><span class="o">.</span><span class="n">slotnames</span><span class="x">,</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">""</span><span class="x">))</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">ci</span><span class="o">.</span><span class="n">slotflags</span><span class="x">,</span> <span class="mi">0</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="c">#argument!(ir, at = 1) # argument for #self# might already exist</span>
    <span class="n">update!</span><span class="x">(</span><span class="n">ci</span><span class="x">,</span> <span class="n">ir</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>
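
<p>To get a feel for the <code class="language-plaintext highlighter-rouge">CodeInfo</code> struct that is being filled in here, it can help to look at one that Julia produces itself. The following is a minimal sketch using only <code class="language-plaintext highlighter-rouge">Base.code_lowered</code>; the function <code class="language-plaintext highlighter-rouge">demo</code> is a hypothetical stand-in and is not part of the code in this post.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># hypothetical function, used only for illustration
demo(a, b) = a / (a + b * b)

# code_lowered returns a vector with one CodeInfo per matching method
ci = first(code_lowered(demo, (Float64, Float64)))

ci isa Core.CodeInfo # true
ci.slotnames         # Symbol("#self#"), :a and :b
ci.code              # the lowered statements</code></pre></figure>

<p>The first slot is always <code class="language-plaintext highlighter-rouge">#self#</code>, which is why the code in this section takes care to add an argument for it.</p>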

<p>This can now be used in <code class="language-plaintext highlighter-rouge">_generate_pullback</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">argument!</span><span class="x">,</span> <span class="n">varargs!</span><span class="x">,</span> <span class="n">pis!</span><span class="x">,</span> <span class="n">slots!</span>
<span class="k">function</span><span class="nf"> _generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
    <span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">}</span>
    <span class="k">if</span> <span class="x">(</span><span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">))</span>
        <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">rrule</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">))</span>
    <span class="k">end</span>    
    <span class="n">g</span> <span class="o">=</span> <span class="n">_generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">g</span><span class="x">)</span>
        <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"No method found for "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="o">$</span><span class="n">T</span><span class="x">),</span> <span class="s">" in world "</span><span class="x">,</span> <span class="o">$</span><span class="n">world</span><span class="x">))</span>
    <span class="k">end</span>
    <span class="n">m</span><span class="x">,</span> <span class="n">pr</span><span class="x">,</span> <span class="n">backs</span> <span class="o">=</span> <span class="n">g</span>
    <span class="n">pr</span> <span class="o">=</span> <span class="n">varargs!</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">pr</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span> <span class="c"># add getfield for each index in args, offset by 1 for f</span>
    <span class="n">pr</span> <span class="o">=</span> <span class="n">slots!</span><span class="x">(</span><span class="n">pis!</span><span class="x">(</span><span class="n">pr</span><span class="x">))</span>
    <span class="n">argument!</span><span class="x">(</span><span class="n">pr</span><span class="x">,</span> <span class="n">at</span> <span class="o">=</span> <span class="mi">1</span><span class="x">)</span> <span class="c"># add #self#</span>
    <span class="n">ci</span> <span class="o">=</span> <span class="n">build_codeinfo_</span><span class="x">(</span><span class="n">pr</span><span class="x">)</span>
    <span class="n">ci</span><span class="o">.</span><span class="n">slotnames</span> <span class="o">=</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"#self#"</span><span class="x">),</span> <span class="o">:</span><span class="n">f</span><span class="x">,</span> <span class="o">:</span><span class="n">args</span><span class="x">]</span>
    <span class="n">ci</span>
<span class="k">end</span></code></pre></figure>

<p>Testing (you should redefine the <code class="language-plaintext highlighter-rouge">@generated pullback</code> function first):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">pr</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># CodeInfo(...)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">)</span> <span class="c"># (0.2,Pullback{...})</span></code></pre></figure>

<h3 id="reverse">3.6 Reverse</h3>

<p>The goal now is to turn <code class="language-plaintext highlighter-rouge">Pullback</code> into a callable struct so that calling <code class="language-plaintext highlighter-rouge">back(1.0)</code> evaluates equations $\ref{eq:reverse}$ and $\ref{eq:accumulate}$.
With <code class="language-plaintext highlighter-rouge">typeof(back)</code> and <code class="language-plaintext highlighter-rouge">back.data</code> we have all the information needed to do this independently of the forward pass.
The result will be:</p>

<div class="message-container info-message">
  <div class="message-icon fa fa-fw fa-2x fa-exclamation-circle"></div>
  <div class="content-container">
    <div class="message-body">
      The IR below contains unused variables that could be removed, e.g. <code>%8</code> (s̄elf). To keep things simple, the code here does not perform such optimisations.
    </div>
  </div>
</div>

<figure class="highlight"><pre><code class="language-plaintext" data-lang="plaintext">(%1, %2)
  %3 = Base.getfield(%1, :data)
  %4 = Base.getindex(%3, 1)
  %5 = Base.getindex(%3, 2)
  %6 = Base.getindex(%3, 3)
  %7 = (%6)(%2)
  %8 = Base.getindex(%7, 1)
  %9 = Base.getindex(%7, 2)
  %10 = Base.getindex(%7, 3)
  %11 = (%5)(%10)
  %12 = Base.getindex(%11, 1)
  %13 = Base.getindex(%11, 2)
  %14 = Base.getindex(%11, 3)
  %15 = (%4)(%14)
  %16 = Base.getindex(%15, 1)
  %17 = Base.getindex(%15, 2)
  %18 = Base.getindex(%15, 3)
  %19 = Main.accum(%9, %13)
  %20 = Main.accum(%17, %18)
  %21 = Base.tuple(nothing, %19, %20)
  return %21</code></pre></figure>

<p>Although harder to read, this IR is equivalent to the expressions in <a href="/machine-learning/2024/08/03/micrograd-2-expr#reverse">part 2</a>.</p>
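
<p>As a rough cross-check, the same steps can be written out by hand. The sketch below assumes that <code class="language-plaintext highlighter-rouge">f(x, y) = x / (x + y * y)</code>, which is consistent with the three pullbacks in the IR and with the <code class="language-plaintext highlighter-rouge">0.2</code> output above, but it is an assumption here. The name <code class="language-plaintext highlighter-rouge">hand_written_back</code> and the variables <code class="language-plaintext highlighter-rouge">d1</code> to <code class="language-plaintext highlighter-rouge">d3</code> are hypothetical, and the pullback outputs are written inline rather than read from <code class="language-plaintext highlighter-rouge">data</code>.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">function hand_written_back(x, y, Δ)
    s = x + y * y                        # forward value needed by the division
    d3 = (nothing, Δ / s, -Δ * x / s^2)  # like %7: pullback of x / s
    d2 = (nothing, d3[3], d3[3])         # like %11: pullback of x + (y * y)
    d1 = (nothing, d2[3] * y, d2[3] * y) # like %15: pullback of y * y
    xbar = d3[2] + d2[2]                 # like %19: x is used in two calls
    ybar = d1[2] + d1[3]                 # like %20: y is used twice in one call
    (nothing, xbar, ybar)
end

hand_written_back(1.0, 2.0, 1.0) # approximately (nothing, 0.16, -0.16)</code></pre></figure>

<p>The first entry of each tuple is the gradient for the function itself, which is why the IR indexes from 2 for the first real argument.</p>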

<p>As with the forward pass, an internal function <code class="language-plaintext highlighter-rouge">_generate_callable_pullback</code> will do most of the work:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">blocks</span><span class="x">,</span> <span class="n">meta</span><span class="x">,</span> <span class="n">slots!</span><span class="x">,</span> <span class="n">inlineable!</span>
<span class="k">function</span><span class="nf"> _generate_callable_pullback</span><span class="x">(</span><span class="n">j</span><span class="o">::</span><span class="kt">Type</span><span class="x">{</span><span class="o">&lt;:</span><span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span> <span class="n">T</span><span class="x">}},</span> <span class="n">world</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">S</span><span class="x">,</span> <span class="n">T</span><span class="x">}</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">S</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">IR</span><span class="x">(</span><span class="n">m</span><span class="x">)</span>
    <span class="n">isnothing</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"Non-differentiable function "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="o">$</span><span class="n">S</span><span class="x">)))</span>
    <span class="n">length</span><span class="x">(</span><span class="n">blocks</span><span class="x">(</span><span class="n">ir</span><span class="x">))</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">||</span> <span class="n">error</span><span class="x">(</span><span class="s">"control flow is not supported"</span><span class="x">)</span>
    <span class="n">back</span> <span class="o">=</span> <span class="n">reverse_differentiate</span><span class="x">(</span><span class="n">ir</span><span class="x">)</span>
    <span class="n">back</span> <span class="o">=</span> <span class="n">slots!</span><span class="x">(</span><span class="n">inlineable!</span><span class="x">(</span><span class="n">back</span><span class="x">))</span>
    <span class="n">ci</span> <span class="o">=</span> <span class="n">build_codeinfo_</span><span class="x">(</span><span class="n">back</span><span class="x">)</span>
    <span class="n">ci</span><span class="o">.</span><span class="n">slotnames</span> <span class="o">=</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"#self#"</span><span class="x">),</span> <span class="o">:</span><span class="n">Δ</span><span class="x">]</span>
    <span class="n">ci</span>
<span class="k">end</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">reverse_differentiate</code> function is a simplified version of <a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L293">Zygote.adjoint</a> and <a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/emit.jl#L65">Zygote.reverse_stacks!</a>.</p>

<p>To start, a dictionary is created to store the gradients.
It maps each variable in the forward IR to an array of gradients, one for each place the variable is used.
It is not accessed directly (e.g. <code class="language-plaintext highlighter-rouge">grads[x]</code>) but rather through the closure functions <code class="language-plaintext highlighter-rouge">grad</code> and <code class="language-plaintext highlighter-rouge">grad!</code>, which automatically handle the arrays.
The first gradient stored is <code class="language-plaintext highlighter-rouge">%2=Δ</code>, which is associated with the final return value of the forward pass.
(<code class="language-plaintext highlighter-rouge">xaccum</code> will be defined shortly.)</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">IRTools</span><span class="o">:</span> <span class="n">argument!</span><span class="x">,</span> <span class="n">arguments</span><span class="x">,</span> <span class="n">block</span><span class="x">,</span> <span class="n">isexpr</span><span class="x">,</span> <span class="n">returnvalue</span><span class="x">,</span> <span class="n">xcall</span><span class="x">,</span> <span class="k">return</span><span class="o">!</span>
<span class="k">function</span><span class="nf"> reverse_differentiate</span><span class="x">(</span><span class="n">forw</span><span class="o">::</span><span class="n">IR</span><span class="x">)</span>
    <span class="n">grads</span> <span class="o">=</span> <span class="kt">Dict</span><span class="x">()</span>
    <span class="n">grad!</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x̄</span><span class="x">)</span> <span class="o">=</span> <span class="n">push!</span><span class="x">(</span><span class="n">get!</span><span class="x">(</span><span class="n">grads</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="x">[]),</span> <span class="n">x̄</span><span class="x">)</span>
    <span class="n">grad</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">xaccum</span><span class="x">(</span><span class="n">get</span><span class="x">(</span><span class="n">grads</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="x">[])</span><span class="o">...</span><span class="x">)</span>
    <span class="n">ir</span> <span class="o">=</span> <span class="n">empty</span><span class="x">(</span><span class="n">forw</span><span class="x">)</span>
    <span class="n">self</span> <span class="o">=</span> <span class="n">argument!</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">at</span> <span class="o">=</span> <span class="mi">1</span><span class="x">,</span> <span class="n">insert</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span>
    <span class="n">grad!</span><span class="x">(</span><span class="n">returnvalue</span><span class="x">(</span><span class="n">block</span><span class="x">(</span><span class="n">forw</span><span class="x">,</span> <span class="mi">1</span><span class="x">)),</span> <span class="n">IRTools</span><span class="o">.</span><span class="n">argument!</span><span class="x">(</span><span class="n">ir</span><span class="x">))</span></code></pre></figure>

<p>The first statement retrieves the <code class="language-plaintext highlighter-rouge">data</code> field in the struct.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">data</span> <span class="o">=</span> <span class="n">push!</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">getfield</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="kt">QuoteNode</span><span class="x">(</span><span class="o">:</span><span class="n">data</span><span class="x">)))</span></code></pre></figure>

<p>Next the code retrieves all the calls with pullbacks from the primal and loops over them in reverse order, calling the pullbacks one by one.
For each call it also loops over the input arguments and unpacks the gradient for each one from the pullback’s output.
Each gradient is added to <code class="language-plaintext highlighter-rouge">grads</code> and may be used later in the loop.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">primal</span><span class="x">(</span><span class="n">forw</span><span class="x">)</span>
    <span class="n">pullbacks</span> <span class="o">=</span> <span class="kt">Dict</span><span class="x">(</span><span class="n">calls</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">=&gt;</span> <span class="n">push!</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">getindex</span><span class="x">,</span> <span class="n">data</span><span class="x">,</span> <span class="n">i</span><span class="x">))</span> <span class="k">for</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">length</span><span class="x">(</span><span class="n">calls</span><span class="x">))</span>
    <span class="k">for</span> <span class="n">v</span> <span class="k">in</span> <span class="n">reverse</span><span class="x">(</span><span class="n">keys</span><span class="x">(</span><span class="n">forw</span><span class="x">))</span>
        <span class="n">ex</span> <span class="o">=</span> <span class="n">forw</span><span class="x">[</span><span class="n">v</span><span class="x">]</span><span class="o">.</span><span class="n">expr</span>
        <span class="k">if</span> <span class="n">isexpr</span><span class="x">(</span><span class="n">ex</span><span class="x">,</span> <span class="o">:</span><span class="n">call</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="o">!</span><span class="n">ignored</span><span class="x">(</span><span class="n">ex</span><span class="x">)</span>
            <span class="n">Δs</span> <span class="o">=</span> <span class="n">push!</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">call</span><span class="x">,</span> <span class="n">pullbacks</span><span class="x">[</span><span class="n">v</span><span class="x">],</span> <span class="n">grad</span><span class="x">(</span><span class="n">v</span><span class="x">)))</span>
            <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">)</span>
                <span class="n">grad!</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">push!</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">getindex</span><span class="x">,</span> <span class="n">Δs</span><span class="x">,</span> <span class="n">i</span><span class="x">)))</span>
            <span class="k">end</span>
        <span class="k">end</span>
    <span class="k">end</span></code></pre></figure>

<p>Finally, the last call retrieves all the necessary gradients for the input arguments and returns the IR:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="k">return</span><span class="o">!</span><span class="x">(</span><span class="n">ir</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="n">grad</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="k">for</span> <span class="n">x</span> <span class="k">in</span> <span class="n">arguments</span><span class="x">(</span><span class="n">forw</span><span class="x">)]</span><span class="o">...</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>

<p>This code calls the <code class="language-plaintext highlighter-rouge">xaccum</code> function, which is defined as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">xaccum</span><span class="x">()</span> <span class="o">=</span> <span class="nb">nothing</span>
<span class="n">xaccum</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span>
<span class="n">xaccum</span><span class="x">(</span><span class="n">xs</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">xcall</span><span class="x">(</span><span class="n">Main</span><span class="x">,</span> <span class="o">:</span><span class="n">accum</span><span class="x">,</span> <span class="n">xs</span><span class="o">...</span><span class="x">)</span></code></pre></figure>
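
<p>To see how these three methods interact with the <code class="language-plaintext highlighter-rouge">grad</code> and <code class="language-plaintext highlighter-rouge">grad!</code> closures, here is a small stand-alone sketch. It uses plain symbols instead of IR variables and a bare <code class="language-plaintext highlighter-rouge">Expr</code> as a stand-in for <code class="language-plaintext highlighter-rouge">xcall</code>, so all the <code class="language-plaintext highlighter-rouge">_demo</code> names are illustrative only.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">grads_demo = Dict()

xaccum_demo() = nothing
xaccum_demo(x) = x
xaccum_demo(xs...) = Expr(:call, :accum, xs...) # stand-in for xcall(Main, :accum, xs...)

# store one more gradient for x
grad_demo!(x, xbar) = push!(get!(grads_demo, x, []), xbar)
# read back all the gradients for x, combined into a single value or expression
grad_demo(x) = xaccum_demo(get(grads_demo, x, [])...)

grad_demo!(:x, :d1)
grad_demo!(:x, :d2)
grad_demo!(:y, :d3)

grad_demo(:x) # :(accum(d1, d2))
grad_demo(:y) # :d3
grad_demo(:z) # nothing</code></pre></figure>

<p>Note that no addition happens at this stage. With two or more gradients, the result is only an expression that will call <code class="language-plaintext highlighter-rouge">accum</code> when the generated code runs.</p>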

<p>The <code class="language-plaintext highlighter-rouge">xaccum</code> function calls an internal accumulate function if it acts on multiple inputs.
At its simplest, <code class="language-plaintext highlighter-rouge">accum</code> is the same as <code class="language-plaintext highlighter-rouge">sum</code>.
However, it also handles <code class="language-plaintext highlighter-rouge">nothing</code> inputs, <code class="language-plaintext highlighter-rouge">Tuple</code>s and <code class="language-plaintext highlighter-rouge">NamedTuple</code>s (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/lib.jl#L14">source</a>).</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">===</span> <span class="nb">nothing</span> <span class="o">?</span> <span class="n">y</span> <span class="o">:</span> <span class="n">y</span> <span class="o">===</span> <span class="nb">nothing</span> <span class="o">?</span> <span class="n">x</span> <span class="o">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Tuple</span><span class="x">,</span> <span class="n">ys</span><span class="o">::</span><span class="kt">Tuple</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">accum</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">ys</span><span class="o">...</span><span class="x">)</span>
<span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">,</span> <span class="n">zs</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">accum</span><span class="x">(</span><span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">),</span> <span class="n">zs</span><span class="o">...</span><span class="x">)</span>
<span class="nd">@generated</span> <span class="k">function</span><span class="nf"> accum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">)</span>
    <span class="c"># assumes that y has no keys apart from those also in x</span>
    <span class="n">fieldnames</span><span class="x">(</span><span class="n">y</span><span class="x">)</span> <span class="n">⊆</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">||</span> <span class="n">throw</span><span class="x">(</span><span class="kt">ArgumentError</span><span class="x">(</span><span class="s">"</span><span class="si">$</span><span class="s">y keys must be a subset of </span><span class="si">$</span><span class="s">x keys"</span><span class="x">))</span>
    <span class="n">grad</span><span class="x">(</span><span class="n">field</span><span class="x">)</span> <span class="o">=</span> <span class="n">field</span> <span class="k">in</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">y</span><span class="x">)</span> <span class="o">?</span> <span class="o">:</span><span class="x">(</span><span class="n">y</span><span class="o">.$</span><span class="n">field</span><span class="x">)</span> <span class="o">:</span> <span class="o">:</span><span class="nb">nothing</span>
    <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">f</span><span class="o">=</span><span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="o">.$</span><span class="n">f</span><span class="x">,</span> <span class="o">$</span><span class="x">(</span><span class="n">grad</span><span class="x">(</span><span class="n">f</span><span class="x">))))</span> <span class="k">for</span> <span class="n">f</span> <span class="k">in</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">x</span><span class="x">)]</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>Examples:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">accum</span><span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span> <span class="c"># 6</span>
<span class="n">accum</span><span class="x">((</span><span class="mi">1</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="x">(</span><span class="mi">3</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (4, 6)</span>
<span class="n">accum</span><span class="x">((;</span><span class="n">a</span><span class="o">=</span><span class="mi">3</span><span class="x">,</span> <span class="n">b</span><span class="o">=</span><span class="mi">2</span><span class="x">),</span> <span class="x">(;</span><span class="n">a</span><span class="o">=</span><span class="mi">1</span><span class="x">))</span> <span class="c"># (a = 4, b = 2)</span></code></pre></figure>

<p>Finally, dispatch on the <code class="language-plaintext highlighter-rouge">Pullback</code> struct to turn it into a callable struct:</p>

<div class="accordion" id="accordianJuliaVersions-callable">
  <div class="card">
    <div class="card-header" id="generatedJuliaPre10s-callable">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left" type="button" data-toggle="collapse" data-target="#collapsePre10-callable" aria-expanded="true" aria-controls="collapsePre10-callable">
          Julia Version before 1.10
        </button>
      </div>
    </div>
    <div id="collapsePre10-callable" class="collapse show" aria-labelledby="generatedJuliaPre10s-callable" data-parent="#accordianJuliaVersions-callable">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="nd">@generated</span> <span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">methodinstance</span><span class="o">::</span><span class="n">Pullback</span><span class="x">)(</span><span class="n">Δ</span><span class="x">)</span>
    <span class="n">_generate_callable_pullback</span><span class="x">(</span><span class="n">methodinstance</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
  <div class="card">
    <div class="card-header" id="generatedJuliaPost10-callable">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left collapsed" type="button" data-toggle="collapse" data-target="#collapsePost10-callable" aria-expanded="false" aria-controls="collapsePost10-callable">
          Julia Version 1.10 and later
        </button>
      </div>
    </div>
    <div id="collapsePost10-callable" class="collapse" aria-labelledby="generatedJuliaPost10-callable" data-parent="#accordianJuliaVersions-callable">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _callable_pullback_generator</span><span class="x">(</span><span class="n">world</span><span class="o">::</span><span class="kt">UInt</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="n">ret</span> <span class="o">=</span> <span class="n">_generate_callable_pullback</span><span class="x">(</span><span class="n">self</span><span class="x">,</span> <span class="n">world</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="n">ret</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="n">ret</span>
    <span class="n">stub</span> <span class="o">=</span> <span class="n">Core</span><span class="o">.</span><span class="n">GeneratedFunctionStub</span><span class="x">(</span><span class="n">identity</span><span class="x">,</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">(</span><span class="o">:</span><span class="n">methodinstance</span><span class="x">,</span> <span class="o">:</span><span class="n">Δ</span><span class="x">),</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">())</span> <span class="c"># names must match symbols in _generate_callable_pullback</span>
    <span class="n">stub</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">ret</span><span class="x">)</span>
<span class="k">end</span>

<span class="nd">@eval</span> <span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">j</span><span class="o">::</span><span class="n">Pullback</span><span class="x">)(</span><span class="n">Δ</span><span class="x">)</span>
    <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated</span><span class="x">,</span> <span class="n">_callable_pullback_generator</span><span class="x">))</span>
    <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated_only</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
</div>

<p>Testing:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">a</span><span class="x">,</span><span class="n">b</span><span class="x">)</span><span class="o">=</span><span class="n">a</span><span class="o">/</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">b</span><span class="o">*</span><span class="n">b</span><span class="x">)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">)</span> <span class="c"># (0.1818, Pullback{...})</span>
<span class="n">_generate_callable_pullback</span><span class="x">(</span><span class="n">typeof</span><span class="x">(</span><span class="n">back</span><span class="x">),</span> <span class="nb">nothing</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># CodeInfo for IR at start</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 0.0744, -0.0992)</span></code></pre></figure>

<p>The results should match equation $\ref{eq:rollup}$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">a</span><span class="x">,</span> <span class="n">b</span> <span class="o">=</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span>
<span class="n">ā</span> <span class="o">=</span> <span class="n">abs2</span><span class="x">(</span><span class="n">b</span><span class="x">)</span><span class="o">/</span><span class="n">abs2</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">abs2</span><span class="x">(</span><span class="n">b</span><span class="x">))</span> <span class="c"># 0.0744</span>
<span class="n">b̄</span> <span class="o">=</span> <span class="o">-</span><span class="mi">2</span><span class="o">*</span><span class="n">a</span><span class="o">*</span><span class="n">b</span><span class="o">/</span><span class="n">abs2</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">abs2</span><span class="x">(</span><span class="n">b</span><span class="x">))</span>  <span class="c"># -0.0992</span></code></pre></figure>
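
<p>A central finite difference gives an independent check that does not rely on any of the code above. This is only a rough sketch; the step size <code class="language-plaintext highlighter-rouge">h</code> and the <code class="language-plaintext highlighter-rouge">_fd</code> names are arbitrary choices made for this illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">f(a, b) = a / (a + b*b)
a, b = 2.0, 3.0
h = 1e-6 # step size (assumed; small enough for four decimal places)
ā_fd = (f(a + h, b) - f(a - h, b)) / (2h) # approximately 0.0744
b̄_fd = (f(a, b + h) - f(a, b - h)) / (2h) # approximately -0.0992</code></pre></figure>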

<h2 id="conclusion">4 Conclusion</h2>

<p>This code works well enough for this simple case. 
It also works for the trigonometry example from <a href="/machine-learning/2024/07/27/micrograd-1-chainrules.html#chainrules-trigonometry">part 1</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">sin</span><span class="x">(</span><span class="n">cos</span><span class="x">(</span><span class="n">x</span><span class="x">))</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">0.9</span><span class="x">)</span> <span class="c"># (0.5823, Pullback{...})</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, -0.6368) </span></code></pre></figure>

<p>However, it fails for the polynomial model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Polynomial</span><span class="x">{</span><span class="n">V</span><span class="o">&lt;:</span><span class="kt">AbstractVector</span><span class="x">}</span>
    <span class="n">weights</span><span class="o">::</span><span class="n">V</span>
<span class="k">end</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">m</span><span class="o">.</span><span class="n">weights</span><span class="x">)</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Polynomial</span><span class="x">([</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="o">-</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span>
<span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">]</span>
<span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="c"># ERROR: No method found for Tuple{typeof(fieldtype) ....}</span></code></pre></figure>

<p>The error is raised five levels down:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">pr1</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="n">pr2</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">map</span><span class="x">),</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="n">pr3</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">Base</span><span class="o">.</span><span class="n">Generator</span><span class="x">),</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="n">TT</span> <span class="o">=</span> <span class="kt">Type</span><span class="x">{</span><span class="n">Base</span><span class="o">.</span><span class="n">Generator</span><span class="x">{</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">},</span> <span class="n">Polynomial</span><span class="x">{</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}}}}</span> <span class="c"># %9</span>
<span class="n">pr4</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">TT</span><span class="x">,</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="n">pr5</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">Core</span><span class="o">.</span><span class="n">fieldtype</span><span class="x">),</span> <span class="n">TT</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span> <span class="c"># error</span></code></pre></figure>

<p>This can be fixed by explicitly defining a pullback for <code class="language-plaintext highlighter-rouge">map</code>.
This and other extensions are the subject of <a href="/machine-learning/2024/08/17/micrograd-4-ext">part 4</a>.</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:generated_reflection" role="doc-endnote">
      <p>Presumably the reason the Julia team tried to prevent reflection in generated functions is that it interferes with the compiler’s ability to properly predict, trigger and/or optimise compilations. <a href="#fnref:generated_reflection" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:has_chain_rule" role="doc-endnote">
      <p>Zygote.jl has more <a href="https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl">complex rules</a> which also consider other fallbacks, keyword arguments and a possible opt out through a <code class="language-plaintext highlighter-rouge">no_rrule</code>. <a href="#fnref:has_chain_rule" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A series on automatic differentiation in Julia. Part 3 uses metaprogramming based on IRTools.jl to generate a modified (primal) forward pass and to reverse differentiate it into a backward pass. This is a more robust approach than the expression based approach in Part 2.]]></summary></entry><entry><title type="html">MicroGrad.jl: Part 2 Automation with expressions</title><link href="https://liorsinai.github.io/machine-learning/2024/08/03/micrograd-2-expr.html" rel="alternate" type="text/html" title="MicroGrad.jl: Part 2 Automation with expressions" /><published>2024-08-03T00:00:00+00:00</published><updated>2024-08-10T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2024/08/03/micrograd-2-expr</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2024/08/03/micrograd-2-expr.html"><![CDATA[<p><em>A series on automatic differentiation in Julia. Part 2 uses metaprogramming to generate a modified (primal) forward pass and to reverse differentiate it into a backward pass. This post uses an expression based approach which can be brittle. Part 3 develops a more robust approach for the same code using IRTools.jl.</em></p>

<p>This is part of a series. The other articles are:</p>
<ul>
  <li><a href="/machine-learning/2024/07/27/micrograd-1-chainrules">Part 1: ChainRules</a>.</li>
  <li><a href="/machine-learning/2024/08/10/micrograd-3-ir">Part 3: Automation with IR</a>.</li>
  <li><a href="/machine-learning/2024/08/17/micrograd-4-ext">Part 4: Extensions</a>.</li>
  <li><a href="/machine-learning/2024/08/19/micrograd-5-mlp">Part 5: MLP</a>.</li>
</ul>

<p>All source code can be found at <a href="https://github.com/LiorSinai/MicroGrad.jl">MicroGrad.jl</a>.
The code here is inspired by the example at <a href="https://github.com/FluxML/IRTools.jl/blob/master/examples/reverse.jl">IRTools.jl</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p><a href="/machine-learning/2024/07/27/micrograd-1-chainrules">Part 1</a> introduced the <code class="language-plaintext highlighter-rouge">rrule</code> for implementing chain rules.
The challenge now is to automate it.
This will be done through metaprogramming and generated functions.</p>

<div class="message-container warning-message">
    <div class="message-icon fa fa-fw fa-2x fa-exclamation-triangle">
    </div>
    <div class="content-container">
        <div class="message-body">
        Metaprogramming is a powerful tool, but it introduces complexity that can make code more difficult to understand. It can easily introduce critical bugs that can crash a program.
        Care should be taken when using it.
        </div>
    </div>
</div>

<p>For example, from part 1 there are <code class="language-plaintext highlighter-rouge">rrule</code>s for <code class="language-plaintext highlighter-rouge">+</code>, <code class="language-plaintext highlighter-rouge">*</code> and <code class="language-plaintext highlighter-rouge">/</code>.
The goal is then to automatically differentiate the following:</p>

\[f(a, b) = \frac{a}{a + b^2}\]

<p>like so:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">a</span><span class="x">,</span> <span class="n">b</span><span class="x">)</span> <span class="o">=</span> <span class="n">a</span> <span class="o">/</span> <span class="x">(</span><span class="n">a</span> <span class="o">+</span> <span class="n">b</span><span class="o">*</span><span class="n">b</span><span class="x">)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">)</span> <span class="c"># (0.1818, ∂(f))</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 0.0744, -0.099)</span></code></pre></figure>

<p>where <code class="language-plaintext highlighter-rouge">pullback</code> is a <code class="language-plaintext highlighter-rouge">@generated</code> function that inspects the lowered code for <code class="language-plaintext highlighter-rouge">f</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">ci</span> <span class="o">=</span> <span class="nd">@code_lowered</span> <span class="n">f</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
<span class="cm">#= CodeInfo(
1 ─ %1 = b * b
│   %2 = a + %1
│   %3 = a / %2
└──      return %3
)
=#</span></code></pre></figure>
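
<p>The returned <code class="language-plaintext highlighter-rouge">CodeInfo</code> object is ordinary data that can be looped over, which is what makes this approach possible. A small sketch, assuming a recent Julia 1.x release (the exact printed form of each statement varies between versions, so no output is shown):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using InteractiveUtils # provides @code_lowered outside the REPL

f(a, b) = a / (a + b*b)
ci = @code_lowered f(2, 3)
ci.slotnames # [Symbol("#self#"), :a, :b]
for (i, statement) in enumerate(ci.code)
    println("%", i, " = ", statement) # one line per statement in the list
end</code></pre></figure>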

<p>This is an advanced use of the Julia programming language.
You should be comfortable with the language before reading this post.
At the very least, the Julia documentation page on <a href="https://docs.julialang.org/en/v1/manual/metaprogramming/">metaprogramming</a> is required reading and is treated as assumed knowledge, especially the sections on “Expressions and evaluation”, “Code Generation” and “Generated Functions”.</p>

<h2 id="wengert-lists">2 Differentiating Wengert Lists</h2>

<p>The <a href="https://fluxml.ai/Zygote.jl/stable/">Zygote.jl</a> automatic differentiation (AD) package is a realisation of the paper <a href="https://arxiv.org/abs/1810.07951">Don’t Unroll Adjoint: Differentiating SSA-Form Programs (2019)</a> by Michael J Innes.<br />
The paper works with Wengert lists, also known as tapes, and a generalisation of them called Static Single Assignment (SSA) form.
The aim here is to develop a minimal AD package, so this series only focuses on the sections on Wengert lists.
A consequence is that the code will not be able to handle any non-linear logic in Julia, for example any control flow like <code class="language-plaintext highlighter-rouge">if</code>, <code class="language-plaintext highlighter-rouge">while</code> or <code class="language-plaintext highlighter-rouge">for</code> blocks.</p>

<p>The paper uses the same example as the introduction:</p>

\[f(a, b) = \frac{a}{a + b^2}
\tag{2.1}
\label{eq:f}\]

<p>This can be broken down into smaller steps where each intermediate variable is saved.
This is known as a Wengert list, or tape, or (backpropagation) graph:</p>

\[\begin{align}
y_1 &amp;= b \times b \\
y_2 &amp;= a + y_1 \\
y_3 &amp;= a / y_2
\end{align}
\tag{2.2}
\label{eq:f_wengert}\]

<p>To differentiate this, all function calls are wrapped with a differentiation function $\mathcal{J}$ which returns both the output $y$ and a pullback function $\mathcal{B}$.
This is called the <em>primal</em> form:</p>

\[\begin{align}
y_1, \mathcal{B}_1 &amp;\leftarrow \mathcal{J}(\times, b, b) \\
y_2, \mathcal{B}_2 &amp;\leftarrow \mathcal{J}(+, a, y_1) \\
y_3, \mathcal{B}_3 &amp;\leftarrow \mathcal{J}(/, a, y_2)
\end{align}
\tag{2.3}
\label{eq:primal}\]

<p>The pullback function $\mathcal{B}$ of a function $y(x)$ takes as input the gradient of a scalar $l$ (typically a loss) with respect to the output $y$ and returns the gradient with respect to the input $x$.
This partial gradient $\frac{\partial l}{\partial x}$ is written as $\bar{x}$.</p>

\[\begin{align}
\bar{x} &amp;= \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x}
\end{align}
\tag{2.4}
\label{eq:bar_x}\]

<p>so the pullback can be written in this notation as:</p>

\[\begin{align}
\bar{x} &amp;\leftarrow \mathcal{B}(\bar{y}) = \bar{y} \frac{\partial y}{\partial x}\\
\text{or} \quad \bar{x} &amp;\leftarrow  \mathcal{B}(\bar{y}) = J^{\dagger}\bar{y}
\end{align}
\tag{2.5}
\label{eq:pullback}\]

<p>where $\bar{y}=\frac{\partial l}{\partial y}$ and $J=\frac{\partial y}{\partial x}$ is the Jacobian (gradient) for arrays.</p>
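
<p>As a worked instance of equation $\ref{eq:pullback}$, take scalar division $y = x_1 / x_2$, the same operation as the last step of equation $\ref{eq:f_wengert}$. Its pullback returns one gradient per argument:</p>

\[\begin{align}
\bar{x}_1 &amp;= \bar{y} \frac{\partial y}{\partial x_1} = \bar{y} \frac{1}{x_2} \\
\bar{x}_2 &amp;= \bar{y} \frac{\partial y}{\partial x_2} = -\bar{y} \frac{x_1}{x_2^2}
\end{align}\]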

<p>The various partial gradients are calculated by reversing the list.
Each pullback function $\mathcal{B}_i$ takes as input the gradient $\bar{y}_i$ of its own output, which is produced by the step before it in the reversed list.
The first pullback takes an existing gradient $\Delta$, which is usually set to 1:</p>

\[\begin{align}
\text{s̄elf}_3, \bar{a}_{3,1}, \bar{y}_2 &amp;\leftarrow \mathcal{B}_3(\Delta) \\
\text{s̄elf}_2, \bar{a}_{2,1}, \bar{y}_1 &amp;\leftarrow \mathcal{B}_2(\bar{y}_2) \\
\text{s̄elf}_1, \bar{b}_{1,1}, \bar{b}_{1,2} &amp;\leftarrow \mathcal{B}_1(\bar{y}_1)
\end{align}
\tag{2.6}
\label{eq:reverse}\]

<p>The final step is to accumulate the gradients for variables which are used multiple times:</p>

\[\begin{align}
\bar{a} &amp;\leftarrow \bar{a}_{3,1} + \bar{a}_{2,1} \\
\bar{b} &amp;\leftarrow \bar{b}_{1,1} + \bar{b}_{1,2}
\end{align}
\tag{2.7}
\label{eq:accumulate}\]

<p>This end result is equivalent to rolling everything up into one function using the multivariable chain rule:</p>

\[\begin{align}
\bar{a} &amp;= \frac{\partial l}{\partial a} = \mathcal{B}_{3,a}(\Delta) + \mathcal{B}_{2,a}(\bar{y}_2) \\
        &amp;= \frac{\partial l}{\partial y_3} \frac{\partial y_3}{\partial a} + \frac{\partial l}{\partial y_2} \frac{\partial y_2}{\partial a} \\
        &amp;= \Delta \cdot \frac{\partial }{\partial a} \left( \frac{a}{y_2}\right) + 
        \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2} \right)\frac{\partial}{\partial a}(a + y_1) \\
        &amp;= \Delta  \frac{1}{y_2} + \left(\Delta \frac{-a}{y_2^2} \right) (1+0) \\
        &amp;= \Delta \frac{b^2}{(a+b^2)^2} \\
\bar{b} &amp;= \frac{\partial l}{\partial b} = 2 \mathcal{B}_{1,b}(\bar{y}_1) \\
        &amp;= 2\frac{\partial l}{\partial y_1} \frac{\partial y_1}{\partial b} \\
        &amp;= 2 \left(\frac{\partial l}{\partial y_3}\frac{\partial y_3}{\partial y_2}\frac{\partial y_2}{\partial y_1} \right) \frac{\partial y_1}{\partial b} \\
        &amp;= 2 \left(\Delta \cdot \frac{\partial}{\partial y_2}\left(\frac{a}{y_2}\right) \cdot \frac{\partial}{\partial y_1}(a + y_1) \right)\frac{\partial}{\partial b'}(b'\times b) \\
        &amp;= 2\left(\Delta \left(-\frac{a}{y_2^2}\right)(0+1)\right)b \\
        &amp;= -\frac{2ab\Delta}{(a+b^2)^2}
\end{align}
\tag{2.8}
\label{eq:rollup}\]
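
<p>Before automating anything, equations $\ref{eq:primal}$, $\ref{eq:reverse}$ and $\ref{eq:accumulate}$ can be written out by hand. The sketch below does not use the <code class="language-plaintext highlighter-rouge">rrule</code>s from part 1. Instead, the three <code class="language-plaintext highlighter-rouge">*_rule</code> helpers and <code class="language-plaintext highlighter-rouge">f_with_pullback</code> are hypothetical stand-ins with the same return convention (an output and a pullback), written here only so that the block is self-contained:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical stand-ins for rrules: each returns (output, pullback).
# The first entry of every pullback result is the gradient for the function itself.
mul_rule(x, y) = (x * y, Δ -&gt; (nothing, Δ * y, Δ * x))
add_rule(x, y) = (x + y, Δ -&gt; (nothing, Δ, Δ))
div_rule(x, y) = (x / y, Δ -&gt; (nothing, Δ / y, -Δ * x / y^2))

function f_with_pullback(a, b)
    # primal pass: save every intermediate output and its pullback
    y1, B1 = mul_rule(b, b)
    y2, B2 = add_rule(a, y1)
    y3, B3 = div_rule(a, y2)
    function back(Δ)
        # reverse pass: walk the list backwards
        _, ā3, ȳ2 = B3(Δ)
        _, ā2, ȳ1 = B2(ȳ2)
        _, b̄1, b̄2 = B1(ȳ1)
        # accumulate gradients of variables that are used more than once
        (nothing, ā3 + ā2, b̄1 + b̄2)
    end
    y3, back
end

z, back = f_with_pullback(2.0, 3.0) # z is approximately 0.1818
back(1.0) # approximately (nothing, 0.0744, -0.0992)</code></pre></figure>

<p>The rest of this post is about generating the body of such a function automatically from the lowered code.</p>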

<h2 id="pullback">3 Pullback</h2>
<h3 id="pullback-definition">3.1 Definition</h3>

<p>The goal is to generate code which automatically implements the equations of section 2.</p>

<div class="message-container info-message">
  <div class="message-icon fa fa-fw fa-2x fa-exclamation-circle"></div>
    <div class="content-container">
      <div class="message-body">
        The <code>pullback</code> function that is implemented here is equivalent to the internal <code>Zygote._pullback</code> function, which returns all partial gradients including for $\frac{\partial l}{\partial \text{self}}$. <code>Zygote.pullback</code> is a thin wrapper around <code>Zygote._pullback</code> which discards that first gradient.
      </div>
    </div>
</div>

<p>To start, define a <code class="language-plaintext highlighter-rouge">pullback</code> function (<a href="https://github.com/FluxML/ZygoteRules.jl/blob/f9bf0e367fa259c5aa68f0e14ccbf2125d734bd6/src/adjoint.jl#L33">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> pullback</span> <span class="k">end</span></code></pre></figure>

<p>This will be turned into a <a href="https://docs.julialang.org/en/v1/manual/metaprogramming/#Generated-functions">generated function</a>.</p>

<p>Julia changed the behaviour of generated functions in <a href="https://github.com/JuliaLang/julia/issues/49715">version 1.10</a>.
Before 1.10, they always had access to the <a href="https://docs.julialang.org/en/v1/manual/methods/">world age counter</a>.
This is a single number that is incremented every time a method is defined, and helps optimise compilations.
However, from version 1.10, calling <code class="language-plaintext highlighter-rouge">Base.get_world_counter()</code> inside a generated function will only return <code class="language-plaintext highlighter-rouge">typemax(UInt)</code>.
This is to prevent reflection (code inspection) in generated functions.<sup id="fnref:generated_reflection" role="doc-noteref"><a href="#fn:generated_reflection" class="footnote" rel="footnote">1</a></sup>
However, the code here relies on reflection.
Thankfully, there is a hack that <a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/interface2.jl#L69C17-L69C31">Zygote.jl</a> uses to access the world age in <code class="language-plaintext highlighter-rouge">pullback</code>.
Because of this, the definition of <code class="language-plaintext highlighter-rouge">pullback</code> is different based on the version, but both will forward to a common internal <code class="language-plaintext highlighter-rouge">_generate_pullback</code> function.</p>
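
<p>The counter can be observed directly at the top level, outside of any generated function. A small illustration, where <code class="language-plaintext highlighter-rouge">world_age_demo</code> is a throwaway name used only here:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">w1 = Base.get_world_counter()
world_age_demo(x) = x + 1 # defining a method advances the world age
w2 = Base.get_world_counter()
w2 &gt; w1 # true</code></pre></figure>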

<div class="message-container info-message">
    <div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
    </div>
    <div class="content-container">
        <div class="message-body">
    Generated functions should only be defined after all other functions. That is, at the bottom of the file or after all functions have been defined in the REPL. Otherwise they will either not be able to access those functions or only see old versions of them. These functions are defined here at the top only for explanatory purposes.
        </div>
    </div>
</div>

<div class="accordion" id="accordianJuliaVersions">
  <div class="card">
    <div class="card-header" id="generatedJuliaPre10">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left" type="button" data-toggle="collapse" data-target="#collapsePre10" aria-expanded="true" aria-controls="collapsePre10">
          Julia Version before 1.10
        </button>
      </div>
    </div>
    <div id="collapsePre10" class="collapse show" aria-labelledby="generatedJuliaPre10" data-parent="#accordianJuliaVersions">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="nd">@generated</span> <span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="n">_generate_pullback</span><span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
  <div class="card">
    <div class="card-header" id="generatedJuliaPost10">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left collapsed" type="button" data-toggle="collapse" data-target="#collapsePost10" aria-expanded="false" aria-controls="collapsePost10">
          Julia Version 1.10 and later
        </button>
      </div>
    </div>
    <div id="collapsePost10" class="collapse" aria-labelledby="generatedJuliaPost10" data-parent="#accordianJuliaVersions">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _pullback_generator</span><span class="x">(</span><span class="n">world</span><span class="o">::</span><span class="kt">UInt</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="x">)</span>
        <span class="n">ret</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="n">ret</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="n">ret</span>
        <span class="n">stub</span> <span class="o">=</span> <span class="n">Core</span><span class="o">.</span><span class="n">GeneratedFunctionStub</span><span class="x">(</span><span class="n">identity</span><span class="x">,</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">(</span><span class="o">:</span><span class="n">methodinstance</span><span class="x">,</span> <span class="o">:</span><span class="n">f</span><span class="x">,</span> <span class="o">:</span><span class="n">args</span><span class="x">),</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">())</span>
        <span class="n">stub</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">ret</span><span class="x">)</span>
<span class="k">end</span>

<span class="nd">@eval</span> <span class="k">function</span><span class="nf"> pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
        <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated</span><span class="x">,</span> <span class="n">_pullback_generator</span><span class="x">))</span>
        <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated_only</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
</div>

<h3 id="chainrules">3.2 ChainRules</h3>

<p>The first goal of <code class="language-plaintext highlighter-rouge">_generate_pullback</code> will be to forward the function and its arguments to a matching <code class="language-plaintext highlighter-rouge">rrule</code> if it exists.
For now it will throw an error if it cannot find one.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
    <span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">}</span>
    <span class="k">if</span> <span class="x">(</span><span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">))</span>
        <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">rrule</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">))</span>
    <span class="k">end</span>
    <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"No rrule found for "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="o">$</span><span class="n">T</span><span class="x">)))</span>
<span class="k">end</span></code></pre></figure>

<p>In <a href="/machine-learning/2024/07/27/micrograd-1-chainrules#chainrules-definition">part 1</a> the most generic method of <code class="language-plaintext highlighter-rouge">rrule</code> was defined for an <code class="language-plaintext highlighter-rouge">Any</code> first argument, so if the compiler dispatches to this method it means no specific <code class="language-plaintext highlighter-rouge">rrule</code> was found.<sup id="fnref:has_chain_rule" role="doc-noteref"><a href="#fn:has_chain_rule" class="footnote" rel="footnote">2</a></sup></p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="n">Tr</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">rrule</span><span class="x">),</span> <span class="n">T</span><span class="o">.</span><span class="n">parameters</span><span class="o">...</span><span class="x">}</span>
    <span class="n">meta_T</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">Tr</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">meta_T</span><span class="x">)</span>
        <span class="k">return</span> <span class="nb">false</span>
    <span class="k">end</span>
    <span class="n">type_signature</span><span class="x">,</span> <span class="n">sps</span><span class="x">,</span> <span class="n">method_</span> <span class="o">=</span> <span class="n">meta_T</span>
    <span class="n">method_</span><span class="o">.</span><span class="n">sig</span><span class="o">.</span><span class="n">parameters</span><span class="x">[</span><span class="mi">2</span><span class="x">]</span> <span class="o">!==</span> <span class="kt">Any</span>
<span class="k">end</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">meta</code> function uses the internal reflection function <code class="language-plaintext highlighter-rouge">Base._methods_by_ftype</code> to get all the methods for a specific type. (This same function is used by <code class="language-plaintext highlighter-rouge">methods</code>.)
The most specific method is assumed to be the last one (<a href="https://github.com/FluxML/IRTools.jl/blob/master/src/reflection/reflection.jl#L71">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> meta</span><span class="x">(</span><span class="n">T</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">())</span>
    <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">world</span><span class="x">)</span>
        <span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span> <span class="c"># in generated function post v1.10 this will return typemax(UInt)</span>
    <span class="k">end</span>
    <span class="n">min_world</span> <span class="o">=</span> <span class="kt">Ref</span><span class="x">{</span><span class="kt">UInt</span><span class="x">}(</span><span class="n">typemin</span><span class="x">(</span><span class="kt">UInt</span><span class="x">))</span>
    <span class="n">max_world</span> <span class="o">=</span> <span class="kt">Ref</span><span class="x">{</span><span class="kt">UInt</span><span class="x">}(</span><span class="n">typemax</span><span class="x">(</span><span class="kt">UInt</span><span class="x">))</span>
    <span class="n">has_ambig</span> <span class="o">=</span> <span class="kt">Ptr</span><span class="x">{</span><span class="kt">Int32</span><span class="x">}(</span><span class="nb">C_NULL</span><span class="x">)</span>  <span class="c"># don't care about ambiguous results</span>
    <span class="n">_methods</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">_methods_by_ftype</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="cm">#=mt=#</span> <span class="nb">nothing</span><span class="x">,</span> <span class="cm">#=lim=#</span> <span class="o">-</span><span class="mi">1</span><span class="x">,</span>
        <span class="n">world</span><span class="x">,</span> <span class="cm">#=ambig=#</span> <span class="nb">false</span><span class="x">,</span>
        <span class="n">min_world</span><span class="x">,</span> <span class="n">max_world</span><span class="x">,</span> <span class="n">has_ambig</span><span class="x">)</span>
    <span class="n">_methods</span> <span class="o">===</span> <span class="nb">nothing</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="nb">nothing</span>
    <span class="n">_methods</span> <span class="k">isa</span> <span class="kt">Bool</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="nb">nothing</span>
    <span class="n">length</span><span class="x">(</span><span class="n">_methods</span><span class="x">)</span> <span class="o">==</span> <span class="mi">0</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="nb">nothing</span>
    <span class="n">last</span><span class="x">(</span><span class="n">_methods</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>
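
<p>As a rough cross-check outside of a generated function, the public <code class="language-plaintext highlighter-rouge">which</code> function reports the method that dispatch would select, so the same test on the second signature parameter can be reproduced without internal reflection. This is only a sketch: <code class="language-plaintext highlighter-rouge">toy_rrule</code>, <code class="language-plaintext highlighter-rouge">g</code> and <code class="language-plaintext highlighter-rouge">is_specific</code> are hypothetical stand-ins for illustration and are not part of the code in this series.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical stand-ins for the rrule methods of part 1.
toy_rrule(::Any, ::Any...) = nothing  # generic fallback
toy_rrule(::typeof(+), x::Number, y::Number) = (x + y, Δ -&gt; (nothing, Δ, Δ))

g(a, b) = a / (a + b * b)  # no specific rule is defined for g

# `which` returns the method that dispatch would select for these argument types.
m_plus = which(toy_rrule, Tuple{typeof(+), Float64, Float64})
m_g = which(toy_rrule, Tuple{typeof(g), Float64, Float64})

# Same test as has_chain_rrule: the fallback has `Any` as its second signature parameter.
is_specific(m::Method) = m.sig.parameters[2] !== Any

is_specific(m_plus) # true
is_specific(m_g)    # false</code></pre></figure>

<p>This sketch does not pass a world age explicitly, so it is only meant for experimenting at the REPL.</p>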

<p>Let’s test all this code from bottom to top for a function with an <code class="language-plaintext highlighter-rouge">rrule</code> and one without: <code class="language-plaintext highlighter-rouge">+</code> and <code class="language-plaintext highlighter-rouge">f(a,b)=a/(a+b*b)</code>.
As a reminder, generated functions only have access to a variable’s type, so to test <code class="language-plaintext highlighter-rouge">_generate_pullback</code> and all the functions under it, we can only work with the types.</p>
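
<p>As a quick illustration of that reminder, here is a minimal generated function. It is a sketch only, and <code class="language-plaintext highlighter-rouge">describe</code> is a hypothetical name. Inside the generator body, <code class="language-plaintext highlighter-rouge">x</code> is bound to the type of the argument and the value is not available:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">@generated function describe(x)
    # Here `x` is the argument's type (e.g. Float64), not its value.
    msg = "called with a $(x)"
    return :($msg)  # the returned expression becomes the function body
end

describe(1.0)   # "called with a Float64"
describe("abc") # "called with a String"</code></pre></figure>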

<p>Firstly, for <code class="language-plaintext highlighter-rouge">+</code> acting on floats:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="o">+</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">Tr</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">rrule</span><span class="x">),</span> <span class="n">T</span><span class="o">.</span><span class="n">parameters</span><span class="o">...</span><span class="x">}</span>
<span class="n">meta</span><span class="x">(</span><span class="n">Tr</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span> <span class="c"># Core.MethodMatch(...), svec(), rrule(::typeof(+), x::Number, y::Number)</span>
<span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span> <span class="c"># true</span>
<span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="o">+</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># :(rrule(f, args...))</span>
<span class="n">pullback</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">)</span> <span class="c"># (3.0, var"#add_back#5"())</span></code></pre></figure>

<p>Now for <code class="language-plaintext highlighter-rouge">f</code>, also acting on floats:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">Tr</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">rrule</span><span class="x">),</span> <span class="n">T</span><span class="o">.</span><span class="n">parameters</span><span class="o">...</span><span class="x">}</span>
<span class="n">meta</span><span class="x">(</span><span class="n">Tr</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span> <span class="c"># Core.MethodMatch(...), svec(), rrule(::Any, ...)</span>
<span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span> <span class="c"># false</span>
<span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># :(error(...))</span>
<span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">)</span> <span class="c"># ERROR: No rrule found ...</span></code></pre></figure>

<p>The more interesting task is to inspect <code class="language-plaintext highlighter-rouge">f</code> and apply the equations of section 2 to fully differentiate it with respect to all of its input parameters.</p>

<h3 id="ast">3.3 AST</h3>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/compiler_diagram.png" alt="Julia compiler steps" />
<figcaption>Source: <a href="https://docs.julialang.org/en/v1/devdocs/eval/">Julia Docs eval</a></figcaption>
</figure>

<p>The first step is to create a Wengert list for <code class="language-plaintext highlighter-rouge">f</code>.
This is trivial because Julia already does this as part of the compilation process.
As the first step of lowering code, the compiler creates an Abstract Syntax Tree (AST) which, in the absence of control flow, is the same as a Wengert list.</p>

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
	</div>
	<div class="content-container">
		<div class="message-body">
    Julia exposes the <code>@code_lowered</code> macro to easily access the Intermediate Representation (IR), which is in Static Single Assignment (SSA) form. This is one step lower than the AST, although in many cases the two are the same. Part 3 works with this form instead of the AST.
		</div>
	</div>
</div>

<p>This AST can be retrieved by calling <code class="language-plaintext highlighter-rouge">Base.uncompressed_ast</code> on the method we have found above:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">type_signature</span><span class="x">,</span> <span class="n">sps</span><span class="x">,</span> <span class="n">method_</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">T</span><span class="x">)</span>
<span class="n">ci</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">uncompressed_ast</span><span class="x">(</span><span class="n">method_</span><span class="x">)</span>
<span class="cm">#=
CodeInfo(
    @ REPL[1]:1 within `f`
1 ─ %1 = b * b
│   %2 = a + %1
│   %3 = a / %2
└──      return %3
)
=#</span></code></pre></figure>

<p>The returned object is a <a href="https://docs.julialang.org/en/v1/devdocs/ast/#CodeInfo">CodeInfo</a> struct and it corresponds exactly to equation $\ref{eq:f_wengert}$.</p>
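
<p>The statements themselves are stored in the <code class="language-plaintext highlighter-rouge">code</code> field of the <code class="language-plaintext highlighter-rouge">CodeInfo</code>. The sketch below walks over them using the public <code class="language-plaintext highlighter-rouge">code_lowered</code> function, which also returns <code class="language-plaintext highlighter-rouge">CodeInfo</code> objects. The names <code class="language-plaintext highlighter-rouge">h</code> and <code class="language-plaintext highlighter-rouge">describe_statement</code> are hypothetical and exist only for this illustration. The exact printing of each statement can differ between Julia versions.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">h(a, b) = a / (a + b * b)  # hypothetical copy of f

ci = only(code_lowered(h, (Float64, Float64)))  # a Core.CodeInfo

# Calls are stored as `Expr`s; the final statement is a `Core.ReturnNode`.
describe_statement(ex) = ex isa Expr ? "Expr with head :$(ex.head)" : string(typeof(ex))

for (i, ex) in enumerate(ci.code)
    println("%", i, ": ", describe_statement(ex))
end</code></pre></figure>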

<p>Using this knowledge, we can now create a new function <code class="language-plaintext highlighter-rouge">_generate_pullback_via_decomposition</code> which will be called if no <code class="language-plaintext highlighter-rouge">rrule</code> exists.
It uses the <code class="language-plaintext highlighter-rouge">CodeInfo</code> block to create the primal (equation $\ref{eq:primal}$) (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/emit.jl#L98">source</a>).</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">T</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="n">isnothing</span><span class="x">(</span><span class="n">m</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"No method found for "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="o">$</span><span class="n">T</span><span class="x">),</span> <span class="s">" in world "</span><span class="x">,</span> <span class="o">$</span><span class="n">world</span><span class="x">))</span>
    <span class="n">type_signature</span><span class="x">,</span> <span class="n">sps</span><span class="x">,</span> <span class="n">method_</span> <span class="o">=</span> <span class="n">m</span>
    <span class="n">ci</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">uncompressed_ast</span><span class="x">(</span><span class="n">method_</span><span class="x">)</span>
    <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">primal</span><span class="x">(</span><span class="n">ci</span><span class="x">,</span> <span class="n">T</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<h3 id="primal">3.4 Primal</h3>

<p>The goal here is to create an expression for equation $\ref{eq:primal}$.
This is what it will look like:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">quote</span>
    <span class="x">(</span><span class="n">y1</span><span class="x">,</span> <span class="n">back1</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">Main</span><span class="o">.:*</span><span class="x">,</span> <span class="n">_3</span><span class="x">,</span> <span class="n">_3</span><span class="x">)</span>
    <span class="x">(</span><span class="n">y2</span><span class="x">,</span> <span class="n">back2</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">Main</span><span class="o">.:+</span><span class="x">,</span> <span class="n">_2</span><span class="x">,</span> <span class="o">%</span><span class="mi">1</span><span class="x">)</span>
    <span class="x">(</span><span class="n">y3</span><span class="x">,</span> <span class="n">back3</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">Main</span><span class="o">.:/</span><span class="x">,</span> <span class="n">_2</span><span class="x">,</span> <span class="o">%</span><span class="mi">2</span><span class="x">)</span>
    <span class="n">Base</span><span class="o">.</span><span class="n">tuple</span><span class="x">(</span><span class="o">%</span><span class="mi">3</span><span class="x">,</span> <span class="x">(</span><span class="n">Pullback</span><span class="x">{</span><span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}})(</span><span class="n">Base</span><span class="o">.</span><span class="n">tuple</span><span class="x">(</span><span class="n">back1</span><span class="x">,</span> <span class="n">back2</span><span class="x">,</span> <span class="n">back3</span><span class="x">)))</span>
<span class="k">end</span></code></pre></figure>

<p>Note that this expression cannot be executed because it still has slot numbers (<code class="language-plaintext highlighter-rouge">_X</code>), which correspond to input arguments, and SSA values (<code class="language-plaintext highlighter-rouge">%X</code>), which correspond to intermediate values.
This will be fixed in the <a href="#sanitise">Sanitise</a> section.</p>
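
<p>To see what these placeholders refer to, the sketch below looks them up in a <code class="language-plaintext highlighter-rouge">CodeInfo</code>. It assumes the same function body as <code class="language-plaintext highlighter-rouge">f</code> under the hypothetical name <code class="language-plaintext highlighter-rouge">h</code>, and the helpers <code class="language-plaintext highlighter-rouge">slot_name</code> and <code class="language-plaintext highlighter-rouge">ssa_statement</code> are made up for this illustration. Slot 1 is the function itself, which is why <code class="language-plaintext highlighter-rouge">a</code> is <code class="language-plaintext highlighter-rouge">_2</code> and <code class="language-plaintext highlighter-rouge">b</code> is <code class="language-plaintext highlighter-rouge">_3</code>.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">h(a, b) = a / (a + b * b)  # hypothetical copy of f
ci = only(code_lowered(h, (Float64, Float64)))

# Slot 1 is the function itself; the arguments follow in order.
slot_name(ci, slot::Core.SlotNumber) = ci.slotnames[slot.id]
# An SSA value %i refers to the result of statement i.
ssa_statement(ci, ssa::Core.SSAValue) = ci.code[ssa.id]

slot_name(ci, Core.SlotNumber(2))   # :a
slot_name(ci, Core.SlotNumber(3))   # :b
ssa_statement(ci, Core.SSAValue(1)) # the statement for b * b</code></pre></figure>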

<p>The first step for the primal function is to define three arrays to store information (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L201">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> primal</span><span class="x">(</span><span class="n">ci</span><span class="o">::</span><span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span><span class="x">,</span> <span class="n">T</span><span class="o">=</span><span class="kt">Any</span><span class="x">)</span>
    <span class="n">tape</span> <span class="o">=</span> <span class="x">[]</span>
    <span class="n">calls</span> <span class="o">=</span> <span class="x">[]</span>
    <span class="n">pullbacks</span> <span class="o">=</span> <span class="x">[]</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">tape</code> array stores the new expressions which will be part of the final expression.
The <code class="language-plaintext highlighter-rouge">calls</code> array stores the subset of expressions that require a pullback.
This will be used to generate the reverse code (equation $\ref{eq:reverse}$) in the next section.
Lastly, <code class="language-plaintext highlighter-rouge">pullbacks</code> stores all the pullbacks.</p>

<p>Next, iterate over each line in the <code class="language-plaintext highlighter-rouge">CodeInfo</code> instance.
Each output variable will be called <code class="language-plaintext highlighter-rouge">y$i</code>.
Then the line’s expression type is inspected.
This minimal code cannot handle control flow or the creation of new objects, so errors will be explicitly thrown if those cases are encountered. 
(Please refer to the <a href="https://docs.julialang.org/en/v1/devdocs/ast/#Lowered-form">Lowered form</a> section in the Julia documentation.)</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">ex</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">ci</span><span class="o">.</span><span class="n">code</span><span class="x">)</span>
      <span class="n">vy</span> <span class="o">=</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"y</span><span class="si">$</span><span class="s">i"</span><span class="x">)</span>
      <span class="k">if</span> <span class="n">ex</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">ReturnNode</span>
          <span class="n">break</span>
      <span class="k">elseif</span> <span class="x">(</span><span class="n">typeof</span><span class="x">(</span><span class="n">ex</span><span class="x">)</span> <span class="k">in</span> <span class="x">[</span><span class="n">Core</span><span class="o">.</span><span class="n">GotoNode</span><span class="x">,</span> <span class="n">Core</span><span class="o">.</span><span class="n">GotoIfNot</span><span class="x">,</span> <span class="n">Core</span><span class="o">.</span><span class="n">SlotNumber</span><span class="x">])</span>
          <span class="n">error</span><span class="x">(</span><span class="s">"</span><span class="si">$</span><span class="s">(typeof(ex)) is not supported"</span><span class="x">)</span></code></pre></figure>

<p>If the expression is of type <code class="language-plaintext highlighter-rouge">Expr</code> and it makes a call, and it is not in a specialised ignore list (to be defined shortly), then the new expression can be created and the three arrays updated.
Otherwise, the statement is kept as is.</p>

<div class="message-container warning-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-triangle">
	</div>
	<div class="content-container">
		<div class="message-body">
		There are possible silent errors, including logic errors, with the <code>else</code> statement here.
    For example, it will not properly handle any <code>:new</code> expression statements.
    This is one of the inherent complexities with this metaprogramming/multiple dispatch approach.
		</div>
	</div>
</div>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">      <span class="k">elseif</span> <span class="x">(</span><span class="n">ex</span> <span class="k">isa</span> <span class="kt">Expr</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">head</span> <span class="o">==</span> <span class="o">:</span><span class="n">call</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="o">!</span><span class="n">ignored</span><span class="x">(</span><span class="n">ex</span><span class="x">)</span>
          <span class="n">vb</span> <span class="o">=</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"back</span><span class="si">$</span><span class="s">i"</span><span class="x">)</span>
          <span class="n">new_ex</span> <span class="o">=</span> <span class="o">:</span><span class="x">((</span><span class="o">$</span><span class="n">vy</span><span class="x">,</span> <span class="o">$</span><span class="n">vb</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="o">$</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="o">...</span><span class="x">)))</span>
          <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="n">new_ex</span><span class="x">)</span>
          <span class="n">push!</span><span class="x">(</span><span class="n">calls</span><span class="x">,</span> <span class="x">(;</span><span class="n">SSA_value</span><span class="o">=</span><span class="n">vy</span><span class="x">,</span> <span class="n">expr</span><span class="o">=</span><span class="n">ex</span><span class="x">))</span>
          <span class="n">push!</span><span class="x">(</span><span class="n">pullbacks</span><span class="x">,</span> <span class="n">vb</span><span class="x">)</span>
      <span class="k">else</span> <span class="c"># keep as is</span>
          <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">vy</span> <span class="o">=</span> <span class="o">$</span><span class="n">ex</span><span class="x">))</span>
      <span class="k">end</span>
    <span class="k">end</span></code></pre></figure>
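
<p>The rewriting step can be tried in isolation on a single hand-built statement. This is a sketch under the assumption that the statement is the lowered form of <code class="language-plaintext highlighter-rouge">b * b</code>. Nothing is evaluated here, so <code class="language-plaintext highlighter-rouge">pullback</code> does not need to be defined:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># A hand-built stand-in for the first lowered statement, %1 = b * b.
ex = Expr(:call, GlobalRef(Base, :*), Core.SlotNumber(3), Core.SlotNumber(3))
i = 1

vy, vb = Symbol("y$i"), Symbol("back$i")
# Splice the callee and its arguments into a call to `pullback`.
new_ex = :(($vy, $vb) = pullback($(ex.args...)))

new_ex.head            # :(=)
new_ex.args[1]         # :((y1, back1))
new_ex.args[2].args[1] # :pullback</code></pre></figure>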

<p>After working through all the lines, a final expression is added which returns a tuple with the final output of the function and a <code class="language-plaintext highlighter-rouge">Pullback</code> struct which stores all the pullbacks.
Everything is then grouped into a single <code class="language-plaintext highlighter-rouge">:block</code> expression:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">pb</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">call</span><span class="x">,</span> <span class="n">Pullback</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="n">pullbacks</span><span class="o">...</span><span class="x">))</span>
    <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="n">returnvalue</span><span class="x">(</span><span class="n">ci</span><span class="x">),</span> <span class="n">pb</span><span class="x">))</span>
    <span class="n">pr</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">block</span><span class="x">,</span> <span class="n">tape</span><span class="o">...</span><span class="x">)</span>
    <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span>
<span class="k">end</span></code></pre></figure>

<p>This code requires definitions for the <code class="language-plaintext highlighter-rouge">Pullback</code> struct as well as the following functions: <code class="language-plaintext highlighter-rouge">ignored</code>, <code class="language-plaintext highlighter-rouge">xcall</code> and <code class="language-plaintext highlighter-rouge">returnvalue</code>.</p>

<p>There are no closures in lowered Julia code, so instead <a href="https://fluxml.ai/Zygote.jl/stable/internals/#Closure-Conversion-1">Zygote.jl</a> stores the pullbacks in a generic struct:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span><span class="n">T</span><span class="x">}</span>
    <span class="n">data</span><span class="o">::</span><span class="n">T</span>
<span class="k">end</span>
<span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">}(</span><span class="n">data</span><span class="x">)</span> <span class="k">where</span> <span class="n">S</span> <span class="o">=</span> <span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span><span class="n">typeof</span><span class="x">(</span><span class="n">data</span><span class="x">)}(</span><span class="n">data</span><span class="x">)</span></code></pre></figure>

<p>In the next section this struct will be turned into a callable struct.
That is, for <code class="language-plaintext highlighter-rouge">back=Pullback{S}(data)</code>, we will create a generated function that dispatches on the struct itself: <code class="language-plaintext highlighter-rouge">(j::Pullback)(Δ)</code> so that we can call <code class="language-plaintext highlighter-rouge">back(Δ)</code>. This <code class="language-plaintext highlighter-rouge">back</code> has all the information to generate the reverse pass independently of the forward pass: the method can be retrieved using <code class="language-plaintext highlighter-rouge">meta(S)</code>, and the relevant data and input parameters can be read from <code class="language-plaintext highlighter-rouge">back.data</code>.</p>
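
<p>The sketch below shows how the signature survives as a type parameter. To avoid clashing with the real definition, it uses a hypothetical copy of the struct called <code class="language-plaintext highlighter-rouge">ToyPullback</code> and a made-up helper <code class="language-plaintext highlighter-rouge">signature</code>. The stored data is arbitrary and only there to show the fields.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical copy of the struct above.
struct ToyPullback{S,T}
    data::T
end
ToyPullback{S}(data) where S = ToyPullback{S,typeof(data)}(data)

# Recover the signature purely from the type, which is all a generated function sees.
signature(::ToyPullback{S}) where S = S

back = ToyPullback{Tuple{typeof(sin),Float64}}((cos,))
signature(back) # Tuple{typeof(sin), Float64}
back.data       # (cos,)</code></pre></figure>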

<p>Here is the ignored functions list (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L171">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> ignored</span><span class="x">(</span><span class="n">ex</span><span class="o">::</span><span class="kt">Expr</span><span class="x">)</span>
    <span class="n">f</span> <span class="o">=</span> <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span>
    <span class="n">ignored_f</span><span class="x">(</span><span class="n">f</span><span class="x">)</span>
<span class="k">end</span>

<span class="n">ignored_f</span><span class="x">(</span><span class="n">f</span><span class="x">)</span> <span class="o">=</span> <span class="n">f</span> <span class="k">in</span> <span class="x">(</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Base</span><span class="x">,</span> <span class="o">:</span><span class="n">not_int</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="o">.</span><span class="n">Intrinsics</span><span class="x">,</span> <span class="o">:</span><span class="n">not_int</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="o">===</span><span class="x">)),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">apply_type</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">typeof</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">throw</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Base</span><span class="x">,</span> <span class="o">:</span><span class="n">kwerr</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">kwfunc</span><span class="x">),</span>
    <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">Core</span><span class="x">,</span> <span class="o">:</span><span class="n">isdefined</span><span class="x">)</span>
<span class="x">)</span></code></pre></figure>

<p><code class="language-plaintext highlighter-rouge">xcall</code> and <code class="language-plaintext highlighter-rouge">returnvalue</code> are convenience functions from <a href="https://github.com/FluxML/IRTools.jl/blob/dd1f2c212258001ea565df696841929ad0fcb614/src/ir/utils.jl#L12">IRTools</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">xcall</span><span class="x">(</span><span class="n">mod</span><span class="o">::</span><span class="kt">Module</span><span class="x">,</span> <span class="n">f</span><span class="o">::</span><span class="kt">Symbol</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">call</span><span class="x">,</span> <span class="kt">GlobalRef</span><span class="x">(</span><span class="n">mod</span><span class="x">,</span> <span class="n">f</span><span class="x">),</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
<span class="n">xcall</span><span class="x">(</span><span class="n">f</span><span class="o">::</span><span class="kt">Symbol</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">xcall</span><span class="x">(</span><span class="n">Base</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
<span class="n">xcall</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">call</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> returnvalue</span><span class="x">(</span><span class="n">ci</span><span class="o">::</span><span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span><span class="x">)</span>
    <span class="k">for</span> <span class="n">expr</span> <span class="k">in</span> <span class="n">ci</span><span class="o">.</span><span class="n">code</span>
        <span class="k">if</span> <span class="n">expr</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">ReturnNode</span>
            <span class="k">return</span> <span class="n">expr</span><span class="o">.</span><span class="n">val</span>
        <span class="k">end</span>
    <span class="k">end</span>
<span class="k">end</span></code></pre></figure>
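
<p>To see what <code class="language-plaintext highlighter-rouge">xcall</code> produces, the following sketch builds a call expression and evaluates it. The first two methods are copied from above so that the snippet runs on its own; the global <code class="language-plaintext highlighter-rouge">xs</code> is a made-up value used only for this example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Copied from the definitions above so that the snippet is self-contained.
xcall(mod::Module, f::Symbol, args...) = Expr(:call, GlobalRef(mod, f), args...)
xcall(f::Symbol, args...) = xcall(Base, f, args...)

ex = xcall(:getindex, :xs, 2) # an Expr that calls Base.getindex(xs, 2)
ex.head                       # :call
ex.args[1] isa GlobalRef      # true: the function is pinned to the Base module

xs = (10, 20, 30)             # hypothetical data, only so that the call can be evaluated
eval(ex)                      # 20</code></pre></figure>

<p>Using a <code class="language-plaintext highlighter-rouge">GlobalRef</code> rather than a bare symbol means the call resolves to the function in the named module even if the surrounding module defines its own function with the same name.</p>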

<p>Running this code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}</span>
<span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">_generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span></code></pre></figure>

<p>gives the expression at the start.</p>

<h3 id="sanitise">3.5 Sanitise</h3>

<p>To evaluate the expression we need to replace all slot values and SSA values, which only have meaning inside lowered code.</p>

<p>For the slot values (<code class="language-plaintext highlighter-rouge">_X</code>), the first parameter in <code class="language-plaintext highlighter-rouge">T</code> will always be the function <code class="language-plaintext highlighter-rouge">f</code>, and the remainder are from <code class="language-plaintext highlighter-rouge">args</code>.
Therefore the first slot needs to be replaced with the symbol <code class="language-plaintext highlighter-rouge">:f</code>, and each remaining slot with <code class="language-plaintext highlighter-rouge">Base.getindex(args, idx)</code>, where <code class="language-plaintext highlighter-rouge">idx</code> is the slot number minus 1.
Here are two recursive functions to accomplish this:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> replace_slot!</span><span class="x">(</span><span class="n">ex</span><span class="o">::</span><span class="kt">Expr</span><span class="x">,</span> <span class="n">idx</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">f</span><span class="o">::</span><span class="kt">Symbol</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">v</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">)</span>
        <span class="k">if</span> <span class="n">v</span> <span class="k">isa</span> <span class="kt">Expr</span>
            <span class="n">replace_slot!</span><span class="x">(</span><span class="n">v</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">f</span><span class="x">)</span>
        <span class="k">elseif</span> <span class="n">v</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">SlotNumber</span> <span class="o">&amp;&amp;</span> <span class="n">v</span><span class="o">.</span><span class="n">id</span> <span class="o">==</span> <span class="n">idx</span>
            <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">f</span><span class="x">)</span> 
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="n">ex</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> varargs!</span><span class="x">(</span><span class="n">ex</span><span class="o">::</span><span class="kt">Expr</span><span class="x">,</span> <span class="n">offset</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">v</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">)</span>
        <span class="k">if</span> <span class="n">v</span> <span class="k">isa</span> <span class="kt">Expr</span>
            <span class="n">varargs!</span><span class="x">(</span><span class="n">v</span><span class="x">,</span> <span class="n">offset</span><span class="x">)</span>
        <span class="k">elseif</span> <span class="n">v</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">SlotNumber</span>
            <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">args</span><span class="x">,</span> <span class="o">$</span><span class="x">(</span><span class="n">v</span><span class="o">.</span><span class="n">id</span> <span class="o">-</span> <span class="n">offset</span><span class="x">)))</span> 
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="n">ex</span>
<span class="k">end</span></code></pre></figure>

<p>The SSA values (<code class="language-plaintext highlighter-rouge">%id</code>) need to be replaced by the <code class="language-plaintext highlighter-rouge">y$id</code> symbol:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> replace_SSA!</span><span class="x">(</span><span class="n">ex</span><span class="o">::</span><span class="kt">Expr</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">v</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">)</span>
        <span class="k">if</span> <span class="n">v</span> <span class="k">isa</span> <span class="kt">Expr</span>
            <span class="n">replace_SSA!</span><span class="x">(</span><span class="n">v</span><span class="x">)</span>
        <span class="k">elseif</span> <span class="n">v</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">SSAValue</span>
            <span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">=</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"y</span><span class="si">$</span><span class="s">(v.id)"</span><span class="x">)</span> 
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="n">ex</span>
<span class="k">end</span></code></pre></figure>

<p>Running this code on <code class="language-plaintext highlighter-rouge">pr</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">replace_slot!</span><span class="x">(</span><span class="n">pr</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="o">:</span><span class="n">f</span><span class="x">)</span>
<span class="n">varargs!</span><span class="x">(</span><span class="n">pr</span><span class="x">)</span>
<span class="n">replace_SSA!</span><span class="x">(</span><span class="n">pr</span><span class="x">)</span></code></pre></figure>

<p>Results in:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">quote</span>
    <span class="x">(</span><span class="n">y1</span><span class="x">,</span> <span class="n">back1</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">Main</span><span class="o">.:*</span><span class="x">,</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">args</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">args</span><span class="x">,</span> <span class="mi">2</span><span class="x">))</span>
    <span class="x">(</span><span class="n">y2</span><span class="x">,</span> <span class="n">back2</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">Main</span><span class="o">.:+</span><span class="x">,</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">args</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">y1</span><span class="x">)</span>
    <span class="x">(</span><span class="n">y3</span><span class="x">,</span> <span class="n">back3</span><span class="x">)</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">Main</span><span class="o">.:/</span><span class="x">,</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">args</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">y2</span><span class="x">)</span>
    <span class="n">Base</span><span class="o">.</span><span class="n">tuple</span><span class="x">(</span><span class="n">y3</span><span class="x">,</span> <span class="x">(</span><span class="n">Pullback</span><span class="x">{</span><span class="kt">Tuple</span><span class="x">{</span><span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">}})(</span><span class="n">Base</span><span class="o">.</span><span class="n">tuple</span><span class="x">(</span><span class="n">back1</span><span class="x">,</span> <span class="n">back2</span><span class="x">,</span> <span class="n">back3</span><span class="x">)))</span>
<span class="k">end</span></code></pre></figure>
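
<p>The three replacements can also be tried on a small hand-built expression. The sketch below folds them into a single pass; <code class="language-plaintext highlighter-rouge">sanitise_sketch!</code> is a hypothetical helper written for this example and is not used elsewhere in this post:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical one-pass helper combining the three replacements above.
function sanitise_sketch!(ex::Expr; offset::Int=1)
    for (i, v) in enumerate(ex.args)
        if v isa Expr
            sanitise_sketch!(v; offset=offset)
        elseif v isa Core.SlotNumber
            if v.id == 1
                ex.args[i] = :f # the first slot is always the function
            else
                ex.args[i] = :(Base.getindex(args, $(v.id - offset)))
            end
        elseif v isa Core.SSAValue
            ex.args[i] = Symbol("y$(v.id)")
        end
    end
    ex
end

# Hand-built stand-in for lowered code: call slot 1 with slot 2 and SSA value 1.
ex = Expr(:call, Core.SlotNumber(1), Core.SlotNumber(2), Core.SSAValue(1))
sanitise_sketch!(ex) # :(f(Base.getindex(args, 1), y1))

# Only ordinary names remain, so the expression can serve as a function body.
fn = eval(:((f, y1, args...) -&gt; $ex))
Base.invokelatest(fn, +, 10, 5) # 15</code></pre></figure>

<p>Before sanitising, the same expression could not be evaluated, because slot numbers and SSA values are not bound to anything outside of lowered code.</p>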

<p>We can now complete <code class="language-plaintext highlighter-rouge">_generate_pullback</code> to also call the decomposition code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">)</span>
    <span class="n">T</span> <span class="o">=</span> <span class="kt">Tuple</span><span class="x">{</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">}</span>
    <span class="k">if</span> <span class="x">(</span><span class="n">has_chain_rrule</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">))</span>
        <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">rrule</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">args</span><span class="o">...</span><span class="x">))</span>
    <span class="k">end</span>    
    <span class="n">pr</span><span class="x">,</span> <span class="n">backs</span> <span class="o">=</span> <span class="n">_generate_pullback_via_decomposition</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">world</span><span class="x">)</span>
    <span class="n">replace_slot!</span><span class="x">(</span><span class="n">pr</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="o">:</span><span class="n">f</span><span class="x">)</span>
    <span class="n">varargs!</span><span class="x">(</span><span class="n">pr</span><span class="x">)</span>
    <span class="n">replace_SSA!</span><span class="x">(</span><span class="n">pr</span><span class="x">)</span>
    <span class="n">pr</span>
<span class="k">end</span></code></pre></figure>

<p>Testing (you should redefine the <code class="language-plaintext highlighter-rouge">@generated pullback</code> function first):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">world</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">get_world_counter</span><span class="x">()</span>
<span class="n">pr</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">f</span><span class="x">),</span> <span class="kt">Float64</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># same as above</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">)</span> <span class="c"># (0.2,Pullback{...})</span></code></pre></figure>

<h3 id="reverse">3.6 Reverse</h3>

<p>The goal is to now turn <code class="language-plaintext highlighter-rouge">Pullback</code> into a callable struct so that we can call <code class="language-plaintext highlighter-rouge">back(1.0)</code> to evaluate equations $\ref{eq:reverse}$ and $\ref{eq:accumulate}$.
With <code class="language-plaintext highlighter-rouge">typeof(back)</code> and <code class="language-plaintext highlighter-rouge">back.data</code> we have all the information to do this independently of the forward pass.
The result will be:</p>

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle"></div>
  <div class="content-container">
    <div class="message-body">
      The code below contains unused variables that could be removed, e.g. <code>x̄3_1</code> (s̄elf). No such optimisations are made here, to keep things simple.
    </div>
  </div>
</div>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">quote</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getfield</span><span class="x">(</span><span class="n">methodinstance</span><span class="x">,</span> <span class="o">:</span><span class="n">data</span><span class="x">)</span>
    <span class="n">back3</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
    <span class="n">Δs</span> <span class="o">=</span> <span class="n">back3</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
    <span class="n">x̄3_1</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">x̄3_2</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">x̄3_3</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
    <span class="n">back2</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">Δs</span> <span class="o">=</span> <span class="n">back2</span><span class="x">(</span><span class="n">x̄3_3</span><span class="x">)</span>
    <span class="n">x̄2_1</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">x̄2_2</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">x̄2_3</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
    <span class="n">back1</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">Δs</span> <span class="o">=</span> <span class="n">back1</span><span class="x">(</span><span class="n">x̄2_3</span><span class="x">)</span>
    <span class="n">x̄1_1</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">x̄1_2</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">x̄1_3</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">Δs</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
    <span class="n">Base</span><span class="o">.</span><span class="n">tuple</span><span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">Main</span><span class="o">.</span><span class="n">accum</span><span class="x">(</span><span class="n">x̄3_2</span><span class="x">,</span> <span class="n">x̄2_2</span><span class="x">),</span> <span class="n">Main</span><span class="o">.</span><span class="n">accum</span><span class="x">(</span><span class="n">x̄1_2</span><span class="x">,</span> <span class="n">x̄1_3</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>
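
<p>To check that this tape computes the right numbers, it can be run by hand. The sketch below assumes <code class="language-plaintext highlighter-rouge">f(a, b) = a / (a + b * b)</code>, which is what the decomposition above implies, and uses hand-written pullbacks as stand-ins for the real <code class="language-plaintext highlighter-rouge">rrule</code>s. All function names in it are hypothetical:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hand-written pullbacks for the three primitive calls (stand-ins for rrules).
# Each returns the value and a closure mapping Δ to (s̄elf, x̄1, x̄2).
mul_pb(a, b) = (a * b, Δ -&gt; (nothing, Δ * b, Δ * a))
add_pb(a, b) = (a + b, Δ -&gt; (nothing, Δ, Δ))
div_pb(a, b) = (a / b, Δ -&gt; (nothing, Δ / b, -Δ * a / b^2))

# Forward pass for the assumed f(a, b) = a / (a + b * b), storing the pullbacks.
function forward_sketch(a, b)
    y1, back1 = mul_pb(b, b)
    y2, back2 = add_pb(a, y1)
    y3, back3 = div_pb(a, y2)
    y3, (back1, back2, back3)
end

# Reverse pass: the same steps as the expression above, without the unused variables.
function reverse_sketch(data, Δ)
    Δs = data[3](Δ)
    x̄3_2, x̄3_3 = Δs[2], Δs[3]
    Δs = data[2](x̄3_3)
    x̄2_2, x̄2_3 = Δs[2], Δs[3]
    Δs = data[1](x̄2_3)
    x̄1_2, x̄1_3 = Δs[2], Δs[3]
    (nothing, x̄3_2 + x̄2_2, x̄1_2 + x̄1_3)
end

z, data = forward_sketch(1.0, 2.0) # z = 0.2
reverse_sketch(data, 1.0)          # approximately (nothing, 0.16, -0.16)</code></pre></figure>

<p>This agrees with the analytical gradients of the assumed function at $(1, 2)$: $\partial f/\partial a = b^2/(a+b^2)^2 = 4/25$ and $\partial f/\partial b = -2ab/(a+b^2)^2 = -4/25$.</p>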

<p>As with the forward pass, an internal function <code class="language-plaintext highlighter-rouge">_generate_callable_pullback</code> will do most of the work.
It uses the <code class="language-plaintext highlighter-rouge">meta</code> function defined above to get the <code class="language-plaintext highlighter-rouge">CodeInfo</code> struct based on the input types:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _generate_callable_pullback</span><span class="x">(</span><span class="n">j</span><span class="o">::</span><span class="kt">Type</span><span class="x">{</span><span class="o">&lt;:</span><span class="n">Pullback</span><span class="x">{</span><span class="n">S</span><span class="x">,</span> <span class="n">T</span><span class="x">}},</span> <span class="n">world</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">S</span><span class="x">,</span> <span class="n">T</span><span class="x">}</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">meta</span><span class="x">(</span><span class="n">S</span><span class="x">;</span> <span class="n">world</span><span class="o">=</span><span class="n">world</span><span class="x">)</span>
    <span class="n">isnothing</span><span class="x">(</span><span class="n">m</span><span class="x">)</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="o">:</span><span class="x">(</span><span class="n">error</span><span class="x">(</span><span class="s">"No method found for "</span><span class="x">,</span> <span class="n">repr</span><span class="x">(</span><span class="o">$</span><span class="n">S</span><span class="x">),</span> <span class="s">" in world "</span><span class="x">,</span> <span class="o">$</span><span class="n">world</span><span class="x">))</span>
    <span class="n">type_signature</span><span class="x">,</span> <span class="n">sps</span><span class="x">,</span> <span class="n">method_</span> <span class="o">=</span> <span class="n">m</span>
    <span class="n">ci</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">uncompressed_ast</span><span class="x">(</span><span class="n">method_</span><span class="x">)</span>
    <span class="n">back</span> <span class="o">=</span> <span class="n">reverse_differentiate</span><span class="x">(</span><span class="n">ci</span><span class="x">,</span> <span class="o">:</span><span class="n">methodinstance</span><span class="x">,</span> <span class="o">:</span><span class="n">Δ</span><span class="x">)</span>
    <span class="n">back</span>
<span class="k">end</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">reverse_differentiate</code> function is a simplified version of <a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/reverse.jl#L293">Zygote.adjoint</a> and <a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/compiler/emit.jl#L65">Zygote.reverse_stacks!</a>.</p>

<p>To start, a dictionary is created to store the gradients.
It maps variable names (symbols) to an array of gradients.
It is not accessed directly (e.g. <code class="language-plaintext highlighter-rouge">grads[x]</code>) but rather through the closure functions <code class="language-plaintext highlighter-rouge">grad</code> and <code class="language-plaintext highlighter-rouge">grad!</code> which automatically handle the arrays.
The first gradient stored is <code class="language-plaintext highlighter-rouge">Δ</code> associated with the final return value of the forward pass.
(<code class="language-plaintext highlighter-rouge">_var_name</code> and <code class="language-plaintext highlighter-rouge">xaccum</code>  will be defined shortly.)</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> reverse_differentiate</span><span class="x">(</span><span class="n">forw</span><span class="o">::</span><span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="n">grads</span> <span class="o">=</span> <span class="kt">Dict</span><span class="x">()</span>
    <span class="n">grad!</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x̄</span><span class="x">)</span> <span class="o">=</span> <span class="n">push!</span><span class="x">(</span><span class="n">get!</span><span class="x">(</span><span class="n">grads</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="x">[]),</span> <span class="n">x̄</span><span class="x">)</span>
    <span class="n">grad</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">xaccum</span><span class="x">(</span><span class="n">get</span><span class="x">(</span><span class="n">grads</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="x">[])</span><span class="o">...</span><span class="x">)</span>
    <span class="n">grad!</span><span class="x">(</span><span class="n">_var_name</span><span class="x">(</span><span class="n">returnvalue</span><span class="x">(</span><span class="n">forw</span><span class="x">)),</span> <span class="n">Δ</span><span class="x">)</span> <span class="c"># _var_name maps to variable names in calls</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">tape</code> for the expression block is started by retrieving the <code class="language-plaintext highlighter-rouge">data</code> field in the struct.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">tape</span> <span class="o">=</span> <span class="kt">Expr</span><span class="x">[]</span>
    <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="n">data</span><span class="o">=$</span><span class="x">(</span><span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">getfield</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="kt">QuoteNode</span><span class="x">(</span><span class="o">:</span><span class="n">data</span><span class="x">)))))</span></code></pre></figure>
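
<p>The gradient bookkeeping can be tried in isolation. In the sketch below, <code class="language-plaintext highlighter-rouge">xaccum_sketch</code> is an assumed stand-in for <code class="language-plaintext highlighter-rouge">xaccum</code> (whose real definition comes later), and the keys and gradient names are made up for the example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Assumed stand-in for xaccum: no gradients gives nothing, one gradient is
# returned as is, and several are wrapped in a call to accum.
xaccum_sketch() = nothing
xaccum_sketch(x) = x
xaccum_sketch(xs...) = Expr(:call, :accum, xs...)

grads = Dict()
grad!(x, x̄) = push!(get!(grads, x, []), x̄)    # append a gradient for x
grad(x) = xaccum_sketch(get(grads, x, [])...) # combine all gradients for x

grad!(:y3, :Δ)    # seed: the gradient of the return value
grad!(:a, :x̄3_2)  # :a receives gradients from two different calls
grad!(:a, :x̄2_2)

grad(:y3) # :Δ
grad(:a)  # :(accum(x̄3_2, x̄2_2))
grad(:b)  # nothing, because no gradient was recorded</code></pre></figure>

<p>Storing an array per variable is what allows a variable used in several calls to have its gradients accumulated at the end.</p>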

<p>Next the code retrieves all the calls with pullbacks from the primal and loops over them in reverse order, calling the pullbacks one by one.
For each call it also loops over the input arguments and unpacks them one by one.
Each variable’s gradient is added to <code class="language-plaintext highlighter-rouge">grads</code> and may be used later in the loop.
The <code class="language-plaintext highlighter-rouge">_var_name</code> function ensures that the keys of <code class="language-plaintext highlighter-rouge">grads</code> can be connected back to the original functions.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">pr</span><span class="x">,</span> <span class="n">calls</span> <span class="o">=</span> <span class="n">primal</span><span class="x">(</span><span class="n">forw</span><span class="x">)</span>
    <span class="n">i</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">calls</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">v</span><span class="x">,</span> <span class="n">ex</span><span class="x">)</span> <span class="k">in</span> <span class="n">reverse</span><span class="x">(</span><span class="n">calls</span><span class="x">)</span>
        <span class="n">vb</span> <span class="o">=</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"back</span><span class="si">$</span><span class="s">i"</span><span class="x">)</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">vb</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">getindex</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="o">$</span><span class="n">i</span><span class="x">)))</span>
        <span class="n">g</span> <span class="o">=</span> <span class="n">grad</span><span class="x">(</span><span class="n">v</span><span class="x">)</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="o">:</span><span class="x">(</span><span class="n">Δs</span> <span class="o">=</span> <span class="o">$</span><span class="n">vb</span><span class="x">(</span><span class="o">$</span><span class="n">g</span><span class="x">)))</span>
        <span class="k">for</span> <span class="x">(</span><span class="n">j</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">ex</span><span class="o">.</span><span class="n">args</span><span class="x">)</span>
            <span class="n">xbar</span> <span class="o">=</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"x̄</span><span class="si">$(i)</span><span class="s">_</span><span class="si">$(j)</span><span class="s">"</span><span class="x">)</span>
            <span class="n">get_xbar</span> <span class="o">=</span> <span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">xbar</span><span class="o">=$</span><span class="x">(</span><span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">getindex</span><span class="x">,</span> <span class="o">:</span><span class="n">Δs</span><span class="x">,</span> <span class="n">j</span><span class="x">)))</span>
            <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="n">get_xbar</span><span class="x">)</span>
            <span class="n">grad!</span><span class="x">(</span><span class="n">_var_name</span><span class="x">(</span><span class="n">x</span><span class="x">),</span> <span class="n">xbar</span><span class="x">)</span>
        <span class="k">end</span>
        <span class="n">i</span> <span class="o">-=</span> <span class="mi">1</span>
    <span class="k">end</span></code></pre></figure>

<p>Finally, the last call retrieves all the necessary gradients for the input arguments and returns a single <code class="language-plaintext highlighter-rouge">quote</code> block.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">push!</span><span class="x">(</span><span class="n">tape</span><span class="x">,</span> <span class="n">xcall</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="n">grad</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="k">for</span> <span class="n">x</span> <span class="k">in</span> <span class="n">arguments</span><span class="x">(</span><span class="n">forw</span><span class="x">)]</span><span class="o">...</span><span class="x">))</span>
    <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">block</span><span class="x">,</span> <span class="n">tape</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>This code requires the following functions: <code class="language-plaintext highlighter-rouge">xaccum</code>, <code class="language-plaintext highlighter-rouge">_var_name</code> and <code class="language-plaintext highlighter-rouge">arguments</code>. They are as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">xaccum</span><span class="x">()</span> <span class="o">=</span> <span class="nb">nothing</span>
<span class="n">xaccum</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span>
<span class="n">xaccum</span><span class="x">(</span><span class="n">xs</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">xcall</span><span class="x">(</span><span class="n">Main</span><span class="x">,</span> <span class="o">:</span><span class="n">accum</span><span class="x">,</span> <span class="n">xs</span><span class="o">...</span><span class="x">)</span>
<span class="n">_var_name</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="n">Core</span><span class="o">.</span><span class="n">SlotNumber</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">id</span> <span class="o">==</span> <span class="mi">1</span> <span class="o">?</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"#self"</span><span class="x">)</span> <span class="o">:</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"args</span><span class="si">$</span><span class="s">(x.id)"</span><span class="x">)</span>
<span class="n">_var_name</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="n">Core</span><span class="o">.</span><span class="n">SSAValue</span><span class="x">)</span>  <span class="o">=</span> <span class="kt">Symbol</span><span class="x">(</span><span class="s">"y</span><span class="si">$</span><span class="s">(x.id)"</span><span class="x">)</span>
<span class="n">_var_name</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span>
<span class="n">arguments</span><span class="x">(</span><span class="n">forw</span><span class="o">::</span><span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span><span class="x">)</span> <span class="o">=</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"#self"</span><span class="x">),</span> <span class="x">[</span><span class="kt">Symbol</span><span class="x">(</span><span class="s">"args</span><span class="si">$</span><span class="s">i"</span><span class="x">)</span> <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">2</span><span class="o">:</span><span class="n">length</span><span class="x">(</span><span class="n">forw</span><span class="o">.</span><span class="n">slotnames</span><span class="x">)]</span><span class="o">...</span><span class="x">]</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">xaccum</code> function calls an internal accumulate function if it acts on multiple inputs. 
At its simplest, <code class="language-plaintext highlighter-rouge">accum</code> is the same as <code class="language-plaintext highlighter-rouge">sum</code>. 
However it also handles <code class="language-plaintext highlighter-rouge">nothing</code> inputs, <code class="language-plaintext highlighter-rouge">Tuple</code>s and <code class="language-plaintext highlighter-rouge">NamedTuple</code>s (<a href="https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/lib.jl#L14">source</a>).</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">)</span> <span class="o">=</span> <span class="n">x</span> <span class="o">===</span> <span class="nb">nothing</span> <span class="o">?</span> <span class="n">y</span> <span class="o">:</span> <span class="n">y</span> <span class="o">===</span> <span class="nb">nothing</span> <span class="o">?</span> <span class="n">x</span> <span class="o">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span>
<span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">Tuple</span><span class="x">,</span> <span class="n">ys</span><span class="o">::</span><span class="kt">Tuple</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">accum</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">ys</span><span class="o">...</span><span class="x">)</span>
<span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">,</span> <span class="n">zs</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">accum</span><span class="x">(</span><span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">),</span> <span class="n">zs</span><span class="o">...</span><span class="x">)</span>
<span class="nd">@generated</span> <span class="k">function</span><span class="nf"> accum</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">NamedTuple</span><span class="x">)</span>
    <span class="c"># assumes that y has no keys apart from those also in x</span>
    <span class="n">fieldnames</span><span class="x">(</span><span class="n">y</span><span class="x">)</span> <span class="n">⊆</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">||</span> <span class="n">throw</span><span class="x">(</span><span class="kt">ArgumentError</span><span class="x">(</span><span class="s">"</span><span class="si">$</span><span class="s">y keys must be a subset of </span><span class="si">$</span><span class="s">x keys"</span><span class="x">))</span>
    <span class="n">grad</span><span class="x">(</span><span class="n">field</span><span class="x">)</span> <span class="o">=</span> <span class="n">field</span> <span class="k">in</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">y</span><span class="x">)</span> <span class="o">?</span> <span class="o">:</span><span class="x">(</span><span class="n">y</span><span class="o">.$</span><span class="n">field</span><span class="x">)</span> <span class="o">:</span> <span class="o">:</span><span class="nb">nothing</span>
    <span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">tuple</span><span class="x">,</span> <span class="x">[</span><span class="o">:</span><span class="x">(</span><span class="o">$</span><span class="n">f</span><span class="o">=</span><span class="n">accum</span><span class="x">(</span><span class="n">x</span><span class="o">.$</span><span class="n">f</span><span class="x">,</span> <span class="o">$</span><span class="x">(</span><span class="n">grad</span><span class="x">(</span><span class="n">f</span><span class="x">))))</span> <span class="k">for</span> <span class="n">f</span> <span class="k">in</span> <span class="n">fieldnames</span><span class="x">(</span><span class="n">x</span><span class="x">)]</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>Examples:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">accum</span><span class="x">(</span><span class="mi">1</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span> <span class="c"># 6</span>
<span class="n">accum</span><span class="x">((</span><span class="mi">1</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="x">(</span><span class="mi">3</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span> <span class="c"># (4, 6)</span>
<span class="n">accum</span><span class="x">((;</span><span class="n">a</span><span class="o">=</span><span class="mi">3</span><span class="x">,</span> <span class="n">b</span><span class="o">=</span><span class="mi">2</span><span class="x">),</span> <span class="x">(;</span><span class="n">a</span><span class="o">=</span><span class="mi">1</span><span class="x">))</span> <span class="c"># (a = 4, b = 2)</span></code></pre></figure>
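
<p>Accumulation matters as soon as a variable is used more than once, because each use contributes its own term to the gradient. The following is a hand-written sketch of that idea for <code>f(a,b)=a/(a+b*b)</code>. It is not the code that the generator produces, and names such as <code>seed</code>, <code>a_bar_div</code> and <code>b_bar_left</code> are made up for this illustration. The two scalar methods of <code>accum</code> are repeated so that the snippet runs on its own.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Two-argument and variadic accum, repeated from the scalar methods above.
accum(x, y) = x === nothing ? y : y === nothing ? x : x + y
accum(x, y, zs...) = accum(accum(x, y), zs...)

# Forward pass of f(a, b) = a / (a + b * b), one intermediate value per call.
a, b = 2.0, 3.0
y1 = b * b   # 9.0
y2 = a + y1  # 11.0
y3 = a / y2  # the primal result, 2/11

# Reverse pass with a seed of 1.0, written out by hand (illustrative names).
seed = 1.0
a_bar_div = seed / y2        # a is the numerator of a / y2
y2_bar = -seed * a / y2^2    # y2 is the denominator of a / y2
a_bar_add = y2_bar           # a is also an input of a + y1
y1_bar = y2_bar              # y1 is the other input of a + y1
b_bar_left = y1_bar * b      # b is the left factor of b * b
b_bar_right = y1_bar * b     # b is the right factor of b * b

# Each variable's gradient is the accumulated total of its contributions.
a_bar = accum(a_bar_div, a_bar_add)     # 1/11 - 2/121 = 9/121
b_bar = accum(b_bar_left, b_bar_right)  # -12/121</code></pre></figure>

<p>Without the accumulation step, only one of the two contributions to each variable would survive and both gradients would be wrong.</p>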

<p>Finally, dispatch on the <code class="language-plaintext highlighter-rouge">Pullback</code> struct to turn it into a callable struct:</p>

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
	</div>
	<div class="content-container">
		<div class="message-body">
    The argument names <code>methodinstance</code> and <code>Δ</code> must match the symbols in the call to <code>reverse_differentiate</code> in <code>_generate_callable_pullback</code>. Otherwise the expression will be unable to find those variables.
		</div>
	</div>
</div>

<div class="accordion" id="accordianJuliaVersions-callable">
  <div class="card">
    <div class="card-header" id="generatedJuliaPre10s-callable">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left" type="button" data-toggle="collapse" data-target="#collapsePre10-callable" aria-expanded="true" aria-controls="collapsePre10-callable">
          Julia Version before 1.10
        </button>
      </div>
    </div>
    <div id="collapsePre10-callable" class="collapse show" aria-labelledby="generatedJuliaPre10s-callable" data-parent="#accordianJuliaVersions-callable">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="nd">@generated</span> <span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">methodinstance</span><span class="o">::</span><span class="n">Pullback</span><span class="x">)(</span><span class="n">Δ</span><span class="x">)</span>
    <span class="n">_generate_callable_pullback</span><span class="x">(</span><span class="n">methodinstance</span><span class="x">,</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
  <div class="card">
    <div class="card-header" id="generatedJuliaPost10-callable">
      <div class="mb-0">
        <button class="btn btn-link btn-block text-left collapsed" type="button" data-toggle="collapse" data-target="#collapsePost10-callable" aria-expanded="false" aria-controls="collapsePost10-callable">
          Julia Version after 1.10
        </button>
      </div>
    </div>
    <div id="collapsePost10-callable" class="collapse" aria-labelledby="generatedJuliaPost10-callable" data-parent="#accordianJuliaVersions-callable">
      <div class="card-body">

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> _callable_pullback_generator</span><span class="x">(</span><span class="n">world</span><span class="o">::</span><span class="kt">UInt</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">self</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="n">ret</span> <span class="o">=</span> <span class="n">_generate_callable_pullback</span><span class="x">(</span><span class="n">self</span><span class="x">,</span> <span class="n">world</span><span class="x">,</span> <span class="n">Δ</span><span class="x">)</span>
    <span class="n">ret</span> <span class="k">isa</span> <span class="n">Core</span><span class="o">.</span><span class="n">CodeInfo</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="n">ret</span>
    <span class="n">stub</span> <span class="o">=</span> <span class="n">Core</span><span class="o">.</span><span class="n">GeneratedFunctionStub</span><span class="x">(</span><span class="n">identity</span><span class="x">,</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">(</span><span class="o">:</span><span class="n">methodinstance</span><span class="x">,</span> <span class="o">:</span><span class="n">Δ</span><span class="x">),</span> <span class="n">Core</span><span class="o">.</span><span class="n">svec</span><span class="x">())</span> <span class="c"># names must match symbols in _generate_callable_pullback</span>
    <span class="n">stub</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">source</span><span class="x">,</span> <span class="n">ret</span><span class="x">)</span>
<span class="k">end</span>

<span class="nd">@eval</span> <span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">j</span><span class="o">::</span><span class="n">Pullback</span><span class="x">)(</span><span class="n">Δ</span><span class="x">)</span>
    <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated</span><span class="x">,</span> <span class="n">_callable_pullback_generator</span><span class="x">))</span>
    <span class="o">$</span><span class="x">(</span><span class="kt">Expr</span><span class="x">(</span><span class="o">:</span><span class="n">meta</span><span class="x">,</span> <span class="o">:</span><span class="n">generated_only</span><span class="x">))</span>
<span class="k">end</span></code></pre></figure>

      </div>
    </div>
  </div>
</div>

<p>Testing:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">a</span><span class="x">,</span><span class="n">b</span><span class="x">)</span><span class="o">=</span><span class="n">a</span><span class="o">/</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">b</span><span class="o">*</span><span class="n">b</span><span class="x">)</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">)</span> <span class="c"># (0.1818, Pullback{...})</span>
<span class="n">_generate_callable_pullback</span><span class="x">(</span><span class="n">typeof</span><span class="x">(</span><span class="n">back</span><span class="x">),</span> <span class="nb">nothing</span><span class="x">,</span> <span class="kt">Float64</span><span class="x">)</span> <span class="c"># expression at start</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 0.0744, -0.0991)</span></code></pre></figure>

<p>The results should match equation $\ref{eq:rollup}$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">a</span><span class="x">,</span> <span class="n">b</span> <span class="o">=</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span>
<span class="n">ā</span> <span class="o">=</span> <span class="n">abs2</span><span class="x">(</span><span class="n">b</span><span class="x">)</span><span class="o">/</span><span class="n">abs2</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">abs2</span><span class="x">(</span><span class="n">b</span><span class="x">))</span> <span class="c"># 0.0744</span>
<span class="n">b̄</span> <span class="o">=</span> <span class="o">-</span><span class="mi">2</span><span class="o">*</span><span class="n">a</span><span class="o">*</span><span class="n">b</span><span class="o">/</span><span class="n">abs2</span><span class="x">(</span><span class="n">a</span><span class="o">+</span><span class="n">abs2</span><span class="x">(</span><span class="n">b</span><span class="x">))</span>  <span class="c"># -0.0991</span></code></pre></figure>
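
<p>As a further sanity check that does not depend on either the pullback machinery or the analytical formula, the same gradients can be approximated with central finite differences. This is only a sketch: the step size <code>h</code> and the names <code>a_bar_fd</code> and <code>b_bar_fd</code> are arbitrary choices made for this example.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">f(a, b) = a / (a + b * b)
a, b = 2.0, 3.0
h = 1e-6  # step size, an arbitrary choice for this check

# Central differences: perturb one input at a time and hold the other fixed.
a_bar_fd = (f(a + h, b) - f(a - h, b)) / (2h)  # approximately 0.0744
b_bar_fd = (f(a, b + h) - f(a, b - h)) / (2h)  # approximately -0.0991</code></pre></figure>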

<h2 id="conclusion">4 Conclusion</h2>

<p>This code works well enough for this simple case. 
It also works for the trigonometry example from <a href="/machine-learning/2024/07/27/micrograd-1-chainrules.html#chainrules-trigonometry">part 1</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">sin</span><span class="x">(</span><span class="n">cos</span><span class="x">(</span><span class="n">x</span><span class="x">))</span>
<span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">0.9</span><span class="x">)</span> <span class="c"># (0.5823, Pullback{...})</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, -0.6368) </span></code></pre></figure>

<p>However it will fail for the polynomial model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Polynomial</span><span class="x">{</span><span class="n">V</span><span class="o">&lt;:</span><span class="kt">AbstractVector</span><span class="x">}</span>
    <span class="n">weights</span><span class="o">::</span><span class="n">V</span>
<span class="k">end</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">m</span><span class="o">.</span><span class="n">weights</span><span class="x">)</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Polynomial</span><span class="x">([</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="o">-</span><span class="mf">3.0</span><span class="x">,</span> <span class="mf">1.0</span><span class="x">])</span>
<span class="n">x</span> <span class="o">=</span> <span class="x">[</span><span class="mf">1.0</span><span class="x">,</span> <span class="mf">2.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">]</span>
<span class="n">pullback</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="c"># ERROR: syntax: invalid syntax (static_parameter 1)</span></code></pre></figure>

<p>The error is raised three levels down:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">pr1</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="n">pr2</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">map</span><span class="x">),</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span>
<span class="n">pr3</span> <span class="o">=</span> <span class="n">_generate_pullback</span><span class="x">(</span><span class="n">world</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">Base</span><span class="o">.</span><span class="n">Generator</span><span class="x">),</span> <span class="n">Polynomial</span><span class="x">,</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Float64</span><span class="x">})</span></code></pre></figure>

<p>This can be fixed by explicitly writing a pullback for <code class="language-plaintext highlighter-rouge">map</code>.</p>

<p>However rather than fixing it here, I first want to rewrite the code using IRTools.
The code written here is brittle and difficult to debug.
Instead of writing expressions, it would be better to directly create a <code class="language-plaintext highlighter-rouge">CodeInfo</code> struct which always contains valid code.
Julia does not allow us to do that, but working with an <code class="language-plaintext highlighter-rouge">IR</code> object which can be readily converted is the next best thing.
This will be the goal of <a href="/machine-learning/2024/08/10/micrograd-3-ir">part 3</a>.</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:generated_reflection" role="doc-endnote">
      <p>Presumably the reason the Julia team tried to prevent reflection in generated functions is that it interferes with the compiler’s ability to properly predict, trigger and/or optimise compilations. <a href="#fnref:generated_reflection" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:has_chain_rule" role="doc-endnote">
      <p>Zygote.jl has more <a href="https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl">complex rules</a> which also consider other fallbacks, keyword arguments and a possible opt out through a <code class="language-plaintext highlighter-rouge">no_rrule</code>. <a href="#fnref:has_chain_rule" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A series on automatic differentiation in Julia. Part 2 uses metaprogramming to generate a modified (primal) forward pass and to reverse differentiate it into a backward pass. This post uses an expression based approach which can be brittle. Part 3 develops a more robust approach for the same code using IRTools.jl.]]></summary></entry><entry><title type="html">MicroGrad.jl: Part 1 ChainRules</title><link href="https://liorsinai.github.io/machine-learning/2024/07/27/micrograd-1-chainrules.html" rel="alternate" type="text/html" title="MicroGrad.jl: Part 1 ChainRules" /><published>2024-07-27T00:00:00+00:00</published><updated>2024-08-03T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2024/07/27/micrograd-1-chainrules</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2024/07/27/micrograd-1-chainrules.html"><![CDATA[<p><em>A series on automatic differentiation in Julia. Part 1 provides an overview and defines explicit chain rules.</em></p>

<p>This is part of a series. The other articles are:</p>
<ul>
  <li><a href="/machine-learning/2024/08/03/micrograd-2-expr">Part 2: Automation with expressions</a>.</li>
  <li><a href="/machine-learning/2024/08/10/micrograd-3-ir">Part 3: Automation with IR</a>.</li>
  <li><a href="/machine-learning/2024/08/17/micrograd-4-ext">Part 4: Extensions</a>.</li>
  <li><a href="/machine-learning/2024/08/19/micrograd-5-mlp">Part 5: MLP</a>.</li>
</ul>

<p>All source code can be found at <a href="https://github.com/LiorSinai/MicroGrad.jl">MicroGrad.jl</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p>A major convenience of modern machine learning frameworks is automatic differentiation (AD).
Training a machine learning model typically consists of two steps: a forward pass and a backward pass.
The forward pass takes an input sample and calculates the result. 
Examples include a label in a classifier model or a word or image in a generative model.
In the backward pass, the result is compared to a ground truth sample and the error is backpropagated throughout the model, from the final layers through to the start.
Backpropagation is driven by gradients which are calculated with the differentiation rules of Calculus.</p>

<p>With modern machine learning frameworks, such as <a href="https://pytorch.org/">PyTorch</a> or <a href="https://fluxml.ai/Flux.jl/stable/">Flux.jl</a>, only the forward pass needs to be defined and they will automatically generate the backward pass. This (1) makes them easier to use and (2) enforces consistency between the forward pass and backward pass.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/moons_decision_boundary.png" alt="Decision boundary" />
<figcaption>The probability boundaries of a multi-layer perceptron trained on the moons dataset with MicroGrad.jl.</figcaption>
</figure>

<p>Andrej Karpathy made an excellent <a href="https://www.youtube.com/watch?v=VMj-3S1tku0">video</a> where he built a minimal automatic differentiation module called <a href="https://github.com/karpathy/micrograd">Micrograd</a> in Python.
This is the first video in his <a href="https://karpathy.ai/zero-to-hero.html">Zero to Hero</a> series.
He later uses it to train a multi-layer perceptron model.
I highly recommend it for anyone who wants to understand backpropagation.</p>

<p>The aim of this series is to create a minimal automatic differentiation package in Julia.
It is based on <a href="https://fluxml.ai/Zygote.jl/stable/">Zygote.jl</a> and works very differently to the Python AD packages.
The latter are based on objects with their own custom implementations of mathematical operations that calculate both the forward and backward passes.
All operations are only done with these objects.<sup id="fnref:micrograd" role="doc-noteref"><a href="#fn:micrograd" class="footnote" rel="footnote">1</a></sup>
Zygote.jl is instead based on the principle that Julia is a functional programming language. 
It utilises Julia’s multiple dispatch feature and its comprehensive metaprogramming abilities to generate new code for the backward pass.
Barring some limitations, it can be used to differentiate all existing functions as well as any custom code.</p>

<p>Zygote’s approach is complex and pushes the boundaries of Julia’s metaprogramming. It can sometimes be <a href="https://discourse.julialang.org/t/state-of-machine-learning-in-julia/74385/4#post_4">buggy</a>.
However its promise is true automatic differentiation of any forward pass code without further work on the coder’s part.</p>

<p>For the final code, see my <a href="https://github.com/LiorSinai/MicroGrad.jl">MicroGrad.jl</a> repository.
It is versatile but has several limitations: it has less code coverage than Zygote.jl and cannot handle control flow or keyword arguments.</p>

<p>There are almost no comprehensive tutorials on AD in Julia and so this series aims to cover that gap.
A good understanding of Julia and of Calculus is required.</p>

<h2 id="julia-ad-ecosystem">2 Julia AD Ecosystem</h2>

<p>The Julia automatic differentiation ecosystem is centered around three packages: Flux.jl, ChainRules.jl and Zygote.jl.</p>
<ul>
  <li><a href="https://fluxml.ai/Flux.jl/stable/">Flux.jl</a> is a machine learning framework. It uses either ChainRules.jl or Zygote.jl to differentiate code.</li>
  <li><a href="https://fluxml.ai/Zygote.jl/stable/">Zygote.jl</a> implements automatic differentiation through metaprogramming.
    <ul>
      <li>The core functionality is defined in the minimal <a href="https://github.com/FluxML/ZygoteRules.jl">ZygoteRules.jl</a> package.</li>
      <li>The main functions it exposes are <code class="language-plaintext highlighter-rouge">gradient</code>, <code class="language-plaintext highlighter-rouge">withgradient</code> and <code class="language-plaintext highlighter-rouge">pullback</code>. The <code class="language-plaintext highlighter-rouge">pullback</code> function is a light wrapper around <code class="language-plaintext highlighter-rouge">_pullback</code> which does most of the heavy lifting.</li>
      <li>The goal of <code class="language-plaintext highlighter-rouge">_pullback</code> is to dispatch a function, its arguments and its keyword arguments to a <code class="language-plaintext highlighter-rouge">ChainRules.rrule</code>. If it cannot, it will inspect the code, decompose it into smaller steps, and follow the rules of differentiation to dispatch each of those to <code class="language-plaintext highlighter-rouge">_pullback</code> to recursively find an <code class="language-plaintext highlighter-rouge">rrule</code>. If this recursive process does not find a valid rule it will raise an error.</li>
    </ul>
  </li>
  <li><a href="https://juliadiff.org/ChainRulesCore.jl/stable/">ChainRules.jl</a> defines forward rules and reverse rules.
    <ul>
      <li>The core functionality is defined in the minimal <a href="https://github.com/JuliaDiff/ChainRulesCore.jl">ChainRulesCore.jl</a> package.</li>
      <li>The main functions it exposes are <code class="language-plaintext highlighter-rouge">frule</code> and <code class="language-plaintext highlighter-rouge">rrule</code>. This series deals only with backpropagation, so it will only concentrate on <code class="language-plaintext highlighter-rouge">rrule</code>.</li>
    </ul>
  </li>
</ul>

<p>Also important is <a href="https://fluxml.ai/IRTools.jl/latest/">IRTools.jl</a>, an extended metaprogramming package for working with an intermediate representation (IR) between raw Julia code and lowered code.
MicroGrad.jl in particular is based on the example code at <a href="https://github.com/FluxML/IRTools.jl/blob/master/examples/reverse.jl">IRTools.jl</a>, adapted to align with Zygote.jl functions and names.</p>

<p>As an example, consider the function $f(x) = \sin(\cos(x))$. Using the chain rule of Calculus, it is differentiated as:</p>

\[\begin{align}
\frac{df}{dx} &amp;= \frac{df}{dh}\frac{dh}{dx} \quad ; h(x)=\cos(x)\\
              &amp;= \frac{d}{dh}\sin(h)\frac{d}{dx}\cos(x) \\
              &amp;= \cos(h)(-\sin(x)) \\
              &amp;= -\cos(\cos(x))\sin(x)
\end{align}\]

<p><code class="language-plaintext highlighter-rouge">Zygote.withgradient</code>, exposed as <code class="language-plaintext highlighter-rouge">Flux.withgradient</code>, can be used to calculate this:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Flux</span>
<span class="n">f</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">sin</span><span class="x">(</span><span class="n">cos</span><span class="x">(</span><span class="n">x</span><span class="x">))</span>
<span class="n">y</span><span class="x">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">withgradient</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="mf">0.9</span><span class="x">)</span> <span class="c"># 0.5823, (-0.6368,)</span>
<span class="n">grad</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="o">==</span> <span class="o">-</span><span class="n">cos</span><span class="x">(</span><span class="n">cos</span><span class="x">(</span><span class="mf">0.9</span><span class="x">))</span><span class="o">*</span><span class="n">sin</span><span class="x">(</span><span class="mf">0.9</span><span class="x">)</span> <span class="c"># true</span></code></pre></figure>
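
<p>As a sanity check that does not depend on any AD package, the analytic derivative can also be compared to a central finite difference. This is only a rough sketch: the step size <code class="language-plaintext highlighter-rouge">h</code> and the tolerance are my own arbitrary choices.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">f(x) = sin(cos(x))
df(x) = -cos(cos(x)) * sin(x)     # analytic derivative from the chain rule
x, h = 0.9, 1e-6                  # h is an arbitrary small step
fd = (f(x + h) - f(x - h)) / (2h) # central finite difference
isapprox(fd, df(x); atol=1e-6)    # true</code></pre></figure>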

<p>More commonly we differentiate with respect to the model, not the data:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">y</span><span class="x">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">withgradient</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">m</span><span class="x">(</span><span class="mf">0.9</span><span class="x">),</span> <span class="n">f</span><span class="x">)</span> <span class="c"># 0.5823, (nothing,)</span></code></pre></figure>

<p>This is more useful for a model with parameters. For example a dense, fully connected layer:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Dense</span><span class="x">(</span><span class="mi">3</span><span class="o">=&gt;</span><span class="mi">1</span><span class="x">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">10</span><span class="x">)</span>
<span class="n">y</span><span class="x">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">withgradient</span><span class="x">(</span><span class="n">m</span><span class="o">-&gt;</span><span class="n">sum</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">x</span><span class="x">)),</span> <span class="n">model</span><span class="x">)</span> <span class="c"># 1.5056f0, ((weight=[4.9142 6.235 5.3379],bias=Fill(10.0f0,1),σ=nothing),)</span></code></pre></figure>
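
<p>The shape of this result can be checked by hand. With the default identity activation the loss is $l=\sum(Wx + b)$, so each weight collects the sum of its input row over all 10 samples and the bias collects a 1 from every sample, which is where the <code class="language-plaintext highlighter-rouge">Fill(10.0f0,1)</code> comes from. Here is a minimal sketch with plain arrays, where <code class="language-plaintext highlighter-rouge">W</code> and <code class="language-plaintext highlighter-rouge">b</code> are stand-ins for the layer’s fields and not the actual Flux objects:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">W = rand(Float32, 1, 3)      # stand-in for model.weight
b = zeros(Float32, 1)        # stand-in for model.bias
x = rand(Float32, 3, 10)
Δ = ones(Float32, 1, 10)     # ∂l/∂Y for l = sum(Y) with Y = W * x .+ b
grad_W = Δ * x'              # 1×3, the row sums of x
grad_b = vec(sum(Δ; dims=2)) # one contribution per sample
grad_W ≈ sum(x; dims=2)'     # true
grad_b == [10.0f0]           # true</code></pre></figure>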

<p>The aim of the rest of the series is to recreate this functionality.
This first part will focus solely on ChainRules.jl and recreating the <code class="language-plaintext highlighter-rouge">rrule</code> function.
Part 2 will focus on recreating the <code class="language-plaintext highlighter-rouge">Zygote._pullback</code> function.
Part 3 repeats part 2 in a more robust manner.
Part 4 extends part 3’s solution to handle maps, anonymous functions and structs.
Finally part 5 shows how this AD code can be used by a machine learning framework.</p>

<h2 id="chainrules">3 ChainRules</h2>
<h3 id="chainrules-definition">3.1 Definition</h3>

<p>ChainRules.jl’s <code class="language-plaintext highlighter-rouge">rrule</code> returns the output of the forward pass $y(x)$ and a function $\mathcal{B}$ which calculates the backward pass.
$\mathcal{B}$ takes as input $\Delta = \frac{\partial l}{\partial y}$, the gradient of some scalar $l$ with respect to the output variable $y$, and returns a tuple of $\left(\frac{\partial l}{\partial \text{self}}, \frac{\partial l}{\partial x_1}, …, \frac{\partial l}{\partial x_n}\right)$, the gradient of $l$ with respect to each of the input variables $x_i$.
(The extra gradient $\frac{\partial l}{\partial \text{self}}$ is needed for internal fields and closures.
See the <code class="language-plaintext highlighter-rouge">Dense</code> layer example above.)
According to the chain rule of Calculus, each gradient is calculated as:</p>

\[\mathcal{B_i}\left(\frac{\partial l}{\partial y}\right) = \frac{\partial l}{\partial x_i} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x_i}\]

<p>As a starting point $\frac{\partial l}{\partial y}=1$ is used to evaluate only $\frac{\partial y}{\partial x}$.</p>

<p>If $x$ and $y$ are vectors, then the gradient $J=\frac{\partial y}{\partial x}$ is a <a href="https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant">Jacobian</a>:</p>

\[J = \begin{bmatrix}
\frac{\partial y_1}{\partial x_1} &amp; \dots &amp; \frac{\partial y_1}{\partial x_n}  \\
\vdots &amp; \ddots &amp; \vdots \\
\frac{\partial y_m}{\partial x_1} &amp; \dots &amp; \frac{\partial y_m}{\partial x_n} 
\end{bmatrix}\]

<p>To maintain the correct order, we need to use the <a href="https://juliadiff.org/ChainRulesCore.jl/stable/maths/propagators.html">conjugate transpose (adjoint) of the Jacobian</a>. So each gradient is calculated as:</p>

\[\mathcal{B_i}(\Delta) = J_i^{\dagger} \Delta\]

<p>Note the Jacobian does not need to be explicitly calculated; only the product needs to be. 
This can be useful when coding the <code class="language-plaintext highlighter-rouge">rrule</code> for matrix functions.
See the section on the chain rule for <a href="#chainrules-matrix-multiplication">matrix multiplication</a> later.</p>
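
<p>To make this concrete, here is a small sketch for the elementwise square $y_i = x_i^2$, which I chose purely for illustration. Its Jacobian is diagonal, so the product $J^{\dagger}\Delta$ reduces to an elementwise multiplication and the matrix never needs to be built:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra

x = [1.0, 2.0, 3.0]
Δ = [0.1, 0.2, 0.3]      # incoming gradient ∂l/∂y
J = diagm(0 =&gt; 2 .* x)   # explicit 3×3 Jacobian of y = x .^ 2, mostly zeros
explicit = J' * Δ        # J†Δ using the full matrix
implicit = 2 .* x .* Δ   # the same product without ever forming J
explicit ≈ implicit      # true</code></pre></figure>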

<p>To start, define a default fallback for <code class="language-plaintext highlighter-rouge">rrule</code> that returns <code class="language-plaintext highlighter-rouge">nothing</code> for any function with any number of arguments (<a href="https://github.com/JuliaDiff/ChainRulesCore.jl/blob/a95c181c662ead23aaf9904b8a560bebeb9022a3/src/rules.jl#L131">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">rrule</span><span class="x">(</span><span class="o">::</span><span class="kt">Any</span><span class="x">,</span> <span class="o">::</span><span class="kt">Vararg</span><span class="x">{</span><span class="kt">Any</span><span class="x">})</span> <span class="o">=</span> <span class="nb">nothing</span></code></pre></figure>

<p>An <code class="language-plaintext highlighter-rouge">rrule</code> can now be defined for any function.
For it to be really useful <code class="language-plaintext highlighter-rouge">rrule</code> must cover a large set of functions.
Thankfully ChainRules.jl provides us with that.
However in this post I’ll only work through a limited set of examples.</p>
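
<p>The interplay between the fallback and a specific rule can be sketched as follows. The rule for <code class="language-plaintext highlighter-rouge">exp</code> is my own illustrative example and is not one of the rules defined in the rest of this post. Multiple dispatch picks the most specific method, so any function without a rule falls through to <code class="language-plaintext highlighter-rouge">nothing</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">rrule(::Any, ::Vararg{Any}) = nothing # fallback: no rule is known

# Illustrative rule only: d/dx exp(x) = exp(x)
function rrule(::typeof(exp), x::Real)
    Ω = exp(x)
    exp_back(Δ) = (nothing, Ω * Δ) # ∂self, ∂x
    Ω, exp_back
end

rrule(tan, 1.0) === nothing # true, there is no rule for tan here
y, back = rrule(exp, 0.0)   # (1.0, exp_back)
back(1.0)                   # (nothing, 1.0)</code></pre></figure>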

<h3 id="chainrules-arithmetic">3.2 Arithmetic</h3>

<p>The derivatives of adding two variables are:</p>

\[\frac{\partial}{\partial x}(x+y) = 1 + 0; \frac{\partial}{\partial y}(x+y) = 0 + 1\]

<p>
  <a class="btn" data-toggle="collapse" href="#proof-derivative-addition" role="button" aria-expanded="false" aria-controls="collapse-derivative-addition">
    Proof &#8681;
  </a>
</p>
<div class="collapse" id="proof-derivative-addition">
  <div class="card card-body ">
		<p>
    $$
    \begin{align}
    \Delta f_x &amp;= (x+\Delta x+ y) - (x+y) \\
    \therefore \lim_{\Delta x \to 0}\frac{\Delta f_x}{\Delta x} &amp;=\frac{\partial f}{\partial x}= 1 \\
    \therefore \lim_{\Delta y \to 0}\frac{\Delta f_y}{\Delta y} &amp;=\frac{\partial f}{\partial y}= 1
    \end{align}
    $$
    </p>
  </div>
</div>

<p>There are no internal fields so $\frac{\partial l}{\partial \text{self}}$ is <code class="language-plaintext highlighter-rouge">nothing</code>.
$\mathcal{B}$ can be returned as an anonymous function, but giving it the name <code class="language-plaintext highlighter-rouge">add_back</code>  helps with debugging (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/fastmath_able.jl#L167">source</a>).</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="o">+</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span>
    <span class="n">add_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="nb">true</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">,</span> <span class="nb">true</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">)</span> <span class="c"># ∂self, ∂x, ∂y</span>
    <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="x">,</span> <span class="n">add_back</span> <span class="c"># also (Δ) -&gt; (nothing, true * Δ, true * Δ)</span>
<span class="k">end</span></code></pre></figure>

<p>Usage:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span> <span class="c"># (3, var"#add_back#"())</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.2</span><span class="x">)</span> <span class="c"># (nothing, 1.2, 1.2)</span></code></pre></figure>

<p>Subtraction is almost identical:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="o">-</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span>
    <span class="n">minus_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="nb">true</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">,</span> <span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">)</span> <span class="c"># ∂self, ∂x, ∂y</span>
    <span class="n">x</span> <span class="o">-</span> <span class="n">y</span><span class="x">,</span> <span class="n">minus_back</span>
<span class="k">end</span></code></pre></figure>

<p>With multiplication, the incoming gradient is multiplied by the other variable:</p>

\[\frac{\partial}{\partial x}(xy) = y; \frac{\partial}{\partial y}(xy) = x\]

<p>
  <a class="btn" data-toggle="collapse" href="#proof-derivative-multiplication" role="button" aria-expanded="false" aria-controls="collapse-derivative-multiplication">
    Proof &#8681;
  </a>
</p>
<div class="collapse" id="proof-derivative-multiplication">
  <div class="card card-body ">
		<p>
    $$
    \begin{align}
    \Delta f_x &amp;= (x+\Delta x)y - xy \\
    \therefore \lim_{\Delta x \to 0}\frac{\Delta f_x}{\Delta x} &amp;=\frac{\partial f}{\partial x}= y \\
    \therefore \lim_{\Delta y \to 0}\frac{\Delta f_y}{\Delta y} &amp;=\frac{\partial f}{\partial y}= x
    \end{align}
    $$
    </p>
  </div>
</div>

<p>In code (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/fastmath_able.jl#L254">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="o">*</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span>
    <span class="n">times_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">y</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">,</span> <span class="n">x</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">)</span> <span class="c"># ∂self, ∂x, ∂y</span>
    <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="x">,</span> <span class="n">times_back</span>
<span class="k">end</span></code></pre></figure>

<p>Note that Julia will create a <em>closure</em> around the incoming <code class="language-plaintext highlighter-rouge">x</code> and <code class="language-plaintext highlighter-rouge">y</code> variables for <code class="language-plaintext highlighter-rouge">times_back</code>.
A closure is when the function stores the values of variables from its parent’s scope (it closes over the variables).
In other words, <code class="language-plaintext highlighter-rouge">x</code> and <code class="language-plaintext highlighter-rouge">y</code> will become constants in the <code class="language-plaintext highlighter-rouge">times_back</code> scope.
In this way, the <code class="language-plaintext highlighter-rouge">times_back</code> function will always “remember” the values that <code class="language-plaintext highlighter-rouge">rrule</code> was called with.</p>

<p>Example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="o">*</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span> <span class="c"># (6, var"#times_back#4"{Int64, Int64}(2, 3))</span>
<span class="n">back</span><span class="o">.</span><span class="n">x</span> <span class="c"># 2</span>
<span class="n">back</span><span class="o">.</span><span class="n">y</span> <span class="c"># 3</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.2</span><span class="x">)</span> <span class="c"># (nothing, 3.6, 2.4)</span></code></pre></figure>

<p>Every call to <code class="language-plaintext highlighter-rouge">rrule</code> with <code class="language-plaintext highlighter-rouge">*</code> will return a different <code class="language-plaintext highlighter-rouge">back</code> instance based on the input arguments.</p>
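
<p>One way to picture this is that a closure behaves like a callable struct whose fields are the captured variables. The sketch below is a hand-written analogue; the name <code class="language-plaintext highlighter-rouge">TimesBack</code> is hypothetical and is not what Julia generates internally:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical hand-written analogue of the times_back closure
struct TimesBack{X,Y}
    x::X
    y::Y
end
(b::TimesBack)(Δ) = (nothing, b.y * Δ, b.x * Δ) # ∂self, ∂x, ∂y

back = TimesBack(2, 3)
back.x    # 2
back.y    # 3
back(1.2) # the same tuple as times_back(1.2) for x=2 and y=3</code></pre></figure>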

<p>Division is slightly different in that the derivatives look different for $x$ and $y$:</p>

\[\frac{\partial}{\partial x}\frac{x}{y} = \frac{1}{y}; \frac{\partial}{\partial y}\frac{x}{y}= -\frac{x}{y^2}\]

<p>
  <a class="btn" data-toggle="collapse" href="#proof-derivative-division" role="button" aria-expanded="false" aria-controls="collapse-derivative-division">
    Proof &#8681;
  </a>
</p>
<div class="collapse" id="proof-derivative-division">
  <div class="card card-body ">
		<p>
    $$
    \begin{align}
    \Delta f_x &amp;= \frac{x+\Delta x}{y} - \frac{x}{y} \\
    \therefore \lim_{\Delta x \to 0}\frac{\Delta f_x}{\Delta x} &amp;=\frac{\partial f}{\partial x}= \frac{1}{y} \\
    \Delta f_y &amp;= \frac{x}{y+\Delta y} - \frac{x}{y} \\
             &amp;= \frac{xy}{y(y+\Delta y)} - \frac{x(y+\Delta y)}{y(y+\Delta y)} \\
             &amp;= -\frac{x \Delta y}{y(y+\Delta y)} \\
   \therefore \lim_{\Delta y \to 0}\frac{\Delta f_y}{\Delta y} &amp;=\frac{\partial f}{\partial y} = -\frac{x}{y^2}         
    \end{align}
    $$
    </p>
  </div>
</div>

<p>Here we can calculate an internal variable <code class="language-plaintext highlighter-rouge">Ω</code> to close over, and use it for the $\frac{\partial}{\partial y}$ derivative (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/fastmath_able.jl#L169">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="o">/</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">Real</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Real</span><span class="x">)</span>
    <span class="n">Ω</span> <span class="o">=</span> <span class="n">x</span> <span class="o">/</span> <span class="n">y</span>
    <span class="n">divide_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">y</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">,</span> <span class="o">-</span><span class="n">Ω</span><span class="o">/</span><span class="n">y</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">)</span> <span class="c"># ∂self, ∂x, ∂y</span>
    <span class="n">Ω</span><span class="x">,</span> <span class="n">divide_back</span>
<span class="k">end</span></code></pre></figure>

<p>Example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">z</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="o">/</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span> <span class="c"># (0.6667, var"#divide_back#5"{Int64, Float64}(3, 0.6667))</span>
<span class="n">back</span><span class="o">.</span><span class="n">Ω</span> <span class="c"># 0.6667</span>
<span class="n">back</span><span class="o">.</span><span class="n">y</span> <span class="c"># 3</span>
<span class="n">back</span><span class="o">.</span><span class="n">x</span> <span class="c"># ERROR</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.2</span><span class="x">)</span> <span class="c"># (nothing, 0.4, -0.2667)</span></code></pre></figure>

<h3 id="chainrules-trigonometry">3.3 Trigonometry</h3>

<p>The derivatives of $\sin$ and $\cos$ are:</p>

\[\begin{align}
  \frac{\partial}{\partial x} \sin(x) &amp;= \cos(x) \\
  \frac{\partial}{\partial x} \cos(x) &amp;= -\sin(x)
\end{align}\]

<p>Each rule needs both $\sin(x)$ and $\cos(x)$: one for the forward pass and the other for the backward pass. We can therefore use <code class="language-plaintext highlighter-rouge">sincos</code> to calculate both simultaneously, which is more efficient than calculating each on its own. This shows the advantage of calculating the forward pass and backward pass at the same time (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/fastmath_able.jl#L12">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">sin</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">Number</span><span class="x">)</span>
    <span class="n">s</span><span class="x">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">sincos</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
    <span class="n">sin_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">Δ</span> <span class="o">*</span> <span class="n">c</span><span class="x">)</span> <span class="c"># ∂self, ∂x</span>
    <span class="n">s</span><span class="x">,</span> <span class="n">sin_back</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">cos</span><span class="x">),</span> <span class="n">x</span><span class="o">::</span><span class="kt">Number</span><span class="x">)</span>
    <span class="n">s</span><span class="x">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">sincos</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
    <span class="n">cos_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span> <span class="o">=</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="o">-</span><span class="n">Δ</span> <span class="o">*</span> <span class="n">s</span><span class="x">)</span> <span class="c"># ∂self, ∂x</span>
    <span class="n">c</span><span class="x">,</span> <span class="n">cos_back</span>
<span class="k">end</span></code></pre></figure>

<p>Let’s now revisit the example from earlier, $f(x) = \sin(\cos(x))$.
We have the forward pass:</p>

\[\begin{align}
y_1 &amp;= \cos(x) \\
y_2 &amp;= \sin(y_1)\\
\end{align}\]

<p>And the backward pass:</p>

\[\begin{align}
\Delta_2 = \frac{\partial y_2}{\partial y_1} &amp;= (1.0)  \frac{\partial}{\partial y_1} \sin(y_1) \\
            &amp;= \cos(y_1) \\
\frac{\partial y_2}{\partial x} &amp;= \Delta_2 \frac{\partial}{\partial x} \cos(x) \\
         &amp;= -\Delta_2 \sin(x)
\end{align}\]

<p>In code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="mf">0.9</span>
<span class="n">y1</span><span class="x">,</span> <span class="n">back1</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">cos</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="c"># (0.6216, cos_back)</span>
<span class="n">y2</span><span class="x">,</span> <span class="n">back2</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">sin</span><span class="x">,</span> <span class="n">y1</span><span class="x">)</span> <span class="c"># (0.5823, sin_back)</span>
<span class="n">grad_sin</span><span class="x">,</span> <span class="n">grad_y1</span> <span class="o">=</span> <span class="n">back2</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 0.8129)</span>
<span class="n">grad_cos</span><span class="x">,</span> <span class="n">grad_x</span> <span class="o">=</span> <span class="n">back1</span><span class="x">(</span><span class="n">grad_y1</span><span class="x">)</span> <span class="c"># (nothing, -0.6368)</span>
<span class="n">grad_x</span> <span class="o">==</span> <span class="o">-</span><span class="n">cos</span><span class="x">(</span><span class="n">cos</span><span class="x">(</span><span class="n">x</span><span class="x">))</span><span class="o">*</span><span class="n">sin</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="c"># true</span></code></pre></figure>
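
<p>The same pattern works for any chain of single-argument functions: run the rules forwards while storing each <code class="language-plaintext highlighter-rouge">back</code> function, then call them in reverse. Here is a rough sketch. <code class="language-plaintext highlighter-rouge">chain_gradient</code> is a hypothetical helper that is not part of ChainRules.jl or Zygote.jl, and the two rules are repeated in a compact form so that the snippet runs on its own:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Compact versions of the rules above, repeated so the sketch is self-contained
rrule(::typeof(sin), x::Number) = (sin(x), Δ -&gt; (nothing, Δ * cos(x)))
rrule(::typeof(cos), x::Number) = (cos(x), Δ -&gt; (nothing, -Δ * sin(x)))

# Hypothetical helper: differentiates fs[end](...fs[1](x)) with respect to x
function chain_gradient(fs, x)
    backs = []
    y = x
    for f in fs                # forward pass, fs[1] is applied first
        y, back = rrule(f, y)
        push!(backs, back)
    end
    Δ = one(y)                 # seed with ∂l/∂y = 1
    for back in reverse(backs) # backward pass, last function first
        _, Δ = back(Δ)
    end
    y, Δ
end

y, grad = chain_gradient((cos, sin), 0.9) # value and derivative of sin(cos(0.9))
grad ≈ -cos(cos(0.9)) * sin(0.9)          # true</code></pre></figure>

<p>This manual bookkeeping is only meant to illustrate the idea on a simple chain.</p>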

<h3 id="chainrules-polynomial">3.4 Polynomials</h3>

<p>The next section will showcase an example of polynomial curve fitting.
This requires an <code class="language-plaintext highlighter-rouge">rrule</code> for the <code class="language-plaintext highlighter-rouge">evalpoly</code> function.</p>

<p>For a general polynomial:</p>

\[y = a_0 + a_1x + a_2x^2 + ... + a_n x^n\]

<p>The derivatives are:</p>

\[\begin{align}
\frac{\partial y}{\partial x} &amp;= 0 + a_1 + 2a_2x^1 + ... + n a_n x^{n-1} \\
\frac{\partial y}{\partial a_i} &amp;= 0 + ... + x^{i} + ... + 0
\end{align}\]

<p>For the most efficient implementation, the powers of $x$ can be calculated for both the forward and backward passes at the same time.
For simplicity, I’m not going to do that (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/evalpoly.jl">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">evalpoly</span><span class="x">),</span> <span class="n">x</span><span class="x">,</span> <span class="n">coeffs</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">coeffs</span><span class="x">)</span>
    <span class="k">function</span><span class="nf"> evalpoly_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
        <span class="n">xpow</span> <span class="o">=</span> <span class="n">one</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
        <span class="n">dp</span> <span class="o">=</span> <span class="n">similar</span><span class="x">(</span><span class="n">coeffs</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">xpow</span> <span class="o">*</span> <span class="n">Δ</span><span class="x">))</span>
        <span class="n">dx</span> <span class="o">=</span> <span class="n">zero</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>
        <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="n">eachindex</span><span class="x">(</span><span class="n">coeffs</span><span class="x">)</span>
            <span class="n">dp</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">=</span> <span class="n">Δ</span> <span class="o">*</span> <span class="n">xpow</span>
            <span class="n">dx</span> <span class="o">+=</span> <span class="x">(</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="x">)</span> <span class="o">*</span> <span class="n">coeffs</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">*</span> <span class="n">xpow</span> <span class="o">/</span> <span class="n">x</span> <span class="o">*</span> <span class="n">Δ</span>
            <span class="n">xpow</span> <span class="o">*=</span> <span class="n">x</span>
        <span class="k">end</span>
        <span class="k">return</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">dx</span><span class="x">,</span> <span class="n">dp</span>
    <span class="k">end</span>
    <span class="n">y</span><span class="x">,</span> <span class="n">evalpoly_back</span>
<span class="k">end</span></code></pre></figure>

<p>Usage:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">y</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">evalpoly</span><span class="x">,</span> <span class="mf">1.2</span><span class="x">,</span> <span class="x">[</span><span class="mf">2.0</span><span class="x">,</span> <span class="mf">0.0</span><span class="x">,</span> <span class="mf">3.0</span><span class="x">,</span> <span class="mf">4.0</span><span class="x">])</span> <span class="c"># 13.232, evalpoly_back</span>
<span class="n">back</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span> <span class="c"># (nothing, 24.48, [1.0, 1.2, 1.44, 1.728]) </span></code></pre></figure>

<h3 id="chainrules-matrix-multiplication">3.5 Matrix multiplication</h3>

<p>For some scalar loss function $l$, we can calculate the derivative $\Delta=\frac{\partial l}{\partial Y}$
with respect to some matrix $Y$. Then for $Y=AB$, the partial derivatives are:</p>

\[\begin{align}
\frac{\partial l}{\partial A} &amp;= \frac{\partial Y}{\partial A} \frac{\partial l}{\partial Y} \\
                              &amp;= \Delta B^T \\
\frac{\partial l}{\partial B} &amp;= \frac{\partial Y}{\partial B} \frac{\partial l}{\partial Y} \\
                              &amp;= A^T \Delta
\end{align}\]

<p>Note that the Jacobians $\frac{\partial Y}{\partial A}$ and $\frac{\partial Y}{\partial B}$ are not explicitly calculated here; only the product is. (These Jacobians would have many zeros because each output element depends only on a small subset of the input elements.)</p>
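
<p>To make the sparsity concrete, here is a minimal sketch that builds the Jacobian $\frac{\partial Y}{\partial A}$ explicitly for a small example. It assumes the matrices are flattened column-wise with <code>vec</code> and relies on the standard identity $\operatorname{vec}(AB) = (B^T \otimes I)\operatorname{vec}(A)$. The variable names are illustrative only and none of this is part of ChainRules.jl.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra

A, B = rand(2, 4), rand(4, 3)
Δ = rand(2, 3)                 # stand-in for ∂l/∂Y

# Explicit Jacobian of vec(Y) with respect to vec(A): a 6×8 matrix
Im = Matrix(1.0I, 2, 2)
J = kron(Matrix(B'), Im)

J * vec(A) ≈ vec(A * B)                  # the Jacobian reproduces Y = AB
reshape(J' * vec(Δ), size(A)) ≈ Δ * B'   # Jacobian-transpose product equals Δ * B'
count(iszero, J)                         # number of zero entries out of 48
</code></pre></figure>

<p>Each entry of $B$ appears in $J$ once per row of $A$, so for an $A$ with $m$ rows only a fraction $1/m$ of the entries of $J$ are non-zero (here 24 of 48, barring an exact zero from <code>rand</code>). The product $\Delta B^T$ gives the same result without ever storing $J$.</p>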

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle"></div>
  <div class="content-container">
    <div class="message-body">
      The most common use case in machine learning is $Y=WX$, where $W$ is a set of weights and $X$ is the data.
      Machine learning algorithms only alter the weights, not the data. Hence only $\frac{\partial l}{\partial W}$ is required.
      This means computation is wasted on $\frac{\partial l}{\partial X}$.
      For large matrices, this can be significant.
      To avoid this, ChainRules.jl uses the <code>ChainRulesCore.@thunk</code> macro to wrap the code in a <code>ChainRulesCore.Thunk</code> struct, which defers the computation until the result is used.
      If the result is never used, the computation never runs.
    </div>
  </div>
</div>

<p>In code (<a href="https://github.com/JuliaDiff/ChainRules.jl/blob/dba6cb57d73ba837c5ab6fd1f968f3a5d301ca9c/src/rulesets/Base/arraymath.jl#L27">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="o">*</span><span class="x">),</span> <span class="n">A</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Real</span><span class="x">},</span> <span class="n">B</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Real</span><span class="x">})</span>
    <span class="k">function</span><span class="nf"> times_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
        <span class="n">dA</span> <span class="o">=</span> <span class="n">Δ</span> <span class="o">*</span> <span class="n">B</span><span class="err">'</span>
        <span class="n">dB</span> <span class="o">=</span> <span class="n">A</span><span class="err">'</span> <span class="o">*</span> <span class="n">Δ</span>
        <span class="k">return</span> <span class="x">(</span><span class="nb">nothing</span><span class="x">,</span> <span class="n">dA</span><span class="x">,</span> <span class="n">dB</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">A</span> <span class="o">*</span> <span class="n">B</span><span class="x">,</span> <span class="n">times_back</span>
<span class="k">end</span></code></pre></figure>

<p>Test:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">A</span><span class="x">,</span> <span class="n">B</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">),</span> <span class="n">rand</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span>
<span class="n">C</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="o">*</span><span class="x">,</span> <span class="n">A</span><span class="x">,</span> <span class="n">B</span><span class="x">)</span> <span class="c"># (2×3 Matrix{Float64}, times_back)</span>
<span class="n">back</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">3</span><span class="x">))</span> <span class="c"># (nothing, 2×4 Matrix, 4×3 Matrix)</span></code></pre></figure>
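
<p>As a sanity check, the analytical gradient can be compared against a finite-difference estimate. The sketch below assumes the scalar loss $l=\sum_{ij} Y_{ij}$, so that $\Delta$ is a matrix of ones. <code>finite_diff_A</code> is a hypothetical helper written for this illustration only.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">A, B = rand(2, 4), rand(4, 3)
Δ = ones(2, 3)          # ∂l/∂Y for l = sum(Y)
dA = Δ * B'             # same expression as in times_back

# Central finite difference of l = sum(A * B) with respect to A[i, j]
function finite_diff_A(A, B, i, j; h=1e-6)
    A_plus, A_minus = copy(A), copy(A)
    A_plus[i, j] += h
    A_minus[i, j] -= h
    (sum(A_plus * B) - sum(A_minus * B)) / (2 * h)
end

dA_numeric = [finite_diff_A(A, B, i, j) for i in axes(A, 1), j in axes(A, 2)]
isapprox(dA, dA_numeric; atol=1e-6)
</code></pre></figure>

<p>Because this loss is linear in $A$, the two gradients should agree up to floating point round-off. The same check can be repeated for $B$ with $A^T\Delta$.</p>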

<h3 id="chainrules-mse">3.6 MSE</h3>

<p>The mean square error (MSE) is a common loss function in machine learning.
It will be used shortly for polynomial curve fitting.
It is:</p>

\[MSE(\hat{y}, y) = \frac{1}{n}\sum^n_{i=1} (\hat{y}_i - y_i)^2\]

<p>with derivatives:</p>

\[\begin{align}
  \frac{\partial MSE}{\partial \hat{y}_i} &amp;= \frac{1}{n}(0 + ... + 2(\hat{y}_i - y_i) + ... + 0) \\
        &amp;= \frac{2(\hat{y}_i - y_i)}{n} \\
  \frac{\partial MSE}{\partial y_i} &amp;= \frac{1}{n}(0 + ... - 2(\hat{y}_i - y_i) + ... + 0) \\
       &amp;= -\frac{2(\hat{y}_i - y_i)}{n}
\end{align}\]

<p>In code it is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">StatsBase</span>
<span class="n">mse</span><span class="x">(</span><span class="n">ŷ</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">)</span> <span class="o">=</span> <span class="n">mean</span><span class="x">(</span><span class="n">abs2</span><span class="o">.</span><span class="x">(</span><span class="n">ŷ</span> <span class="o">-</span> <span class="n">y</span><span class="x">))</span></code></pre></figure>

<p>Flux.jl does not define an <code class="language-plaintext highlighter-rouge">rrule</code> for its <code class="language-plaintext highlighter-rouge">mse</code> because it can be decomposed into functions which already have an <code class="language-plaintext highlighter-rouge">rrule</code> (<code class="language-plaintext highlighter-rouge">-</code>, <code class="language-plaintext highlighter-rouge">broadcast</code>, <code class="language-plaintext highlighter-rouge">abs2</code> and <code class="language-plaintext highlighter-rouge">mean</code>). 
However, since we have neither <code class="language-plaintext highlighter-rouge">rrule</code>s for these parts nor automated decomposition yet, it is simplest to create an <code class="language-plaintext highlighter-rouge">rrule</code> for the entire function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">mse</span><span class="x">),</span> <span class="n">ŷ</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">AbstractVecOrMat</span><span class="x">)</span>
    <span class="n">Ω</span> <span class="o">=</span> <span class="n">mse</span><span class="x">(</span><span class="n">ŷ</span><span class="x">,</span> <span class="n">y</span><span class="x">)</span>
    <span class="k">function</span><span class="nf"> mse_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
        <span class="n">c</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="x">(</span><span class="n">ŷ</span> <span class="o">-</span> <span class="n">y</span><span class="x">)</span> <span class="o">/</span> <span class="n">length</span><span class="x">(</span><span class="n">y</span><span class="x">)</span> <span class="o">*</span> <span class="n">Δ</span>
        <span class="k">return</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">c</span><span class="x">,</span> <span class="o">-</span><span class="n">c</span> <span class="c"># ∂self, ∂ŷ, ∂y</span>
    <span class="k">end</span>
    <span class="n">Ω</span><span class="x">,</span> <span class="n">mse_back</span>
<span class="k">end</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">mse</code> can also be applied to each data point individually, with the per-point results summed afterwards.
This form is not common but will be useful for explanatory purposes in the polynomial curve fitting section:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mse</span><span class="x">(</span><span class="n">ŷ</span><span class="o">::</span><span class="kt">Number</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Number</span><span class="x">,</span> <span class="n">n</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span> <span class="o">=</span> <span class="n">abs2</span><span class="x">(</span><span class="n">ŷ</span> <span class="o">-</span> <span class="n">y</span><span class="x">)</span><span class="o">/</span><span class="n">n</span>
<span class="k">function</span><span class="nf"> rrule</span><span class="x">(</span><span class="o">::</span><span class="n">typeof</span><span class="x">(</span><span class="n">mse</span><span class="x">),</span> <span class="n">ŷ</span><span class="o">::</span><span class="kt">Number</span><span class="x">,</span> <span class="n">y</span><span class="o">::</span><span class="kt">Number</span><span class="x">,</span> <span class="n">n</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">Ω</span> <span class="o">=</span> <span class="n">mse</span><span class="x">(</span><span class="n">ŷ</span><span class="x">,</span> <span class="n">y</span><span class="x">,</span> <span class="n">n</span><span class="x">)</span>
    <span class="k">function</span><span class="nf"> mse_back</span><span class="x">(</span><span class="n">Δ</span><span class="x">)</span>
        <span class="n">c</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="x">(</span><span class="n">ŷ</span> <span class="o">-</span> <span class="n">y</span><span class="x">)</span> <span class="o">/</span> <span class="n">n</span> <span class="o">*</span> <span class="n">Δ</span>
        <span class="k">return</span> <span class="nb">nothing</span><span class="x">,</span> <span class="n">c</span><span class="x">,</span> <span class="o">-</span><span class="n">c</span><span class="x">,</span> <span class="o">-</span><span class="n">Ω</span> <span class="o">/</span> <span class="n">n</span> <span class="o">*</span> <span class="n">Δ</span> <span class="c"># ∂self, ∂ŷ, ∂y, ∂n</span>
    <span class="k">end</span>
    <span class="n">Ω</span><span class="x">,</span> <span class="n">mse_back</span>
<span class="k">end</span></code></pre></figure>
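
<p>The two forms agree once the per-point values are summed. Here is a small check on made-up numbers. It uses <code>Statistics</code> from the standard library instead of <code>StatsBase</code> so that the block runs without extra packages, and the data values are arbitrary.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Statistics

mse(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = mean(abs2.(ŷ - y))
mse(ŷ::Number, y::Number, n::Int) = abs2(ŷ - y) / n

ŷ = [1.0, 2.0, 4.0]   # arbitrary predictions
y = [1.5, 2.0, 3.0]   # arbitrary labels
n = length(y)

# Sum of the per-point contributions (0.25 + 0 + 1) / 3
total = sum(mse(ŷi, yi, n) for (ŷi, yi) in zip(ŷ, y))
total ≈ mse(ŷ, y)
</code></pre></figure>

<p>The same holds for the gradients: the per-point $\frac{2(\hat{y}_i - y_i)}{n}$ is exactly the $i$-th entry of the vector gradient.</p>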

<h2 id="gradient-descent">4 Gradient Descent</h2>
<h3 id="polynomial-curve-fitting">4.1 Polynomial curve fitting</h3>

<p>Gradient descent is a great algorithm to illustrate the usefulness of the code developed so far.
The toy example of fitting a polynomial to data will be used.
This is a useful example because (1) we can start with a target curve and so have ground truth values to compare and (2) this problem can be solved analytically without gradients.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/polyfit_data.png" alt="Polynomial with noise" />
<figcaption></figcaption>
</figure>

<p>Here is code to create the above data:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">StatsBase</span>
<span class="n">target_weights</span> <span class="o">=</span> <span class="x">[</span><span class="mf">15.0</span><span class="x">,</span> <span class="o">-</span><span class="mf">2.1</span><span class="x">,</span> <span class="mf">13.9</span><span class="x">,</span> <span class="mf">1.5</span><span class="x">]</span>
<span class="n">noise_factor</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">xs</span> <span class="o">=</span> <span class="x">(</span><span class="n">rand</span><span class="x">(</span><span class="mi">100</span><span class="x">)</span> <span class="o">.-</span> <span class="mf">0.5</span><span class="x">)</span> <span class="o">.*</span> <span class="mi">10</span>
<span class="n">ys</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">x</span> <span class="o">-&gt;</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">target_weights</span><span class="x">),</span> <span class="n">xs</span><span class="x">)</span>
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">mean</span><span class="x">(</span><span class="n">abs</span><span class="o">.</span><span class="x">(</span><span class="n">ys</span><span class="x">))</span>
<span class="n">ys</span> <span class="o">.+=</span> <span class="n">randn</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">ys</span><span class="x">))</span> <span class="o">*</span> <span class="n">scale_factor</span> <span class="o">*</span> <span class="n">noise_factor</span></code></pre></figure>

<p>
  <a class="btn" data-toggle="collapse" href="#poly-fit-analytical" role="button" aria-expanded="false" aria-controls="collapseExample">
    Analytical least squares fitting of polynomials &#8681;
  </a>
</p>
<div class="collapse" id="poly-fit-analytical">
  <div class="card card-body ">
		<p> For a polynomial of order $p$, if there are exactly $n=p+1$ training samples (one for each coefficient, including the constant $a_0$) then there are exactly $n$ equations for $n$ unknowns ($a_0$,...,$a_p$) and this can be solved as an ordinary linear system:
        $$
            \begin{align}
            &amp;a_0 + a_1 x_1 + a_2x_1^2 + ... + a_p x_1^p = y_1 \\
            &amp;\vdots \\
            &amp;a_0 + a_1 x_n + a_2x_n^2 + ... + a_p x_n^p = y_n \\
            &amp;\Rightarrow \begin{bmatrix}
            1 &amp; x_1 &amp; x_1^2 &amp; \cdots &amp; x_1^p \\
            \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
            1 &amp; x_n &amp; x_n^2 &amp; \cdots &amp; x_n^p
            \end{bmatrix}
            \begin{bmatrix}
            a_0 \\
            \vdots \\
            a_p
            \end{bmatrix}
            =
            \begin{bmatrix}
            y_1 \\
            \vdots \\
            y_n
            \end{bmatrix} \\
            &amp;\Rightarrow XA=Y \\
            &amp;\Rightarrow A = X^{-1}Y
            \end{align}
        $$
        where $X$ is a square matrix and $X^{-1}$ exists as long as all the $x_i$ are distinct.
        </p>
        <p>
        However, usually $n &gt; p + 1$, so $X$ is not square and $X^{-1}$ does not exist. In that case the pseudoinverse $X^+$, also called the <a href="https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse">Moore-Penrose inverse</a>, can be used instead:
        $$
            \begin{align}
            X^{+} &amp;= (X^T X)^{-1} X^T \\
            \Rightarrow A &amp;= X^{+}Y
            \end{align}
        $$
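        It can be proven that this solution for $A$ minimises the sum of squared errors. The formula for $X^{+}$ above assumes that $X^T X$ is invertible, which holds when at least $p+1$ of the $x_i$ are distinct.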
        </p>
        <p>
        Here is this solution in code:

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">LinearAlgebra</span>
<span class="k">function</span><span class="nf"> solve_poly_linear</span><span class="x">(</span><span class="n">order</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">xs</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span> <span class="n">ys</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span>
    <span class="n">n</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">xs</span><span class="x">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">zeros</span><span class="x">(</span><span class="n">n</span><span class="x">,</span> <span class="n">order</span> <span class="o">+</span> <span class="mi">1</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">i</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">xs</span><span class="x">)</span>
        <span class="n">xpow</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="k">for</span> <span class="n">j</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">X</span><span class="x">,</span> <span class="mi">2</span><span class="x">))</span>
            <span class="n">X</span><span class="x">[</span><span class="n">i</span><span class="x">,</span> <span class="n">j</span><span class="x">]</span> <span class="o">=</span> <span class="n">xpow</span>
            <span class="n">xpow</span> <span class="o">*=</span> <span class="n">x</span>
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="n">pinv</span><span class="x">(</span><span class="n">X</span><span class="x">)</span> <span class="o">*</span> <span class="n">ys</span>
<span class="k">end</span></code></pre></figure>

        </p>
  </div>
</div>
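
<p>The analytical solution can be checked on noise-free samples, where it should recover the target coefficients up to floating point error. This is a minimal sketch: the evenly spaced <code>xs</code> are an assumption made for reproducibility, and the backslash solve is shown only as a common alternative to forming the pseudoinverse explicitly.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra

target = [15.0, -2.1, 13.9, 1.5]
xs = collect(range(-5, 5; length=20))
ys = [evalpoly(x, target) for x in xs]   # noise-free samples

# Design matrix with columns 1, x, x^2, x^3
X = [x^j for x in xs, j in 0:3]

a_pinv = pinv(X) * ys   # explicit pseudoinverse, as in the text
a_qr = X \ ys           # least-squares solve via a QR factorisation
</code></pre></figure>

<p>Both <code>a_pinv</code> and <code>a_qr</code> should match <code>target</code> closely. With noisy data they instead return the least squares fit, which is the "Analytical" result that gradient descent is compared against.</p>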

<p>Here is a simple version of gradient descent:</p>
<blockquote>
<u><b>Gradient descent</b></u> <br />
<b>while</b> (stopping criterion is not met) <b>do</b>:<br />
$\quad$ $\Delta = 0$ <br />
$\quad$ <b>for</b> sample, label in train_set <b>do</b>: <br />
$\quad\quad$ $\Delta \leftarrow \Delta + \frac{\partial}{\partial\theta_j}L$($m_{\theta_j}$(sample), label) <br />
$\quad$ $\theta_{j+1}$ $\leftarrow \theta_j - \alpha \Delta$
</blockquote>

<p>where $m_\theta$ is the model with parameters $\theta$ and $L$ is the loss function.</p>

<p>This is a Julia implementation for specifically applying the algorithm to polynomials.
The stopping condition is a maximum number of iterations, so the <code class="language-plaintext highlighter-rouge">while</code> loop has been replaced with a <code class="language-plaintext highlighter-rouge">for</code> loop.
The code also saves the loss so that the training progress can be analysed.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> gradient_descent_poly!</span><span class="x">(</span>
    <span class="n">coeffs</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span>
    <span class="n">xs</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span>
    <span class="n">ys</span><span class="o">::</span><span class="kt">AbstractVector</span>
    <span class="x">;</span> <span class="n">learning_rate</span><span class="o">::</span><span class="kt">AbstractFloat</span><span class="o">=</span><span class="mf">0.1</span><span class="x">,</span>
    <span class="n">max_iters</span><span class="o">::</span><span class="kt">Integer</span><span class="o">=</span><span class="mi">100</span>
    <span class="x">)</span>
    <span class="n">history</span> <span class="o">=</span> <span class="kt">Float64</span><span class="x">[]</span>
    <span class="n">n</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">xs</span><span class="x">)</span>
    <span class="n">p</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">coeffs</span><span class="x">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">max_iters</span>
        <span class="n">loss_iter</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">Δcoeffs</span> <span class="o">=</span> <span class="n">zeros</span><span class="x">(</span><span class="n">p</span><span class="x">)</span>
        <span class="k">for</span> <span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">y</span><span class="x">)</span> <span class="k">in</span> <span class="n">zip</span><span class="x">(</span><span class="n">xs</span><span class="x">,</span> <span class="n">ys</span><span class="x">)</span>
            <span class="c"># forward</span>
            <span class="n">ŷ</span><span class="x">,</span> <span class="n">back_poly</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">evalpoly</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">coeffs</span><span class="x">)</span>
            <span class="n">loss_x</span><span class="x">,</span> <span class="n">back_loss</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">mse</span><span class="x">,</span> <span class="n">ŷ</span><span class="x">,</span> <span class="n">y</span><span class="x">,</span> <span class="n">n</span><span class="x">)</span>
            <span class="c"># reverse</span>
            <span class="n">Δloss</span><span class="x">,</span> <span class="n">Δŷ</span><span class="x">,</span> <span class="n">Δy</span><span class="x">,</span> <span class="n">Δn</span> <span class="o">=</span> <span class="n">back_loss</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span>    
            <span class="n">Δevalpoly</span><span class="x">,</span> <span class="n">Δx</span><span class="x">,</span> <span class="n">Δcoeffs_x</span> <span class="o">=</span> <span class="n">back_poly</span><span class="x">(</span><span class="n">Δŷ</span><span class="x">)</span>
            <span class="c"># accumulate</span>
            <span class="n">loss_iter</span> <span class="o">+=</span> <span class="n">loss_x</span>
            <span class="n">Δcoeffs</span> <span class="o">+=</span> <span class="n">Δcoeffs_x</span>
        <span class="k">end</span>
        <span class="c"># update</span>
        <span class="n">coeffs</span> <span class="o">.-=</span> <span class="n">learning_rate</span> <span class="o">.*</span> <span class="n">Δcoeffs</span>
        <span class="c"># history</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">history</span><span class="x">,</span> <span class="n">loss_iter</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">history</span>
<span class="k">end</span></code></pre></figure>

<p>Calling the code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">coeffs</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">4</span><span class="x">)</span>
<span class="n">history</span> <span class="o">=</span> <span class="n">gradient_descent_poly!</span><span class="x">(</span><span class="n">coeffs</span><span class="x">,</span> <span class="n">xs</span><span class="x">,</span> <span class="n">ys</span><span class="x">;</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-5</span><span class="x">,</span> <span class="n">max_iters</span><span class="o">=</span><span class="mi">2000</span><span class="x">)</span></code></pre></figure>

<p>Plotting the history:</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/polyfit_training.png" alt="Gradient descent training history" />
<figcaption></figcaption>
</figure>

<p>Comparing losses on the train set:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">ys_est</span> <span class="o">=</span>  <span class="n">map</span><span class="x">(</span><span class="n">x</span> <span class="o">-&gt;</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">coeffs</span><span class="x">),</span> <span class="n">xs</span><span class="x">)</span>
<span class="n">mse</span><span class="x">(</span><span class="n">ys_est</span><span class="x">,</span> <span class="n">ys</span><span class="x">)</span></code></pre></figure>

<table><thead>
  <tr>
    <th>Method</th>
    <th>Loss</th>
    <th>Coefficients</th>
  </tr></thead>
<tbody>
  <tr>
    <td>Target</td>
    <td>416.62</td>
    <td>(15.0, -2.1, 13.9, 1.5)</td>
  </tr>
  <tr>
    <td>Analytical</td>
    <td>391.64</td>
    <td>(15.34, -3.24, 13.84, 1.46)</td>
  </tr>
  <tr>
    <td>Gradient Descent</td>
    <td>498.50</td>
    <td>(1.37, 0.54, 14.51, 1.26)</td>
  </tr>
</tbody>
</table>

<p>And finally, comparing the curves:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x_model</span> <span class="o">=</span> <span class="o">-</span><span class="mi">5</span><span class="o">:</span><span class="mf">0.01</span><span class="o">:</span><span class="mi">5</span>
<span class="n">ys_model</span> <span class="o">=</span>  <span class="n">map</span><span class="x">(</span><span class="n">x</span> <span class="o">-&gt;</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">coeffs</span><span class="x">),</span> <span class="n">x_model</span><span class="x">)</span></code></pre></figure>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/micrograd/polyfit.png" alt="Fitted polynomial curves" />
<figcaption></figcaption>
</figure>

<h3 id="gradient-descent-map">4.2 Revisited with map</h3>

<p>It is possible to replace the inner loop over the training data with <code class="language-plaintext highlighter-rouge">map</code>.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> gradient_descent_poly!</span><span class="x">(</span>
    <span class="n">coeffs</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span>
    <span class="n">xs</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span>
    <span class="n">ys</span><span class="o">::</span><span class="kt">AbstractVector</span>
    <span class="x">;</span> <span class="n">learning_rate</span><span class="o">::</span><span class="kt">AbstractFloat</span><span class="o">=</span><span class="mf">0.1</span><span class="x">,</span>
    <span class="n">max_iters</span><span class="o">::</span><span class="kt">Integer</span><span class="o">=</span><span class="mi">100</span>
    <span class="x">)</span>
    <span class="n">history</span> <span class="o">=</span> <span class="kt">Float64</span><span class="x">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">max_iters</span>
        <span class="c"># forward</span>
        <span class="n">ys_and_backs</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">x</span><span class="o">-&gt;</span><span class="n">rrule</span><span class="x">(</span><span class="n">evalpoly</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">coeffs</span><span class="x">),</span> <span class="n">xs</span><span class="x">)</span>
        <span class="n">ŷ</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">first</span><span class="x">,</span> <span class="n">ys_and_backs</span><span class="x">)</span>
        <span class="n">loss_iter</span><span class="x">,</span> <span class="n">back_loss</span> <span class="o">=</span> <span class="n">rrule</span><span class="x">(</span><span class="n">mse</span><span class="x">,</span> <span class="n">ŷ</span><span class="x">,</span> <span class="n">ys</span><span class="x">)</span>
        <span class="c"># reverse</span>
        <span class="n">Δmse</span><span class="x">,</span> <span class="n">Δŷ</span><span class="x">,</span> <span class="n">Δy</span> <span class="o">=</span> <span class="n">back_loss</span><span class="x">(</span><span class="mf">1.0</span><span class="x">)</span>
        <span class="n">∂f_and_∂x_zipped</span> <span class="o">=</span> <span class="n">map</span><span class="x">(((</span><span class="n">_</span><span class="x">,</span> <span class="n">pb</span><span class="x">),</span> <span class="n">δ</span><span class="x">)</span> <span class="o">-&gt;</span> <span class="n">pb</span><span class="x">(</span><span class="n">δ</span><span class="x">),</span> <span class="n">ys_and_backs</span><span class="x">,</span> <span class="n">Δŷ</span><span class="x">)</span>
        <span class="n">Δcoeffs_unzipped</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">Δ</span><span class="o">-&gt;</span><span class="n">Δ</span><span class="x">[</span><span class="mi">3</span><span class="x">],</span> <span class="n">∂f_and_∂x_zipped</span><span class="x">)</span> <span class="c"># Δ[i] = (Δevalpoly, Δx, Δcoeffs)</span>
        <span class="n">Δcoeffs</span> <span class="o">=</span> <span class="n">reduce</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="n">Δcoeffs_unzipped</span><span class="x">)</span>
        <span class="c"># update</span>
        <span class="n">coeffs</span> <span class="o">.-=</span> <span class="n">learning_rate</span> <span class="o">.*</span> <span class="n">Δcoeffs</span>
        <span class="c"># history</span>
        <span class="n">push!</span><span class="x">(</span><span class="n">history</span><span class="x">,</span> <span class="n">loss_iter</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">history</span>
<span class="k">end</span></code></pre></figure>

<p>This code is slightly more complex than the previous version.
The behaviour and performance are practically identical.
However, it is one step closer to being generic.</p>

<p>In machine learning, models usually execute on multiple inputs at once.
We could make a polynomial model that does that:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Polynomial</span><span class="x">{</span><span class="n">V</span><span class="o">&lt;:</span><span class="kt">AbstractVector</span><span class="x">}</span>
    <span class="n">weights</span><span class="o">::</span><span class="n">V</span>
<span class="k">end</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">evalpoly</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">m</span><span class="o">.</span><span class="n">weights</span><span class="x">)</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Polynomial</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">m</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span></code></pre></figure>
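
<p>As a quick check of the two call methods, here is a minimal sketch with made-up coefficients. The values <code class="language-plaintext highlighter-rouge">[1.0, 2.0, 3.0]</code> are an assumption for illustration only and are not the coefficients from the fitting example. <code class="language-plaintext highlighter-rouge">evalpoly</code> takes coefficients in ascending order, so this model is $1 + 2x + 3x^2$. The struct is repeated so that the snippet runs on its own:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">struct Polynomial{V&lt;:AbstractVector}
    weights::V
end
(m::Polynomial)(x) = evalpoly(x, m.weights)
(m::Polynomial)(x::AbstractVector) = map(m, x)

# Hypothetical coefficients, ascending order: 1 + 2x + 3x^2
model = Polynomial([1.0, 2.0, 3.0])

model(2.0)         # scalar input: 1 + 2*2 + 3*4 = 17.0
model([0.0, 1.0])  # vector input, mapped element-wise: [1.0, 6.0]</code></pre></figure>

<p>The vector method simply maps the scalar method over each input, which is why a <code class="language-plaintext highlighter-rouge">pullback</code> for <code class="language-plaintext highlighter-rouge">map</code> is needed.</p>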

<p>The goal then is to get gradients for the model’s weights directly:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">model</span> <span class="o">=</span> <span class="n">Polynomial</span><span class="x">(</span><span class="n">coeffs</span><span class="x">)</span>
<span class="n">zs</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">m</span> <span class="o">-&gt;</span> <span class="n">m</span><span class="x">(</span><span class="n">xs</span><span class="x">),</span> <span class="n">model</span><span class="x">)</span></code></pre></figure>

<p>In the next sections we will write code that will inspect the model function call, recognise that it calls <code class="language-plaintext highlighter-rouge">map</code>, and call a <code class="language-plaintext highlighter-rouge">pullback</code> for map.<sup id="fnref:pullback_vs_rrule" role="doc-noteref"><a href="#fn:pullback_vs_rrule" class="footnote" rel="footnote">2</a></sup>
This in turn will call the <code class="language-plaintext highlighter-rouge">pullback</code> for <code class="language-plaintext highlighter-rouge">evalpoly</code>, which will pass the arguments to the <code class="language-plaintext highlighter-rouge">rrule</code> defined above.</p>

<h2 id="conclusion">5 Conclusion</h2>

<p>The next two parts will develop the <code class="language-plaintext highlighter-rouge">pullback</code> function.
It will inspect and decompose code with the goal of passing arguments to <code class="language-plaintext highlighter-rouge">rrule</code> and accumulating gradients via the chain rule.</p>

<p><a href="/machine-learning/2024/08/03/micrograd-2-expr">Part 2</a> will introduce metaprogramming in Julia and generate expressions for the backpropagation code.
However the code is unstable and prone to errors - it is recursive metaprogramming - so <a href="/machine-learning/2024/08/10/micrograd-3-ir">part 3</a> will introduce more robust code making use of the <a href="https://fluxml.ai/IRTools.jl/latest/">IRTools.jl</a> package. 
This code really pushes Julia’s metaprogramming to its limits.</p>

<p>It is possible to jump straight to <a href="/machine-learning/2024/08/10/micrograd-3-ir">part 3</a> if desired.</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:micrograd" role="doc-endnote">
      <p>For example, Micrograd defines a <code class="language-plaintext highlighter-rouge">Value</code> class that has a custom definition for <code class="language-plaintext highlighter-rouge">__add__</code>. This custom definition then calculates the forward pass and prepares the backwards pass. The same is true of the <code class="language-plaintext highlighter-rouge">Tensor</code> objects in PyTorch. <a href="#fnref:micrograd" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:pullback_vs_rrule" role="doc-endnote">
      <p>It is a design choice to use <code class="language-plaintext highlighter-rouge">pullback</code> and not <code class="language-plaintext highlighter-rouge">rrule</code> for map. Both <code class="language-plaintext highlighter-rouge">rrule</code> and <code class="language-plaintext highlighter-rouge">pullback</code> have the same outputs. However <code class="language-plaintext highlighter-rouge">rrule</code> is intended for standalone gradients, whereas <code class="language-plaintext highlighter-rouge">pullback</code> will potentially involve recursive calls to itself. <a href="#fnref:pullback_vs_rrule" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A series on automatic differentiation in Julia. Part 1 provides an overview and defines explicit chain rules.]]></summary></entry><entry><title type="html">Covering all birthdays</title><link href="https://liorsinai.github.io/mathematics/2024/07/09/birthday-covering.html" rel="alternate" type="text/html" title="Covering all birthdays" /><published>2024-07-09T00:00:00+00:00</published><updated>2024-07-28T00:00:00+00:00</updated><id>https://liorsinai.github.io/mathematics/2024/07/09/birthday-covering</id><content type="html" xml:base="https://liorsinai.github.io/mathematics/2024/07/09/birthday-covering.html"><![CDATA[<p><em>Quantifying how likely each birthday is present (covered) in some large group of people.</em></p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p>I recently got <a href="https://xkcd.com/356/">nerd sniped</a> by a fascinating post on <a href="https://news.ycombinator.com/">Hacker News</a> titled <a href="https://seniormars.com/posts/everyday-birthday/?">Every day is an Owl’s Birthday!</a> by SeniorMars. 
It explored the problem of estimating if there was at least one student at a university for every birthday. Put another way, it explored the following question:</p>

<blockquote>
  <p>Given $n$ people, what is the probability that all $N$ birthdays are covered? That is, given $n$ people, what is the probability that there is at least 1 person for each birthday?</p>
</blockquote>

<p>As well as the related question:</p>

<blockquote>
  <p>What is the expected number of people required to have at least 1 person for each birthday? That is, how many people do you need to approach and ask what their birthday is before you see all birthdays?</p>
</blockquote>

<p>For the latter, the minimum number of people is obviously $n=N=365$. 
However you would have to be very lucky to get this outcome.
On the other extreme, one can imagine an incredibly unlucky case where 1000 people are approached and all are born on May 5th, and hence you would be no closer to your goal than when you started. 
You might end up approaching hundreds of thousands of people.
But between these two extremes, how many people do we expect to approach on average?
The first question then seeks to quantify how lucky you are with the number you get.</p>

<p>I want to present the results in the <a href="https://seniormars.com/posts/everyday-birthday/?">original post</a> in my own way.
It took me a few reads to understand those explanations, and I think I can improve on them here.
However I will leave out extra material from the original including mathematical proofs, accounting for leap years and accounting for non-uniform birthday distributions.</p>

<p>This problem is different to the <a href="https://en.wikipedia.org/wiki/Birthday_problem">birthday paradox</a>, which tries to determine how likely duplicate birthdays are in a group of people, and which comes up with the surprising answer that it is very likely for even a small number.
I have explored this problem in an earlier <a href="/mathematics/2021/06/04/birthday-collisions">blog post</a>.
The key difference is that the birthday paradox deals with $n&lt;N$ (fewer people than birthdays) where duplicates are not guaranteed, but here the problem has $n&gt;N$ (more people than birthdays), where duplicates cause the extra complexity.</p>

<h2 id="the-coupon-collector-problem">2 The Coupon Collector Problem</h2>

<p>We will start with the second problem because it is simpler to solve. It is:</p>

<blockquote>
  <p>What is the expected number of people required to have at least 1 person for each birthday?</p>
</blockquote>

<p>This problem is identical to the <a href="https://en.wikipedia.org/wiki/Coupon_collector%27s_problem">Coupon Collector’s Problem</a>:</p>

<blockquote>
  <p>Given $N$ coupons, how many coupons do you expect you need to draw with replacement before having drawn each coupon at least once?</p>
</blockquote>

<p>with $N=365$ birthdays.</p>

<p>I’ll first simulate it and present results, and then match the numbers to theory.</p>

<h3 id="the-coupon-collector-simulation">2.1 Simulation</h3>

<p>To run a <a href="https://en.wikipedia.org/wiki/Monte_Carlo_method">Monte Carlo simulation</a>, for each trial create a vector <code class="language-plaintext highlighter-rouge">seen</code> of size 365 and set every entry to <code class="language-plaintext highlighter-rouge">false</code>.
Then get stuck in a <code class="language-plaintext highlighter-rouge">while</code> loop, and on each iteration generate 1 random birthday <code class="language-plaintext highlighter-rouge">k</code> and set <code class="language-plaintext highlighter-rouge">seen[k]</code> to <code class="language-plaintext highlighter-rouge">true</code>.
Exit when all of <code class="language-plaintext highlighter-rouge">seen</code> is <code class="language-plaintext highlighter-rouge">true</code>. Repeat this for some large number $T$ trials.</p>

<p>Here is an implementation in Julia:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">ProgressMeter</span>

<span class="k">function</span><span class="nf"> coupon_collecting_simulation</span><span class="x">(</span><span class="n">ncoupons</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">ntrials</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">counts</span> <span class="o">=</span> <span class="kt">Vector</span><span class="x">{</span><span class="kt">Int</span><span class="x">}(</span><span class="nb">undef</span><span class="x">,</span> <span class="n">ntrials</span><span class="x">)</span>
    <span class="nd">@showprogress</span> <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="n">eachindex</span><span class="x">(</span><span class="n">counts</span><span class="x">)</span>
        <span class="n">counts</span><span class="x">[</span><span class="n">i</span><span class="x">]</span> <span class="o">=</span> <span class="n">run_collection_trial</span><span class="x">(</span><span class="n">ncoupons</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">counts</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> run_collection_trial</span><span class="x">(</span><span class="n">ncoupons</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">seen</span> <span class="o">=</span> <span class="n">fill</span><span class="x">(</span><span class="nb">false</span><span class="x">,</span> <span class="n">ncoupons</span><span class="x">)</span>
    <span class="n">coupons</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">while</span> <span class="o">!</span><span class="n">all</span><span class="x">(</span><span class="n">seen</span><span class="x">)</span>
        <span class="n">coupons</span> <span class="o">+=</span> <span class="mi">1</span>
        <span class="n">k</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="n">ncoupons</span><span class="x">)</span>
        <span class="n">seen</span><span class="x">[</span><span class="n">k</span><span class="x">]</span> <span class="o">=</span> <span class="nb">true</span>
    <span class="k">end</span>
    <span class="n">coupons</span>
<span class="k">end</span></code></pre></figure>

<p>This code can be run with <code class="language-plaintext highlighter-rouge">coupon_collecting_simulation(365, 10_000)</code>.</p>

<p>Here are the results:</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/coupon_collecting.png" alt="Coupon collecting histogram" />
<figcaption>Monte Carlo frequency of stopping counts for coupon collecting.</figcaption>
</figure>

<p>The stopping counts range from 1,349 to 5,832. The average is 2364.84 with a standard deviation of 466.68.
So on average we need 6.5$\times$ as many people as the number of birthdays to see all of them.</p>
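
<p>These summary statistics can be computed with the <code class="language-plaintext highlighter-rouge">Statistics</code> standard library. The following is a minimal sketch rather than the exact script behind the figure: the seed and the smaller number of trials are arbitrary assumptions, so the numbers will differ slightly from those quoted above. The trial function is repeated, without the progress bar, so that the snippet runs on its own:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random, Statistics

# Same trial logic as above: draw birthdays until every one has been seen.
function run_collection_trial(ncoupons::Int)
    seen = fill(false, ncoupons)
    coupons = 0
    while !all(seen)
        coupons += 1
        seen[rand(1:ncoupons)] = true
    end
    coupons
end

Random.seed!(1)  # arbitrary seed (assumption), only for reproducibility
counts = [run_collection_trial(365) for _ in 1:1_000]

extrema(counts)     # (smallest, largest) stopping count
mean(counts)        # sample average
std(counts)         # sample standard deviation
mean(counts) / 365  # average as a multiple of the number of birthdays</code></pre></figure>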

<h3 id="the-coupon-collector-theory">2.2 Theory</h3>

<p>To calculate this number theoretically, it helps to answer the following easier questions.</p>

<p>How many people on average do we need to ask to collect a new birthday,</p>
<ol>
  <li>At the start?</li>
  <li>After collecting 50 unique birthdays?</li>
  <li>After collecting 265 unique birthdays?</li>
  <li>At the end, after collecting 364 unique birthdays?</li>
</ol>

<p>Answers:</p>
<ol>
  <li>One. The first person will give us our first birthday.</li>
  <li>There are 315 remaining birthdays and so $\frac{315}{365}\approx\frac{1}{1.16}=86.3\%$ of birthdays will be new. This means about 1 in every 1.16 people will give us a new birthday, so we need to ask 1.16 people on average to get a new birthday.</li>
  <li>There are 100 remaining birthdays and so $\frac{100}{365}=\frac{1}{3.65}=27.4\%$ of birthdays will be new. This means 1 in every 3.65 people will give us a new birthday, so we need to ask 3.65 people on average to get a new birthday.</li>
  <li>At the end only $\frac{1}{365}$ of birthdays will be new. This means 1 in 365 people will give us a new birthday, so we need to ask a full 365 people to get a new birthday.</li>
</ol>
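
<p>The four answers above all follow the same rule: with $c$ birthdays already collected, the chance of a new one is $(N-c)/N$, so the expected number of people to ask is the reciprocal. A small sketch, where <code class="language-plaintext highlighter-rouge">expected_wait</code> is a hypothetical helper name introduced here:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Expected number of people to ask for one new birthday,
# given that c of the N birthdays have already been collected.
expected_wait(N, c) = N / (N - c)

[expected_wait(365, c) for c in (0, 50, 265, 364)]
# approximately [1.0, 1.16, 3.65, 365.0]</code></pre></figure>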

<p>From this it follows that the formula for the expected number of people $n$ is the sum of all the different scenarios:</p>

\[\begin{align}
n &amp;= \sum_{i=1}^N \frac{1}{p_i} \\
  &amp;= \sum_{i=1}^N \frac{1}{(N-i+1)/N} \\
  &amp;= N\sum_{k=1}^N \frac{1}{k} \quad ; k=N-i+1
\end{align}\]

<p>Setting $N=365$, we get 2364.64 people, which is extremely close to our simulated value of 2364.84.</p>
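
<p>The final line of the formula is a one-liner in Julia. This is a sketch, with <code class="language-plaintext highlighter-rouge">expected_draws</code> as a hypothetical function name:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Expected number of draws to see all N coupons: N * (1/1 + 1/2 + ... + 1/N)
expected_draws(N) = N * sum(1 / k for k in 1:N)

expected_draws(4)    # 4 seasons: 4 * 25/12, about 8.33
expected_draws(365)  # 365 birthdays: about 2364.6</code></pre></figure>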

<p>The sum $\sum_{k=1}^N \frac{1}{k}$ is the $N$th <a href="https://en.wikipedia.org/wiki/Harmonic_number">harmonic number</a> and is approximated by $\ln N + \gamma$, where $\gamma\approx 0.5772156649$ is the Euler-Mascheroni constant.
This shows that the sum grows without bound as $N$ increases.
That is, the more coupons there are, the more people need to be asked, even relative to $N$.</p>
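
<p>To get a feel for how good the logarithmic approximation is, the exact sum can be compared against it for a few values of $N$. This is a sketch with hypothetical helper names; the test values of $N$ are arbitrary. The gap between the two is roughly $\frac{1}{2N}$, so it shrinks as $N$ grows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">harmonic(N) = sum(1 / k for k in 1:N)                        # exact sum
harmonic_approx(N) = log(N) + Base.MathConstants.eulergamma  # ln N + γ

for N in (4, 365, 10_000)
    println(N, ": exact = ", harmonic(N), ", approx = ", harmonic_approx(N))
end</code></pre></figure>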

<h2 id="covering-birthdays">3 Covering Birthdays</h2>

<p>Now to solve the first problem. It is:</p>

<blockquote>
  <p>Given $n$ people, what is the probability that all $N$ birthdays are covered? That is, given $n$ people, what is the probability that there is at least 1 person for each birthday?</p>
</blockquote>

<p>Based on the previous answer, we expect the probability to be very low below $n=2364$, and very high above it.</p>

<p>For the theory part you’ll need a good understanding of counting methods and how the <a href="https://en.wikipedia.org/wiki/Binomial_coefficient">binomial coefficient</a> $n \choose k$ (read as “n choose k”) is used in combinatorics.
The main calculation is with the <a href="https://en.wikipedia.org/wiki/Inclusion%E2%80%93exclusion_principle">Inclusion-Exclusion Principle</a> formula, which I’ll introduce gently.</p>

<p>Because the numbers get very large very quickly, I’ll also work with the simpler case of covering 4 seasons with 5 people: spring 🌱, summer ☀️, autumn 🍂 and winter ❄️.</p>

<h3 id="covering-birthdays-simulation">3.1 Simulation</h3>

<p>For this <a href="https://en.wikipedia.org/wiki/Monte_Carlo_method">Monte Carlo simulation</a> more work needs to be done per data point.
Take a fixed $n$ and then generate $n$ random birthdays a large number of $T$ times.
Each time check if all the birthdays are covered or not and add this to a count $c$.
(The simplest way to do this is to check if the length of the hashed set is 365.)
After all the trials estimate the probability as $c/T$.
Then repeat this for several different $n$’s.</p>

<p>Here is an implementation in Julia:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">ProgressMeter</span>

<span class="k">function</span><span class="nf"> covering_simulation</span><span class="x">(</span><span class="n">ndays</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">ntrials</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">population_sizes</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Int</span><span class="x">})</span>
    <span class="n">ratio_covered</span> <span class="o">=</span> <span class="n">zeros</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">population_sizes</span><span class="x">))</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">idx</span><span class="x">,</span> <span class="n">pop_size</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">population_sizes</span><span class="x">)</span>
        <span class="n">covered</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="n">progress</span> <span class="o">=</span> <span class="n">Progress</span><span class="x">(</span><span class="n">ntrials</span><span class="x">;</span> <span class="n">desc</span><span class="o">=</span><span class="s">"Population size: </span><span class="si">$</span><span class="s">pop_size "</span><span class="x">)</span>
        <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">ntrials</span>
            <span class="n">next!</span><span class="x">(</span><span class="n">progress</span><span class="x">)</span>
            <span class="n">is_covered</span> <span class="o">=</span> <span class="n">covering_trial</span><span class="x">(</span><span class="n">ndays</span><span class="x">,</span> <span class="n">pop_size</span><span class="x">)</span>
            <span class="k">if</span> <span class="n">is_covered</span>
                <span class="n">covered</span> <span class="o">+=</span> <span class="mi">1</span>
            <span class="k">end</span>
        <span class="k">end</span>
        <span class="n">ratio_covered</span><span class="x">[</span><span class="n">idx</span><span class="x">]</span> <span class="o">=</span> <span class="n">covered</span> <span class="o">/</span> <span class="n">ntrials</span>
    <span class="k">end</span>
    <span class="n">ratio_covered</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> covering_trial</span><span class="x">(</span><span class="n">ndays</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">n</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">birthdays</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="n">ndays</span><span class="x">,</span> <span class="n">n</span><span class="x">)</span>
    <span class="n">length</span><span class="x">(</span><span class="kt">Set</span><span class="x">(</span><span class="n">birthdays</span><span class="x">))</span> <span class="o">==</span> <span class="n">ndays</span>
<span class="k">end</span></code></pre></figure>

<p>It can be run with:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">population_sizes</span> <span class="o">=</span> <span class="x">[</span><span class="mi">365</span><span class="x">,</span> <span class="mi">500</span><span class="x">,</span> <span class="mi">1000</span><span class="x">,</span> <span class="mi">1500</span><span class="x">,</span> <span class="mi">2000</span><span class="x">,</span> <span class="mi">2364</span><span class="x">,</span> <span class="mi">2500</span><span class="x">,</span> <span class="mi">3000</span><span class="x">,</span> <span class="mi">4000</span><span class="x">,</span> <span class="mi">5000</span><span class="x">]</span>
<span class="n">ratio_covered</span> <span class="o">=</span> <span class="n">covering_simulation</span><span class="x">(</span><span class="mi">365</span><span class="x">,</span> <span class="mi">10_000</span><span class="x">,</span> <span class="n">population_sizes</span><span class="x">)</span></code></pre></figure>

<p>Here are the results:</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/covering_probability.png" alt="Monte Carlo covering graph" />
<figcaption>Monte Carlo estimate of the ratio of trials in which all birthdays are covered, per $n$</figcaption>
</figure>

<p>The probability is almost zero below 1500, but rises rapidly afterwards and by 4000 is almost one.
At the expected value of $n=2364$, the ratio covered is 0.5739.</p>

<h3 id="covering-birthdays-theory">3.2 Theory</h3>

<h4 id="counting-configurations">Counting configurations</h4>

<p>One way to estimate the probability is to count all the different configurations of birthdays.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/seasons_5.png" alt="5 people in pens" />
<figcaption></figcaption>
</figure>

<p>For the season problem, this is straightforward. With 5 people and 4 seasons, covering every season means exactly one season is shared by a pair. ${5 \choose 2} = 10$ pairs can share a season, then there are 4 seasons that can be assigned to the pair, then 3 remaining seasons to the next person, then 2 to the next person, and finally the last person must take the last remaining season. This is out of $4^5$ possible configurations:</p>

\[\begin{align}
    P(🌱\cup ☀️\cup 🍂 \cup ❄️ ) &amp;= \frac{ {5 \choose 2} 4!}{4^5} \\
        &amp;= 0.234375
\end{align}\]

<p>This is just under one quarter.</p>
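
<p>Because the season problem is so small, the count can be checked by brute force. The sketch below enumerates all $4^5=1024$ assignments of seasons (numbered 1 to 4) to the 5 people and counts those that contain all 4 seasons:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Every assignment of a season (1 to 4) to each of 5 people: 4^5 = 1024 tuples.
assignments = Iterators.product(fill(1:4, 5)...)

# An assignment covers all seasons if it contains 4 distinct values.
covered = count(length(Set(a)) == 4 for a in assignments)

covered        # 240, which is binomial(5, 2) * factorial(4)
covered / 4^5  # 0.234375</code></pre></figure>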

<p>For the birthday problem, this is much more difficult. 
There are many, many different configurations which all need to be summed together.
For example, one such configuration between $n=2364$ people is 180 birthdays each shared 6 times (1080 people), another 180 birthdays each shared 7 times (1260 people), 4 shared 5 times (20 people), and 1 shared 4 times (4 people). This is out of $365^n$ configurations:</p>

\[\begin{align}
    P\left(X \right) &amp;= \left[{2364 \choose 1080}{1284 \choose 1260}{24 \choose 20}{4 \choose 4} \right] \cdot \\
      &amp;\phantom{=} \quad \left[ {365 \choose 180 } {185 \choose 180 } {5 \choose 4 } {1 \choose 1 }\right] / 365^{2364} \\
      &amp;= \frac{2364!}{1080! 1260! 20! 4!} \frac{365!}{ (180!)^2 4! 1!} / 365^{2364} \\
      &amp;= 8.4\times 10^{-5179}
\end{align}\]

<p>This probability is absolutely tiny. 
Worse, there are an extremely large number of configurations like this, all with extremely small probabilities.
Adding them up is complex and might have numerical issues.</p>

<p>Thankfully, there is a simpler way.</p>

<h4 id="counting-missing-birthdays">Counting missing birthdays</h4>

<p>All probabilities sum to 1.
From this, the probability that at least one person has each birthday is 1 minus the probability that at least one birthday is missing.</p>

<p>As a start, assume mutual exclusivity between missing birthdays.
That is, assume the events of missing different birthdays never overlap.
This is clearly false: a group of people can have multiple missing birthdays.
However, it makes the calculations simple.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/seasons_exclude.png" alt="Season counting trees" />
<figcaption>Counting trees for each $S \setminus x$ (S exclude x) season.</figcaption>
</figure>

<p>For the season problem, there are 4 possible ways we can exclude 1 of 4 seasons, and then there are $3^5$ possibilities for all of the five people. The probability is thus:</p>

\[\begin{align}
    P(🌱\cup ☀️\cup 🍂 \cup ❄️ )&amp;= 1 - P(\bar{🌱}\cup \bar{☀️}\cup \bar{🍂} \cup \bar{❄️} )\\
      &amp;= 1 - \frac{4 \cdot 3^5}{4^5} \\
      &amp;= 0.0508
\end{align}\]

<p>This is much smaller than the original value of 0.234. The mutual exclusivity assumption clearly does not hold here.
(This will be corrected shortly.)</p>

<p>For the birthdays, there are 365 possible ways we can exclude 1 of 365 birthdays, and then there are $364^n$ possibilities for the birthdays for $n$ people:</p>

\[\begin{align}
    P\left(\bigcup\limits_{i=1}^{365} B_i \right) &amp;= 1 - P\left(\bigcup\limits_{i=1}^{365} \bar{B}_i \right) \\
      &amp;= 1 - \frac{365 \cdot 364^{2364} }{365^{2364} } \\
      &amp;= 0.4432
\end{align}\]

<p>This is much closer to the target value (77% of the simulated value).
This is because with 2364 people it is somewhat likely that only 1 of the 365 birthdays is missing.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/seasons_overlap.png" alt="Overlap in season counting trees" />
<figcaption>Overlap between counting trees.</figcaption>
</figure>

<p>To correct these values, we need to account for overlaps in the counting trees.
For the season problem, we can exclude both winter ❄️ and autumn 🍂 by only choosing spring 🌱 or summer ☀️
in either the $S\setminus ❄️$ tree or the $S\setminus 🍂$ tree.
Since in both we have a choice of 2 seasons for each of the 5 people, there are $2^5=32$ overlapping branches.
In total there are ${4 \choose 2} = 6$ sets of overlapping branches:</p>

<ul>
  <li>Between $S\setminus ❄️$ and $S\setminus 🌱$.</li>
  <li>Between $S\setminus ❄️$ and $S\setminus ☀️$.</li>
  <li>Between $S\setminus ❄️$ and $S\setminus 🍂$.</li>
  <li>Between $S\setminus 🌱$ and $S\setminus ☀️$.</li>
  <li>Between $S\setminus 🌱$ and $S\setminus 🍂$.</li>
  <li>Between $S\setminus ☀️$ and $S\setminus 🍂$.</li>
</ul>

<p>Each overlapping branch has been counted twice, so we need to subtract one copy to correct it:</p>

\[\begin{align}
P(🌱\cup ☀️\cup 🍂 \cup ❄️ )&amp;= 1 - P(\bar{🌱}\cup \bar{☀️}\cup \bar{🍂} \cup \bar{❄️} )\\
      &amp;= 1 - \left[ \frac{4 \cdot 3^5}{4^5} - \frac{ {4 \choose 2} \cdot 2^5}{4^5}\right] \\
      &amp;= 0.23828125
\end{align}\]

<p>Much closer to our original answer of 0.234375!</p>

<p>Similarly, for the birthdays:</p>

\[\begin{align}
    P\left(\bigcup\limits_{i=1}^{365} B_i \right) &amp;= 1 - P\left(\bigcup\limits_{i=1}^{365} \bar{B}_i\right) \\
      &amp;= 1 - \left[ \frac{365 \cdot 364^{2364} }{365^{2364} } - \frac{ {365 \choose 2} \cdot 363^{2364} }{365^{2364} }\right]\\
      &amp;= 0.5955
\end{align}\]

<p>This is slightly over the simulated value of 0.5739.</p>

<p>For the next correction, it is helpful to draw a Venn diagram:</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/seasons_venn.png" alt="Venn diagram overlap" />
<figcaption></figcaption>
</figure>

<p>For the seasons, we initially double count the overlaps between 2 circles, but then correct this by subtracting each one once. But this means that the middle, which is initially counted 3 times, is subtracted 3 times. So we need to add it back once.
There are $ {4 \choose 3 } = 4 $ overlaps we need to add back:</p>

<ul>
  <li>Between $S\setminus ❄️$, $S\setminus 🌱$ and $S\setminus ☀️$.</li>
  <li>Between $S\setminus ❄️$,  $S\setminus 🌱$ and $S\setminus 🍂$.</li>
  <li>Between $S\setminus ❄️$,  $S\setminus ☀️$ and $S\setminus 🍂$.</li>
  <li>Between $S\setminus ☀️$,  $S\setminus 🌱$ and $S\setminus 🍂$.</li>
</ul>

\[\begin{align}
   P(🌱\cup ☀️\cup 🍂 \cup ❄️ )&amp;= 1 - P(\bar{🌱}\cup \bar{☀️}\cup \bar{🍂} \cup \bar{❄️} )\\
      &amp;= 1 - \left[ \frac{4 \cdot 3^5}{4^5} - \frac{ {4 \choose 2} \cdot 2^5}{4^5} + \frac{ {4 \choose 3} \cdot 1^5}{4^5}\right] \\
      &amp;= 0.234375
\end{align}\]

<p>This is the exact same value as with counting the configurations.</p>
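
<p>The three successive corrections for the seasons can be reproduced with exact rational arithmetic. In this sketch, <code class="language-plaintext highlighter-rouge">coverage_truncated</code> is a hypothetical helper that keeps only the first <code class="language-plaintext highlighter-rouge">terms</code> corrections:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Coverage probability for N categories and n people, keeping only the first
# `terms` corrections. Rationals (//) keep the small example exact.
function coverage_truncated(N::Int, n::Int, terms::Int)
    missing_any = sum((-1)^(k + 1) * (binomial(N, k) * (N - k)^n) // N^n for k in 1:terms)
    1 - missing_any
end

[coverage_truncated(4, 5, t) for t in 1:3]
# 13//256, 61//256 and 15//64, i.e. 0.05078125, 0.23828125 and 0.234375</code></pre></figure>

<p>Plain <code class="language-plaintext highlighter-rouge">Int</code> arithmetic is only safe here because the numbers are small; $365^{2364}$ would overflow it.</p>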

<p>For the birthdays:</p>

\[\begin{align}
    P\left(\bigcup\limits_{i=1}^{365} B_i \right) &amp;= 1 - P\left(\bigcup\limits_{i=1}^{365} \bar{B}_i \right) \\
      &amp;= 1 - \left[ \frac{365 \cdot 364^{2364} }{365^{2364} } - \frac{ {365 \choose 2} \cdot 363^{2364} }{365^{2364}} \right. \\
      &amp;\phantom{=} \left. + \frac{ {365 \choose 3} \cdot 362^{2364} }{365^{2364}}  \right] \\
      &amp;= 0.5681
\end{align}\]

<p>This is now slightly under the simulated value of 0.5739.</p>

<p>For the seasons, we are done. For the birthdays, we can continue this pattern of over-correcting/under-correcting under the <a href="https://en.wikipedia.org/wiki/Inclusion%E2%80%93exclusion_principle">Inclusion-Exclusion Principle</a>:</p>

<div class="card">
  <div class="card-body">
    <h5 class="card-title">Inclusion-Exclusion Principle</h5>
    <p class="card-text">
		$$
        \begin{align}
        \left| \bigcup\limits_{i=1}^{n} A_i \right| &amp;= \sum_{i=1}^{n} |A_i| - \sum_{1\leq i &lt;j \leq n} | A_i \cap A_j| \\
        &amp;\phantom{=} + \sum_{1\leq i &lt;j &lt;k \leq n} | A_i \cap A_j \cap A_k| \\
        &amp;\phantom{=} - ... \\
        &amp;\phantom{=} + (-1)^{n+1} | A_1 \cap ... \cap A_n| \\
        &amp;= \sum_{k=1}^n (-1)^{k+1} \left(\sum_{1\leq i_1 &lt; ... &lt;i_k \leq n} | A_{i_1} \cap ... \cap A_{i_k}|  \right) \\
        &amp;= \sum_{\emptyset \neq J\subseteq \{1,...,n\}} (-1)^{|J|+1} \left| \bigcap\limits_{j \in {J}} A_j \right|
        \end{align} \\
        $$
	</p>
  </div>
</div>

<p>For the birthday problem, each $A_i$ is the event that one birthday is missing (e.g. $A_5$ is January 5th missing). Each sum over intersections of $k$ events, $\sum \vert A_{i_1} \cap \dots \cap A_{i_k} \vert$, is calculated as the number of combinations ${365 \choose k}$ of $k$ missing birthdays multiplied by the probability $\left(\frac{365-k}{365}\right)^n$ that all $n$ people avoid those $k$ birthdays.</p>

<p>The formula is then:</p>

\[\begin{align}
P\left(\bigcup\limits_{i=1}^{365} B_i\right) &amp;= 1 - P\left(\bigcup\limits_{i=1}^{365} \bar{B}_i\right) \\
&amp;= 1 - \frac{1}{365^n}\sum_{k=1}^{365} (-1)^{(k+1)} { 365 \choose k} (365 - k)^n
\end{align}\]

<p>For $n=2364$, we get an answer of 0.5712. The simulated value of 0.5739 was close.</p>
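
<p>Here is one way the full formula could be evaluated in Julia. This is a sketch under the assumption that exact <code class="language-plaintext highlighter-rouge">BigInt</code> arithmetic is acceptable; it is not necessarily how the value above was computed, and <code class="language-plaintext highlighter-rouge">coverage_probability</code> is a hypothetical name. The integers involved, such as $365^{2364}$, are far too large for <code class="language-plaintext highlighter-rouge">Int64</code>, and the alternating signs make floating point risky for small $n$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Probability that n people cover all N birthdays, by inclusion-exclusion.
# BigInt keeps every term exact; only the final ratio is converted to Float64.
function coverage_probability(N::Int, n::Int)
    Nb = big(N)
    missing_count = sum((-1)^(k + 1) * binomial(Nb, k) * (Nb - k)^n for k in 1:N)
    Float64(1 - missing_count // Nb^n)
end

coverage_probability(4, 5)       # 0.234375, the seasons answer
coverage_probability(365, 2364)  # the text above reports 0.5712
[coverage_probability(365, n) for n in (1000, 2000, 3000, 4000)]</code></pre></figure>

<p>With fewer people than birthdays, for example <code class="language-plaintext highlighter-rouge">coverage_probability(365, 364)</code>, the exact sum cancels to zero, as it should.</p>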

<p>We can now construct a theoretical graph and compare it to the graph from the simulation:</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/birthday-covering/covering_probability_theory.png" alt="Covering probability with theory" />
<figcaption></figcaption>
</figure>

<p>The graphs match very well.</p>

<h2 id="conclusion">4 Conclusion</h2>

<p>The answer to the question, what is the probability that all birthdays ($N=365$) are present in a group of $n$ people is:</p>

<ul>
  <li>Very low for fewer than 1000 people ($&lt;3N$).</li>
  <li>About 50% for roughly 2300 people ($\approx 6N$).</li>
  <li>Very high for 3000 people ($8N$) and almost certain for 4000 and above ($&gt;10N$).</li>
</ul>

<p>More generally, the <a href="https://en.wikipedia.org/wiki/Inclusion%E2%80%93exclusion_principle">Inclusion-Exclusion Principle</a> can be used to calculate exact probabilities for this and similar problems.</p>

<p>This was an interesting problem, but I’m not sure if there is a practical use to it.</p>]]></content><author><name>Lior Sinai</name></author><category term="mathematics" /><category term="mathematics" /><category term="probability" /><summary type="html"><![CDATA[Quantifying how likely each birthday is present (covered) in some large group of people.]]></summary></entry><entry><title type="html">Generative transformer from first principles in Julia</title><link href="https://liorsinai.github.io/machine-learning/2024/03/23/transformers-gpt.html" rel="alternate" type="text/html" title="Generative transformer from first principles in Julia" /><published>2024-03-23T00:00:00+00:00</published><updated>2026-01-10T00:00:00+00:00</updated><id>https://liorsinai.github.io/machine-learning/2024/03/23/transformers-gpt</id><content type="html" xml:base="https://liorsinai.github.io/machine-learning/2024/03/23/transformers-gpt.html"><![CDATA[<p><em>A transformer for generating text in Julia, trained on Shakespeare’s plays. This model can be used as a Generative Pre-trained Transformer (GPT) with further work. This post was inspired by Andrej Karpathy’s Zero to Hero course.</em></p>

<p><em>Update 2 February 2025: update to Flux 0.16.</em></p>

<p>See also a previous post: <a href="/machine-learning/2022/05/18/transformers">Transformers from first principles in Julia</a>.
And a later post: <a href="/machine-learning/2025/02/22/mla">DeepSeek’s Multi-Head Latent Attention</a>.</p>

<p>All code available at <a href="https://github.com/LiorSinai/TransformersLite.jl">github.com/LiorSinai/TransformersLite.jl</a>.</p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p>The transformer architecture was introduced by Google AI in their famous <a href="https://arxiv.org/abs/1706.03762">Attention is all you need (2017)</a> paper.
Transformers have dominated the natural language processing (NLP) landscape since then.
Nearly all state-of-the-art NLP models today are transformer models.
Most of them have an architecture that is remarkably similar to the original and differ only in training regimes, datasets and sizes.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/transformers/transformer_model_sizes_annotated_2024.png" alt="transformer model sizes 2017-2024" />
<figcaption>Transformers have continued to grow in size.</figcaption>
</figure>

<p>In 2018 OpenAI released a paper titled <a href="https://web.archive.org/web/20210126024542/https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf">Improving Language Understanding by Generative Pre-Training</a>.
This led to the development of their first Generative Pre-trained Transformer (GPT) model.
As of 2024 they have released four versions of GPT, with the latest requiring over <a href="https://www.youtube.com/watch?v=Y2F8yisiS6E&amp;t=1202s">1.8 trillion parameters</a>.
The interactive version of the model, ChatGPT, has gained widespread fame for its human-like responses.</p>

<figure class="post-figure" id="fig-gpt1">
<img class="img-80" src="/assets/posts/transformers/gpt.png" alt="GPT architecture" />
<figcaption>GPT Transformer architecture (left) and fine tuning tasks (right). Source: <a href="https://web.archive.org/web/20210126024542/https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf">GPT1 paper (2018)</a></figcaption>
</figure>

<p>The goal of this post is to create a generative transformer following OpenAI’s methodology in their first GPT paper.
It will be a vanilla transformer without many of the additions that have been proposed in this fast-paced field.
The model will be trained on Shakespeare plays and will be able to generate text that looks and sounds like Shakespeare.
This model can then be used as the pre-trained foundation for further supervised tasks.</p>

<h3 id="outcome">Outcome</h3>

<p>The goal is to create a model which implements the architecture in the <a href="#fig-gpt1">GPT paper</a>:</p>
<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>TransformerGenerator(
  Embedding(72 =&gt; 32),                  # 2_304 parameters
  Embedding(64 =&gt; 32),                  # 2_048 parameters
  Dropout(0.1),
  TransformerBlock(
    MultiHeadAttention(
      nhead=4,
      denseQ = Dense(32 =&gt; 32; bias=false),  # 1_024 parameters
      denseK = Dense(32 =&gt; 32; bias=false),  # 1_024 parameters
      denseV = Dense(32 =&gt; 32; bias=false),  # 1_024 parameters
      denseO = Dense(32 =&gt; 32),         # 1_056 parameters
    ),
    LayerNorm(32),                      # 64 parameters
    Dense(32 =&gt; 128, relu),             # 4_224 parameters
    Dense(128 =&gt; 32),                   # 4_128 parameters
    LayerNorm(32),                      # 64 parameters
    Dropout(0.1),
  ),
  ..., # 2x more TransformerBlocks
  Dense(32 =&gt; 72),                      # 2_376 parameters
  mask = 64×64 Matrix{Bool},
)        # Total: 43 trainable arrays, 44_552 parameters,
          # plus 1 non-trainable, 4_096 parameters, summarysize 180.641 KiB.
</code></pre></div></div>

<p>It will map tokens to indices and will operate on those:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">8</span><span class="x">,</span> <span class="mi">8</span><span class="x">))</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">indexer</span><span class="x">(</span><span class="n">collect</span><span class="x">(</span><span class="s">"LYSANDER"</span><span class="x">))</span> <span class="c"># [23, 36, 30, 12, 25, 15, 16, 29]</span>
<span class="n">model</span><span class="x">(</span><span class="n">indices</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span></code></pre></figure>

<p>It will return a $ V \times n $ matrix, where $V$ is the vocabulary size and $n$ is the length of the input vector (8 in this example).
Each column represents logits for each token. 
These will then be normalised to values between 0 and 1 using the softmax function.
The model will be trained so that each column holds the probabilities of the next token given all the tokens before it, up to a fixed context length $n$.
As a whole the matrix represents the probabilities associated with shifting the input one value to the right.</p>
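
<p>The normalisation step can be sketched without any packages. This is a minimal illustration with a hypothetical function name and random logits standing in for the model output. In practice the <code class="language-plaintext highlighter-rouge">softmax</code> function from NNlib would be used instead.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Column-wise softmax: each column of logits becomes a probability distribution.
function softmax_columns(logits::AbstractMatrix)
    shifted = logits .- maximum(logits; dims=1) # subtract the column maximum for numerical stability
    exps = exp.(shifted)
    exps ./ sum(exps; dims=1)
end

logits = randn(72, 8)           # V × n, random stand-in for the model output
probs = softmax_columns(logits) # every column sums to 1</code></pre></figure>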

<p>As an example, during training the input will be “LYSANDER” and the reference “YSANDER\n”.
The model will output a probability matrix and after sampling the result will be something like “YSANDR\nH”. This is then compared to the reference to improve the output.</p>
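
<p>The input and reference pair from this example can be built by slicing the same passage twice, one character apart. The variable names here are illustrative only.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">passage = collect("LYSANDER\n") # 9 characters
n = 8                           # context length
input = passage[1:n]            # "LYSANDER"
reference = passage[2:n+1]      # "YSANDER\n", the input shifted by one</code></pre></figure>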

<p>The model computes all the probabilities for all $n$ characters in parallel through the same set of matrix operations, which makes this very efficient during training.
We will effectively compare $n$ different predictions for one sample.
However at inference time we are only interested in the prediction at the last ($n$th) position, which is the new character, because we already have the first $n$ characters.
Therefore we will discard the first $n-1$ predictions. (They would have already been used internally in the model.)</p>

<p>This is an inherent inefficiency in the transformer model architecture. 
(KV Caching can be used to overcome it. See <a href="https://medium.com/@joaolages/kv-caching-explained-276520203249">João Lages’ visual explanation</a> or a later <a href="/machine-learning/2025/02/22/mla#kv-caching">blog post</a>.)</p>
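
<p>A minimal sketch of inference using only the last column. Everything here is an assumption for illustration: <code class="language-plaintext highlighter-rouge">toy_model</code> is a stand-in for a trained model, <code class="language-plaintext highlighter-rouge">generate</code> is a hypothetical name, and the next token is picked greedily with <code class="language-plaintext highlighter-rouge">argmax</code> instead of being sampled.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Stand-in for a trained model: returns a V × n matrix of logits that always
# favours the index after the current one (wrapping around at V).
toy_model(x::Vector{Int}; V::Int=5) =
    [(i == mod1(x[j] + 1, V)) ? 1.0 : 0.0 for i in 1:V, j in eachindex(x)]

function generate(model, context::Vector{Int}; context_size::Int=4, max_tokens::Int=3)
    context = copy(context)
    for _ in 1:max_tokens
        window = context[max(1, end - context_size + 1):end] # keep at most context_size tokens
        logits = model(window)              # V × n
        next_token = argmax(logits[:, end]) # the first n-1 columns are discarded
        push!(context, next_token)
    end
    context
end

generate(toy_model, [1]) # [1, 2, 3, 4]</code></pre></figure>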

<p>Generation will repeat inference many times, each time adding the last generated token to the context and generating a new token. The result is something like:</p>

<blockquote><pre>
CLATIO.
No, Goe, him buchieds is, hand I was,
To queer thee that of till moxselat by twish are.

BENET.
Are warrain Astier, the Cowlles,
bourse and nope, Merfore myen our to of them coun-mothared man,
Here is
Mafter my thath and herop, and in in have low’t so, veriege a the can eeset thy
inscestle marriom.

ADY.
Thus him stome
To so an streeward. Here cas, which id renuderser what thou bee of as the hightseleh-to.

CHAESS.
With he mand, th’ fouthos. I purcot Lay,
You.

GATHENT.
Who, to hath fres
</pre></blockquote>

<p>This was generated by a tiny 42,400 parameter model with a <a href="https://en.wikipedia.org/wiki/Perplexity">perplexity</a> of 6.3, down from a random sampling perplexity of 71 for 71 characters.</p>
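
<p>To see why random sampling gives a perplexity equal to the number of characters, here is a small sketch. The function name is hypothetical. It takes the probabilities that a model assigned to the correct tokens.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Perplexity is the exponential of the average negative log-probability
# assigned to the correct tokens.
perplexity(target_probs) = exp(-sum(log, target_probs) / length(target_probs))

uniform = fill(1 / 71, 1000) # a model that guesses uniformly over 71 characters
perplexity(uniform)          # ≈ 71</code></pre></figure>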

<h3 id="background">Background</h3>

<p>In May 2022 I wrote a blog post on <a href="/machine-learning/2022/05/18/transformers">transformers from first principles in Julia</a>.
It developed a transformer for a classification task, namely predicting stars for Amazon Reviews.
That post was lacking however in that it did not create a decoder transformer.
This post is dedicated to that task.
I’ve written this as a stand-alone from the original even though much of the code is the same.
I refer back to the original post for some explanations.
Please see the <a href="/machine-learning/2022/05/18/transformers#design-considerations">Design Considerations</a> section which is not repeated here.</p>

<p>This post was inspired by Andrej Karpathy’s <a href="https://karpathy.ai/zero-to-hero.html">Zero to Hero</a> course.
I highly recommend it.
It covers many ideas like backpropagation, normalisation and embeddings that are assumed knowledge in this post.
In particular, this post emulates <a href="https://www.youtube.com/watch?v=kCc8FmEb1nY">lesson 7</a> except the language and framework used are Julia and Flux.jl, not Python and PyTorch.
The source code can be accessed at Karpathy’s famed <a href="https://github.com/karpathy/nanoGPT">nanoGPT</a> repository.</p>

<p>My own repositories with the code in this blog post can be accessed at <a href="https://github.com/LiorSinai/TransformersLite.jl">TransformersLite.jl</a> and <a href="https://github.com/LiorSinai/TransformersLite-examples">TransformersLite-examples</a>.
I will not detail any “pretty” printing function here - please see the repository for those.</p>

<p>This is not meant to be a full scale Julia solution.
For that, please see the <a href="https://github.com/chengchingwen/Transformers.jl">Transformers.jl</a> package. 
It has better optimizations, APIs for HuggingFace and more.</p>

<h2 id="data">2 Data</h2>
<h3 id="Download">2.1 Download</h3>

<p>The Complete Works of William Shakespeare by William Shakespeare has no copyright attached and can be downloaded legally from <a href="https://www.gutenberg.org/ebooks/100">Project Gutenberg</a>.</p>

<p>Here is a line to download it with cURL:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>curl https://www.gutenberg.org/cache/epub/100/pg100.txt <span class="o">&gt;</span> project_gutenberg_shakespeare.txt
</code></pre></div></div>

<h3 id="preparation">2.2 Preparation</h3>

<p>A typical passage from the text looks like:</p>

<blockquote>
<pre>
LYSANDER.
How now, my love? Why is your cheek so pale?
How chance the roses there do fade so fast?

HERMIA.
Belike for want of rain, which I could well
Beteem them from the tempest of my eyes.

LYSANDER.
Ay me! For aught that I could ever read,
Could ever hear by tale or history,
The course of true love never did run smooth.
But either it was different in blood—
</pre>
</blockquote>

<p>This is what we want the transformer to learn and the vast majority of the text follows this format.
However some pieces do not. These include the Project Gutenberg introduction and conclusion, the table of contents, the sonnets, the preambles - these list the acts and scenes in each play - and so on. 
Those should all be removed.</p>

<p>Optionally, the small number of non-ASCII characters (œ, Æ, æ, …) can be removed. I also removed the “&amp;” symbol and changed the archaic usage of “&amp;c.” to “etc.”.</p>

<p>I’ve made a script which does all this work, <a href="https://github.com/LiorSinai/TransformersLite-Examples/blob/72d0d76256fc5b8447a84855f2eb065ef05f8b27/data/Shakespeare/prepare_shakespeare.jl">prepare_shakespeare.jl</a>.
It reduces the file size from 5.4 MB to 4.8 MB.</p>

<h3 id="Exploration">2.3 Exploration</h3>

<p>We can load the text in Julia with:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">text</span> <span class="o">=</span> <span class="n">open</span><span class="x">(</span><span class="n">filepath</span><span class="x">)</span> <span class="k">do</span> <span class="n">file</span>
    <span class="n">read</span><span class="x">(</span><span class="n">file</span><span class="x">,</span> <span class="kt">String</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>Some basic statistics:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">count</span><span class="x">(</span><span class="sc">'\n'</span><span class="x">,</span> <span class="n">text</span><span class="x">)</span>   <span class="c"># 182,027 lines</span>
<span class="n">count</span><span class="x">(</span><span class="s">"</span><span class="se">\n\n</span><span class="s">"</span><span class="x">,</span> <span class="n">text</span><span class="x">)</span> <span class="c"># 38,409 passages</span>
<span class="n">count</span><span class="x">(</span><span class="n">r</span><span class="s">"\w+"</span><span class="x">,</span> <span class="n">text</span><span class="x">)</span> <span class="c"># 921,816 words</span>
<span class="n">length</span><span class="x">(</span><span class="n">text</span><span class="x">)</span>        <span class="c"># 4,963,197 characters</span></code></pre></figure>

<p>The prepared dataset contains 182,027 lines spanning approximately 38,409 passages, 921,816 words and 4,963,197 characters.</p>

<p>Most passages are very short - less than 100 characters.
The longest is Richard’s monologue in “The Third Part of King Henry the Sixth” which consists of 3047 characters.</p>

<p>Lines have an average of 26.27 characters with the longest being 77 characters in length.</p>
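
<p>Statistics like these can be computed by splitting the text. The sketch below uses a two-passage sample instead of the full file. Skipping empty lines with <code class="language-plaintext highlighter-rouge">keepempty=false</code> is an assumption, so the numbers for the full text may differ slightly depending on how blank lines are counted.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">sample = "LYSANDER.\nHow now, my love?\n\nHERMIA.\nBelike for want of rain.\n"

passages = split(sample, "\n\n")             # passages are separated by blank lines
lines = split(sample, '\n'; keepempty=false) # assumption: ignore empty lines

longest_passage = maximum(length, passages)           # 33
longest_line = maximum(length, lines)                 # 24
mean_line_length = sum(length, lines) / length(lines) # 14.25</code></pre></figure>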

<figure class="post-figure">
<img class="img-80" src="/assets/posts/transformers/shakespeare_char_frequencies.png" alt="Character frequencies" />
<figcaption>Frequencies of characters in the Complete Works of Shakespeare</figcaption>
</figure>

<p>After the data preparation there are 71 unique characters in the text: <code class="language-plaintext highlighter-rouge">\n !(),-.:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz—‘’“”</code></p>

<p>There are approximately 30,040 unique words in the dataset. 
Of these, approximately 80% appear fewer than 10 times and 96.5% fewer than 100 times.
The most frequent word is “the” with 23,467 occurrences.</p>
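
<p>A sketch of how such word counts can be gathered, shown on a short sample. Lowercasing the words is an assumption here and may not match how the numbers above were produced. The <code class="language-plaintext highlighter-rouge">countmap</code> function from StatsBase is an alternative to the manual loop.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">sample = "the course of true love never did run smooth but the course of the play did"
words = [lowercase(m.match) for m in eachmatch(r"\w+", sample)]

counts = Dict{String,Int}()
for word in words
    counts[word] = get(counts, word, 0) + 1
end

ranked = sort(collect(counts); by=last, rev=true) # most frequent first
first(ranked)                                     # "the" =&gt; 3
seen_once = count(p -&gt; last(p) == 1, ranked)      # 7 of the 11 unique words</code></pre></figure>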

<h2 id="model">3 Model</h2>

<h3 id="project-setup">3.1 Project Setup</h3>

<p>To start, make a package in the Julia REPL:</p>
<figure class="highlight">
    <code class="language-julia-repl hljs" data-lang="julia-repl">
        <span class="hljs-meta">julia&gt;</span><span class="julia"> cd(<span class="hljs-string">"path\\to\\project"</span>)</span>
        <br />
        <span class="hljs-meta">julia&gt;</span><span class="julia"> ] <span class="hljs-comment"># enter package mode</span></span>
        <br />
        <span class="hljs-meta">(@v1.x) pkg&gt;</span><span class="julia"> generate TransformersLite <span class="hljs-comment"># make a directory structure</span></span>
        <br /> 
        <span class="hljs-meta">(@v1.x) pkg&gt;</span><span class="julia"> dev "path\\to\\project\\TransformersLite"</span>
    </code>
</figure>

<p>The purpose of making a package is that we can now use the super helpful Revise package,
which picks up most code changes during development without restarting the Julia session:</p>

<figure class="highlight"><pre><code class="language-julia-repl" data-lang="julia-repl">julia&gt; using Revise
julia&gt; using TransformersLite</code></pre></figure>

<p>The following packages need to be loaded/added for this tutorial:</p>

<figure class="highlight"><pre><code class="language-julia-repl" data-lang="julia-repl">julia&gt; using Flux, LinearAlgebra, NNlib, ProgressMeter, Random, StatsBase</code></pre></figure>

<h3 id="tokenization">3.2 Tokenization</h3>

<p>The model will predict probabilities for each token in a given vocabulary.
There is a choice as to what constitutes a token.
One extreme is one token for each word in the dataset.
Here there are far too many unique words so it will explode the parameter count while providing too few training samples per token.
The other extreme is character level tokens. This compresses the learning space too much to get fully realistic outputs, but otherwise it works surprisingly well.
In between is sub-word tokenization such as Byte Pair Encoding.
This allows configurable vocabulary lengths.
See my <a href="https://github.com/LiorSinai/TokenizersLite">TokenizersLite.jl</a> package, the <a href="https://github.com/chengchingwen/BytePairEncoding.jl">BytePairEncoding.jl</a> package or Karpathy’s latest <a href="https://www.youtube.com/watch?v=zduSFxRajkE&amp;feature=youtu.be">video</a>.</p>

<p>Here we will follow Karpathy’s approach and use character level tokens.
The model will learn to predict each word character by character.</p>

<p>First get all the characters:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">characters</span> <span class="o">=</span> <span class="n">sort</span><span class="x">(</span><span class="n">collect</span><span class="x">(</span><span class="kt">Set</span><span class="x">(</span><span class="n">text</span><span class="x">)))</span></code></pre></figure>

<p>Karpathy uses two dictionaries to convert between characters and indices: <code class="language-plaintext highlighter-rouge">char_to_int</code> and <code class="language-plaintext highlighter-rouge">int_to_char</code>.
I’m going to wrap these in a slightly more complex <code class="language-plaintext highlighter-rouge">IndexTokenizer</code> struct introduced in my <a href="/machine-learning/2022/05/18/transformers#tokenizers">first post</a>.
It holds a vector of the vocabulary (equivalent to <code class="language-plaintext highlighter-rouge">int_to_char</code>) and a <code class="language-plaintext highlighter-rouge">lookup</code> for reversing this (equivalent to <code class="language-plaintext highlighter-rouge">char_to_int</code>).
Additionally, it has an unknown symbol if any of the characters are not in the vocabulary.</p>
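
<p>For comparison, the plain two-dictionary approach looks roughly like this in Julia. This is a sketch on a single sentence, not code from Karpathy’s repository.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">characters = sort(collect(Set("How now, my love?")))
char_to_int = Dict(c =&gt; i for (i, c) in enumerate(characters))
int_to_char = Dict(i =&gt; c for (i, c) in enumerate(characters))

encoded = [char_to_int[c] for c in "How"]
decoded = join(int_to_char[i] for i in encoded) # "How"</code></pre></figure>

<p>A character that is missing from <code class="language-plaintext highlighter-rouge">char_to_int</code> throws a <code class="language-plaintext highlighter-rouge">KeyError</code> here, which is the gap the unknown symbol fills.</p>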

<p>The struct and its inner constructor are as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> IndexTokenizer</span><span class="x">{</span><span class="n">T</span><span class="x">}</span>
    <span class="n">vocabulary</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="n">T</span><span class="x">}</span>
    <span class="n">lookup</span><span class="o">::</span><span class="kt">Dict</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="kt">Int</span><span class="x">}</span>
    <span class="n">unksym</span><span class="o">::</span><span class="n">T</span>
    <span class="n">unkidx</span><span class="o">::</span><span class="kt">Int</span>
    <span class="k">function</span><span class="nf"> IndexTokenizer</span><span class="x">(</span><span class="n">vocab</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">unksym</span><span class="o">::</span><span class="n">T</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
        <span class="k">if</span> <span class="o">!</span><span class="x">(</span><span class="n">unksym</span> <span class="n">∈</span> <span class="n">vocab</span><span class="x">)</span>
            <span class="n">pushfirst!</span><span class="x">(</span><span class="n">vocab</span><span class="x">,</span> <span class="n">unksym</span><span class="x">)</span>
            <span class="n">unkidx</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="k">else</span>
            <span class="n">unkidx</span> <span class="o">=</span> <span class="n">findfirst</span><span class="x">(</span><span class="n">isequal</span><span class="x">(</span><span class="n">unksym</span><span class="x">),</span> <span class="n">vocab</span><span class="x">)</span>
        <span class="k">end</span>
        <span class="n">lookup</span> <span class="o">=</span> <span class="kt">Dict</span><span class="x">(</span><span class="n">x</span> <span class="o">=&gt;</span> <span class="n">idx</span> <span class="k">for</span> <span class="x">(</span><span class="n">idx</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">vocab</span><span class="x">))</span>
        <span class="n">new</span><span class="x">{</span><span class="n">T</span><span class="x">}(</span><span class="n">vocab</span><span class="x">,</span> <span class="n">lookup</span><span class="x">,</span> <span class="n">unksym</span><span class="x">,</span> <span class="n">unkidx</span><span class="x">)</span>
    <span class="k">end</span>
<span class="k">end</span>

<span class="n">Base</span><span class="o">.</span><span class="n">length</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">)</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">vocabulary</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> Base.show</span><span class="x">(</span><span class="n">io</span><span class="o">::</span><span class="kt">IO</span><span class="x">,</span> <span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">)</span> 
    <span class="n">T</span> <span class="o">=</span> <span class="n">eltype</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">vocabulary</span><span class="x">)</span>
    <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="s">"IndexTokenizer{</span><span class="si">$(T)</span><span class="s">}(length(vocabulary)=</span><span class="si">$</span><span class="s">(length(tokenizer)), unksym=</span><span class="si">$</span><span class="s">(tokenizer.unksym))"</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>For encoding we look up the character in the dictionary, returning the index of the unknown symbol by default:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> encode</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">x</span><span class="o">::</span><span class="n">T</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">get</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">lookup</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">unkidx</span><span class="x">)</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> encode</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">seq</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">map</span><span class="x">(</span><span class="n">x</span><span class="o">-&gt;</span><span class="n">encode</span><span class="x">(</span><span class="n">tokenizer</span><span class="x">,</span> <span class="n">x</span><span class="x">),</span> <span class="n">seq</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>
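
<p>The fallback behaviour can be seen in isolation with a tiny vocabulary. The names below are illustrative and only mirror the <code class="language-plaintext highlighter-rouge">get</code> call used in <code class="language-plaintext highlighter-rouge">encode</code>.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">vocab = ['Ø', 'a', 'b', 'c'] # unknown symbol first
lookup = Dict(x =&gt; idx for (idx, x) in enumerate(vocab))
unkidx = 1

encode_char(c::Char) = get(lookup, c, unkidx) # same fallback as in encode
map(encode_char, collect("cabz")) # [4, 2, 3, 1]; 'z' is not in the vocabulary</code></pre></figure>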

<p>We can also define a method on instances of <code class="language-plaintext highlighter-rouge">IndexTokenizer</code>, 
which makes the struct callable like a function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="x">(</span><span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">)(</span><span class="n">x</span><span class="x">)</span> <span class="o">=</span> <span class="n">encode</span><span class="x">(</span><span class="n">tokenizer</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span></code></pre></figure>
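
<p>The same pattern on a toy type, to show what a callable struct looks like in general. <code class="language-plaintext highlighter-rouge">Scaler</code> is a made-up example and is not part of the package.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">struct Scaler
    factor::Float64
end

(s::Scaler)(x) = s.factor * x # calling the object applies it

double = Scaler(2.0)
double(3) # 6.0</code></pre></figure>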

<p>Encoding example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">push!</span><span class="x">(</span><span class="n">characters</span><span class="x">,</span> <span class="sc">'Ø'</span><span class="x">)</span> <span class="c"># unknown symbol</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="n">length</span><span class="x">(</span><span class="n">characters</span><span class="x">)</span> <span class="c"># 72</span>
<span class="n">indexer</span> <span class="o">=</span> <span class="n">IndexTokenizer</span><span class="x">(</span><span class="n">characters</span><span class="x">,</span> <span class="sc">'Ø'</span><span class="x">)</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">indexer</span><span class="x">(</span><span class="n">collect</span><span class="x">(</span><span class="s">"How now, my love?"</span><span class="x">))</span> <span class="c"># [19, 55, 63, 2, 54, ..., 62, 45, 11]</span></code></pre></figure>

<p>Decoding goes the other way:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">decode</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">x</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span> <span class="o">=</span> 
	<span class="mi">1</span> <span class="o">&lt;=</span> <span class="n">x</span> <span class="o">&lt;=</span> <span class="n">length</span><span class="x">(</span><span class="n">tokenizer</span><span class="x">)</span> <span class="o">?</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">vocabulary</span><span class="x">[</span><span class="n">x</span><span class="x">]</span> <span class="o">:</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">unksym</span>

<span class="k">function</span><span class="nf"> decode</span><span class="x">(</span><span class="n">tokenizer</span><span class="o">::</span><span class="n">IndexTokenizer</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">seq</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Int</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">map</span><span class="x">(</span><span class="n">x</span><span class="o">-&gt;</span><span class="n">decode</span><span class="x">(</span><span class="n">tokenizer</span><span class="x">,</span> <span class="n">x</span><span class="x">),</span> <span class="n">seq</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>An example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">join</span><span class="x">(</span><span class="n">decode</span><span class="x">(</span><span class="n">indexer</span><span class="x">,</span> <span class="x">[</span><span class="mi">23</span><span class="x">,</span> <span class="mi">36</span><span class="x">,</span> <span class="mi">30</span><span class="x">,</span> <span class="mi">12</span><span class="x">,</span> <span class="mi">25</span><span class="x">,</span> <span class="mi">15</span><span class="x">,</span> <span class="mi">16</span><span class="x">,</span> <span class="mi">29</span><span class="x">]))</span> <span class="c"># LYSANDER</span></code></pre></figure>

<h3 id="embeddings">3.3 Embeddings</h3>

<p>Each token is transformed into a vector of floating point numbers. 
This vector represents some sort of meaning in a large, abstract vector space, where vectors that are closer to each other are more similar.
(There is plenty of literature on this subject.)</p>

<p>Flux.jl comes with an embedding layer which can be used directly:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">embedding</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">Embedding</span><span class="x">(</span><span class="mi">72</span> <span class="o">=&gt;</span> <span class="mi">32</span><span class="x">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="mi">72</span><span class="x">,</span> <span class="mi">10</span><span class="x">)</span> <span class="c"># [40, 49, 55, 65, 27, 50, 35, 69, 40, 29]</span>
<span class="n">embedding</span><span class="x">(</span><span class="n">x</span><span class="x">)</span> <span class="c"># 32×10 Matrix{Float32}</span></code></pre></figure>
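
<p>Conceptually the layer is a column lookup in a weight matrix. The sketch below uses a plain random matrix as a stand-in for the layer’s weights, so it runs without Flux. It also checks that the lookup is equivalent to multiplying by a one-hot matrix.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">W = randn(Float32, 32, 72) # one column per token, same shape as the weights of Embedding(72 =&gt; 32)
x = [40, 49, 55]
E = W[:, x]                # 32×3: one embedding vector per index

onehot = Float32.((1:72) .== x') # 72×3 matrix with a single 1 per column
W * onehot ≈ E                   # true</code></pre></figure>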

<p>Here is the <a href="https://github.com/FluxML/Flux.jl/blob/009d9841960ac15d9a02499ac6e341e777dedf34/src/layers/basic.jl#L762">source code</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> Embedding</span><span class="x">{</span><span class="n">W</span><span class="o">&lt;:</span><span class="kt">AbstractMatrix</span><span class="x">}</span>
  <span class="n">weight</span><span class="o">::</span><span class="n">W</span>
<span class="k">end</span>

<span class="n">Flux</span><span class="o">.</span><span class="nd">@layer</span> <span class="n">Embedding</span>

<span class="n">Embedding</span><span class="x">((</span><span class="k">in</span><span class="x">,</span> <span class="n">out</span><span class="x">)</span><span class="o">::</span><span class="kt">Pair</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Integer</span><span class="x">,</span> <span class="o">&lt;:</span><span class="kt">Integer</span><span class="x">};</span> <span class="n">init</span> <span class="o">=</span> <span class="n">randn32</span><span class="x">)</span> <span class="o">=</span> <span class="n">Embedding</span><span class="x">(</span><span class="n">init</span><span class="x">(</span><span class="n">out</span><span class="x">,</span> <span class="k">in</span><span class="x">))</span>

<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Embedding</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">Integer</span><span class="x">)</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">x</span><span class="x">]</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Embedding</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">)</span> <span class="o">=</span> <span class="n">NNlib</span><span class="o">.</span><span class="n">gather</span><span class="x">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="x">,</span> <span class="n">x</span><span class="x">)</span>
<span class="x">(</span><span class="n">m</span><span class="o">::</span><span class="n">Embedding</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">)</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">vec</span><span class="x">(</span><span class="n">x</span><span class="x">)),</span> <span class="o">:</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">)</span><span class="o">...</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> Base.show</span><span class="x">(</span><span class="n">io</span><span class="o">::</span><span class="kt">IO</span><span class="x">,</span> <span class="n">m</span><span class="o">::</span><span class="n">Embedding</span><span class="x">)</span>
  <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="s">"Embedding("</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="s">" =&gt; "</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="s">")"</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>This struct stores a weight matrix, which by default uses the smaller datatype of <code class="language-plaintext highlighter-rouge">Float32</code> rather than the usual Julia default of <code class="language-plaintext highlighter-rouge">Float64</code>.
This halves the memory needed, usually without a meaningful loss in model accuracy. 
(<code class="language-plaintext highlighter-rouge">Float16</code>, <code class="language-plaintext highlighter-rouge">Float8</code> and as low as <code class="language-plaintext highlighter-rouge">Float4</code> are all used in machine learning models.)</p>

<p>On the forward pass each index is used to retrieve the associated column vector from the matrix.
However instead of using <code class="language-plaintext highlighter-rouge">m.weight[:, x]</code> the function uses <code class="language-plaintext highlighter-rouge">NNlib.gather(m.weight, x)</code>. 
This is because <code class="language-plaintext highlighter-rouge">gather</code> comes with an  <code class="language-plaintext highlighter-rouge">rrule</code> defined for it (<a href="https://github.com/FluxML/NNlib.jl/blob/1af2535d12cfdcabc6ccd2f259968c16e84c7b81/src/gather.jl#L131">source</a>):</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">∇gather_src</span><span class="x">(</span><span class="n">Δ</span><span class="x">,</span> <span class="n">src_size</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="o">=</span> <span class="n">scatter!</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="n">fill!</span><span class="x">(</span><span class="n">similar</span><span class="x">(</span><span class="n">Δ</span><span class="x">,</span> <span class="n">eltype</span><span class="x">(</span><span class="n">Δ</span><span class="x">),</span> <span class="n">src_size</span><span class="x">),</span> <span class="mi">0</span><span class="x">),</span> <span class="n">Δ</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span></code></pre></figure>

<p>The <code class="language-plaintext highlighter-rouge">rrule</code> is a reverse (backwards) rule that encodes the derivative for backpropagation.
It is what makes the magic of automatic differentiation work.</p>

<p>The function <code class="language-plaintext highlighter-rouge">gather</code> does not have a formal derivative, but <code class="language-plaintext highlighter-rouge">scatter</code> is the opposite of it and is what we need to apply when we propagate the loss backwards:</p>
<figure class="post-figure">
<img class="img-60" src="/assets/posts/transformers/gather.png" alt="architecture" />
<figcaption></figcaption>
</figure>

<p>At the end of backpropagation we need to distribute the error matrix amongst the original word embeddings.
This is what <code class="language-plaintext highlighter-rouge">scatter</code> does. Note that we use the red column twice, so we have two error columns directed towards it.
The <code class="language-plaintext highlighter-rouge">rrule</code> applies <code class="language-plaintext highlighter-rouge">+</code> as the reducing function; that is, the two errors are added together and then to the word embedding.</p>
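
<p>As a small illustration of the duplicate index case, the sketch below gathers column 1 twice and then scatters a matrix of ones back. The sizes, indices and variable names are arbitrary choices for this example and are not taken from the model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using NNlib

W = Float32[1 2 3 4; 5 6 7 8]   # 2×4 stand-in for an embedding matrix
idx = [1, 3, 1]                 # column 1 is used twice

X = NNlib.gather(W, idx)        # 2×3, the same as W[:, idx]

Δ = ones(Float32, 2, 3)         # pretend error matrix from backpropagation
∇W = NNlib.scatter(+, Δ, idx; dstsize=size(W))
# 2×4 Matrix{Float32}:
#  2.0  0.0  1.0  0.0
#  2.0  0.0  1.0  0.0</code></pre></figure>

<p>Column 1 receives the sum of both error columns, column 3 receives one, and the unused columns stay at zero.</p>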

<p>Scatter can be inefficient.
If we do a small experiment and call scatter we will see it results in a large matrix of mostly zeros:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">NNlib</span><span class="o">.</span><span class="n">scatter</span><span class="x">(</span><span class="o">+</span><span class="x">,</span> <span class="n">rand</span><span class="x">(</span><span class="mi">8</span><span class="x">,</span> <span class="mi">4</span><span class="x">),</span> <span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="mi">5</span><span class="x">,</span> <span class="mi">11</span><span class="x">,</span> <span class="mi">1</span><span class="x">];</span> <span class="n">dstsize</span><span class="o">=</span><span class="x">(</span><span class="mi">8</span><span class="x">,</span> <span class="mi">15</span><span class="x">))</span>
<span class="mi">8</span><span class="n">×15</span> <span class="kt">Matrix</span><span class="x">{</span><span class="kt">Float64</span><span class="x">}</span><span class="o">:</span>
 <span class="mf">1.62703</span>   <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.495725</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.237452</span>     <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">0.979735</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.984499</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.145738</span>     <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">0.892948</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.76959</span>   <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.714658</span>     <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">1.45113</span>   <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.883492</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.52775</span>      <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">0.702824</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.965256</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0966964</span>    <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">1.16978</span>   <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.568429</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.000161501</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">1.80566</span>   <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.271676</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.430018</span>     <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>
 <span class="mf">1.16445</span>   <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.911601</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.786343</span>     <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span>  <span class="mf">0.0</span></code></pre></figure>

<h3 id="position-encoding">3.4 Position encoding</h3>

<p>The matrix operations used in the transformer are parallel operations.
This speeds up computation and is a major reason why they are so popular.
However this creates an issue: these operations do not take order into account.
We can shuffle the columns in the embedding matrix and it will not affect the output.</p>

<figure class="post-figure">
<img class="img-80" src="/assets/posts/transformers/position_encoding_similarities.png" alt="position encoding cosine similarities" />
<figcaption>Cosine similarities of different position encodings. The learned embedding is from a model made using the code in this blog post.</figcaption>
</figure>

<p>To counteract this, the authors of the <a href="https://arxiv.org/abs/1706.03762">Attention is all you need (2017)</a> paper suggested adding a second embedding to the first where the indices are the positions in the sequence.<sup id="fnref:cosine" role="doc-noteref"><a href="#fn:cosine" class="footnote" rel="footnote">1</a></sup></p>

<p>We can use an <code class="language-plaintext highlighter-rouge">Embedding</code> matrix as before, except with a different input:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">position_encoding</span> <span class="o">=</span> <span class="n">Embedding</span><span class="x">(</span><span class="mi">16</span> <span class="o">=&gt;</span> <span class="mi">32</span><span class="x">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">32</span><span class="x">,</span> <span class="mi">10</span><span class="x">)</span> <span class="c"># the output of the first embedding layer</span>
<span class="n">indices</span> <span class="o">=</span> <span class="mi">1</span><span class="o">:</span><span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span> <span class="c"># 1:10</span>
<span class="n">position_encoding</span><span class="x">(</span><span class="n">indices</span><span class="x">)</span> <span class="c"># 32×10 Matrix{Float32}</span></code></pre></figure>
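
<p>To make the "adding a second embedding to the first" step concrete, here is a minimal sketch that combines both layers. The vocabulary size of 71 and the random token indices are assumptions made only for this example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Flux

vocab_size = 71                                # hypothetical vocabulary size
embedding = Embedding(vocab_size =&gt; 32)
position_encoding = Embedding(16 =&gt; 32)

tokens = rand(1:vocab_size, 10)                # 10 random token indices
x = embedding(tokens)                          # 32×10 Matrix{Float32}
x = x .+ position_encoding(1:size(x, 2))       # 32×10, now position dependent</code></pre></figure>

<p>Because the position vectors differ per column, shuffling the tokens now changes the values that reach the attention layers.</p>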

<div class="card">
  <div class="card-body">
    <h5 class="card-title">Other position encodings</h5>
    <p class="card-text">
    Transformers are an active area of research and many position encodings have been proposed.
    <ul>
        <li> Sinusoidal Position Encodings: The original paper gave equations to calculate a fixed embedding matrix.
        For an explanation and implementation see my <a href="/machine-learning/2022/05/18/transformers#position-encodings">first post</a>.
        </li>
        <li> <a href="https://arxiv.org/abs/1803.02155">Relative Position Embeddings (RPE) (2018)</a>: add embeddings in the attention step where each entry encodes the relative position $r=j_k-i_q$ between a key at position $j_k$ and a query at position $i_q$.
        The <a href="https://arxiv.org/pdf/1809.04281">Music Transformer (2018)</a> paper greatly improved computation of this matrix.
        For a helpful video see <a href="https://www.youtube.com/watch?v=XdlmDfa2hew">Relative Self-Attention Explained</a>.
        </li>
        <li><a href="https://arxiv.org/abs/2104.09864">Rotary Position Embeddings (RoPE) (2021)</a>: encode absolute position with a fixed rotation matrix, which handles longer sequences better and encodes relative positions better than sinusoidal embeddings. A downside is that every <code>key</code> and <code>query</code> in the attention step needs to be multiplied by this rotation matrix instead of adding an encoding once at the start.
        </li>
        <li><a href="https://arxiv.org/abs/2306.15595">Position Interpolation (2023)</a>: extending embedding matrices by linearly down-scaling the input position indices to match the original context window size. For example, each index in a 128 context window can be down-scaled by a factor of 2 to match a 64 length encoding matrix. Half indices like 42.5 are a linear combination of the indices before and after (so 42 and 43). Some fine tuning is required for best results.
        This can be combined with RoPE.
        </li>
    </ul>
    </p>
  </div>
</div>

<p>This embedding matrix restricts the context size.
In the example the embedding matrix is 32×16, so a maximum of 16 tokens can be passed to the model at a time.
To overcome this a sliding window must be implemented and the model will completely “forget” any character outside of the window.</p>
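
<p>A sliding window can be as simple as keeping only the most recent tokens. The helper below is a hypothetical sketch (the name <code>clip_context</code> is not part of the code base) that assumes the tokens are stored in a plain vector:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># keep at most `context_size` of the most recent tokens
clip_context(tokens::AbstractVector, context_size::Int) =
    tokens[max(1, end - context_size + 1):end]

tokens = collect(1:40)            # stand-in for 40 token indices
window = clip_context(tokens, 16) # tokens 25 to 40; tokens 1 to 24 are "forgotten"
short = clip_context([7, 8, 9], 16) # shorter inputs are returned unchanged</code></pre></figure>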

<p>Ideally we would create an embedding matrix as large as possible so that the bottleneck is the training data, not the model. 
However attention, which will be discussed in the next section, scales with $n^2$ for a context length $n$.
This is a significant performance penalty for a larger context size.</p>

<h3 id="attention">3.5 Attention</h3>

<figure class="post-figure" id="fig-multi-head-attention">
<img class="img-30" src="/assets/posts/transformers/multihead_attention.png" alt="Multi-head attention" />
<figcaption>Source: <a href="https://arxiv.org/abs/1706.03762">Attention paper (2017)</a></figcaption>
</figure>

<ol>
  <li><a href="#attention-definition">Definition</a></li>
  <li><a href="#attention-masking">Masking</a></li>
  <li><a href="#attention-batched-multiplication">Batched multiplication</a></li>
  <li><a href="#attention-multiheadattention-layer">MultiHeadAttention layer</a></li>
  <li><a href="#attention-multi-head-attention">Multi-Head Attention</a></li>
  <li><a href="#attention-scaled-dot-attention">Scaled Dot Attention</a></li>
  <li><a href="#attention-full-example">Full example</a></li>
</ol>

<h4 id="attention-definition">3.5.1 Definition</h4>

<p>Attention is the main mechanism at the heart of the transformer.
Theoretically it is a weighting of every token towards every other token, including itself.
It is asymmetrical and so forms a full $n \times n$ matrix.
For example consider word level tokens for the sentence “The elephant in the room”.
The tokens “The”, “in”, “the” and “room” might all rate “elephant” the highest, but “elephant” will probably only rate “room” highly.</p>

<div class="message-container info-message">
	<div class="message-icon fa fa-fw fa-2x fa-exclamation-circle">
	</div>
	<div class="content-container">
		<div class="message-body">
		Julia uses column major format whereas Python uses row major format. In Julia word vectors are columns while in Python they are rows.
		Equations between the two formats will look backwards to each other.
		They need to be transposed and definitions also need to be transposed. 
		E.g. $K^TQ \rightarrow (K_c^TQ_c)^T=Q_c^TK_c= Q_r K_r^T$
		</div>
	</div>
</div>

<p>The attention equation is:
\(A = V\text{softmax}\left(\frac{1}{\sqrt{d_h}}K^T Q\right)
\label{eq:attention}
\tag{3.6.1}\)</p>

<p>where $\text{softmax}$ is applied to each column $z$ of the $n \times n$ matrix and is given by:</p>

\[\text{softmax}(z, i) = \frac{e^{z_i}}{\sum_{r=1}^{n} e^{z_r}}
\label{eq:softmax}
\tag{3.6.2}\]

<p>Its calculation scales with $\mathcal{O}(n^2d_h)$ where $n$ is the input token length and $d_h$ is the head dimension, also known as the hidden dimension.</p>

<div class="card">
  <div class="card-body">
    <h5 class="card-title">Efficient self-attention</h5>
    <p class="card-text">
    Given the $n^2$ scaling of attention much effort has gone into altering this step.
    This includes sparse attention layers, factorisation/kernels for linear attention and down-sampling.
    A detailed survey can be found at <a href="https://arxiv.org/abs/2009.06732">Efficient Transformers: A Survey (2020)</a>.
    All these alternatives are faster than attention but come at the expense of accuracy.
    </p>
    <p>
    Another line of research is to improve the computation.
    This includes
    <a href="https://arxiv.org/abs/2205.14135">Flash attention (2022)</a> which improves computational efficiency on a single GPU while <a href="https://arxiv.org/abs/2310.01889">Ring attention (2023)</a> aims to distribute the work efficiently across multiple devices.
    </p>
  </div>
</div>

<p>Here the key $K$, query $Q$ and value $V$ are derived from the input matrix $X$ using weights:</p>

\[\begin{align}
    K = W_K X \\
    Q = W_Q X \\
    V = W_V X
\end{align}\]

<p>Each weight $W$ has a size $d_h \times d_\text{emb}$ and the input matrix has a size $d_\text{emb} \times n$ where $d_\text{emb}$ is the embedding dimension. Each of these matrices therefore has a size $d_h \times n$.</p>

<p>From this we can show that the first matrix product is a weighted <a href="https://en.wikipedia.org/wiki/Dot_product">dot product</a> of every vector to every other vector in the input matrix, resulting in a $n \times n$ matrix:</p>

\[K^T Q = (W_KX)^T(W_QX) = X^T W_K^T W_Q X\]

<p>This is then followed by scaling ($1/\sqrt{d_h}$) and normalisation ($\text{softmax}$).
Lastly this matrix is used as a weight for $V$. The output is $d_h \times n$.</p>
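
<p>The steps above can be sketched for a single head with plain matrices. This is only an illustration with arbitrary sizes and random weights, not the layer developed later in this post:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using NNlib: softmax

d_emb, d_h, n = 8, 4, 5                      # arbitrary sizes for this example
X = randn(Float32, d_emb, n)                 # one column per token
W_K, W_Q, W_V = (randn(Float32, d_h, d_emb) for _ in 1:3)

K, Q, V = W_K * X, W_Q * X, W_V * X          # each is d_h×n
scores = softmax((K' * Q) ./ sqrt(Float32(d_h)); dims=1) # n×n, each column sums to 1
A = V * scores                               # d_h×n</code></pre></figure>

<p>The softmax is taken over <code>dims=1</code> because in the column major convention each column of $K^TQ$ holds the scores of all keys for one query.</p>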

<h4 id="attention-masking">3.5.2 Masking</h4>

<p>There is a flaw in this architecture.
The attention is computed across all tokens at once.
This means that past tokens will be given access to future tokens.
However the training objective is to predict future tokens.
Therefore only the $n$th token, whose next token is missing, will be trained fairly.</p>

<p>To overcome this the authors of the <a href="https://arxiv.org/abs/1706.03762">Attention (2017)</a> paper suggested masking the matrix before the softmax with $-\infty$ at each illegal connection, so that $\exp(-\infty)=0$ which effectively removes their influence.</p>

<p>The masked matrix will look like:</p>

\[\begin{bmatrix}
s_{11} &amp; s_{12} &amp; s_{13} &amp;... &amp; s_{1n} \\
-\infty &amp; s_{22} &amp; s_{23} &amp;... &amp; s_{2n} \\
-\infty &amp; -\infty &amp; s_{33} &amp;... &amp; s_{3n} \\
\vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
-\infty &amp; -\infty &amp; -\infty &amp; ... &amp; s_{nn}
\end{bmatrix}\]

<p>Firstly a mask is made where all valid connections have a <code class="language-plaintext highlighter-rouge">true</code> and all illegal connections have a <code class="language-plaintext highlighter-rouge">false</code>.
Here is the code from <a href="https://github.com/FluxML/NNlib.jl/blob/07833637dec96d12d0614308d3145b432fdb320a/src/attention.jl#L149">NNlib.jl</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">LinearAlgebra</span>
<span class="k">function</span><span class="nf"> make_causal_mask</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">;</span> <span class="n">dims</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">2</span><span class="x">)</span>
  <span class="n">len</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">dims</span><span class="x">)</span>
  <span class="n">mask</span> <span class="o">=</span> <span class="n">triu</span><span class="x">(</span><span class="n">trues_like</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="x">(</span><span class="n">len</span><span class="x">,</span> <span class="n">len</span><span class="x">)))</span>
  <span class="n">mask</span>
<span class="k">end</span>

<span class="n">trues_like</span><span class="x">(</span><span class="n">x</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">,</span> <span class="n">sz</span><span class="o">=</span><span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">))</span> <span class="o">=</span> <span class="n">fill!</span><span class="x">(</span><span class="n">similar</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="kt">Bool</span><span class="x">,</span> <span class="n">sz</span><span class="x">),</span> <span class="nb">true</span><span class="x">)</span></code></pre></figure>

<div class="card">
  <div class="card-body">
    <h5 class="card-title">Dataless masks</h5>
    <p class="card-text">
    We don't have to allocate memory to create a mask. The causal mask is defined by $j \geq i$ for all indices $i$, $j$. We can write this as a function as long as we can also write an equivalent <code>rrule</code> for it as well. See <a href="https://github.com/chengchingwen/NeuralAttentionlib.jl/blob/master/src/mask/dataless.jl">NeuralAttentionlib.jl</a> for such an implementation.
    </p>
  </div>
</div>

<p>The mask will be applied through <code class="language-plaintext highlighter-rouge">ifelse</code>, where <code class="language-plaintext highlighter-rouge">true</code>s maintain their value but the <code class="language-plaintext highlighter-rouge">false</code>s are replaced with the most negative value of the element type (<code class="language-plaintext highlighter-rouge">-Inf</code> for floating point numbers).</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">apply_mask</span><span class="x">(</span><span class="n">logits</span><span class="x">,</span> <span class="n">mask</span><span class="o">::</span><span class="kt">Nothing</span><span class="x">)</span> <span class="o">=</span> <span class="n">logits</span>

<span class="k">function</span><span class="nf"> apply_mask</span><span class="x">(</span><span class="n">logits</span><span class="x">,</span> <span class="n">mask</span><span class="x">)</span>
    <span class="n">neginf</span> <span class="o">=</span> <span class="n">typemin</span><span class="x">(</span><span class="n">eltype</span><span class="x">(</span><span class="n">logits</span><span class="x">))</span>
    <span class="n">ifelse</span><span class="o">.</span><span class="x">(</span><span class="n">mask</span><span class="x">,</span> <span class="n">logits</span><span class="x">,</span> <span class="n">neginf</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>Usage:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">5</span><span class="x">,</span> <span class="mi">5</span><span class="x">))</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">5</span><span class="x">,</span> <span class="mi">5</span><span class="x">)</span>
<span class="n">apply_mask</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">mask</span><span class="x">)</span> <span class="c"># 5×5 Matrix{Float32}:</span></code></pre></figure>
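
<p>To see the effect of the mask on the attention weights, the following self-contained sketch repeats the masking inline (so it does not depend on the functions above) and then applies the softmax. The size of 4 is an arbitrary choice:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra, NNlib

n = 4
scores = randn(Float32, n, n)
mask = triu(trues(n, n))                         # true on and above the diagonal
masked = ifelse.(mask, scores, typemin(Float32)) # -Inf below the diagonal
weights = softmax(masked; dims=1)                # each column still sums to 1

all(iszero, tril(weights, -1))                   # true: no weight on future tokens</code></pre></figure>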

<p>Backpropagation:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Flux</span><span class="o">:</span> <span class="n">pullback</span>
<span class="n">y</span><span class="x">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">pullback</span><span class="x">(</span><span class="n">apply_mask</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">mask</span><span class="x">);</span>
<span class="n">grads</span> <span class="o">=</span> <span class="n">back</span><span class="x">(</span><span class="n">randn</span><span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">y</span><span class="x">)</span><span class="o">...</span><span class="x">))</span>
<span class="n">grads</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="c"># zero where -inf</span></code></pre></figure>

<p>As an experiment, set the mask to <code class="language-plaintext highlighter-rouge">nothing</code> during training. 
It should be possible to get very low training losses (below 0.5) corresponding to very low perplexities (less than 2) with very small models but without a corresponding increase in generation quality.</p>

<h4 id="attention-batched-multiplication">3.5.3 Batched multiplication</h4>

<p>The <a href="https://arxiv.org/abs/1706.03762">Attention (2017)</a> paper suggested a further enhancement on attention where the input matrix is divided amongst $H$ heads. This results in a $d_h \times n \times H$ array, where the head dimension is $d_h = \tfrac{d_\text{emb}}{H}$.
Furthermore, working with batches of size $B$ adds an extra dimension: $d_h \times n \times H \times B$.</p>
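
<p>One possible way to do this split, shown here as a sketch with made-up sizes and without the batch dimension, is a <code>reshape</code> followed by a <code>permutedims</code>. Each head then owns a contiguous block of rows of the original matrix:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">d_emb, n, H = 8, 5, 2
d_h = d_emb ÷ H                                        # 4 rows per head
X = reshape(collect(Float32, 1:d_emb*n), d_emb, n)     # 8×5 input matrix

heads = permutedims(reshape(X, d_h, H, n), (1, 3, 2))  # d_h×n×H = 4×5×2

heads[:, :, 2] == X[d_h+1:2d_h, :]                     # true: head 2 holds rows 5 to 8</code></pre></figure>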

<p>We could work with these arrays as <code class="language-plaintext highlighter-rouge">Vector{&lt;:Matrix{T}}</code> and <code class="language-plaintext highlighter-rouge">Vector{&lt;:Vector{&lt;:Matrix{T}}}</code> respectively, but it is more efficient to work with them as <code class="language-plaintext highlighter-rouge">Array{T, 3}</code> and <code class="language-plaintext highlighter-rouge">Array{T, 4}</code> because  then we can work with optimised array functions.</p>

<p>My <a href="/machine-learning/2022/05/18/transformers#multiplication-with-higher-order-arrays">first post</a> goes into more detail about multiplication with higher order arrays.<sup id="fnref:tensors" role="doc-noteref"><a href="#fn:tensors" class="footnote" rel="footnote">2</a></sup>
It compares vanilla versions with optimised versions.
Here I will present the optimised version only.</p>

<p>Batch multiplication is defined as:</p>

\[C_{ijk} = \sum_r A_{irk} B_{rjk}\]

<p>An optimised version is available through the NNlib.jl library, a dependency of Flux.jl:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">NNlib</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">6</span><span class="x">,</span> <span class="mi">8</span><span class="x">,</span> <span class="mi">4</span><span class="x">);</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">8</span><span class="x">,</span> <span class="mi">5</span><span class="x">,</span> <span class="mi">4</span><span class="x">);</span>
<span class="n">NNlib</span><span class="o">.</span><span class="n">batched_mul</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="n">B</span><span class="x">)</span> <span class="c"># 6×5×4 Array{Float64, 3}</span></code></pre></figure>
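
<p>As a quick sanity check of the definition, the result can be compared against an ordinary matrix product for each slice along the third dimension. The array sizes are the same arbitrary ones as above:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using NNlib

A = rand(6, 8, 4)
B = rand(8, 5, 4)
C = NNlib.batched_mul(A, B)

# every slice k is an independent 6×8 by 8×5 matrix product
all(C[:, :, k] ≈ A[:, :, k] * B[:, :, k] for k in 1:4) # true</code></pre></figure>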

<p>The 4D batched multiplication is defined as:</p>

\[C_{ijkl} = \sum_r A_{irkl} B_{rjkl}\]

<p>We can calculate this array with the same <code class="language-plaintext highlighter-rouge">batched_mul</code> by reshaping the 4D $m\times n \times p \times q$ arrays into 3D $m\times n \times pq$ arrays, doing the multiplication, and reshaping back.
This is exactly what the implementation does behind the <a href="https://github.com/FluxML/NNlib.jl/blob/07833637dec96d12d0614308d3145b432fdb320a/src/batched/batchedmul.jl#L47">scenes</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">NNlib</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">6</span><span class="x">,</span> <span class="mi">8</span><span class="x">,</span> <span class="mi">4</span><span class="x">,</span> <span class="mi">3</span><span class="x">);</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">8</span><span class="x">,</span> <span class="mi">5</span><span class="x">,</span> <span class="mi">4</span><span class="x">,</span> <span class="mi">3</span><span class="x">);</span>
<span class="n">NNlib</span><span class="o">.</span><span class="n">batched_mul</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="n">B</span><span class="x">)</span> <span class="c"># 6×5×4×3 Array{Float64, 4}</span></code></pre></figure>
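
<p>The reshape trick can also be written out by hand. The sketch below only relies on the 3D version of <code>batched_mul</code> and checks one slice against an ordinary matrix product. The intermediate names <code>A3</code> and <code>B3</code> are just for this example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using NNlib

A = rand(6, 8, 4, 3)
B = rand(8, 5, 4, 3)

A3 = reshape(A, 6, 8, :)   # 6×8×12, last two dimensions merged
B3 = reshape(B, 8, 5, :)   # 8×5×12
C = reshape(NNlib.batched_mul(A3, B3), 6, 5, 4, 3)

C[:, :, 2, 3] ≈ A[:, :, 2, 3] * B[:, :, 2, 3] # true</code></pre></figure>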

<p>The Flux <code class="language-plaintext highlighter-rouge">Dense</code> layer does something <a href="https://github.com/FluxML/Flux.jl/blob/348c56f6172c6ce838790b0ba23c5f4c58d93b83/src/layers/basic.jl#L177">similar</a>.</p>

<h4 id="attention-multiheadattention-layer">3.5.4 MultiHeadAttention layer</h4>

<p>Flux.jl now comes with a <code class="language-plaintext highlighter-rouge">Flux.MultiHeadAttention</code> layer.
However for continuity with my <a href="/machine-learning/2022/05/18/transformers#multi-head-attention">first post</a>, I will present my own <code class="language-plaintext highlighter-rouge">MultiheadAttention</code> layer except now with masking.
It is very similar to the code in <a href="https://github.com/FluxML/Flux.jl/blob/master/src/layers/attention.jl">Flux.jl</a> and <a href="https://github.com/FluxML/NNlib.jl/blob/master/src/attention.jl">NNlib.jl</a>.
The differences are in design choices for the inputs and Flux.jl’s implementations are slightly more generic.</p>

<p>First define a struct to hold all the dense layers and a parameter for $H$ called <code class="language-plaintext highlighter-rouge">nhead</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> MultiheadAttention</span><span class="x">{</span><span class="n">Q</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">K</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">V</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span> <span class="n">O</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">}</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span>
    <span class="n">denseQ</span><span class="o">::</span><span class="n">Q</span>
    <span class="n">denseK</span><span class="o">::</span><span class="n">K</span>
    <span class="n">denseV</span><span class="o">::</span><span class="n">V</span>
    <span class="n">denseO</span><span class="o">::</span><span class="n">O</span>
<span class="k">end</span>

<span class="cm">#= tell Flux which parameters are trainable =#</span>
<span class="n">Flux</span><span class="o">.</span><span class="nd">@layer</span> <span class="n">MultiheadAttention</span> <span class="n">trainable</span><span class="o">=</span><span class="x">(</span><span class="n">denseQ</span><span class="x">,</span> <span class="n">denseK</span><span class="x">,</span> <span class="n">denseV</span><span class="x">,</span> <span class="n">denseO</span><span class="x">)</span></code></pre></figure>

<p>The model is defined by 4 values: the number of heads $H$, the input dimension $d_\text{in}$, the output dimension $d_\text{out}$ and the head dimension $d_h$. The default for $d_h$ is $d_\text{in}/H$.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> MultiheadAttention</span><span class="x">(</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_in</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_head</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_out</span><span class="o">::</span><span class="kt">Int</span>
    <span class="x">)</span>
    <span class="n">MultiheadAttention</span><span class="x">(</span>
        <span class="n">nhead</span><span class="x">,</span>
        <span class="n">Dense</span><span class="x">(</span><span class="n">dim_in</span><span class="x">,</span> <span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">),</span>
        <span class="n">Dense</span><span class="x">(</span><span class="n">dim_in</span><span class="x">,</span> <span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">),</span>
        <span class="n">Dense</span><span class="x">(</span><span class="n">dim_in</span><span class="x">,</span> <span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span><span class="x">;</span> <span class="n">bias</span><span class="o">=</span><span class="nb">false</span><span class="x">),</span>
        <span class="n">Dense</span><span class="x">(</span><span class="n">dim_head</span><span class="o">*</span><span class="n">nhead</span><span class="x">,</span> <span class="n">dim_out</span><span class="x">),</span>
    <span class="x">)</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> MultiheadAttention</span><span class="x">(</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_in</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">dim_out</span><span class="o">::</span><span class="kt">Int</span>
    <span class="x">)</span>
    <span class="k">if</span> <span class="n">dim_in</span> <span class="o">%</span> <span class="n">nhead</span> <span class="o">!=</span> <span class="mi">0</span> 
        <span class="n">error</span><span class="x">(</span><span class="s">"input dimension=</span><span class="si">$</span><span class="s">dim_in is not divisible by number of heads=</span><span class="si">$</span><span class="s">nhead"</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">MultiheadAttention</span><span class="x">(</span><span class="n">nhead</span><span class="x">,</span> <span class="n">dim_in</span><span class="x">,</span> <span class="n">div</span><span class="x">(</span><span class="n">dim_in</span><span class="x">,</span> <span class="n">nhead</span><span class="x">),</span> <span class="n">dim_out</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>Now for the forward pass.
In general there are three input arrays, named <code class="language-plaintext highlighter-rouge">query</code>, <code class="language-plaintext highlighter-rouge">key</code> and <code class="language-plaintext highlighter-rouge">value</code>.
Later we will pass the same array <code class="language-plaintext highlighter-rouge">x</code> for all of them.
From these we can calculate $Q$, $K$ and $V$ and pass them to the <code class="language-plaintext highlighter-rouge">multi_head_scaled_dot_attention</code> function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">mha</span><span class="o">::</span><span class="n">MultiheadAttention</span><span class="x">)(</span><span class="n">query</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">value</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">}}</span>
    <span class="n">Q</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">denseQ</span><span class="x">(</span><span class="n">query</span><span class="x">)</span>
    <span class="n">K</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">denseK</span><span class="x">(</span><span class="n">key</span><span class="x">)</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">mha</span><span class="o">.</span><span class="n">denseV</span><span class="x">(</span><span class="n">value</span><span class="x">)</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">multi_head_scaled_dot_attention</span><span class="x">(</span><span class="n">mha</span><span class="o">.</span><span class="n">nhead</span><span class="x">,</span> <span class="n">Q</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">V</span><span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span>
    <span class="n">mha</span><span class="o">.</span><span class="n">denseO</span><span class="x">(</span><span class="n">A</span><span class="x">),</span> <span class="n">scores</span>
<span class="k">end</span></code></pre></figure>

<p>This layer returns the scores as well, like Flux.jl’s <code class="language-plaintext highlighter-rouge">MultiHeadAttention</code> layer.
These are useful for inspecting the model.</p>

<h4 id="attention-multi-head-attention">3.5.5 Multi-Head Attention</h4>

<p>The <code class="language-plaintext highlighter-rouge">multi_head_scaled_dot_attention</code> function begins as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> multi_head_scaled_dot_attention</span><span class="x">(</span><span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">Q</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">K</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">V</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">}}</span>
    <span class="n">qs</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">Q</span><span class="x">)</span>
    <span class="n">ks</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">K</span><span class="x">)</span>
    <span class="n">vs</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">V</span><span class="x">)</span>
    <span class="n">dm</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">Q</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">dh</span> <span class="o">=</span> <span class="n">div</span><span class="x">(</span><span class="n">dm</span><span class="x">,</span> <span class="n">nhead</span><span class="x">)</span></code></pre></figure>

<p>The $Q$, $K$ and $V$ matrices need to be split from $d_m \times N \times B$ to $d_h \times N \times H \times B$.
This is done in two steps:</p>
<ol>
  <li>$(d_h \times H)\times N \times B$ (break $d_m$ into $d_h$ and $H$)</li>
  <li>$d_h \times N \times H \times B$ (swap the 2nd and 3rd dimensions)</li>
</ol>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">Q</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">Q</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">qs</span><span class="x">[</span><span class="mi">2</span><span class="x">],</span> <span class="n">qs</span><span class="x">[</span><span class="mi">3</span><span class="x">]),</span> <span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">])</span>
    <span class="n">K</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">K</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">ks</span><span class="x">[</span><span class="mi">2</span><span class="x">],</span> <span class="n">ks</span><span class="x">[</span><span class="mi">3</span><span class="x">]),</span> <span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">])</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">reshape</span><span class="x">(</span><span class="n">V</span><span class="x">,</span> <span class="n">dh</span><span class="x">,</span> <span class="n">nhead</span><span class="x">,</span> <span class="n">vs</span><span class="x">[</span><span class="mi">2</span><span class="x">],</span> <span class="n">vs</span><span class="x">[</span><span class="mi">3</span><span class="x">]),</span> <span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">])</span></code></pre></figure>
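
<p>As a sanity check, here is a minimal sketch of the same split on a small array, using only Base Julia. The sizes and the names <code class="language-plaintext highlighter-rouge">X</code> and <code class="language-plaintext highlighter-rouge">Xh</code> are made up for illustration. Head $h$ of the result should hold rows $(h-1)d_h+1$ to $hd_h$ of the original array:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">dh, nhead, N, B = 2, 3, 4, 5 # illustrative sizes
dm = dh * nhead
X = reshape(collect(1:dm*N*B), dm, N, B)                    # dm×N×B
Xh = permutedims(reshape(X, dh, nhead, N, B), (1, 3, 2, 4)) # dh×N×H×B
size(Xh) # (2, 4, 3, 5)
# head h holds rows (h-1)*dh+1 : h*dh of the original array
all(Xh[:, :, h, :] == X[(h-1)*dh+1:h*dh, :, :] for h in 1:nhead) # true</code></pre></figure>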

<p>Then we calculate the scaled dot attention for each head, combine the results and return them along with the scores:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scaled_dot_attention</span><span class="x">(</span><span class="n">Q</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">V</span><span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="x">[</span><span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">,</span> <span class="mi">2</span><span class="x">,</span> <span class="mi">4</span><span class="x">])</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="n">dm</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">3</span><span class="x">),</span> <span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">4</span><span class="x">))</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span>
<span class="k">end</span></code></pre></figure>
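
<p>The final <code class="language-plaintext highlighter-rouge">permutedims</code> and <code class="language-plaintext highlighter-rouge">reshape</code> are the exact inverse of the split at the start of the function. A minimal sketch of the round trip, where <code class="language-plaintext highlighter-rouge">split_heads</code> and <code class="language-plaintext highlighter-rouge">merge_heads</code> are hypothetical helper names that are not part of the code in this post:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># hypothetical helpers, for illustration only
split_heads(X, nhead) = permutedims(
    reshape(X, div(size(X, 1), nhead), nhead, size(X, 2), size(X, 3)), (1, 3, 2, 4))
merge_heads(A) = reshape(
    permutedims(A, (1, 3, 2, 4)), size(A, 1) * size(A, 3), size(A, 2), size(A, 4))

X = randn(Float32, 8, 5, 2) # dm×N×B
Xh = split_heads(X, 4)      # dh×N×H×B
size(Xh)                    # (2, 5, 4, 2)
merge_heads(Xh) == X        # true, only the memory layout changed</code></pre></figure>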

<h4 id="attention-scaled-dot-attention">3.5.6 Scaled Dot Attention</h4>

<p>The scaled dot attention is defined by default for 3D arrays. $Q$ is of size $d_h \times d_q \times H$
while $K$ and $V$ are both of size $d_h \times d_{kv} \times H$. Usually $n=d_q=d_{kv}$.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> scaled_dot_attention</span><span class="x">(</span>
    <span class="n">query</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A3</span><span class="x">,</span> <span class="n">value</span><span class="o">::</span><span class="n">A3</span>
    <span class="x">;</span> <span class="n">mask</span><span class="o">::</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="n">M</span><span class="x">}</span><span class="o">=</span><span class="nb">nothing</span>
    <span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A3</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">},</span> <span class="n">M</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="kt">Bool</span><span class="x">}}</span>
    <span class="n">dh</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">query</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
    <span class="n">keyT</span> <span class="o">=</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="x">(</span><span class="mi">2</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">3</span><span class="x">))</span> <span class="c"># (dkv, dh, nhead)</span>
    <span class="n">atten</span> <span class="o">=</span> <span class="n">one</span><span class="x">(</span><span class="n">T</span><span class="x">)</span><span class="o">/</span><span class="n">convert</span><span class="x">(</span><span class="n">T</span><span class="x">,</span> <span class="n">sqrt</span><span class="x">(</span><span class="n">dh</span><span class="x">))</span> <span class="o">.*</span> <span class="n">batched_mul</span><span class="x">(</span><span class="n">keyT</span><span class="x">,</span> <span class="n">query</span><span class="x">)</span> <span class="c"># (dkv, dh, nhead)*(dh, dq, nhead) =&gt; (dkv, dq, nhead)</span>
    <span class="n">atten</span> <span class="o">=</span> <span class="n">apply_mask</span><span class="x">(</span><span class="n">atten</span><span class="x">,</span> <span class="n">mask</span><span class="x">)</span> <span class="c"># (dkv, dq, nhead)</span>
    <span class="n">scores</span> <span class="o">=</span> <span class="n">softmax</span><span class="x">(</span><span class="n">atten</span><span class="x">;</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span> <span class="c"># (dkv, dq, nhead)</span>
    <span class="n">batched_mul</span><span class="x">(</span><span class="n">value</span><span class="x">,</span> <span class="n">scores</span><span class="x">),</span> <span class="n">scores</span> <span class="c"># (dh, dkv, nhead)*(dkv, dq, nhead) =&gt; (dh, dq, nhead)</span>
<span class="k">end</span></code></pre></figure>
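
<p>For a single head, the batched code above reduces to ordinary matrix products. The following is a minimal reference sketch in plain Julia, without NNlib and without masking. The names <code class="language-plaintext highlighter-rouge">colsoftmax</code> and <code class="language-plaintext highlighter-rouge">attention_single_head</code> are hypothetical. It could be used to check one slice of the batched result:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># column-wise softmax; subtracting the maximum avoids overflow
function colsoftmax(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims=1))
    e ./ sum(e; dims=1)
end

# Q is dh×dq while K and V are dh×dkv
function attention_single_head(Q::AbstractMatrix, K::AbstractMatrix, V::AbstractMatrix)
    dh = size(Q, 1)
    scores = colsoftmax(K' * Q ./ sqrt(eltype(Q)(dh))) # dkv×dq
    V * scores, scores                                 # dh×dq, dkv×dq
end

Q, K, V = randn(Float32, 8, 5), randn(Float32, 8, 7), randn(Float32, 8, 7)
A, scores = attention_single_head(Q, K, V)
size(A), size(scores)            # ((8, 5), (7, 5))
sum(scores; dims=1) ≈ ones(1, 5) # the weights for each query sum to 1</code></pre></figure>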

<p>As explained <a href="#attention-batched-multiplication">above</a>, we need to reshape 4D arrays into 3D arrays, apply the usual scaled dot attention and then reshape back:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> scaled_dot_attention</span><span class="x">(</span><span class="n">query</span><span class="o">::</span><span class="n">A4</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="n">A4</span><span class="x">,</span> <span class="n">value</span><span class="o">::</span><span class="n">A4</span>
    <span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="n">A4</span> <span class="o">&lt;:</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">4</span><span class="x">}}</span>
    <span class="n">batch_size</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">query</span><span class="x">)[</span><span class="mi">3</span><span class="o">:</span><span class="k">end</span><span class="x">]</span>
    <span class="n">Q</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">V</span> <span class="o">=</span> <span class="n">map</span><span class="x">(</span><span class="n">x</span> <span class="o">-&gt;</span> <span class="n">reshape</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="o">:</span><span class="x">),</span> <span class="x">(</span><span class="n">query</span><span class="x">,</span> <span class="n">key</span><span class="x">,</span> <span class="n">value</span><span class="x">))</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">scaled_dot_attention</span><span class="x">(</span><span class="n">Q</span><span class="x">,</span> <span class="n">K</span><span class="x">,</span> <span class="n">V</span><span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="n">batch_size</span><span class="o">...</span><span class="x">))</span>
    <span class="n">scores</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">scores</span><span class="x">,</span> <span class="x">(</span><span class="n">size</span><span class="x">(</span><span class="n">scores</span><span class="x">,</span> <span class="mi">1</span><span class="x">),</span> <span class="n">size</span><span class="x">(</span><span class="n">scores</span><span class="x">,</span> <span class="mi">2</span><span class="x">),</span> <span class="n">batch_size</span><span class="o">...</span><span class="x">))</span>
    <span class="n">A</span><span class="x">,</span> <span class="n">scores</span>
<span class="k">end</span></code></pre></figure>
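
<p>Because Julia arrays are column-major, collapsing the trailing dimensions with <code class="language-plaintext highlighter-rouge">:</code> only changes how the data is indexed. A small sketch with made-up sizes shows that head $h$ of batch $b$ ends up in slice $h + (b-1)H$ of the flattened array, and that the reverse reshape restores the original:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">dh, N, H, B = 2, 3, 4, 5 # illustrative sizes
X4 = randn(Float32, dh, N, H, B)
batch_size = size(X4)[3:end]                  # (4, 5)
X3 = reshape(X4, size(X4, 1), size(X4, 2), :) # dh×N×(H*B)
size(X3)                                      # (2, 3, 20)
h, b = 3, 2
X3[:, :, h + (b - 1) * H] == X4[:, :, h, b]   # true
reshape(X3, size(X3, 1), size(X3, 2), batch_size...) == X4 # true</code></pre></figure>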

<h4 id="attention-full-example">3.5.7 Full example</h4>

<p>Model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mha</span> <span class="o">=</span> <span class="n">MultiheadAttention</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">32</span><span class="x">)</span>
<span class="n">Flux</span><span class="o">.</span><span class="n">_big_show</span><span class="x">(</span><span class="nb">stdout</span><span class="x">,</span> <span class="n">mha</span><span class="x">)</span>
<span class="cm">#=
MultiheadAttention(
  4,
  Dense(32 =&gt; 32; bias=false),          # 1_024 parameters
  Dense(32 =&gt; 32; bias=false),          # 1_024 parameters
  Dense(32 =&gt; 32; bias=false),          # 1_024 parameters
  Dense(32 =&gt; 32),                      # 1_056 parameters
)                   # Total: 5 arrays, 4_128 parameters, 16.422 KiB.
=#</span></code></pre></figure>
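
<p>The parameter total in the printout can be reproduced by hand. A quick arithmetic sketch, assuming the constructor above (no bias on the $Q$, $K$ and $V$ projections, a bias on the output projection); the function name is hypothetical:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># hypothetical helper: count the parameters of the layer defined above
function mha_param_count(nhead::Int, dim_in::Int, dim_head::Int, dim_out::Int)
    dim_inner = dim_head * nhead
    qkv = 3 * dim_in * dim_inner        # three bias-free projections
    out = dim_inner * dim_out + dim_out # weight plus bias
    qkv + out
end

mha_param_count(4, 32, 8, 32) # 4128</code></pre></figure>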

<p>Forward pass:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">20</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span> <span class="c"># d×n×B</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">32</span><span class="x">,</span> <span class="mi">20</span><span class="x">))</span>
<span class="n">y</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">mha</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="c"># 32×20×2 Array{Float32, 3}, 20×20×4×2 Array{Float32, 4}</span></code></pre></figure>
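
<p>To see what the mask does, here is a small sketch in plain Julia. It assumes, as in NNlib.jl, that the causal mask is an upper-triangular Boolean matrix, and that <code class="language-plaintext highlighter-rouge">apply_mask</code> effectively sets the masked-out entries to $-\infty$ before the softmax. The variable names are for illustration only:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra: triu

n = 4
mask = triu(trues(n, n)) # mask[i, j] is true when key i ≤ query j
# 4×4 BitMatrix:
#  1  1  1  1
#  0  1  1  1
#  0  0  1  1
#  0  0  0  1
atten = randn(Float32, n, n)          # (dkv, dq)
masked = ifelse.(mask, atten, -Inf32) # assumption: masked entries become -Inf
e = exp.(masked .- maximum(masked; dims=1))
scores = e ./ sum(e; dims=1)          # column-wise softmax
scores[2:end, 1] # all zeros: query 1 can only attend to key 1</code></pre></figure>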

<p>Backpropagation:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Flux</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">sum</span> <span class="c"># dummy loss function</span>
<span class="n">grads</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">gradient</span><span class="x">(</span><span class="n">m</span> <span class="o">-&gt;</span> <span class="n">loss</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)[</span><span class="mi">1</span><span class="x">]),</span> <span class="n">mha</span><span class="x">)</span>
<span class="n">keys</span><span class="x">(</span><span class="n">grads</span><span class="x">[</span><span class="mi">1</span><span class="x">])</span> <span class="c"># (:nhead, :denseQ, :denseK, :denseV, :denseO)</span></code></pre></figure>

<h3 id="transformer-blocks">3.6 Transformer Blocks</h3>

<p>The other components we need for the transformer block are Layer Norm, Feed Forward (two consecutive dense layers) and dropout. 
We can use the Flux.jl implementations for these.</p>

<figure class="post-figure" id="fig-gpt-block">
<img class="img-20" src="/assets/posts/transformers/gpt-block.png" alt="Transformer block" />
<figcaption>Source: <a href="https://web.archive.org/web/20210126024542/https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf">GPT1 paper (2018)</a></figcaption>
</figure>

<p>This means we can now create a transformer block:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> TransformerBlock</span><span class="x">{</span>
    <span class="n">MHA</span><span class="o">&lt;:</span><span class="n">MultiheadAttention</span><span class="x">,</span>
    <span class="n">N1</span><span class="o">&lt;:</span><span class="n">LayerNorm</span><span class="x">,</span>
    <span class="n">D1</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span>
    <span class="n">D2</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span>
    <span class="n">N2</span><span class="o">&lt;:</span><span class="n">LayerNorm</span><span class="x">,</span>
    <span class="n">DO</span><span class="o">&lt;:</span><span class="n">Dropout</span><span class="x">}</span>
    <span class="n">multihead_attention</span><span class="o">::</span><span class="n">MHA</span>
    <span class="n">norm_attention</span><span class="o">::</span><span class="n">N1</span>
    <span class="n">dense1</span><span class="o">::</span><span class="n">D1</span>
    <span class="n">dense2</span><span class="o">::</span><span class="n">D2</span>
    <span class="n">norm_feedforward</span><span class="o">::</span><span class="n">N2</span>
    <span class="n">dropout</span><span class="o">::</span><span class="n">DO</span>
<span class="k">end</span>

<span class="n">Flux</span><span class="o">.</span><span class="nd">@layer</span> <span class="n">TransformerBlock</span> <span class="c"># make whole layer trainable</span></code></pre></figure>

<p>This whole block can be defined with only 5 parameters:</p>
<ol>
  <li>The number of heads $H$.</li>
  <li>The dimension $d$.</li>
  <li>The hidden dimension for the feed-forward network. The convention is $4d$.</li>
  <li>The activation function.</li>
  <li>The dropout probability.</li>
</ol>

<p>In code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">TransformerBlock</span><span class="x">(</span>
    <span class="n">nhead</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span>
    <span class="n">dim_model</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span>
    <span class="n">dim_hidden</span><span class="o">::</span><span class="kt">Int</span><span class="x">;</span>
    <span class="n">act</span><span class="o">=</span><span class="n">relu</span><span class="x">,</span>
    <span class="n">pdrop</span><span class="o">::</span><span class="kt">Float64</span><span class="o">=</span><span class="mf">0.1</span><span class="x">,</span>
    <span class="x">)</span> <span class="o">=</span> <span class="n">TransformerBlock</span><span class="x">(</span>
    <span class="n">MultiheadAttention</span><span class="x">(</span><span class="n">nhead</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">),</span>
    <span class="n">LayerNorm</span><span class="x">(</span><span class="n">dim_model</span><span class="x">),</span>
    <span class="n">Dense</span><span class="x">(</span><span class="n">dim_model</span><span class="x">,</span> <span class="n">dim_hidden</span><span class="x">,</span> <span class="n">act</span><span class="x">),</span>
    <span class="n">Dense</span><span class="x">(</span><span class="n">dim_hidden</span><span class="x">,</span> <span class="n">dim_model</span><span class="x">),</span>
    <span class="n">LayerNorm</span><span class="x">(</span><span class="n">dim_model</span><span class="x">),</span>
    <span class="n">Dropout</span><span class="x">(</span><span class="n">pdrop</span><span class="x">),</span>
<span class="x">)</span></code></pre></figure>

<p>There are skip connections in the forward pass:<sup id="fnref:block_scores" role="doc-noteref"><a href="#fn:block_scores" class="footnote" rel="footnote">3</a></sup></p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">t</span><span class="o">::</span><span class="n">TransformerBlock</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="n">A</span><span class="x">;</span> <span class="n">mask</span><span class="o">::</span><span class="n">M</span><span class="o">=</span><span class="nb">nothing</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span>
    <span class="n">A</span><span class="o">&lt;:</span><span class="kt">AbstractArray</span><span class="x">,</span> <span class="n">M</span><span class="o">&lt;:</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="kt">AbstractArray</span><span class="x">{</span><span class="kt">Bool</span><span class="x">}}}</span>
    <span class="n">h</span><span class="x">,</span> <span class="n">scores</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">multihead_attention</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">,</span> <span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="c"># (dm, N, B)</span>
    <span class="n">h</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">dropout</span><span class="x">(</span><span class="n">h</span><span class="x">)</span> 
    <span class="n">h</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">h</span>
    <span class="n">h</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">norm_attention</span><span class="x">(</span><span class="n">h</span><span class="x">)</span>     <span class="c"># (dm, N, B)</span>
    <span class="n">hff</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">dense1</span><span class="x">(</span><span class="n">h</span><span class="x">)</span>           <span class="c"># (dim_hidden, N, B)</span>
    <span class="n">hff</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">dense2</span><span class="x">(</span><span class="n">hff</span><span class="x">)</span>         <span class="c"># (dm, N, B)</span>
    <span class="n">hff</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">dropout</span><span class="x">(</span><span class="n">hff</span><span class="x">)</span>
    <span class="n">h</span> <span class="o">=</span> <span class="n">h</span> <span class="o">+</span> <span class="n">hff</span>
    <span class="n">h</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">norm_feedforward</span><span class="x">(</span><span class="n">h</span><span class="x">)</span>   <span class="c"># (dm, N, B)</span>
    <span class="n">h</span>
<span class="k">end</span></code></pre></figure>

<p>Model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">block</span> <span class="o">=</span> <span class="n">TransformerBlock</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">32</span><span class="o">*</span><span class="mi">4</span><span class="x">)</span> 
<span class="n">Flux</span><span class="o">.</span><span class="n">_big_show</span><span class="x">(</span><span class="nb">stdout</span><span class="x">,</span> <span class="n">block</span><span class="x">)</span>
<span class="cm">#=
TransformerBlock(
  ...
)  # Total: 13 arrays, 12_608 parameters, 50.234 KiB.
=#</span></code></pre></figure>

<p>Forward pass:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="n">randn</span><span class="x">(</span><span class="kt">Float32</span><span class="x">,</span> <span class="mi">32</span><span class="x">,</span> <span class="mi">20</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span> <span class="c"># d×n×B</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="mi">32</span><span class="x">,</span> <span class="mi">20</span><span class="x">))</span> <span class="c"># 20×20 Matrix{Bool}</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">block</span><span class="x">(</span><span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="c"># 32×20×2 Array{Float32, 3}</span></code></pre></figure>

<p>Backpropagation:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">loss</span> <span class="o">=</span> <span class="n">sum</span> <span class="c"># dummy loss function</span>
<span class="n">grads</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">gradient</span><span class="x">(</span><span class="n">m</span> <span class="o">-&gt;</span> <span class="n">loss</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)),</span> <span class="n">block</span><span class="x">)</span>
<span class="n">keys</span><span class="x">(</span><span class="n">grads</span><span class="x">[</span><span class="mi">1</span><span class="x">])</span> <span class="c"># (:multihead_attention, :norm_attention, :dense1, :dense2, :norm_feedforward, :dropout)</span></code></pre></figure>

<h3 id="generator">3.7 Generator</h3>

<figure class="post-figure" id="fig-gpt-model">
<img class="img-30" src="/assets/posts/transformers/gpt-model.png" alt="Transformer generator" />
<figcaption>Modified from <a href="https://web.archive.org/web/20210126024542/https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf">GPT1 paper (2018)</a></figcaption>
</figure>

<p>We will create a struct to hold the generator.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> TransformerGenerator</span><span class="x">{</span>
    <span class="n">E</span><span class="o">&lt;:</span><span class="n">Flux</span><span class="o">.</span><span class="n">Embedding</span><span class="x">,</span> 
    <span class="n">PE</span><span class="o">&lt;:</span><span class="n">Flux</span><span class="o">.</span><span class="n">Embedding</span><span class="x">,</span> 
    <span class="n">DO</span><span class="o">&lt;:</span><span class="n">Dropout</span><span class="x">,</span> 
    <span class="n">TB</span><span class="o">&lt;:</span><span class="kt">Vector</span><span class="x">{</span><span class="o">&lt;:</span><span class="n">TransformerBlock</span><span class="x">},</span> 
    <span class="n">D</span><span class="o">&lt;:</span><span class="n">Dense</span><span class="x">,</span>
    <span class="n">M</span><span class="o">&lt;:</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="kt">AbstractMatrix</span><span class="x">{</span><span class="kt">Bool</span><span class="x">}},</span>
    <span class="x">}</span> 
    <span class="n">embedding</span><span class="o">::</span><span class="n">E</span>
    <span class="n">position_encoding</span><span class="o">::</span><span class="n">PE</span>
    <span class="n">dropout</span><span class="o">::</span><span class="n">DO</span>
    <span class="n">blocks</span><span class="o">::</span><span class="n">TB</span>
    <span class="n">head</span><span class="o">::</span><span class="n">D</span>
    <span class="n">mask</span><span class="o">::</span><span class="n">M</span> <span class="c"># optional buffer</span>
<span class="k">end</span>

<span class="n">Flux</span><span class="o">.</span><span class="nd">@layer</span> <span class="n">TransformerGenerator</span> <span class="n">trainable</span><span class="o">=</span><span class="x">(</span><span class="n">embedding</span><span class="x">,</span> <span class="n">position_encoding</span><span class="x">,</span> <span class="n">blocks</span><span class="x">,</span> <span class="n">dropout</span><span class="x">,</span> <span class="n">head</span><span class="x">)</span></code></pre></figure>

<p>By default the forward pass uses the model’s stored mask. The caller can override it by passing a different mask as a keyword argument:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> </span><span class="o">(</span><span class="n">t</span><span class="o">::</span><span class="n">TransformerGenerator</span><span class="x">)(</span><span class="n">x</span><span class="o">::</span><span class="n">A</span><span class="x">;</span> <span class="n">mask</span><span class="o">::</span><span class="n">M</span><span class="o">=</span><span class="n">t</span><span class="o">.</span><span class="n">mask</span><span class="x">)</span> <span class="k">where</span> <span class="x">{</span>
    <span class="n">A</span><span class="o">&lt;:</span><span class="kt">AbstractArray</span><span class="x">,</span> <span class="n">M</span><span class="o">&lt;:</span><span class="kt">Union</span><span class="x">{</span><span class="kt">Nothing</span><span class="x">,</span> <span class="kt">AbstractMatrix</span><span class="x">{</span><span class="kt">Bool</span><span class="x">}}}</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">embedding</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>              <span class="c"># (dm, N, B)</span>
    <span class="n">N</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">x</span><span class="x">,</span> <span class="mi">2</span><span class="x">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">.+</span> <span class="n">t</span><span class="o">.</span><span class="n">position_encoding</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="n">N</span><span class="x">)</span> <span class="c"># (dm, N, B)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">dropout</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>                <span class="c"># (dm, N, B)</span>
    <span class="k">for</span> <span class="n">block</span> <span class="k">in</span> <span class="n">t</span><span class="o">.</span><span class="n">blocks</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">block</span><span class="x">(</span><span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span>     <span class="c"># (dm, N, B)</span>
    <span class="k">end</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">head</span><span class="x">(</span><span class="n">x</span><span class="x">)</span>                   <span class="c"># (vocab_size, N, B)</span>
    <span class="n">x</span>
<span class="k">end</span></code></pre></figure>
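
<p>The position encoding step relies on broadcasting: a <code class="language-plaintext highlighter-rouge">(dm, N)</code> matrix is added to every sample of the <code class="language-plaintext highlighter-rouge">(dm, N, B)</code> embedding array. Here is a minimal sketch with plain arrays, where <code class="language-plaintext highlighter-rouge">pos</code> is a hypothetical stand-in for the output of <code class="language-plaintext highlighter-rouge">t.position_encoding(1:N)</code> and the sizes are chosen only for illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">dm, N, B = 4, 3, 2
x = zeros(Float32, dm, N, B)                # stand-in for the token embeddings
pos = reshape(Float32.(1:(dm * N)), dm, N)  # stand-in for t.position_encoding(1:N)
y = x .+ pos                                # pos is broadcast along the batch dimension
size(y) == (dm, N, B)                       # true
y[:, :, 1] == y[:, :, 2] == pos             # true: every sample gets the same encoding</code></pre></figure>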

<p>Create a model:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">context_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">dim</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="mi">71</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="n">context_size</span><span class="x">,</span> <span class="n">context_size</span><span class="x">))</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">TransformerGenerator</span><span class="x">(</span>
    <span class="n">Embedding</span><span class="x">(</span><span class="n">vocab_size</span> <span class="o">=&gt;</span> <span class="n">dim</span><span class="x">),</span>
    <span class="n">Embedding</span><span class="x">(</span><span class="n">context_size</span> <span class="o">=&gt;</span> <span class="n">dim</span><span class="x">),</span>
    <span class="n">Dropout</span><span class="x">(</span><span class="mf">0.1</span><span class="x">),</span>
    <span class="n">TransformerBlock</span><span class="x">[</span>
        <span class="n">TransformerBlock</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="n">dim</span><span class="x">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="mi">4</span><span class="x">;</span> <span class="n">pdrop</span><span class="o">=</span><span class="mf">0.1</span><span class="x">),</span>
        <span class="n">TransformerBlock</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="n">dim</span><span class="x">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="mi">4</span><span class="x">;</span> <span class="n">pdrop</span><span class="o">=</span><span class="mf">0.1</span><span class="x">),</span>
        <span class="n">TransformerBlock</span><span class="x">(</span><span class="mi">4</span><span class="x">,</span> <span class="n">dim</span><span class="x">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="mi">4</span><span class="x">;</span> <span class="n">pdrop</span><span class="o">=</span><span class="mf">0.1</span><span class="x">),</span>
    <span class="x">],</span>
    <span class="n">Dense</span><span class="x">(</span><span class="n">dim</span><span class="x">,</span> <span class="n">vocab_size</span><span class="x">),</span>
    <span class="n">copy</span><span class="x">(</span><span class="n">mask</span><span class="x">)</span>
<span class="x">)</span>
<span class="n">Flux</span><span class="o">.</span><span class="n">_big_show</span><span class="x">(</span><span class="nb">stdout</span><span class="x">,</span> <span class="n">model</span><span class="x">)</span>
<span class="cm">#=
TransformerGenerator(
  ...
)         # Total: 43 trainable arrays, 44_487 parameters,
          # plus 1 non-trainable, 4_096 parameters, summarysize 180.410 KiB.
=#</span></code></pre></figure>

<p>We can test it with a random vector of indices:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">x</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">rand</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="n">vocab_size</span><span class="x">,</span> <span class="mi">34</span><span class="x">),</span> <span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span> <span class="c"># make it a batch of 1</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">make_causal_mask</span><span class="x">(</span><span class="n">ones</span><span class="x">(</span><span class="n">dim</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">x</span><span class="x">)))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="x">(</span><span class="n">x</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="c"># 71×34×1 Array{Float32, 3}</span></code></pre></figure>

<p>Or a random batch:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">X</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="n">vocab_size</span><span class="x">,</span> <span class="mi">34</span><span class="x">,</span> <span class="mi">10</span><span class="x">)</span>
<span class="n">Y</span> <span class="o">=</span> <span class="n">model</span><span class="x">(</span><span class="n">X</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="c"># 71×34×10</span></code></pre></figure>

<h3 id="generation">3.8 Generation</h3>

<p>Let’s now generate text with the model.</p>

<p>The model has a fixed context length.
To generate text longer than this fixed length we will implement a sliding window.
This window will take the last $n$ tokens (rows) of the current context for each column (sample) in the batch:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> tail</span><span class="x">(</span><span class="n">A</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">,</span> <span class="n">n</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">n</span> <span class="o">=</span> <span class="n">min</span><span class="x">(</span><span class="n">n</span><span class="x">,</span> <span class="n">size</span><span class="x">(</span><span class="n">A</span><span class="x">,</span> <span class="mi">1</span><span class="x">))</span>
    <span class="n">A</span><span class="x">[(</span><span class="k">end</span> <span class="o">-</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="x">)</span><span class="o">:</span><span class="k">end</span><span class="x">,</span> <span class="o">:</span><span class="x">]</span>
<span class="k">end</span></code></pre></figure>
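
<p>A small usage sketch of the window on a toy context. The definition is repeated in compact form so that the snippet runs on its own, and the values are arbitrary:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Same behaviour as tail above, written as a one-liner.
tail(A::AbstractMatrix, n::Int) = A[(end - min(n, size(A, 1)) + 1):end, :]

context = reshape(collect(1:10), 5, 2) # 5 tokens (rows) × 2 samples (columns)
window = tail(context, 3)              # rows 3:5 of each column: [3 8; 4 9; 5 10]
whole = tail(context, 8)               # window longer than the context: all 5 rows are returned</code></pre></figure>

<p>While the context is still shorter than the window, nothing is cropped. Once it is longer, only the most recent tokens are passed to the model.</p>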

<p>The transformer generates a $V\times N \times B$ array of logits. At each iteration we only take the logits for the last token, resulting in a $V\times B$ matrix.
These logits will be converted to probabilities via the softmax function $\ref{eq:softmax}$.</p>

<p>We have a choice of how to sample these probabilities.
The greedy approach is to always take the token with the maximum probability.
A better approach is to randomly sample based on the probabilities.
That way a token with a high probability is more likely to be chosen, but it is not guaranteed.
This gives us some diversity in the results.
We then add this to the context and repeat.</p>
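
<p>To make the difference between the two strategies concrete, here is a small sketch on a made-up vector of three logits. The helper <code class="language-plaintext highlighter-rouge">softmax_vec</code> is a hypothetical stand-in for the softmax used elsewhere in this post, and the seed and number of draws are arbitrary:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Random, StatsBase

# Hypothetical helper: a numerically stable softmax for a single vector.
softmax_vec(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

logits = Float32[2.0, 1.0, 0.1]  # made-up logits for a vocabulary of 3 tokens
probs = softmax_vec(logits)

greedy = argmax(probs)           # greedy: always token 1, the most likely one

rng = MersenneTwister(0)
draws = [sample(rng, Weights(probs)) for _ in 1:1000]  # random sampling
freqs = counts(draws, 1:3) ./ length(draws)            # empirical frequencies, close to probs</code></pre></figure>

<p>Greedy decoding returns the same token every time, whereas the sampled tokens follow the distribution, so the less likely tokens still appear occasionally.</p>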

<p>The full function is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Random</span><span class="x">,</span> <span class="n">StatsBase</span>
<span class="k">function</span><span class="nf"> generate</span><span class="x">(</span>
    <span class="n">rng</span><span class="o">::</span><span class="kt">AbstractRNG</span><span class="x">,</span> <span class="n">model</span><span class="o">::</span><span class="n">TransformerGenerator</span><span class="x">,</span> <span class="n">context</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">{</span><span class="n">T</span><span class="x">}</span>
    <span class="x">;</span> <span class="n">context_size</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">max_tokens</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">100</span><span class="x">,</span>
    <span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="k">for</span> <span class="n">i</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">max_tokens</span>
        <span class="n">context_crop</span> <span class="o">=</span> <span class="n">tail</span><span class="x">(</span><span class="n">context</span><span class="x">,</span> <span class="n">context_size</span><span class="x">)</span>
        <span class="n">n</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">context_crop</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">model</span><span class="o">.</span><span class="n">mask</span><span class="x">)</span> <span class="o">?</span> <span class="nb">nothing</span> <span class="o">:</span> <span class="n">view</span><span class="x">(</span><span class="n">model</span><span class="o">.</span><span class="n">mask</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">n</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">n</span><span class="x">)</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="x">(</span><span class="n">context_crop</span><span class="x">;</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="o">|&gt;</span> <span class="n">cpu</span> <span class="c"># (vocab_size, n, B)</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="k">end</span><span class="x">,</span> <span class="o">:</span><span class="x">]</span> <span class="c"># (vocab_size, B) </span>
        <span class="n">context_next</span> <span class="o">=</span> <span class="n">multinomial_sampling</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">logits</span><span class="x">)</span>
        <span class="n">context</span> <span class="o">=</span> <span class="n">cat</span><span class="x">(</span><span class="n">context</span><span class="x">,</span> <span class="n">permutedims</span><span class="x">(</span><span class="n">context_next</span><span class="x">);</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span> <span class="c"># append the new tokens as a 1×B row</span>
    <span class="k">end</span>
    <span class="n">context</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> generate</span><span class="x">(</span><span class="n">model</span><span class="o">::</span><span class="n">TransformerGenerator</span><span class="x">,</span> <span class="n">context</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span>
    <span class="n">generate</span><span class="x">(</span><span class="n">Random</span><span class="o">.</span><span class="n">default_rng</span><span class="x">(),</span> <span class="n">model</span><span class="x">,</span> <span class="n">context</span><span class="x">;</span> <span class="n">kwargs</span><span class="o">...</span><span class="x">)</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> multinomial_sampling</span><span class="x">(</span><span class="n">rng</span><span class="o">::</span><span class="kt">AbstractRNG</span><span class="x">,</span> <span class="n">logits</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">)</span>
    <span class="n">probs</span> <span class="o">=</span> <span class="n">softmax</span><span class="x">(</span><span class="n">logits</span><span class="x">;</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span>
    <span class="n">tokens</span> <span class="o">=</span> <span class="x">[</span><span class="n">sample</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">Weights</span><span class="x">(</span><span class="n">p</span><span class="x">))</span> <span class="k">for</span> <span class="n">p</span> <span class="k">in</span> <span class="n">eachcol</span><span class="x">(</span><span class="n">probs</span><span class="x">)]</span>
    <span class="n">tokens</span>
<span class="k">end</span></code></pre></figure>
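
<p>The <code class="language-plaintext highlighter-rouge">view(model.mask, 1:n, 1:n)</code> step works because the top-left block of a triangular mask is itself a triangular mask of the smaller size. Here is a minimal sketch, assuming the stored mask is triangular. It uses <code class="language-plaintext highlighter-rouge">triu(trues(N, N))</code> as a stand-in and does not rely on the exact layout that <code class="language-plaintext highlighter-rouge">make_causal_mask</code> returns:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra

N = 6
full_mask = triu(trues(N, N))     # stand-in for the mask stored in the model (assumed triangular)
n = 3                             # current context length, shorter than N
crop = view(full_mask, 1:n, 1:n)  # no copy is made
crop == triu(trues(n, n))         # true: the crop is a valid mask for n tokens</code></pre></figure>

<p>The same holds for a lower triangular layout, so one mask built for the full context size can be reused for every shorter context.</p>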

<p>Testing it out:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">context</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">([</span><span class="mi">1</span><span class="x">],</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span> <span class="c"># start with the new line symbol</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">generate</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">context</span><span class="x">;</span> <span class="n">context_size</span><span class="o">=</span><span class="mi">64</span><span class="x">)</span> <span class="c"># 101×1 Matrix{Int64}</span></code></pre></figure>

<p>Decode the output using the tokenizer from <a href="#tokenization">section 3.2</a>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">decoded_text</span> <span class="o">=</span> <span class="n">join</span><span class="x">(</span><span class="n">decode</span><span class="x">(</span><span class="n">indexer</span><span class="x">,</span> <span class="n">out</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="x">]))</span>
<span class="n">print</span><span class="x">(</span><span class="n">decoded_text</span><span class="x">)</span></code></pre></figure>

<p>The output:</p>
<blockquote><pre>

A[RH N)pEy.QEgs?YbgnRsz-ZRDdUXvU Pzwzzxukvv_P;goxe(G;C;I
RIgB ‘E[xIqZ-J;gK—wwEUTZYtUg:tEhl-kZ;s:x.ggt
</pre></blockquote>

<p>This is nonsense. The model does no better than drawing each character randomly.
We need to train the model to get something sensible out of it.</p>

<h2 id="training">4 Training</h2>
<h3 id="train-validation-split">4.1 Train/validation split</h3>

<p>It is always good practice to split the data into train, validation and test splits.
For simplicity, we’ll only use a train and validation split. 
We’ll put the first 95% of data in the train split and the remainder in the validation split.<sup id="fnref:split" role="doc-noteref"><a href="#fn:split" class="footnote" rel="footnote">4</a></sup></p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">tokens</span> <span class="o">=</span> <span class="n">indexer</span><span class="x">(</span><span class="n">collect</span><span class="x">(</span><span class="n">text</span><span class="x">))</span>
<span class="n">n_train</span> <span class="o">=</span> <span class="n">floor</span><span class="x">(</span><span class="kt">Int</span><span class="x">,</span> <span class="x">(</span><span class="mf">0.95</span><span class="x">)</span> <span class="o">*</span> <span class="n">length</span><span class="x">(</span><span class="n">tokens</span><span class="x">))</span> <span class="c"># number of tokens in the train split</span>
<span class="n">train_data</span> <span class="o">=</span> <span class="n">tokens</span><span class="x">[</span><span class="mi">1</span><span class="o">:</span><span class="n">n_train</span><span class="x">]</span>
<span class="n">val_data</span> <span class="o">=</span> <span class="n">tokens</span><span class="x">[(</span><span class="n">n_train</span> <span class="o">+</span> <span class="mi">1</span><span class="x">)</span><span class="o">:</span><span class="k">end</span><span class="x">]</span></code></pre></figure>

<h3 id="batch-generation">4.2 Batch generation</h3>

<p>The model will be trained on segments of the text which match the context length $n$.
For a text of length $L$, a segment can start at any of the first $L-n+1$ characters; the last $n-1$ characters are too close to the end to start a full segment.
For the Shakespeare text, this results in approximately 4.9 million different segments.</p>

<p>There is however plenty of overlap so we don’t have to train on all of them.
We can instead randomly sample segments from the text.
A character at any given position in the text appears in a randomly sampled segment with probability $p\approx n/L$ (positions near the ends are less likely).
Over $s$ steps the number of appearances follows a binomial distribution, which for many steps can be approximated with a normal distribution with a mean $sp\approx sn/L$ and standard deviation $\sqrt{sp(1-p)}\approx \sqrt{sn/L}$. For example, for 4.9 million characters, a context length of 64 and 100,000 steps, the character at each position will appear 1.31±1.14 times.</p>
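
<p>The arithmetic behind these numbers can be reproduced in a few lines. This is only a sketch: the text length is rounded to 4.9 million characters and each step is assumed to sample a single segment:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">L = 4_900_000  # approximate text length in characters (rounded)
n = 64         # context length
s = 100_000    # number of steps, assuming one sampled segment per step

n_segments = L - n + 1             # possible starting positions: 4_899_937
p = n / L                          # probability that a given position is in a sampled segment
mean_count = s * p                 # expected number of appearances
std_count = sqrt(s * p * (1 - p))  # binomial standard deviation
round(mean_count; digits=2)        # 1.31
round(std_count; digits=2)         # 1.14</code></pre></figure>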

<p>The other important task is to create the reference text that the model will be trained to generate, which is simply the input text shifted by one.
(This reduces the number of valid segments by 1.)</p>

<p>The function is as follows:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Random</span>
<span class="k">function</span><span class="nf"> get_shifted_batch</span><span class="x">(</span><span class="n">rng</span><span class="o">::</span><span class="kt">AbstractRNG</span><span class="x">,</span> <span class="n">data</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span> <span class="n">context_size</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">batch_size</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="n">indices</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="x">(</span><span class="n">length</span><span class="x">(</span><span class="n">data</span><span class="x">)</span><span class="o">-</span><span class="n">context_size</span><span class="x">),</span> <span class="n">batch_size</span><span class="x">)</span>
    <span class="n">X</span> <span class="o">=</span> <span class="n">similar</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="n">context_size</span><span class="x">,</span> <span class="n">batch_size</span><span class="x">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">similar</span><span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="n">context_size</span><span class="x">,</span> <span class="n">batch_size</span><span class="x">)</span>
    <span class="k">for</span> <span class="x">(</span><span class="n">j</span><span class="x">,</span> <span class="n">idx</span><span class="x">)</span> <span class="k">in</span> <span class="n">enumerate</span><span class="x">(</span><span class="n">indices</span><span class="x">)</span>
        <span class="n">X</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">j</span><span class="x">]</span> <span class="o">=</span> <span class="n">data</span><span class="x">[</span><span class="n">idx</span><span class="o">:</span><span class="x">(</span><span class="n">idx</span> <span class="o">+</span> <span class="n">context_size</span> <span class="o">-</span> <span class="mi">1</span><span class="x">)]</span>
        <span class="n">Y</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="n">j</span><span class="x">]</span> <span class="o">=</span> <span class="n">data</span><span class="x">[(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="x">)</span><span class="o">:</span><span class="x">(</span><span class="n">idx</span> <span class="o">+</span> <span class="n">context_size</span><span class="x">)]</span>
    <span class="k">end</span>
    <span class="n">X</span><span class="x">,</span> <span class="n">Y</span>
<span class="k">end</span>

<span class="n">get_shifted_batch</span><span class="x">(</span><span class="n">data</span><span class="o">::</span><span class="kt">AbstractVector</span><span class="x">,</span> <span class="n">context_size</span><span class="o">::</span><span class="kt">Int</span><span class="x">,</span> <span class="n">batch_size</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span> <span class="o">=</span> 
    <span class="n">get_shifted_batch</span><span class="x">(</span><span class="n">Random</span><span class="o">.</span><span class="n">default_rng</span><span class="x">(),</span> <span class="n">data</span><span class="x">,</span> <span class="n">context_size</span><span class="x">,</span> <span class="n">batch_size</span><span class="x">)</span></code></pre></figure>

<p>Usage:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">text</span> <span class="o">=</span> <span class="n">rand</span><span class="x">(</span><span class="mi">1</span><span class="o">:</span><span class="mi">72</span><span class="x">,</span> <span class="mi">1000</span><span class="x">)</span> <span class="c"># pretend we've already indexed it</span>
<span class="n">rng</span> <span class="o">=</span> <span class="kt">MersenneTwister</span><span class="x">(</span><span class="mi">2</span><span class="x">)</span>
<span class="n">X</span><span class="x">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">get_shifted_batch</span><span class="x">(</span><span class="n">rng</span><span class="x">,</span> <span class="n">text</span><span class="x">,</span> <span class="mi">4</span><span class="x">,</span> <span class="mi">3</span><span class="x">)</span></code></pre></figure>

<p>The outputs look like:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    X              Y
 1  70  66  |   9  60   3
 9  60   3  |  26   4  32
26   4  32  |   1  17  35
 1  17  35  |  68  54  70
</code></pre></div></div>

<p>Lastly, it can be convenient to wrap this functionality in a struct similar to Flux.jl’s <code class="language-plaintext highlighter-rouge">DataLoader</code>.
For an example of this, please see the <code class="language-plaintext highlighter-rouge">BatchGenerator</code> object in my <a href="https://github.com/LiorSinai/TransformersLite-Examples/blob/main/examples/GPT/generate_batches.jl">generate_batches.jl</a> file.</p>

<h3 id="loss">4.3 Loss</h3>

<p>What is our goal?</p>

<blockquote>
  <p>We want the probability of the true next character to be the highest.</p>
</blockquote>

<p>The model returns a $V \times n \times B$ array of logits. We have an $n \times B$ reference array of the true next characters ($Y$). The first step is to convert the model output to probabilities - a range of values from 0 to 1 summing to 1 along the vocabulary dimension - with the softmax equation $\ref{eq:softmax}$.
We can then pick out the next true characters by converting the reference array to a one hot matrix and multiplying:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Z</span> <span class="o">=</span> <span class="n">model</span><span class="x">(</span><span class="n">X</span><span class="x">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="x">)</span> <span class="c"># V×n×B</span>
<span class="n">probs</span> <span class="o">=</span> <span class="n">softmax</span><span class="x">(</span><span class="n">Z</span><span class="x">,</span> <span class="n">dims</span><span class="o">=</span><span class="mi">1</span><span class="x">)</span>
<span class="n">Y_onehot</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">onehotbatch</span><span class="x">(</span><span class="n">Y</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">vocab_size</span><span class="x">)</span> <span class="c"># V×n×B</span>
<span class="n">Y_onehot</span> <span class="o">.*</span> <span class="n">probs</span> <span class="c"># V×n×B</span></code></pre></figure>

<p>All the non-zero values are the probabilities of interest.</p>
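<p>To make the masking step concrete, here is a minimal sketch in plain Julia with a toy vocabulary of 3 and made-up logits. The helpers <code class="language-plaintext highlighter-rouge">softmax_cols</code> and <code class="language-plaintext highlighter-rouge">onehot</code> are stand-ins written for this illustration only; they are not part of Flux.jl:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Toy example: V=3 vocabulary, n=2 positions, B=1 batch. All values are made up.
function softmax_cols(Z)
    E = exp.(Z .- maximum(Z, dims=1)) # subtract the max for numerical stability
    E ./ sum(E, dims=1)
end

# Hypothetical stand-in for Flux.onehotbatch: a V×n×B array of Bools
onehot(Y, V) = reshape(1:V, V, 1, 1) .== reshape(Y, 1, size(Y)...)

Z = reshape(Float32[2, 0, 0, 0, 1, 0], 3, 2, 1) # logits, V×n×B
Y = reshape([1, 2], 2, 1)                       # true next tokens, n×B
probs = softmax_cols(Z)                         # each column sums to 1
picked = onehot(Y, 3) .* probs                  # zero except at the true tokens
true_probs = dropdims(sum(picked, dims=1), dims=1) # n×B</code></pre></figure>

<p>Summing over the vocabulary dimension collapses the mostly zero array into one probability per position, which is the quantity the loss is built from.</p>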

<p>Since these values are small numbers the convention is to instead use the cross entropy, so $-Y\log(P)$ rather than $YP$.
This maps the values from the range $(0, 1)$ to the range $(0, \infty)$.
We then reduce it to a single value by taking the mean.
This is known as the cross entropy loss:</p>

\[\begin{align}
l(y, p) &amp;= -\frac{1}{N}\sum^{N}_i y_i \log(p_i) \\
  &amp;= -\frac{1}{N}\sum^{N}_i y_i \log\left(\frac{e^{z_i}}{\sum e^z}\right) \\
  &amp;= -\frac{1}{N}\sum^{N}_i y_i \left(z_i - \log\left(\sum e^z\right)\right) 
  \tag{4.2.1} \label{eq:cross_entropy}
\end{align}\]

<p>where $N=nB$.</p>

<p>As a baseline, imagine a model which predicts characters uniformly randomly.
All probabilities will be $1/V$ and hence the loss will reduce to $-\log(1/V)$.
For $V=72$ the expected loss is therefore 4.277.
A trained model should achieve a value closer to 0.</p>

<p>Flux.jl comes with <code class="language-plaintext highlighter-rouge">Flux.logitcrossentropy</code> that will implement equation $\ref{eq:cross_entropy}$:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">l1</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">logitcrossentropy</span><span class="x">(</span><span class="n">Z</span><span class="x">,</span> <span class="n">Y_onehot</span><span class="x">)</span> <span class="c"># Float32</span>
<span class="n">l2</span> <span class="o">=</span> <span class="o">-</span><span class="n">sum</span><span class="x">(</span><span class="n">Y_onehot</span> <span class="o">.*</span> <span class="n">log</span><span class="o">.</span><span class="x">(</span><span class="n">probs</span><span class="x">))</span> <span class="o">/</span> <span class="x">(</span><span class="n">n</span> <span class="o">*</span> <span class="n">B</span><span class="x">)</span> <span class="c"># Float32</span>
<span class="n">l1</span> <span class="n">≈</span> <span class="n">l2</span> <span class="c"># true</span></code></pre></figure>

<p>In a single function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> full_loss</span><span class="x">(</span><span class="n">Ŷ</span><span class="o">::</span><span class="kt">AbstractArray</span><span class="x">{</span><span class="n">T</span><span class="x">,</span> <span class="mi">3</span><span class="x">},</span> <span class="n">Y</span><span class="o">::</span><span class="kt">AbstractMatrix</span><span class="x">{</span><span class="kt">Int</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">vocab_size</span> <span class="o">=</span> <span class="n">size</span><span class="x">(</span><span class="n">Ŷ</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span> 
    <span class="n">Y_onehot</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">onehotbatch</span><span class="x">(</span><span class="n">Y</span><span class="x">,</span> <span class="mi">1</span><span class="o">:</span><span class="n">vocab_size</span><span class="x">)</span>
    <span class="n">Flux</span><span class="o">.</span><span class="n">logitcrossentropy</span><span class="x">(</span><span class="n">Ŷ</span><span class="x">,</span> <span class="n">Y_onehot</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>I’ve called it the full loss to indicate that it is over all $nB$ token predictions and not only the last ($B$) tokens.</p>
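<p>For intuition, the last line of equation $\ref{eq:cross_entropy}$ can also be written out by hand. The following is a sketch only, assuming the same $V \times n \times B$ logits and $n \times B$ integer targets as above; <code class="language-plaintext highlighter-rouge">manual_full_loss</code> is a hypothetical name and this is not how Flux.jl implements it internally:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Cross entropy straight from the logits:
# the mean over all n*B positions of log(sum(e^z)) - z_true.
function manual_full_loss(Z::AbstractArray{T, 3}, Y::AbstractMatrix{Int}) where T
    zmax = maximum(Z, dims=1)                        # 1×n×B, for numerical stability
    lse = zmax .+ log.(sum(exp.(Z .- zmax), dims=1)) # 1×n×B, log(sum(e^z))
    total = zero(T)
    for b in axes(Y, 2), i in axes(Y, 1)
        total += lse[1, i, b] - Z[Y[i, b], i, b]     # pick out the true token's logit
    end
    total / length(Y)                                # N = n*B
end

Z_uniform = zeros(Float32, 5, 4, 2) # all logits equal, so all probabilities are 1/5
Y_any = ones(Int, 4, 2)
manual_full_loss(Z_uniform, Y_any)  # equals log(5), the uniform baseline for V=5</code></pre></figure>

<p>Indexing the true logit directly avoids building the one hot array, at the cost of an explicit loop.</p>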

<h3 id="perplexity">4.4 Perplexity</h3>

<p>Another common measure of the ability of the model is perplexity, which is the inverse of the geometric mean of the probabilities assigned to the true characters.
It is defined as:</p>

\[e^{l(y, p)} = \prod_i^N p_i^{-y_i/N} = 1 \div \left(\prod_i^N p_i \right)^{1/N} \tag{4.3} \label{eq:perplexity}\]

<p>where $l(y, p)$ is the cross entropy loss.</p>

<p>The perplexity for random sampling with $p_i=1/V$ is simply $V$.
In other words, the perplexity for randomly sampling 72 characters is a 1 in 72 chance for each character.
A trained model should achieve a value closer to 1 in 1, because the context and known distributions allow the model to select characters with greater than random chance.</p>

<p>Like other types of averages, perplexity does not describe the shape of the distribution and outliers can have an outsized effect on it.</p>
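<p>Both points can be checked with a few made-up probability vectors. This is a sketch with hypothetical values; <code class="language-plaintext highlighter-rouge">perplexity_of</code> is a helper defined here, and each vector holds the probabilities assigned to the true characters only:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Perplexity = exp(mean negative log probability)
#            = 1 / geometric mean of the true-character probabilities.
perplexity_of(p) = exp(-sum(log.(p)) / length(p))

V = 72
uniform = fill(1 / V, 100)          # every true character gets probability 1/V
good = fill(0.5, 100)               # every true character gets probability 0.5
outlier = vcat(fill(0.5, 99), 1e-9) # as above, but with one very confident miss

perplexity_of(uniform) # ≈ 72
perplexity_of(good)    # ≈ 2
perplexity_of(outlier) # ≈ 2.44, so a single bad prediction shifts the average noticeably</code></pre></figure>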

<p>We can estimate it from many samples, say 1000 steps with a batch size of 32:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">ProgressMeter</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">num_steps</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="n">mean_loss</span> <span class="o">=</span> <span class="mf">0.0f0</span>
<span class="nd">@showprogress</span> <span class="k">for</span> <span class="n">step</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="n">num_steps</span>
    <span class="n">X</span><span class="x">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">get_shifted_batch</span><span class="x">(</span><span class="n">tokens</span><span class="x">,</span> <span class="n">context_size</span><span class="x">,</span> <span class="n">batch_size</span><span class="x">)</span>
    <span class="n">mean_loss</span> <span class="o">+=</span> <span class="n">full_loss</span><span class="x">(</span><span class="n">model</span><span class="x">(</span><span class="n">X</span><span class="x">),</span> <span class="n">Y</span><span class="x">)</span>
<span class="k">end</span>
<span class="n">mean_loss</span> <span class="o">/=</span> <span class="n">num_steps</span>
<span class="n">perplexity</span> <span class="o">=</span> <span class="n">exp</span><span class="x">(</span><span class="n">mean_loss</span><span class="x">)</span></code></pre></figure>

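<p>Running this with an untrained model gave me a mean loss of 4.518 and perplexity of 91.7, which is even worse than the theoretical values for uniform random guessing. This is likely because a randomly initialised model is not uniform: it makes some confident but wrong predictions, which the logarithm penalises heavily.</p>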

<h3 id="training-loop">4.5 Training loop</h3>

<p>We can now set up a training loop. It will use gradient descent with an <a href="https://arxiv.org/abs/1412.6980">Adam optimizer</a> to adjust learning rates during the process:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Flux</span><span class="x">,</span> <span class="n">ProgressMeter</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">opt_state</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">setup</span><span class="x">(</span><span class="n">Flux</span><span class="o">.</span><span class="n">Adam</span><span class="x">(</span><span class="mf">0.01</span><span class="x">),</span> <span class="n">model</span><span class="x">)</span> 
<span class="nd">@showprogress</span> <span class="k">for</span> <span class="n">step</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="mi">1_000</span>
    <span class="n">X</span><span class="x">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">get_shifted_batch</span><span class="x">(</span><span class="n">train_data</span><span class="x">,</span> <span class="n">context_size</span><span class="x">,</span> <span class="n">batch_size</span><span class="x">)</span>
    <span class="n">batch_loss</span><span class="x">,</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">Flux</span><span class="o">.</span><span class="n">withgradient</span><span class="x">(</span><span class="n">model</span><span class="x">)</span> <span class="k">do</span> <span class="n">m</span>
        <span class="n">full_loss</span><span class="x">(</span><span class="n">m</span><span class="x">(</span><span class="n">X</span><span class="x">),</span> <span class="n">Y</span><span class="x">)</span>
    <span class="k">end</span>
    <span class="n">Flux</span><span class="o">.</span><span class="n">update!</span><span class="x">(</span><span class="n">opt_state</span><span class="x">,</span> <span class="n">model</span><span class="x">,</span> <span class="n">grads</span><span class="x">[</span><span class="mi">1</span><span class="x">])</span>
<span class="k">end</span></code></pre></figure>

<p>This works well enough, but will require many more steps to train. I recommend at least 10 epochs, where one epoch is defined as $0.95L/(nB)$ steps. 
Then based on the logic in <a href="#batch-generation">Batch Generation</a> each character at each position in the text should appear approximately once per epoch.
For $L=4.9\times10^6$, $n=64$ and $B=32$, this is 2,300 steps per epoch.</p>
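<p>As a quick check of the arithmetic, here is the same calculation in Julia. It takes the 0.95 factor and the values of $L$, $n$ and $B$ as given above; the variable names are only for illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">L = 4.9e6 # approximate value of L quoted above
n = 64    # context size
B = 32    # batch size
steps_per_epoch = 0.95 * L / (n * B) # ≈ 2273, which rounds to the 2,300 quoted above
total_steps = 10 * steps_per_epoch   # ≈ 22,700 steps for the recommended 10 epochs</code></pre></figure>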

<p>Please see my <a href="https://github.com/LiorSinai/TransformersLite-Examples/blob/main/common/training.jl">training.jl</a> file for a <code class="language-plaintext highlighter-rouge">train!</code> function which also does the following:</p>
<ul>
  <li>Displays a running total of the latest batch loss and the mean batch loss.</li>
  <li>Calculates the total loss and accuracy at the end of each epoch.</li>
  <li>Returns a history <code class="language-plaintext highlighter-rouge">Dictionary</code> which saves these values for each epoch for each metric.</li>
</ul>

<h2 id="evaluation">5 Evaluation</h2>
<h3 id="evaluation-qualitive">5.1 Qualitative</h3>

<p>After the model has been properly trained, we can test how well it generates text:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">context</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">([</span><span class="mi">1</span><span class="x">],</span> <span class="mi">1</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span> <span class="c"># start with the new line symbol</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">generate</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">context</span><span class="x">;</span> <span class="n">context_size</span><span class="o">=</span><span class="mi">64</span><span class="x">,</span> <span class="n">max_tokens</span><span class="o">=</span><span class="mi">300</span><span class="x">)</span>
<span class="n">decoded_text</span> <span class="o">=</span> <span class="n">join</span><span class="x">(</span><span class="n">decode</span><span class="x">(</span><span class="n">indexer</span><span class="x">,</span> <span class="n">out</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="x">]))</span></code></pre></figure>

<blockquote><pre>
Enter which at con to pratele-timen, man,
Nus maxchant newall the strainans, spauks wring-all likell come bein?

PAGLERANIA.
I not all sakompty hanet are the our
parry adme is waith
On shalt, full, in mety infor to thee I pater: let fathing
you, do taks you mail was the mascain
Am.
Him fore our waka
</pre></blockquote>

<p>The output is not fully coherent and is not proper English.
However, there are many true English words and the made-up words follow general English patterns.
The general structure of the output matches the <a href="#preparation">input structure</a>. 
Characters (Actors) are introduced in all capitals.</p>

<p>We could also use a prompt, for example a famous line from the input:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">tokens</span> <span class="o">=</span> <span class="n">indexer</span><span class="x">(</span><span class="n">collect</span><span class="x">(</span><span class="s">"To be, or not to be, that is the question:</span><span class="se">\n</span><span class="s">"</span><span class="x">));</span>
<span class="n">context</span> <span class="o">=</span> <span class="n">reshape</span><span class="x">(</span><span class="n">tokens</span><span class="x">,</span> <span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="x">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">generate</span><span class="x">(</span><span class="n">model</span><span class="x">,</span> <span class="n">context</span><span class="x">;</span> <span class="n">context_size</span><span class="o">=</span><span class="mi">64</span><span class="x">,</span> <span class="n">max_tokens</span><span class="o">=</span><span class="mi">300</span><span class="x">)</span>
<span class="n">decoded_text</span> <span class="o">=</span> <span class="n">join</span><span class="x">(</span><span class="n">decode</span><span class="x">(</span><span class="n">indexer</span><span class="x">,</span> <span class="n">out</span><span class="x">[</span><span class="o">:</span><span class="x">,</span> <span class="mi">1</span><span class="x">]))</span></code></pre></figure>

<blockquote><pre>
To be, or not to be, that is the question:
Of them I conful usall but as dull will henow
I wold you stay shaked marce, I witth all mine Ren, to siven,
Thoumbeines.

GUERD.
Swetr with they bloctain now tires’d do stord’s my leed.

NAWIPAR.
And then tillf’s broky! house stoop lord
you lay’d beater of Ettion say.

DUKEnge what by to and King an
</pre></blockquote>

<p>It does not reproduce the actual line in the text (“Whether ’tis nobler in the mind to suffer”). However, as before, the output is not entirely nonsense and at least has the correct structure.</p>

<h3 id="evaluation-quantitative">5.2 Quantitative</h3>

<p>Calculate the mean loss and perplexity again:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">mean_loss</span> <span class="o">=</span> <span class="mf">0.0f0</span>
<span class="nd">@showprogress</span> <span class="k">for</span> <span class="n">step</span> <span class="k">in</span> <span class="mi">1</span><span class="o">:</span><span class="mi">1000</span>
    <span class="n">X</span><span class="x">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">get_shifted_batch</span><span class="x">(</span><span class="n">val_data</span><span class="x">,</span> <span class="mi">64</span><span class="x">,</span> <span class="mi">32</span><span class="x">)</span>
    <span class="n">mean_loss</span> <span class="o">+=</span> <span class="n">full_loss</span><span class="x">(</span><span class="n">model</span><span class="x">(</span><span class="n">X</span><span class="x">),</span> <span class="n">Y</span><span class="x">)</span>
<span class="k">end</span>
<span class="n">mean_loss</span> <span class="o">/=</span> <span class="mi">1000</span>
<span class="n">perplexity</span> <span class="o">=</span> <span class="n">exp</span><span class="x">(</span><span class="n">mean_loss</span><span class="x">)</span></code></pre></figure>

<p>The mean loss is 1.853 and the perplexity is 6.379.
This is a significant improvement on the uniform random baseline values of 4.277 and 72.0 respectively.</p>

<h2 id="inspection">6 Inspection</h2>
<h3 id="inspect-embeddings">6.1 Embeddings</h3>

<p>For the most part the model we have created is a black box. There are however various techniques to inspect the model. For example, cosine similarity, which was showcased in the <a href="#position-encoding">Position Encoding</a> section.</p>

<p>Another popular technique is to visually examine the embeddings after dimension reduction. For example, our model has an embedding dimension of 32, which we can reduce to 2 dimensions and then show as a 2D scatter plot. Two popular techniques for this are PCA (Principal Component Analysis) and t-SNE (t-distributed Stochastic Neighbor Embedding). t-SNE can optionally start from a PCA reduction and then iterates to give better looking results.</p>

<p>Here is an implementation of t-SNE with Julia:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">TSne</span><span class="x">,</span> <span class="n">Plots</span>
<span class="n">W</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">embedding</span><span class="o">.</span><span class="n">weight</span> <span class="c"># or transpose(model.head.weight)</span>
<span class="n">reduce_dims</span><span class="x">,</span> <span class="n">max_iter</span><span class="x">,</span> <span class="n">perplexit</span> <span class="o">=</span> <span class="mi">0</span><span class="x">,</span> <span class="mi">1000</span><span class="x">,</span> <span class="mf">20.0</span>
<span class="n">Y</span> <span class="o">=</span> <span class="n">tsne</span><span class="x">(</span><span class="n">transpose</span><span class="x">(</span><span class="n">W</span><span class="x">),</span> <span class="mi">2</span><span class="x">,</span> <span class="n">reduce_dims</span><span class="x">,</span> <span class="n">max_iter</span><span class="x">,</span> <span class="n">perplexit</span><span class="x">);</span>
<span class="n">scatter</span><span class="x">(</span><span class="n">Y</span><span class="x">[</span><span class="o">:</span><span class="x">,</span><span class="mi">1</span><span class="x">],</span> <span class="n">Y</span><span class="x">[</span><span class="o">:</span><span class="x">,</span><span class="mi">2</span><span class="x">],</span> <span class="n">series_annotations</span><span class="o">=</span><span class="n">vocabulary</span><span class="x">,</span> 
    <span class="n">markeralpha</span><span class="o">=</span><span class="mf">0.0</span><span class="x">,</span>
    <span class="n">label</span><span class="o">=</span><span class="s">""</span><span class="x">,</span>
    <span class="n">aspectratio</span><span class="o">=:</span><span class="n">equal</span>
<span class="x">)</span></code></pre></figure>

<p>where the <code class="language-plaintext highlighter-rouge">vocabulary</code> is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">vocabulary</span> <span class="o">=</span> <span class="n">string</span><span class="o">.</span><span class="x">(</span><span class="n">indexer</span><span class="o">.</span><span class="n">vocabulary</span><span class="x">)</span>
<span class="n">vocabulary</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="o">=</span> <span class="n">string</span><span class="x">(</span><span class="kt">Int</span><span class="x">(</span><span class="n">indexer</span><span class="o">.</span><span class="n">vocabulary</span><span class="x">[</span><span class="mi">1</span><span class="x">]))</span> <span class="c">#\n =&gt; 10</span>
<span class="n">vocabulary</span><span class="x">[</span><span class="mi">2</span><span class="x">]</span> <span class="o">=</span> <span class="n">string</span><span class="x">(</span><span class="kt">Int</span><span class="x">(</span><span class="n">indexer</span><span class="o">.</span><span class="n">vocabulary</span><span class="x">[</span><span class="mi">2</span><span class="x">]))</span> <span class="c">#' '=&gt; 32</span></code></pre></figure>

<p>The output:</p>
<figure class="post-figure">
    <div class="row">
        <div class="col">
            <img class="img-fluid" src="/assets/posts/transformers/embedding_tsne.png" alt="Embedding t-SNE" />
        </div>
        <div class="col">
            <img class="img-fluid" src="/assets/posts/transformers/head_tsne.png" alt="Head t-SNE" />
        </div>
    </div>
    <figcaption>t-SNE embeddings for the embedding matrix (left) and head matrix (right). New line is 10 and space is 32.</figcaption>
</figure>

<p>Note that t-SNE is stochastic and each run will give different results.</p>

<p>For the embedding matrix we can see that the model groups all the vowels (a, e, i, o, u) and their capital forms together. It also tends to group the lowercase and uppercase forms of a letter together e.g. ‘g’ and ‘G’. The head meanwhile has 3 distinct groups: capital letters, punctuation and lowercase letters. It also groups the vowels together.</p>

<p>Perhaps with further training more meaning would be encoded into these vectors.</p>
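<p>For comparison, PCA, mentioned above as the other popular technique, is deterministic and needs only the standard library. The following is a minimal sketch, assuming the embedding weights are stored as a $d \times V$ matrix with one column per token; <code class="language-plaintext highlighter-rouge">pca_project</code> is a hypothetical helper and the random matrix is a stand-in for the trained weights:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using LinearAlgebra, Statistics

# Project the columns of W (one column per token) onto their first k principal components.
# Returns a V×k matrix, the same layout as the t-SNE output above.
function pca_project(W::AbstractMatrix, k::Int=2)
    centred = W .- mean(W, dims=2) # subtract the mean embedding from every token
    U = svd(centred).U             # principal directions, ordered by singular value
    transpose(centred) * U[:, 1:k] # coordinates of each token along the first k directions
end

W_demo = randn(32, 72)         # stand-in for model.embedding.weight
Y_pca = pca_project(W_demo, 2) # 72×2</code></pre></figure>

<p>The two columns of the result can be passed to the same <code class="language-plaintext highlighter-rouge">scatter</code> call as the t-SNE output.</p>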

<h3 id="inspect-attention">6.2 Attention scores</h3>

<p>We can pass an input to the model and visually inspect the attention scores.
To do this we need to alter the <a href="#transformer-blocks">transformer functions</a> to return the score as well (including reshaping it as needed).
At the top level - the forward pass of the model - these scores should be saved in a vector.
Then we can plot them:</p>

<p>
  <a class="btn" data-toggle="collapse" href="#code-scores-plot" role="button" aria-expanded="false" aria-controls="collapseExample">
    Code for scores plot &#8681;
  </a>
</p>
<div class="collapse" id="code-scores-plot">
  <div class="card card-body ">
    <code><pre>
using Plots
text = """LYSANDER.
How now, my love? Why is your cheek so pale?
How chance the roses there do fade so fast?"""
tokens = reshape(indexer(collect(text)), :, 1);
X = tokens[1:context_size, :];
X_text = decode(indexer, X[:, 1]);
Y, scores = predict_with_scores(model, X, mask=model.mask); # modified forward pass
s = scores[3][:, :, 3, 1]
s = ifelse.(model.mask, s, NaN)
heatmap(s,
    xticks=(1:context_size, X_text),
    yticks=(1:context_size, X_text),
    yrotation=90,
    aspectratio=:equal,
    xlims=(0.5, context_size + 0.5),
    size=(500, 500),
)
</pre></code>
  </div>
</div>

<figure class="post-figure" id="fig-attention-scores-block3-head3">
<img class="img-60" src="/assets/posts/transformers/attention_scores_block3_head3.png" alt="Attention scores for block 3, head 3" />
<figcaption>Attention scores for block 3, head 3.</figcaption>
</figure>

<p>The attention matrices are very sparse. 
Most tokens only place emphasis on the four or fewer tokens directly before them.
This suggests we could have used a much smaller context length, for example 16, and indeed that does work.</p>
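<p>One way to put a number on this sparsity is to measure how much of each token’s attention falls within a short window. The sketch below is built on an assumption about the layout: that <code class="language-plaintext highlighter-rouge">s[i, j]</code> is the weight that query token $j$ places on key token $i$, and that it is zero for keys after the query. If your scores are stored the other way round, transpose them first. The function name and the demo matrix are made up for illustration:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Fraction of each query's attention that falls on itself and the k tokens before it.
# Assumption: s[i, j] is the weight query j places on key i, with zeros for i after j.
function recent_attention_mass(s::AbstractMatrix, k::Int)
    n = size(s, 2)
    map(1:n) do j
        window = max(1, j - k):j
        sum(s[window, j]) / sum(s[1:j, j])
    end
end

# Demo with a made-up causal matrix whose weights halve with every step back.
n_demo = 16
s_demo = zeros(n_demo, n_demo)
for j in 1:n_demo, i in 1:j
    s_demo[i, j] = 0.5 ^ (j - i)
end
s_demo ./= sum(s_demo, dims=1) # normalise each column to sum to 1
mass = recent_attention_mass(s_demo, 4)</code></pre></figure>

<p>Values close to 1 for most queries would support using a shorter context length.</p>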

<p>Ideally the model should be learning long range relationships and it is worrying that it is not.</p>

<p>That said, the model does confidently predict that the character after “How chanc” is an “e”:</p>
<p>
  <a class="btn" data-toggle="collapse" href="#code-probs-plot" role="button" aria-expanded="false" aria-controls="collapseExample">
    Code for probability plot &#8681;
  </a>
</p>
<div class="collapse" id="code-probs-plot">
  <div class="card card-body ">
    <code><pre>
using Plots
probs_next = softmax(Y[:, end, 1])
v = length(indexer.vocabulary)
bar(probs_next,
    xticks=(1:v, indexer.vocabulary),
    xlims=(1, v),
    label="",
    ylabel="probabilities",
    xlabel="tokens"
)
</pre></code>
  </div>
</div>

<figure class="post-figure" id="fig-prob-next">
<img class="img-60" src="/assets/posts/transformers/probs_next.png" alt="Probability next" />
<figcaption>Probabilities for the next token after the last token in the sequence.</figcaption>
</figure>

<p>Perhaps with more training the model would give better results.</p>

<h2 id="conclusion">Conclusion</h2>

<p>Thank you for following this tutorial.
I hope you now have a working transformer and have much better insight into how they work.</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:cosine" role="doc-endnote">
      <p>The cosine similarity is calculated as $W^TW/ m^T m $ where $m_{1j}=\sqrt{\sum_i W_{ij}^2}$ for each column $j$ in $W$.
In code:</p>

      <pre><code class="language-julia">using LinearAlgebra
function cosine_similarity(W::AbstractMatrix)
    sim = transpose(W) * W
    magnitudes = sqrt.(diag(sim))
    for i in 1:size(sim, 1)
        for j in 1:size(sim, 2)
            sim[i, j] /= magnitudes[i] * magnitudes[j]
        end
    end
    sim
end
</code></pre>
      <p><a href="#fnref:cosine" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:tensors" role="doc-endnote">
      <p>In general multiplication is not defined for higher order arrays. But there is a set of multidimensional algebraic objects called <a href="https://en.wikipedia.org/wiki/Tensor">tensors</a> where it is. 
Confusingly, Google named their machine learning framework TensorFlow and calls higher order arrays tensors.
So one should differentiate between machine learning tensors and geometric tensors.
They are not the same.
To give a simple explanation: one can think of geometric tensors as higher order arrays with severe constraints on their entries and operations because they represent geometric objects. These constraints make it harder - not easier - to code higher order arrays as geometric tensors. <a href="#fnref:tensors" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:block_scores" role="doc-endnote">
      <p>The design decision is to purposely drop the attention scores in the <code class="language-plaintext highlighter-rouge">TransformerBlock</code>’s forward pass. This is to simplify the code and to not place a bias on the attention.
In a typical block the <code class="language-plaintext highlighter-rouge">MultiheadAttention</code> layer will make up 1/3rd of parameters while the dense layers will make up 2/3rds, so the dense layers are potentially more important.
To return the scores you can edit the forward pass for the block and model, or create two new functions entirely. <a href="#fnref:block_scores" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:split" role="doc-endnote">
      <p>A smarter strategy is to randomly sample passages throughout the text until the desired proportions are reached. <a href="#fnref:split" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="machine-learning" /><category term="mathematics" /><category term="transformers" /><category term="&apos;machine" /><category term="learning&apos;" /><category term="&apos;deep" /><category term="learning&apos;" /><summary type="html"><![CDATA[A transformer for generating text in Julia, trained on Shakespeare’s plays. This model can be used as a Generative Pre-trained Transformer (GPT) with further work. This post was inspired by Andrej Karpathy’s Zero to Hero course.]]></summary></entry><entry><title type="html">Radix Tree in Julia</title><link href="https://liorsinai.github.io/coding/2024/03/21/radix-tree.html" rel="alternate" type="text/html" title="Radix Tree in Julia" /><published>2024-03-21T00:00:00+00:00</published><updated>2024-03-21T00:00:00+00:00</updated><id>https://liorsinai.github.io/coding/2024/03/21/radix-tree</id><content type="html" xml:base="https://liorsinai.github.io/coding/2024/03/21/radix-tree.html"><![CDATA[<p><em>A radix tree in Julia, built following Test Driven Development (TDD).</em></p>

<h3 id="table-of-contents">Table of Contents</h3>

<nav id="toc"></nav>
<script src="/assets/makeTableOfContents.js"></script>

<h2 id="introduction">1 Introduction</h2>

<p>I recently discovered <a href="https://en.wikipedia.org/wiki/Radix_tree">radix trees</a>, also known as compressed tries. They are a specialised, space-optimised data structure for storing and searching through strings.
They can be used for text suggestions in search engines and for predictive text.
They are used in databases for storing IP addresses and for the <a href="https://www.geeksforgeeks.org/inverted-index/">inverted index</a> of search engines.<sup id="fnref:inverted_index" role="doc-noteref"><a href="#fn:inverted_index" class="footnote" rel="footnote">1</a></sup></p>

<figure class="post-figure">
    <img class="img-80" src="/assets/posts/radix-tree/wiki_radix_tree.png" alt="radix tree" />
    <figcaption>Source: <a href="https://en.wikipedia.org/wiki/Radix_tree">en.wikipedia.org/wiki/Radix_tree</a>.</figcaption>
</figure>

<p>The above figure shows an example of a radix tree. Each edge stores part of a string. 
The full string can be recovered by combining all the edges of the parents of a given node.
Searching through the tree is $\mathcal{O}(\log_r(n))$ where $r$ is called the radix of the tree and $n$ is the total number of items stored in the tree.</p>

<p>This post describes how to build one in Julia.
I’ll be following Test Driven Development (TDD) for part of the process.</p>

<p>As always, the full code can be viewed at my GitHub repository at <a href="https://github.com/LiorSinai/RadixTree.jl">github.com/LiorSinai/RadixTree.jl</a>.</p>

<p>I’d like to note upfront that radix trees are not always the best solution for text search.
In particular, binary search through a sorted linear list is $\mathcal{O}(\log_2(n))$ and is much simpler.
In Julia the inbuilt <code class="language-plaintext highlighter-rouge">searchsortedfirst</code> function does this:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">idx</span> <span class="o">=</span> <span class="n">searchsortedfirst</span><span class="x">(</span><span class="n">sorted_words</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span></code></pre></figure>

<p>So this is partly an academic exercise.</p>
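<p>For reference, the sorted list approach extends to prefix search as well, because all words sharing a prefix sit next to each other once sorted. Here is a minimal sketch using only Julia’s Base; <code class="language-plaintext highlighter-rouge">words_with_prefix</code> is a hypothetical helper and is not part of RadixTree.jl:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># All words starting with `prefix` form one contiguous block in a sorted list.
function words_with_prefix(sorted_words::AbstractVector{String}, prefix::AbstractString)
    start = searchsortedfirst(sorted_words, prefix) # binary search for the first candidate
    stop = start - 1
    for idx in start:length(sorted_words)           # linear scan over the matching block
        startswith(sorted_words[idx], prefix) || break
        stop = idx
    end
    sorted_words[start:stop]
end

sorted_words = sort(["team", "tea", "toast", "test", "slow"])
words_with_prefix(sorted_words, "te") # ["tea", "team", "test"]
words_with_prefix(sorted_words, "x")  # String[]</code></pre></figure>

<p>The binary search finds the start of the block and the scan costs one comparison per match, so this is a reasonable baseline to compare the radix tree against.</p>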

<h2 id="implementation">2 Implementation</h2>

<h3 id="project-setup-optional">Project setup (optional)</h3>

<p>To start, make a package in the Julia REPL:</p>
<figure class="highlight">
    <code class="language-julia-repl hljs" data-lang="julia-repl">
        <span class="hljs-meta">julia&gt;</span><span class="julia"> cd(<span class="hljs-string">"path\\to\\project"</span>)</span>
        <br />
        <span class="hljs-meta">julia&gt;</span><span class="julia"> ] <span class="hljs-comment"># enter package mode</span></span>
        <br />
        <span class="hljs-meta">(@v1.x) pkg&gt;</span><span class="julia"> generate RadixTree <span class="hljs-comment"># make a directory structure</span></span>
        <br /> 
        <span class="hljs-meta">(@v1.x) pkg&gt;</span><span class="julia"> dev "path\\to\\project\\RadixTree"</span>
    </code>
</figure>

<p>The purpose of making a package is that we can now use the super helpful Revise package,
which picks up most code changes during development without having to restart the REPL:</p>

<figure class="highlight"><pre><code class="language-julia-repl" data-lang="julia-repl">julia&gt; using Revise
julia&gt; using RadixTree</code></pre></figure>

<h3 id="radixtreenode">RadixTreeNode</h3>

<p>My goal is to create a simple radix tree where each node stores a string.
In this way the tree functions as a type of array.<sup id="fnref:Dictionary" role="doc-noteref"><a href="#fn:Dictionary" class="footnote" rel="footnote">2</a></sup>
The struct looks like:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">mutable struct</span><span class="nc"> RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="o">&lt;:</span><span class="kt">AbstractString</span><span class="x">}</span>
    <span class="n">data</span><span class="o">::</span><span class="n">T</span>
    <span class="n">is_label</span><span class="o">::</span><span class="kt">Bool</span>
    <span class="n">children</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="o">&lt;:</span><span class="n">RadixTreeNode</span><span class="x">}</span>
<span class="k">end</span>

<span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">data</span><span class="o">::</span><span class="n">T</span><span class="o">=</span><span class="s">""</span><span class="x">,</span> <span class="n">label</span><span class="o">::</span><span class="kt">Bool</span><span class="o">=</span><span class="nb">false</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span> <span class="o">=</span> 
    <span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">}(</span><span class="n">data</span><span class="x">,</span> <span class="n">label</span><span class="x">,</span> <span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">}[])</span></code></pre></figure>

<p>In Julia an immutable <code class="language-plaintext highlighter-rouge">struct</code> is usually preferable because the compiler can more easily optimise code for it.
However here we will often need to change the data field during inserts, and so require a <code class="language-plaintext highlighter-rouge">mutable struct</code>.</p>
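
<p>To make the difference concrete, here is a small sketch with two hypothetical single-field types, <code class="language-plaintext highlighter-rouge">FrozenNode</code> and <code class="language-plaintext highlighter-rouge">EditableNode</code>. They are not part of the package and only illustrate what happens when a field is reassigned:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical types for illustration only.
struct FrozenNode
    data::String
end

mutable struct EditableNode
    data::String
end

frozen = FrozenNode("test")
editable = EditableNode("test")

editable.data = "te"   # fine: fields of a mutable struct can be reassigned

# Reassigning a field of an immutable struct throws an error.
threw = try
    frozen.data = "te"
    false
catch
    true
end

threw                # true
ismutable(frozen)    # false
ismutable(editable)  # true</code></pre></figure>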

<p>The whole tree will be accessed through the first node, which is called the root:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">()</span> <span class="c"># RadixTreeNode{String}("", false, RadixTreeNode{String}[])</span></code></pre></figure>

<p>If we store children in the root then the default printing will print them too:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">{</span><span class="kt">String</span><span class="x">}(</span><span class="s">""</span><span class="x">,</span> <span class="nb">false</span><span class="x">,</span> <span class="x">[</span><span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">"a"</span><span class="x">,</span> <span class="nb">true</span><span class="x">),</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">"b"</span><span class="x">,</span> <span class="nb">true</span><span class="x">)])</span>
<span class="cm">#= RadixTreeNode{String}("", false, RadixTreeNode{String}[RadixTreeNode{String}("a", true, RadixTreeNode{String}[]), RadixTreeNode{String}("b", true, RadixTreeNode{String}[])]) =#</span></code></pre></figure>

<p>This will get out of hand for a large tree, as it will print the entire tree.
To avoid this, we can create a custom printing function which will only print the data for the immediate children of a node:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">children_data</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">)</span> <span class="o">=</span> <span class="x">[</span><span class="n">child</span><span class="o">.</span><span class="n">data</span> <span class="k">for</span> <span class="n">child</span> <span class="k">in</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">]</span>

<span class="k">function</span><span class="nf"> Base.show</span><span class="x">(</span><span class="n">io</span><span class="o">::</span><span class="kt">IO</span><span class="x">,</span> <span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">)</span>
    <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="n">typeof</span><span class="x">(</span><span class="n">node</span><span class="x">))</span>
    <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="s">"(data="</span><span class="x">,</span> <span class="n">node</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
    <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="s">", is_label="</span><span class="x">,</span> <span class="n">node</span><span class="o">.</span><span class="n">is_label</span><span class="x">)</span>
    <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="s">", children="</span><span class="x">,</span> <span class="n">children_data</span><span class="x">(</span><span class="n">node</span><span class="x">))</span>
    <span class="n">print</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="s">")"</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>Now if we <code class="language-plaintext highlighter-rouge">print(root)</code> we get:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="cm">#= RadixTreeNode{String}(data=, is_label=false, children=["a", "b"]) =#</span></code></pre></figure>

<p>We can create other helper functions for the <code class="language-plaintext highlighter-rouge">RadixTreeNode</code>:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Base</span><span class="o">.</span><span class="n">eltype</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">})</span> <span class="k">where</span> <span class="n">T</span> <span class="o">=</span> <span class="n">T</span>
<span class="n">children</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">)</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span>
<span class="n">is_leaf</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">)</span> <span class="o">=</span> <span class="n">isempty</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">)</span></code></pre></figure>
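
<p>As a quick usage sketch, the helpers behave as follows on the two-child root from earlier. The definitions are repeated so that the snippet runs on its own; the only change is that the convenience constructor here requires the <code class="language-plaintext highlighter-rouge">data</code> argument, which is a simplification for this example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">mutable struct RadixTreeNode{T&lt;:AbstractString}
    data::T
    is_label::Bool
    children::Vector{&lt;:RadixTreeNode}
end

# Simplified constructor for this sketch: `data` has no default value.
RadixTreeNode(data::T, label::Bool=false) where {T} =
    RadixTreeNode{T}(data, label, RadixTreeNode{T}[])

Base.eltype(node::RadixTreeNode{T}) where T = T
children(node::RadixTreeNode) = node.children
is_leaf(node::RadixTreeNode) = isempty(node.children)

root = RadixTreeNode{String}("", false, [RadixTreeNode("a", true), RadixTreeNode("b", true)])

eltype(root)                # String
is_leaf(root)               # false: the root has two children
is_leaf(children(root)[1])  # true: the "a" node has no children</code></pre></figure>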

<h3 id="search">Search</h3>

<p>We can use a very basic example to create and test a search function. See the tree below:</p>

<figure class="post-figure">
    <img class="img-80" src="/assets/posts/radix-tree/radix_tree_get.png" alt="radix tree get" />
</figure>

<p>We can construct it directly as:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">{</span><span class="kt">String</span><span class="x">}(</span>
    <span class="s">""</span><span class="x">,</span> <span class="nb">false</span><span class="x">,</span> <span class="x">[</span> 
            <span class="n">RadixTreeNode</span><span class="x">{</span><span class="kt">String</span><span class="x">}(</span><span class="s">"te"</span><span class="x">,</span> <span class="nb">false</span><span class="x">,</span> 
            <span class="x">[</span>
                <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">"am"</span><span class="x">),</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">"st"</span><span class="x">)</span>
            <span class="x">]</span>
        <span class="x">)</span>
    <span class="x">]</span>
<span class="x">)</span></code></pre></figure>

<p>The goal of the search algorithm is to return the deepest node in the tree that matches the given <code class="language-plaintext highlighter-rouge">key</code>.
We would also like to know how many letters are matched.
We can make the following two tests:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">using</span> <span class="n">Test</span>
<span class="n">node</span><span class="x">,</span> <span class="n">num_found</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"hello"</span><span class="x">)</span>
<span class="nd">@test</span> <span class="n">node</span> <span class="o">==</span> <span class="n">root</span> <span class="o">&amp;&amp;</span> <span class="n">num_found</span> <span class="o">==</span> <span class="mi">0</span>
<span class="n">node</span><span class="x">,</span> <span class="n">num_found</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"team"</span><span class="x">)</span>
<span class="nd">@test</span> <span class="n">node</span> <span class="o">==</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span> <span class="o">&amp;&amp;</span> <span class="n">num_found</span> <span class="o">==</span> <span class="mi">4</span></code></pre></figure>

<p>The algorithm on <a href="https://en.wikipedia.org/wiki/Radix_tree">Wikipedia</a> is as follows:</p>
<ol>
  <li>Check if any child has a matching prefix with the key.</li>
  <li>Chop off the matching prefix (keep the suffix) of the key and set the node to the child.</li>
  <li>Repeat steps 1-2. Stop when:
    <ul>
      <li>There is no matching prefix.</li>
      <li>Or the node is a leaf (has no children).</li>
      <li>Or all the letters are matched.</li>
    </ul>
  </li>
</ol>

<p>Here is the full algorithm in code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> Base.get</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span>
    <span class="n">node</span> <span class="o">=</span> <span class="n">root</span>
    <span class="n">num_found</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">suffix</span> <span class="o">=</span> <span class="n">key</span>
    <span class="k">while</span> <span class="o">!</span><span class="x">(</span><span class="n">isnothing</span><span class="x">(</span><span class="n">node</span><span class="x">))</span> <span class="o">&amp;&amp;</span> <span class="o">!</span><span class="x">(</span><span class="n">is_leaf</span><span class="x">(</span><span class="n">node</span><span class="x">))</span> <span class="o">&amp;&amp;</span> <span class="x">(</span><span class="n">num_found</span> <span class="o">&lt;</span> <span class="n">length</span><span class="x">(</span><span class="n">key</span><span class="x">))</span>
        <span class="n">child</span> <span class="o">=</span> <span class="n">search_children</span><span class="x">(</span><span class="n">node</span><span class="x">,</span> <span class="n">suffix</span><span class="x">)</span>
        <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">child</span><span class="x">)</span>
            <span class="n">break</span>
        <span class="k">end</span>
        <span class="n">node</span> <span class="o">=</span> <span class="n">child</span>
        <span class="n">num_found</span> <span class="o">+=</span> <span class="n">length</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
        <span class="n">suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">suffix</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">data</span><span class="x">))</span>
    <span class="k">end</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">num_found</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> get_suffix</span><span class="x">(</span><span class="n">s</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">,</span> <span class="n">head</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">isempty</span><span class="x">(</span><span class="n">s</span><span class="x">)</span>
        <span class="k">return</span> <span class="n">s</span>
    <span class="k">end</span>
    <span class="n">s</span><span class="x">[</span><span class="n">nextind</span><span class="x">(</span><span class="n">s</span><span class="x">,</span> <span class="n">firstindex</span><span class="x">(</span><span class="n">s</span><span class="x">),</span> <span class="n">head</span><span class="x">)</span><span class="o">:</span><span class="k">end</span><span class="x">]</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> search_children</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span>
    <span class="k">for</span> <span class="n">child</span> <span class="k">in</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span>
        <span class="k">if</span> <span class="n">startswith</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="n">child</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
            <span class="k">return</span> <span class="n">child</span>
        <span class="k">end</span>
    <span class="k">end</span>
<span class="k">end</span></code></pre></figure>

<p>This passes both tests.</p>

<p>Some comments:</p>
<ul>
  <li>These functions are fully compatible with Unicode strings. See this <a href="https://en.wikibooks.org/wiki/Introducing_Julia/Strings_and_characters#Unicode_strings">tutorial</a> for more information.</li>
  <li>The <code class="language-plaintext highlighter-rouge">get_suffix</code> function may also be implemented using <code class="language-plaintext highlighter-rouge">chop(s; head=head, tail=0)</code> which returns <code class="language-plaintext highlighter-rouge">SubString</code> instead of <code class="language-plaintext highlighter-rouge">String</code>. Working directly with strings seems to reduce memory allocations.</li>
  <li>The <code class="language-plaintext highlighter-rouge">search_children</code> function can be made faster with binary search. But in practice the child arrays tend to be small so this is not essential.</li>
</ul>
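
<p>The first two comments can be checked with a short sketch. The Greek test string is an arbitrary choice for illustration; each of its letters takes two bytes in UTF-8, so character counts and byte indices differ:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">function get_suffix(s::AbstractString, head::Int)
    if isempty(s)
        return s
    end
    # nextind advances `head` characters, not bytes.
    s[nextind(s, firstindex(s), head):end]
end

s = "αβγδ"                # arbitrary multi-byte example string
length(s), ncodeunits(s)  # (4, 8): 4 characters stored in 8 bytes

suffix = get_suffix(s, 2)            # "γδ"
chopped = chop(s; head=2, tail=0)    # "γδ"

typeof(suffix)   # String
typeof(chopped)  # SubString{String}: a view into `s` rather than a copy

# Indexing by raw byte position is not safe here:
# s[2:end] throws a StringIndexError because byte 2 is inside "α".</code></pre></figure>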

<p>A question is, what will <code class="language-plaintext highlighter-rouge">get(root, "tea")</code> return?
Technically “tea” is in the tree as a prefix of “team”, which is split across the “te” and “am” nodes.
However this function is purposely limited to only full matching prefixes and not partial matches.
Hence the “te” node will be returned with a match length of 2.</p>
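
<p>This behaviour can be verified with a self-contained sketch. It repeats the definitions from above in a condensed form (the convenience constructor requires <code class="language-plaintext highlighter-rouge">data</code> here, which is a simplification for this example) and rebuilds the same small tree:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">mutable struct RadixTreeNode{T&lt;:AbstractString}
    data::T
    is_label::Bool
    children::Vector{&lt;:RadixTreeNode}
end

# Simplified constructor for this sketch: `data` has no default value.
RadixTreeNode(data::T, label::Bool=false) where {T} =
    RadixTreeNode{T}(data, label, RadixTreeNode{T}[])

is_leaf(node::RadixTreeNode) = isempty(node.children)

get_suffix(s::AbstractString, head::Int) =
    isempty(s) ? s : s[nextind(s, firstindex(s), head):end]

function search_children(node::RadixTreeNode, key::AbstractString)
    for child in node.children
        if startswith(key, child.data)
            return child
        end
    end
    nothing
end

function Base.get(root::RadixTreeNode, key::AbstractString)
    node, num_found, suffix = root, 0, key
    while !is_leaf(node) &amp;&amp; num_found &lt; length(key)
        child = search_children(node, suffix)
        if isnothing(child)
            break
        end
        node = child
        num_found += length(node.data)
        suffix = get_suffix(suffix, length(node.data))
    end
    node, num_found
end

root = RadixTreeNode{String}("", false, [
    RadixTreeNode{String}("te", false, [RadixTreeNode("am"), RadixTreeNode("st")])
])

node, num_found = get(root, "tea")
node.data   # "te": neither "am" nor "st" is a full prefix of the remaining "a"
num_found   # 2</code></pre></figure>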

<h3 id="insert">Insert</h3>

<figure class="post-figure">
    <img class="img-80" src="/assets/posts/radix-tree/radix_tree_insert.png" alt="radix tree insert examples" />
</figure>

<p>The <a href="https://en.wikipedia.org/wiki/Radix_tree">Wikipedia</a> page has a fairly complex insert example.
I’m instead going to work through four simple examples, extending the <code class="language-plaintext highlighter-rouge">insert!</code> function each time to make the tests pass.
By the end the function will be able to handle all scenarios.</p>

<h4 id="1-insert-in-order">1 Insert in order</h4>

<p>For efficient search we want the children kept in sorted order. Our test is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">()</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"t"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"z"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"a"</span><span class="x">)</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"a"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">2</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"t"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">3</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"z"</span></code></pre></figure>

<p>For a given key we first need to find which node to insert it at (<code class="language-plaintext highlighter-rouge">get</code>), and then we can use <code class="language-plaintext highlighter-rouge">searchsortedfirst</code> to find which index to put it in:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> Base.insert!</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">match_length</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span>
    <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="nb">true</span><span class="x">)</span>
    <span class="n">idx</span> <span class="o">=</span> <span class="n">searchsortedfirst</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">new_node</span><span class="x">;</span> <span class="n">lt</span><span class="o">=</span><span class="x">(</span><span class="n">n1</span><span class="x">,</span> <span class="n">n2</span><span class="x">)</span><span class="o">-&gt;</span><span class="n">n1</span><span class="o">.</span><span class="n">data</span> <span class="o">&lt;</span> <span class="n">n2</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
    <span class="n">insert!</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">new_node</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>And all our tests pass.</p>
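
<p>The sorted-insert idea can be seen in isolation on a plain vector of strings. The sketch below also shows the <code class="language-plaintext highlighter-rouge">by</code> keyword of <code class="language-plaintext highlighter-rouge">searchsortedfirst</code> as a possible alternative to a custom <code class="language-plaintext highlighter-rouge">lt</code>. The named tuples and the <code class="language-plaintext highlighter-rouge">getdata</code> helper are stand-ins for nodes and are assumptions made for this example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Insert each word at the index that keeps the vector sorted.
words = String[]
for w in ["t", "z", "a"]
    idx = searchsortedfirst(words, w)
    insert!(words, idx, w)
end
words  # ["a", "t", "z"]

# Stand-ins for nodes: named tuples with a `data` field (illustration only).
nodes = [(data="a",), (data="t",), (data="z",)]
getdata(n) = n.data  # hypothetical accessor

# `by` is applied to both the elements and the searched value before comparing.
idx = searchsortedfirst(nodes, (data="m",); by=getdata)
idx  # 2: "m" sorts between "a" and "t"</code></pre></figure>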

<h4 id="2-extend">2 Extend</h4>

<p>If we add strings whose prefixes are already fully stored in existing nodes, then we only want to extend the tree by the remaining suffix. Our test is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">""</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"s"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"slow"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"slowly"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"slower"</span><span class="x">)</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"s"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"low"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"er"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">2</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"ly"</span></code></pre></figure>

<p>The new code is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> Base.insert!</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">match_length</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span>
    <span class="n">suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="n">match_length</span><span class="x">)</span> <span class="c"># new</span>
    <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">T</span><span class="x">(</span><span class="n">suffix</span><span class="x">),</span> <span class="nb">true</span><span class="x">)</span> <span class="c"># edit</span>
    <span class="n">idx</span> <span class="o">=</span> <span class="n">searchsortedfirst</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">new_node</span><span class="x">;</span> <span class="n">lt</span><span class="o">=</span><span class="x">(</span><span class="n">n1</span><span class="x">,</span> <span class="n">n2</span><span class="x">)</span><span class="o">-&gt;</span><span class="n">n1</span><span class="o">.</span><span class="n">data</span> <span class="o">&lt;</span> <span class="n">n2</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
    <span class="n">insert!</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">new_node</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<h4 id="3-split">3 Split</h4>

<p>If we add a string which shares only part of an existing node’s data as a prefix, then we have to split that node.</p>

<p>Our test is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">""</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"test"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"team"</span><span class="x">)</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"te"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"am"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">2</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"st"</span></code></pre></figure>

<p>Unlike before with <code class="language-plaintext highlighter-rouge">get</code>, we will now go the extra step of checking if any child overlaps with the remaining suffix.
This requires checking all prefixes up to the suffix length $s$ for all children $c$, so this is inherently an $\mathcal{O}(cs)$ operation.
If a child does overlap, we will <code class="language-plaintext highlighter-rouge">split!</code> that child into two and then add the suffix as a new child.
The child will then have exactly two children - the suffix of the old data and this new suffix - so determining the order is straightforward.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> Base.insert!</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">match_length</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span>
    <span class="n">suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="n">match_length</span><span class="x">)</span>
    <span class="n">child</span><span class="x">,</span> <span class="n">overlap</span> <span class="o">=</span> <span class="n">search_children_with_overlap</span><span class="x">(</span><span class="n">node</span><span class="x">,</span> <span class="n">suffix</span><span class="x">)</span> <span class="c"># new</span>
    <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">child</span><span class="x">)</span> <span class="c"># new</span>
        <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">T</span><span class="x">(</span><span class="n">suffix</span><span class="x">),</span> <span class="nb">true</span><span class="x">)</span>
        <span class="n">idx</span> <span class="o">=</span> <span class="n">searchsortedfirst</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">new_node</span><span class="x">;</span> <span class="n">lt</span><span class="o">=</span><span class="x">(</span><span class="n">n1</span><span class="x">,</span> <span class="n">n2</span><span class="x">)</span><span class="o">-&gt;</span><span class="n">n1</span><span class="o">.</span><span class="n">data</span> <span class="o">&lt;</span> <span class="n">n2</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
        <span class="n">insert!</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">new_node</span><span class="x">)</span>
    <span class="k">else</span> <span class="c"># new</span>
        <span class="n">node</span> <span class="o">=</span> <span class="n">child</span> <span class="c"># new</span>
        <span class="n">split!</span><span class="x">(</span><span class="n">node</span><span class="x">,</span> <span class="n">overlap</span><span class="x">)</span> <span class="c"># new</span>
        <span class="n">new_suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">suffix</span><span class="x">,</span> <span class="n">overlap</span><span class="x">)</span> <span class="c"># new</span>
        <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">T</span><span class="x">(</span><span class="n">new_suffix</span><span class="x">),</span> <span class="nb">true</span><span class="x">)</span> <span class="c"># new</span>
        <span class="n">idx</span> <span class="o">=</span> <span class="n">new_node</span><span class="o">.</span><span class="n">data</span> <span class="o">&lt;</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">?</span> <span class="mi">1</span> <span class="o">:</span> <span class="mi">2</span> <span class="c"># new</span>
        <span class="n">insert!</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">new_node</span><span class="x">)</span> <span class="c"># new</span>
    <span class="k">end</span> <span class="c"># new</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> search_children_with_overlap</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">,</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span>
    <span class="k">for</span> <span class="n">len_prefix</span> <span class="k">in</span> <span class="n">length</span><span class="x">(</span><span class="n">key</span><span class="x">)</span><span class="o">:-</span><span class="mi">1</span><span class="o">:</span><span class="mi">1</span>
        <span class="k">for</span> <span class="n">child</span> <span class="k">in</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span>
            <span class="n">data</span> <span class="o">=</span> <span class="n">first</span><span class="x">(</span><span class="n">child</span><span class="o">.</span><span class="n">data</span><span class="x">,</span> <span class="n">len_prefix</span><span class="x">)</span>
            <span class="k">if</span> <span class="n">startswith</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="n">data</span><span class="x">)</span>
                <span class="k">return</span> <span class="n">child</span><span class="x">,</span> <span class="n">min</span><span class="x">(</span><span class="n">len_prefix</span><span class="x">,</span> <span class="n">length</span><span class="x">(</span><span class="n">data</span><span class="x">))</span>
            <span class="k">end</span>
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="nb">nothing</span><span class="x">,</span> <span class="mi">0</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> split!</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">i</span><span class="o">::</span><span class="kt">Int</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">data</span><span class="x">,</span> <span class="n">i</span><span class="x">)</span>
    <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">}(</span><span class="n">T</span><span class="x">(</span><span class="n">suffix</span><span class="x">),</span> <span class="n">node</span><span class="o">.</span><span class="n">is_label</span><span class="x">,</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">)</span>
    <span class="n">node</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">first</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">data</span><span class="x">,</span> <span class="n">i</span><span class="x">)</span>
    <span class="n">node</span><span class="o">.</span><span class="n">children</span> <span class="o">=</span> <span class="x">[</span><span class="n">new_node</span><span class="x">]</span>
    <span class="n">node</span><span class="o">.</span><span class="n">is_label</span> <span class="o">=</span> <span class="nb">false</span>
    <span class="n">node</span>
<span class="k">end</span></code></pre></figure>
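
<p>The overlap returned by <code class="language-plaintext highlighter-rouge">search_children_with_overlap</code> is, in effect, the number of leading characters that the suffix shares with a child’s data. As a minimal sketch of that idea for just two strings, the helper below counts the shared prefix directly. The name <code class="language-plaintext highlighter-rouge">common_prefix_length</code> is hypothetical and is not part of the code in this post.</p>

```julia
# Hypothetical helper for illustration only: counts how many leading
# characters two strings have in common.
function common_prefix_length(a::AbstractString, b::AbstractString)
    n = 0
    for (ca, cb) in zip(a, b)   # zip stops at the end of the shorter string
        ca == cb || break       # stop at the first mismatch
        n += 1
    end
    n
end

common_prefix_length("team", "tea")     # 3
common_prefix_length("rubens", "ruber") # 4
common_prefix_length("ten", "apple")    # 0
```

<p>An overlap of zero corresponds to the case where no child matches and the new node is inserted as is.</p>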

<h4 id="4-split-with-no-add">4 Split with no add</h4>

<p>There are two extra scenarios we have to account for.
The first is when the word is already in the tree, in which case we should ignore it.
The second is when the new word is a complete prefix of an existing word, in which case we split the existing node but shouldn’t add a new node afterwards.</p>

<p>Our test is:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">()</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"team"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"team"</span><span class="x">)</span> <span class="c"># ignore</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"tea"</span><span class="x">)</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"tea"</span>
<span class="nd">@test</span> <span class="n">root</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">==</span> <span class="s">"m"</span></code></pre></figure>

<p>This requires extra checks:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> Base.insert!</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="n">key</span><span class="o">::</span><span class="kt">AbstractString</span><span class="x">)</span> <span class="k">where</span> <span class="n">T</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">match_length</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">match_length</span> <span class="o">==</span> <span class="n">length</span><span class="x">(</span><span class="n">key</span><span class="x">)</span> <span class="c"># new</span>
        <span class="n">node</span><span class="o">.</span><span class="n">is_label</span> <span class="o">=</span> <span class="nb">true</span> <span class="c"># new</span>
        <span class="k">return</span> <span class="c"># new</span>
    <span class="k">end</span>  <span class="c"># new</span>
    <span class="n">suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">key</span><span class="x">,</span> <span class="n">match_length</span><span class="x">)</span>
    <span class="n">child</span><span class="x">,</span> <span class="n">overlap</span> <span class="o">=</span> <span class="n">search_children_with_overlap</span><span class="x">(</span><span class="n">node</span><span class="x">,</span> <span class="n">suffix</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">child</span><span class="x">)</span>
        <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">T</span><span class="x">(</span><span class="n">suffix</span><span class="x">),</span> <span class="nb">true</span><span class="x">)</span>
        <span class="n">idx</span> <span class="o">=</span> <span class="n">searchsortedfirst</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">new_node</span><span class="x">;</span> <span class="n">lt</span><span class="o">=</span><span class="x">(</span><span class="n">n1</span><span class="x">,</span> <span class="n">n2</span><span class="x">)</span><span class="o">-&gt;</span><span class="n">n1</span><span class="o">.</span><span class="n">data</span> <span class="o">&lt;</span> <span class="n">n2</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
        <span class="n">insert!</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">new_node</span><span class="x">)</span>
    <span class="k">else</span>
        <span class="n">node</span> <span class="o">=</span> <span class="n">child</span>
        <span class="n">split!</span><span class="x">(</span><span class="n">node</span><span class="x">,</span> <span class="n">overlap</span><span class="x">)</span>
        <span class="k">if</span> <span class="n">overlap</span> <span class="o">&lt;</span> <span class="n">length</span><span class="x">(</span><span class="n">suffix</span><span class="x">)</span> <span class="c"># new</span>
            <span class="n">new_suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="n">suffix</span><span class="x">,</span> <span class="n">overlap</span><span class="x">)</span>
            <span class="n">new_node</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="n">T</span><span class="x">(</span><span class="n">new_suffix</span><span class="x">),</span> <span class="nb">true</span><span class="x">)</span>
            <span class="n">idx</span> <span class="o">=</span> <span class="n">new_node</span><span class="o">.</span><span class="n">data</span> <span class="o">&lt;</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="mi">1</span><span class="x">]</span><span class="o">.</span><span class="n">data</span> <span class="o">?</span> <span class="mi">1</span> <span class="o">:</span> <span class="mi">2</span>
            <span class="n">insert!</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">new_node</span><span class="x">)</span>
        <span class="k">else</span> <span class="c"># new</span>
            <span class="n">node</span><span class="o">.</span><span class="n">is_label</span> <span class="o">=</span> <span class="nb">true</span> <span class="c"># new</span>
            <span class="n">node</span> <span class="c"># new</span>
        <span class="k">end</span> <span class="c"># new</span>
    <span class="k">end</span>
<span class="k">end</span></code></pre></figure>

<h3 id="print-tree">Print tree</h3>

<p>We can now make fairly complex trees.
To see this, it helps to print the entire tree.</p>

<p>The tree will be printed by visiting a node and printing its data, then moving on to each of its children and doing the same one by one.
This is known as a pre-order traversal.</p>

<p>Each time we go down a level we will increase the indent for easy reading.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">print_tree</span><span class="x">(</span><span class="n">io</span><span class="o">::</span><span class="kt">IO</span><span class="x">,</span> <span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">print_tree_preorder</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="n">root</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span>
<span class="n">print_tree</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span> <span class="o">=</span> <span class="n">print_tree</span><span class="x">(</span><span class="nb">stdout</span><span class="x">,</span> <span class="n">root</span><span class="x">;</span> <span class="n">options</span><span class="o">...</span><span class="x">)</span>

<span class="k">function</span><span class="nf"> print_tree_preorder</span><span class="x">(</span><span class="n">io</span><span class="o">::</span><span class="kt">IO</span><span class="x">,</span> <span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">,</span> <span class="n">level_indent</span><span class="o">=</span><span class="s">""</span>
    <span class="x">;</span> <span class="n">indent</span><span class="o">::</span><span class="kt">AbstractString</span><span class="o">=</span><span class="s">"--"</span><span class="x">,</span> <span class="n">use_data_as_separator</span><span class="o">::</span><span class="kt">Bool</span><span class="o">=</span><span class="nb">false</span>
    <span class="x">)</span>
    <span class="n">println</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="n">level_indent</span> <span class="o">*</span> <span class="n">node</span><span class="o">.</span><span class="n">data</span><span class="x">)</span>
    <span class="n">separator</span> <span class="o">=</span> <span class="n">use_data_as_separator</span> <span class="o">?</span> <span class="n">node</span><span class="o">.</span><span class="n">data</span> <span class="o">:</span> <span class="s">"|"</span>
    <span class="n">next_level</span> <span class="o">=</span> <span class="n">level_indent</span> <span class="o">*</span> <span class="n">separator</span> <span class="o">*</span> <span class="n">indent</span>
    <span class="k">for</span> <span class="n">child</span> <span class="k">in</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span>
        <span class="n">print_tree_preorder</span><span class="x">(</span><span class="n">io</span><span class="x">,</span> <span class="n">child</span><span class="x">,</span> <span class="n">next_level</span>
        <span class="x">;</span> <span class="n">indent</span><span class="o">=</span><span class="n">indent</span><span class="x">,</span> <span class="n">use_data_as_separator</span><span class="o">=</span><span class="n">use_data_as_separator</span>
        <span class="x">)</span>
    <span class="k">end</span>
<span class="k">end</span></code></pre></figure>

<p>A basic example:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">"&lt;root&gt;"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"t"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"ten"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"team"</span><span class="x">)</span>
<span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="s">"tea"</span><span class="x">)</span>
<span class="n">print_tree</span><span class="x">(</span><span class="n">root</span><span class="x">)</span></code></pre></figure>

<p>The output:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;root&gt;
|--t
|--|--e
|--|--|--a
|--|--|--|--m
|--|--|--n
</code></pre></div></div>

<p>Here is a fairly complex example from Wikipedia:</p>
<figure class="post-figure">
    <img class="img-80" src="/assets/posts/radix-tree/wiki_romane_tree.png" alt="Romane radix tree" />
</figure>

<p>In code:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">(</span><span class="s">"&lt;root&gt;"</span><span class="x">)</span>
<span class="k">for</span> <span class="n">key</span> <span class="k">in</span> <span class="x">[</span><span class="s">"romane"</span><span class="x">,</span> <span class="s">"romanus"</span><span class="x">,</span> <span class="s">"romulus"</span><span class="x">,</span> <span class="s">"rubens"</span><span class="x">,</span> <span class="s">"ruber"</span><span class="x">,</span> <span class="s">"rubicon"</span><span class="x">,</span> <span class="s">"rubicundus"</span><span class="x">]</span>
    <span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span>
<span class="k">end</span>
<span class="n">print_tree</span><span class="x">(</span><span class="n">root</span><span class="x">)</span></code></pre></figure>

<p>The output:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;root&gt;
|--r
|--|--om
|--|--|--an
|--|--|--|--e
|--|--|--|--us
|--|--|--ulus
|--|--ub
|--|--|--e
|--|--|--|--ns
|--|--|--|--r
|--|--|--ic
|--|--|--|--on
|--|--|--|--undus
</code></pre></div></div>

<h3 id="height">Height</h3>

<p>An important statistic of the tree is its height. This is the maximum number of nodes below the root that a search must traverse to find a key.
The height can be computed with a recursive function:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> get_height</span><span class="x">(</span><span class="n">node</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">,</span> <span class="n">height</span><span class="o">::</span><span class="kt">Int</span><span class="o">=</span><span class="mi">0</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">is_leaf</span><span class="x">(</span><span class="n">node</span><span class="x">)</span>
        <span class="k">return</span> <span class="n">height</span>
    <span class="k">end</span>
    <span class="n">next_height</span> <span class="o">=</span> <span class="n">height</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="k">for</span> <span class="n">child</span> <span class="k">in</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span>
        <span class="n">height</span> <span class="o">=</span> <span class="n">max</span><span class="x">(</span><span class="n">height</span><span class="x">,</span> <span class="n">get_height</span><span class="x">(</span><span class="n">child</span><span class="x">,</span> <span class="n">next_height</span><span class="x">))</span>
    <span class="k">end</span>
    <span class="n">height</span>
<span class="k">end</span></code></pre></figure>

<p>For the Romane tree above this returns a height of 4.</p>
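
<p>To check this figure by hand, here is a minimal self-contained sketch. It assumes a hypothetical <code class="language-plaintext highlighter-rouge">SimpleNode</code> type standing in for <code class="language-plaintext highlighter-rouge">RadixTreeNode</code>, and builds the Romane tree manually rather than with <code class="language-plaintext highlighter-rouge">insert!</code>. The height function is a compact variant of the one above, not a replacement for it.</p>

```julia
# Hypothetical minimal node for illustration; the real RadixTreeNode also
# stores an is_label flag.
struct SimpleNode
    data::String
    children::Vector{SimpleNode}
end
SimpleNode(data::String) = SimpleNode(data, SimpleNode[])

# A leaf has height 0; otherwise the height is one more than the tallest child.
simple_height(node::SimpleNode) =
    isempty(node.children) ? 0 : 1 + maximum(simple_height, node.children)

# The Romane tree from the printed output above, built by hand.
romane = SimpleNode("root", [
    SimpleNode("r", [
        SimpleNode("om", [
            SimpleNode("an", [SimpleNode("e"), SimpleNode("us")]),
            SimpleNode("ulus"),
        ]),
        SimpleNode("ub", [
            SimpleNode("e", [SimpleNode("ns"), SimpleNode("r")]),
            SimpleNode("ic", [SimpleNode("on"), SimpleNode("undus")]),
        ]),
    ]),
])

simple_height(romane) # 4
```

<p>The longest paths, such as root, “r”, “om”, “an”, “e”, pass through four nodes below the root.</p>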

<h3 id="iteration">Iteration</h3>

<p>The last useful feature I want to add is an iterator, also known as a generator in other languages.
An iterator returns one data point at a time, which uses less memory than returning the entire dataset at once.</p>

<p>Julia’s iteration interface is functional in style and as such making an iterator requires more thought than in some other languages.
In Python for example it is easy to implement one with the <code class="language-plaintext highlighter-rouge">yield</code> keyword.
In Julia, the onus is on the programmer to manage the state of the iterator.
At first I found it challenging to make one for a tree but Henrique Becker’s answer in this <a href="https://discourse.julialang.org/t/iterating-over-a-tree-recursively-with-base-iterate/62512">Discourse thread</a> gave me clarity.</p>

<p>Once again, the default is a pre-order traversal:</p>

<figure class="post-figure">
    <img class="img-95" src="/assets/posts/radix-tree/preorder.png" alt="Pre-order traversal through a radix tree" />
</figure>

<p>According to the documentation on <a href="https://docs.julialang.org/en/v1/manual/interfaces/">interfaces</a>, the following code</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">for</span> <span class="n">item</span> <span class="k">in</span> <span class="n">iter</span>   
    <span class="c"># body</span>
<span class="k">end</span></code></pre></figure>

<p>is translated into:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">next</span> <span class="o">=</span> <span class="n">iterate</span><span class="x">(</span><span class="n">iter</span><span class="x">)</span>
<span class="k">while</span> <span class="n">next</span> <span class="o">!==</span> <span class="nb">nothing</span>
    <span class="x">(</span><span class="n">item</span><span class="x">,</span> <span class="n">state</span><span class="x">)</span> <span class="o">=</span> <span class="n">next</span>
    <span class="c"># body</span>
    <span class="n">next</span> <span class="o">=</span> <span class="n">iterate</span><span class="x">(</span><span class="n">iter</span><span class="x">,</span> <span class="n">state</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>
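
<p>Before applying this to the tree, here is a minimal sketch of the protocol on a much simpler object. The <code class="language-plaintext highlighter-rouge">Countdown</code> type is hypothetical and exists only to show how the state is passed from one call of <code class="language-plaintext highlighter-rouge">iterate</code> to the next.</p>

```julia
# Hypothetical type for illustration only: counts down from `start` to 1.
struct Countdown
    start::Int
end

# The state is the next value to emit. Returning `nothing` ends the loop.
function Base.iterate(c::Countdown, state::Int=c.start)
    (state in 1:c.start) ? (state, state - 1) : nothing
end

Base.length(c::Countdown) = max(c.start, 0)
Base.eltype(::Type{Countdown}) = Int

collect(Countdown(3)) # [3, 2, 1]
```

<p>Here the state is a single integer. For a tree the state has to carry more information, which is the main difficulty in what follows.</p>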

<p>The iterator will be a <code class="language-plaintext highlighter-rouge">PreOrderTraversal</code> object which will step through all nodes of the tree.
We only want to return labels, so the outer loop keeps stepping through the traversal and returns as soon as it reaches a label.
Each item is a tuple of the <code class="language-plaintext highlighter-rouge">data</code> and a boolean for <code class="language-plaintext highlighter-rouge">is_label</code>.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">function</span><span class="nf"> Base.iterate</span><span class="x">(</span><span class="n">root</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">,</span> <span class="n">state</span><span class="o">=</span><span class="nb">nothing</span><span class="x">)</span>
    <span class="n">iter</span> <span class="o">=</span> <span class="n">PreOrderTraversal</span><span class="x">(</span><span class="n">root</span><span class="x">)</span>
    <span class="n">next</span> <span class="o">=</span> <span class="n">isnothing</span><span class="x">(</span><span class="n">state</span><span class="x">)</span> <span class="o">?</span> <span class="n">iterate</span><span class="x">(</span><span class="n">iter</span><span class="x">)</span> <span class="o">:</span> <span class="n">iterate</span><span class="x">(</span><span class="n">iter</span><span class="x">,</span> <span class="n">state</span><span class="x">)</span>
    <span class="k">while</span> <span class="n">next</span> <span class="o">!==</span> <span class="nb">nothing</span>
        <span class="x">((</span><span class="n">data</span><span class="x">,</span> <span class="n">is_label</span><span class="x">),</span> <span class="n">state</span><span class="x">)</span> <span class="o">=</span> <span class="n">next</span>
        <span class="k">if</span> <span class="n">is_label</span>
            <span class="k">return</span> <span class="x">(</span><span class="n">data</span><span class="x">,</span> <span class="n">state</span><span class="x">)</span>
        <span class="k">end</span>
        <span class="n">next</span> <span class="o">=</span> <span class="n">iterate</span><span class="x">(</span><span class="n">iter</span><span class="x">,</span> <span class="n">state</span><span class="x">)</span>
    <span class="k">end</span>
<span class="k">end</span>

<span class="n">Base</span><span class="o">.</span><span class="n">IteratorSize</span><span class="x">(</span><span class="o">::</span><span class="n">RadixTreeNode</span><span class="x">)</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">SizeUnknown</span><span class="x">()</span> </code></pre></figure>

<p>This shifts the problem to making an iterator for the <code class="language-plaintext highlighter-rouge">PreOrderTraversal</code>. 
Firstly, this object is just a wrapper around the node:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="k">struct</span><span class="nc"> PreOrderTraversal</span><span class="x">{</span><span class="n">R</span><span class="o">&lt;:</span><span class="n">RadixTreeNode</span><span class="x">}</span>
    <span class="n">root</span><span class="o">::</span><span class="n">R</span>
<span class="k">end</span></code></pre></figure>

<p>The hardest part is, what is the state?
It is all the information about the node’s parent, that parent’s parent and so on, so that we can backtrack when we need to do so.
For example at step 5 in the figure, we are at “test” which is the first child (“est”) of the second child (“t”) of the root.
This is nothing more than a list of tuples of <code class="language-plaintext highlighter-rouge">(node, idx, word)</code>, where <code class="language-plaintext highlighter-rouge">idx</code> is the index of the next child to visit and <code class="language-plaintext highlighter-rouge">word</code> is the text accumulated from the root down to that node.
We can implement this as a stack.
If <code class="language-plaintext highlighter-rouge">idx ≤ length(node.children)</code>, then visit child <code class="language-plaintext highlighter-rouge">idx</code> and increment <code class="language-plaintext highlighter-rouge">idx</code> by one, otherwise pop from the stack and backtrack.
In full:<sup id="fnref:stack" role="doc-noteref"><a href="#fn:stack" class="footnote" rel="footnote">3</a></sup></p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">Base</span><span class="o">.</span><span class="n">IteratorSize</span><span class="x">(</span><span class="o">::</span><span class="n">PreOrderTraversal</span><span class="x">)</span> <span class="o">=</span> <span class="n">Base</span><span class="o">.</span><span class="n">SizeUnknown</span><span class="x">()</span> 

<span class="n">Base</span><span class="o">.</span><span class="n">iterate</span><span class="x">(</span><span class="n">iter</span><span class="o">::</span><span class="n">PreOrderTraversal</span><span class="x">)</span> <span class="o">=</span> <span class="x">((</span><span class="n">iter</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">data</span><span class="x">,</span> <span class="n">iter</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">is_label</span><span class="x">),</span> <span class="x">[(</span><span class="n">iter</span><span class="o">.</span><span class="n">root</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="n">iter</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">data</span><span class="x">)])</span>

<span class="k">function</span><span class="nf"> Base.iterate</span><span class="x">(</span><span class="n">iter</span><span class="o">::</span><span class="n">PreOrderTraversal</span><span class="x">,</span> <span class="n">stack_</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="kt">Tuple</span><span class="x">{</span><span class="n">RadixTreeNode</span><span class="x">{</span><span class="n">T</span><span class="x">},</span> <span class="kt">Int</span><span class="x">,</span> <span class="n">T</span><span class="x">}})</span> <span class="k">where</span> <span class="n">T</span>
    <span class="k">if</span> <span class="n">isempty</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
        <span class="k">return</span> <span class="nb">nothing</span>
    <span class="k">end</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">word</span> <span class="o">=</span> <span class="n">last</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
    <span class="k">if</span> <span class="n">idx</span> <span class="o">&lt;=</span> <span class="n">length</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">)</span>
        <span class="k">return</span> <span class="n">_increment_stack!</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
    <span class="k">else</span> <span class="c"># backtrack</span>
        <span class="n">pop!</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
        <span class="k">while</span> <span class="o">!</span><span class="x">(</span><span class="n">isempty</span><span class="x">(</span><span class="n">stack_</span><span class="x">))</span>
            <span class="n">node</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">word</span> <span class="o">=</span> <span class="n">last</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
            <span class="k">if</span> <span class="n">idx</span> <span class="o">&lt;=</span> <span class="n">length</span><span class="x">(</span><span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">)</span>
                <span class="k">return</span> <span class="n">_increment_stack!</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
            <span class="k">end</span>
            <span class="n">pop!</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
        <span class="k">end</span>
    <span class="k">end</span>
    <span class="nb">nothing</span>
<span class="k">end</span>

<span class="k">function</span><span class="nf"> _increment_stack!</span><span class="x">(</span><span class="n">stack_</span><span class="o">::</span><span class="kt">Vector</span><span class="x">{</span><span class="o">&lt;:</span><span class="kt">Tuple</span><span class="x">})</span>
    <span class="n">node</span><span class="x">,</span> <span class="n">idx</span><span class="x">,</span> <span class="n">word</span> <span class="o">=</span> <span class="n">last</span><span class="x">(</span><span class="n">stack_</span><span class="x">)</span>
    <span class="n">stack_</span><span class="x">[</span><span class="k">end</span><span class="x">]</span> <span class="o">=</span> <span class="x">(</span><span class="n">node</span><span class="x">,</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="x">,</span> <span class="n">word</span><span class="x">)</span>
    <span class="n">child</span> <span class="o">=</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span><span class="x">[</span><span class="n">idx</span><span class="x">]</span>
    <span class="n">new_word</span> <span class="o">=</span> <span class="n">word</span> <span class="o">*</span> <span class="n">child</span><span class="o">.</span><span class="n">data</span>
    <span class="n">push!</span><span class="x">(</span><span class="n">stack_</span><span class="x">,</span> <span class="x">(</span><span class="n">child</span><span class="x">,</span> <span class="mi">1</span><span class="x">,</span> <span class="n">new_word</span><span class="x">))</span>
    <span class="x">(</span><span class="n">new_word</span><span class="x">,</span> <span class="n">child</span><span class="o">.</span><span class="n">is_label</span><span class="x">),</span> <span class="n">stack_</span> 
<span class="k">end</span></code></pre></figure>

<p>Testing it out:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">root</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">()</span>
<span class="k">for</span> <span class="n">key</span> <span class="k">in</span> <span class="x">[</span><span class="s">"toast"</span><span class="x">,</span> <span class="s">"toaster"</span><span class="x">,</span> <span class="s">"toasting"</span><span class="x">,</span> <span class="s">"test"</span><span class="x">,</span> <span class="s">"slow"</span><span class="x">,</span> <span class="s">"slower"</span><span class="x">,</span> <span class="s">"slowly"</span><span class="x">]</span>
    <span class="n">insert!</span><span class="x">(</span><span class="n">root</span><span class="x">,</span> <span class="n">key</span><span class="x">)</span>
<span class="k">end</span>
<span class="k">for</span> <span class="n">item</span> <span class="k">in</span> <span class="n">PreOrderTraversal</span><span class="x">(</span><span class="n">root</span><span class="x">)</span>
    <span class="n">print</span><span class="x">(</span><span class="n">item</span><span class="x">,</span> <span class="s">", "</span><span class="x">)</span>
<span class="k">end</span>
<span class="cm">#= ("", false), ("slow", true), ("slower", true), ("slowly", true), ("t", false), ("test", true), ("toast", true), ("toaster", true), ("toasting", true) =#</span>
<span class="k">for</span> <span class="n">item</span> <span class="k">in</span> <span class="n">root</span>
    <span class="n">print</span><span class="x">(</span><span class="n">item</span><span class="x">,</span> <span class="s">", "</span><span class="x">)</span>
<span class="k">end</span>
<span class="cm">#= slow, slower, slowly, test, toast, toaster, toasting, =#</span></code></pre></figure>
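
<p>The backtracking logic can also be seen in isolation. Below is a minimal, self-contained sketch of the same stack-based pre-order traversal on a stripped-down tree. The names <code class="language-plaintext highlighter-rouge">DemoNode</code> and <code class="language-plaintext highlighter-rouge">DemoPreOrder</code> are hypothetical and are not part of the code above. The sketch folds the two branches into a single loop and does not track the accumulated word or <code class="language-plaintext highlighter-rouge">is_label</code>.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical minimal node type, for illustration only (not RadixTreeNode).
struct DemoNode
    data::String
    children::Vector{DemoNode}
end
DemoNode(data::String) = DemoNode(data, DemoNode[])

# Hypothetical iterator wrapper, mirroring the role of PreOrderTraversal.
struct DemoPreOrder
    root::DemoNode
end

Base.IteratorSize(::Type{DemoPreOrder}) = Base.SizeUnknown()
Base.eltype(::Type{DemoPreOrder}) = String

# First call: yield the root and start the stack with (node, next child index).
function Base.iterate(iter::DemoPreOrder)
    stack_ = [(iter.root, 1)]
    iter.root.data, stack_
end

# Later calls: descend into the next unvisited child, otherwise pop to backtrack.
function Base.iterate(iter::DemoPreOrder, stack_::Vector{Tuple{DemoNode, Int}})
    while !isempty(stack_)
        node, idx = last(stack_)
        if idx in eachindex(node.children)
            stack_[end] = (node, idx + 1) # remember where to resume in this node
            child = node.children[idx]
            push!(stack_, (child, 1))
            return child.data, stack_
        end
        pop!(stack_) # all children visited, so backtrack to the parent
    end
    nothing
end

root = DemoNode("a", [DemoNode("b", [DemoNode("c")]), DemoNode("d")])
collect(DemoPreOrder(root)) # ["a", "b", "c", "d"]</code></pre></figure>

<p>Each stack entry stores the index of the next child to visit, so no node is expanded twice and no recursion is needed.</p>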

<h2 id="worked-example">3 Worked example</h2>

<p>Here is a list of 10,000 words compiled by MIT: <a href="https://www.mit.edu/~ecprice/wordlist.10000">www.mit.edu/~ecprice/wordlist.10000</a>.<sup id="fnref:MIT_word_list" role="doc-noteref"><a href="#fn:MIT_word_list" class="footnote" rel="footnote">4</a></sup></p>

<p>After downloading the list we can load and insert it into a tree:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">tree</span> <span class="o">=</span> <span class="n">RadixTreeNode</span><span class="x">()</span>
<span class="n">filepath</span> <span class="o">=</span> <span class="s">"mit_words.txt"</span>
<span class="n">open</span><span class="x">(</span><span class="n">filepath</span><span class="x">,</span> <span class="s">"r"</span><span class="x">)</span> <span class="k">do</span> <span class="n">f</span>
    <span class="k">for</span> <span class="n">line</span> <span class="k">in</span> <span class="n">eachline</span><span class="x">(</span><span class="n">f</span><span class="x">)</span>
        <span class="n">insert!</span><span class="x">(</span><span class="n">tree</span><span class="x">,</span> <span class="n">line</span><span class="x">)</span>
    <span class="k">end</span>
<span class="k">end</span></code></pre></figure>
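
<p>The download step can also be scripted. The following is a sketch only: <code class="language-plaintext highlighter-rouge">read_words</code> and <code class="language-plaintext highlighter-rouge">fetch_words</code> are hypothetical helpers that are not part of the code in this post. Stripping whitespace and skipping blank lines is purely defensive, and I have not checked whether the list actually needs it.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia">using Downloads # standard library since Julia 1.6

# Hypothetical helper: collect the non-empty, whitespace-stripped lines from any IO.
function read_words(io::IO)
    words = String[]
    for line in eachline(io)
        word = strip(line)
        isempty(word) || push!(words, word)
    end
    words
end

# Hypothetical helper: download the list once, then reuse the local copy.
function fetch_words(url::String, filepath::String)
    isfile(filepath) || Downloads.download(url, filepath)
    open(read_words, filepath, "r")
end

# Usage (requires network access on the first call):
# words = fetch_words("https://www.mit.edu/~ecprice/wordlist.10000", "mit_words.txt")

read_words(IOBuffer("toast\n\n  test \n")) # ["toast", "test"]</code></pre></figure>

<p>The resulting vector can then be inserted into the tree word by word, exactly as in the loop above.</p>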

<p>Some basic statistics:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">get_height</span><span class="x">(</span><span class="n">tree</span><span class="x">)</span> <span class="c"># 11</span>
<span class="n">Base</span><span class="o">.</span><span class="n">summarysize</span><span class="x">(</span><span class="n">tree</span><span class="x">)</span> <span class="c"># 978170 = 0.93 MB</span></code></pre></figure>

<p>Print the tree to a file:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">open</span><span class="x">(</span><span class="s">"tree.txt"</span><span class="x">,</span> <span class="s">"w"</span><span class="x">)</span> <span class="k">do</span> <span class="n">f</span>
    <span class="n">print_tree</span><span class="x">(</span><span class="n">f</span><span class="x">,</span> <span class="n">tree</span><span class="x">;</span> <span class="n">use_data_as_separator</span><span class="o">=</span><span class="nb">true</span><span class="x">)</span>
<span class="k">end</span></code></pre></figure>

<p>All words that start with “trea”:</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"><span class="n">node</span><span class="x">,</span> <span class="n">matched</span> <span class="o">=</span> <span class="n">get</span><span class="x">(</span><span class="n">tree</span><span class="x">,</span> <span class="s">"trea"</span><span class="x">)</span>
<span class="n">prefix</span> <span class="o">=</span> <span class="n">first</span><span class="x">(</span><span class="s">"trea"</span><span class="x">,</span> <span class="n">matched</span><span class="x">)</span>
<span class="n">suffix</span> <span class="o">=</span> <span class="n">get_suffix</span><span class="x">(</span><span class="s">"trea"</span><span class="x">,</span> <span class="n">matched</span><span class="x">)</span>
<span class="k">for</span> <span class="n">child</span> <span class="k">in</span> <span class="n">node</span><span class="o">.</span><span class="n">children</span>
    <span class="k">if</span> <span class="n">startswith</span><span class="x">(</span><span class="n">child</span><span class="o">.</span><span class="n">data</span><span class="x">,</span> <span class="n">suffix</span><span class="x">)</span>
        <span class="k">for</span> <span class="n">data</span> <span class="k">in</span> <span class="n">child</span>
            <span class="n">print</span><span class="x">(</span><span class="n">prefix</span> <span class="o">*</span> <span class="n">data</span><span class="x">,</span> <span class="s">", "</span><span class="x">)</span>
        <span class="k">end</span>
    <span class="k">end</span>
<span class="k">end</span>
<span class="cm">#= treasure, treasurer, treasures, treasury, treat, treated, treating, treatment, treatments, treaty, =#</span></code></pre></figure>
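
<p>A quick way to sanity check a prefix query is to compare it with a brute-force linear scan over the plain word list. This is a different technique from the tree lookup: it checks every word on every query. The sketch below uses a hypothetical helper <code class="language-plaintext highlighter-rouge">prefix_matches</code> and a small made-up sample instead of the full list. Sorting the result only matches the order of the tree output if the tree keeps its children in sorted order, which the outputs above suggest.</p>

<figure class="highlight"><pre><code class="language-julia" data-lang="julia"># Hypothetical helper: brute-force prefix search, linear in the number of words.
prefix_matches(words, prefix) = sort(filter(startswith(prefix), words))

# Made-up sample for illustration; in practice this would be the full word list.
sample_words = ["toast", "treat", "test", "treaty", "tree", "treasure"]
prefix_matches(sample_words, "trea") # ["treasure", "treat", "treaty"]</code></pre></figure>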

<h2 id="conclusion">4 Conclusion</h2>

<p>Thank you for following along. I hope you found this useful.</p>

<hr />

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:inverted_index" role="doc-endnote">
      <p>For an example of a radix tree used for an inverted index, see this post from <a href="https://www.algolia.com/blog/engineering/inside-the-algolia-engine-part-2-the-indexing-challenge-of-instant-search">Algolia</a>. As far as inverted indexes go, however, <a href="https://lucene.apache.org/core/">Lucene</a> is the industry standard with the most optimised implementation. Its more complicated inverted index is based on skip lists and finite state transducers. Lucene forms the basis of the popular <a href="https://www.elastic.co/elasticsearch">ElasticSearch</a> search engine. <a href="#fnref:inverted_index" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:Dictionary" role="doc-endnote">
      <p>Another option is to make the tree a kind of dictionary by using the string at each node as a key and storing another value. This is the design choice made by <a href="https://juliacollections.github.io/DataStructures.jl/stable/trie/">DataStructures.jl</a> in their <code class="language-plaintext highlighter-rouge">Trie</code> data structure. The values we could store are the term frequency of the word or a list of documents where that word occurs (inverted index). <a href="#fnref:Dictionary" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:stack" role="doc-endnote">
      <p>I’ve called the variable <code class="language-plaintext highlighter-rouge">stack_</code> instead of <code class="language-plaintext highlighter-rouge">stack</code> because a function already exists with that name. <a href="#fnref:stack" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:MIT_word_list" role="doc-endnote">
      <p>Warning: there are profanities in this list. Also there are at least two mistakes: “trembl” and “documentcreatetextnode”. <a href="#fnref:MIT_word_list" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Lior Sinai</name></author><category term="coding" /><category term="radix" /><category term="tree" /><summary type="html"><![CDATA[A radix tree in Julia, built following Test Driven Development (TDD).]]></summary></entry></feed>