MicroGrad.jl: Part 4 Extensions

Posted on 17 August, 2024


A series on automatic differentiation in Julia. Part 4 extends part 3 to handle maps, getfield and anonymous functions. It creates a generic gradient descent and uses this to fit a polynomial.

This is part of a series.

All source code can be found at MicroGrad.jl.


1 Introduction

By the end of part 3 we had code that could automatically differentiate many functions, as long as rrules were defined for the primitives and there was no control flow.

However, the code failed for the polynomial model:

struct Polynomial{V<:AbstractVector}
    weights::V
end
(m::Polynomial)(x) = evalpoly(x, m.weights)
(m::Polynomial)(x::AbstractVector) = map(m, x)
model = Polynomial([3.0, 2.0, -3.0, 1.0])
x = [1.0, 2.0, 3.0, 4.0]
pullback(model, x) # ERROR: No method found for Tuple{typeof(fieldtype) ....}

Calling @code_ir model(x), we can see that the code is lowered as follows:

1: (%1, %2)
  %3 = Main.map(%1, %2)
  return %3

And further that model(1.0) is lowered to:

1: (%1, %2)
  %3 = Base.getproperty(%1, :weights)
  %4 = Main.evalpoly(%2, %3)
  return %4

We could have also defined the map using an anonymous function:

(m::Polynomial)(x::AbstractVector) = map(x->evalpoly(x, m.weights), x)

In which case it would have been lowered to:

1: (%1, %2)
  %3 = Main.:(var"#43#44")
  %4 = Core.typeof(%1)
  %5 = Core.apply_type(%3, %4)
  %6 = %new(%5, %1)
  %7 = Main.map(%6, %2)
  return %7

The calls to Core.typeof and Core.apply_type are in the list of ignored functions. However, we still need to handle map, getproperty and %new. These functions do not have formal mathematical derivatives, so they do not have rrules in ChainRules.jl. Instead, Zygote.jl handles them with its own custom pullbacks, and replaces some low-level functions like new, getproperty and getindex entirely with custom code.

2 Extending pullback

2.1 map

The pullback for map is fairly complex. What will be presented here is a simplified version. It might also help to look at the less generic code in the example in part 1.

Consider the following code:

f(x) = sin(x)
x = [0.1, 0.2, 0.5]
map(f, x)

The pullback for map should return 3 values: $\text{s̄elf}$ for map, $\bar{f}$ for the function f and $\bar{x}$ for each value in x.

The code will start by getting pullbacks for each value in x:

ys_and_backs = map((xs...) -> pullback(f, xs...), x) # [(0.099, Pullback), (0.198, Pullback), (0.479, Pullback)]

This list is in a “zipped” format: for an array of length $n$, there are $n$ entries of the form $(y_i, \mathcal{B}_i)$. This will be unzipped into two lists, each of length $n$: $(y_1,…,y_n), (\mathcal{B}_1,…,\mathcal{B}_n)$:

Δ = ones(length(x))
ys = map(first, ys_and_backs) # [0.099, 0.198, 0.479]
∂f_and_∂x_zipped = map(((_, pb), δ) -> pb(δ), ys_and_backs, Δ) # [(nothing, 0.995), (nothing, 0.980), (nothing, 0.877)]

This list of $n$ gradient tuples

\[((\text{s̄elf}_1, \bar{x}_{11}, ..., \bar{x}_{k1}), ...,(\text{s̄elf}_n, \bar{x}_{1n}, ..., \bar{x}_{kn}))\]

needs to be further unzipped into $k+1$ lists for $\text{s̄elf}$ and $k$ arguments:

\[(\text{s̄elf}_1,...,\text{s̄elf}_{n}), (\bar{x}_{11},...,\bar{x}_{1n}), ... (\bar{x}_{k1},...,\bar{x}_{kn})\]

This is done with an unzip function which generalises first to any index i (source):

struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing

function _unzip(tuples, ::Val{N}) where {N}
  getters = ntuple(n -> StaticGetter{n}(), N)
  map(g -> map(g, tuples), getters)
end

function unzip(tuples)
  N = length(first(tuples))
  _unzip(tuples, Val(N))
end
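
For example, unzipping a vector of three pairs gives a tuple of two vectors:

unzip([(1, 2.0), (3, 4.0), (5, 6.0)]) # ([1, 3, 5], [2.0, 4.0, 6.0])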

The result:

∂f_and_∂x = unzip(∂f_and_∂x_zipped) # [nothing, nothing, nothing], [0.995, 0.98, 0.877]

As a final step, all the gradients for the function are accumulated into one value:

∂f = reduce(accum, ∂f_and_∂x[1]) # nothing
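
As a reminder, accum was defined in part 3; assuming the usual Zygote-style definition, it adds gradients while treating nothing as a zero:

accum(nothing, nothing) # nothing
accum(nothing, 2.0)     # 2.0
accum(1.5, 2.0)         # 3.5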

Putting all this code in a single function (source):

function pullback(::typeof(map), f::F, args::Vararg{Any, N}) where {F, N}
    ys_and_backs = map((xs...) -> pullback(f, xs...), args...)
    ys = map(first, ys_and_backs)
    function map_pullback(Δ)
      # technically should apply f in reverse and reverse back afterwards in case f is stateful
      ∂f_and_∂x_zipped = map(((_, pb), δ) -> pb(δ), ys_and_backs, Δ)
      ∂f_and_∂x = unzip(∂f_and_∂x_zipped) 
      ∂f = reduce(accum, ∂f_and_∂x[1])
      ∂args = ∂f_and_∂x[2:end]
      return (nothing, ∂f, ∂args...)
    end
    ys, map_pullback
end

Testing:

x = [0.1, 0.2, 0.5]
z, back = pullback(map, sin, x) 
back(ones(length(x))) # (nothing, nothing, [0.995, 0.98, 0.877])

And also:

f(a, b) = a / (a + b*b)
z, back = pullback(map, f, [2.0, 4.0], [3.0, 5.0]) 
back([1.0, 1.0]) # (nothing, nothing, [0.074, 0.029], [-0.099, -0.047])

2.2 Instrument

Zygote.jl modifies some of the source code before creating the primal and reverse passes. Here is a simplified version of its instrument function, which only replaces %new expressions and literal getfield calls (source):

function instrument(ir::IR)
    pr = Pipe(ir)
    for (v, st) in pr
        ex = st.expr
        if isexpr(ex, :new)
            pr[v] = xcall(Main, :__new__, ex.args...)
        elseif is_literal_getfield(ex)
            pr[v] = xcall(Main, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3])))
        end
    end
    finish(pr)
end

iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n)
unwrapquote(x) = x
unwrapquote(x::QuoteNode) = x.value

is_literal_getfield(ex) =
  (iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) &&
  ex.args[3] isa Union{QuoteNode,Integer}
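
As an example of its effect, instrumenting the IR of the anonymous function version from the introduction should replace the %new expression with a call to __new__, giving something like (exact variable numbering may differ):

1: (%1, %2)
  %3 = Main.:(var"#43#44")
  %4 = Core.typeof(%1)
  %5 = Core.apply_type(%3, %4)
  %6 = Main.__new__(%5, %1)
  %7 = Main.map(%6, %2)
  return %7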

Modify the existing _generate_pullback_via_decomposition and _generate_callable_pullback functions to call it:

function _generate_pullback_via_decomposition(T, world)
    m = meta(T; world=world)
    isnothing(m) && return nothing
    ir = IR(m)
    length(blocks(ir)) == 1 || error("control flow is not supported")
    ir = instrument(ir) # new
    pr, calls = primal(ir, T)
    m, pr, calls
end

function _generate_callable_pullback(j::Type{<:Pullback{S, T}}, world, Δ) where {S, T}
    m = meta(S; world=world)
    ir = IR(m)
    isnothing(ir) && return :(error("Non-differentiable function ", repr(args[1])))
    length(blocks(ir)) == 1 || error("control flow is not supported")
    ir = instrument(ir) # new
    back = reverse_differentiate(ir)
    back = slots!(inlineable!(back))
    ci = build_codeinfo_(back)
    ci.slotnames = [Symbol("#self#"), :Δ]
    ci
end 

Now we need to define literal_getfield and __new__ and their pullbacks.

2.3 getfield

Calls to getproperty default to getfield for any field declared in a struct’s definition. The getfield call is substituted with literal_getfield (source):

literal_getfield(x, ::Val{f}) where f = getfield(x, f)

The pullback returns a NamedTuple with an entry for each field, where the gradient is Δ for the relevant field and nothing for the others (source):

@generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...)
@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,)
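
These two generated helpers build the pieces of the gradient named tuple. For example, with a hypothetical Point struct:

struct Point
    x
    y
end
p = Point(1.0, 2.0)
nt_nothing(p)                               # (x = nothing, y = nothing)
pair(Val(:y), 5.0)                          # (y = 5.0,)
(; nt_nothing(p)..., pair(Val(:y), 5.0)...) # (x = nothing, y = 5.0)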

function pullback(::typeof(literal_getfield), x, ::Val{f}) where f
  val = getfield(x, f)
  function literal_getfield_back(Δ)
    if isimmutable(x)
      dx = (; nt_nothing(x)..., pair(Val(f), Δ)...)
      (nothing, dx, nothing)
    else
      error("multable stucts not supported")
    end
  end
  val, literal_getfield_back
end

pullback(::typeof(getfield), x, field_name::Symbol) = pullback(literal_getfield, x, Val(field_name))

For example:

struct Foo
    a
    b
    c
end
foo = Foo(1.0, 'a', "hello")
z, back = pullback(getfield, foo, :b) # ('a', literal_getfield_back)
back(1.0) # (nothing, (a = nothing, b = 1.0, c = nothing), nothing)

And for the polynomial model:

z, back = pullback(model, 1.0)
back(2.3) # ((weights = [2.3, 2.3, 2.3, 2.3],), -2.3)

For the first time we have a non-trivial value for $\text{s̄elf}$: a named tuple with a gradient for each field.

2.4 new

The code now works with:

(m::Polynomial)(x::AbstractVector) = map(m, x)

It returns $\text{s̄elf}$ and $\bar{x}$:

model = Polynomial([3.0, 2.0, -3.0, 1.0])
x = [1.0, 2.0, 3.0, 4.0]
z, back = pullback(model, x)
back(ones(4)) # ((weights = [4.0, 10.0, 30.0, 100.0],), [-1.0, 2.0, 11.0, 26.0])

However with an anonymous function:

(m::Polynomial)(x::AbstractVector) = map(x->evalpoly(x, m.weights), x)

nothing is returned for $\text{s̄elf}$:

z, back = pullback(model, x)
back(ones(4)) # (nothing, [-1.0, 2.0, 11.0, 26.0])

If we inspect primal(ir), we can see why: no pullbacks, and hence no gradients, are recorded against variable %1 (self):

1: (%1, %2)
  %3 = Main.:(var"#74#75")
  %4 = Core.typeof(%1)
  %5 = Core.apply_type(%3, %4)
  %6 = %new(%5, %1)
  %7 = Main.pullback(Main.map, %6, %2)
  %8 = Base.getindex(%7, 1)
  %9 = Base.getindex(%7, 2)
  %10 = Base.tuple(%9)
  %11 = (Pullback{Any})(%10)
  %12 = Base.tuple(%8, %11)
  return %12

The solution is to swap %new for a call to a custom function __new__ which does have a pullback. This function is defined as follows (source):

macro __splatnew__(T, args)
  esc(Expr(:splatnew, T, args))
end

@inline __new__(T, args...) = @__splatnew__(T, args)
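
__new__ constructs an instance directly; the :splatnew expression bypasses any inner constructors. For example, with the Foo struct from section 2.3:

__new__(Foo, 1.0, 'a', "hello") # Foo(1.0, 'a', "hello")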

And the pullback is (source):

using Base: RefValue
struct Jnew{T,G}
  g::G
end

Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)

function pullback(::typeof(__new__), T, args...)
  x = __new__(T, args...)
  g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(x)
  x, Jnew{T,typeof(g)}(g)
end

@generated function (back::Jnew{T,G})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
  !ismutabletype(T) && Δ == Nothing && return :nothing
  Δ = G == Nothing ? :Δ :
      Δ <: RefValue ? :(back.g[]) :
      :(accum(back.g[], Δ))
  quote
    x̄ = $Δ
    $(G == Nothing || :(back.g[] = nt_nothing($Δ)))
    (nothing, nothing, $(map(f -> :(x̄.$f), fieldnames(T))...))
  end
end
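
The forward pass above calls grad_mut, which hasn't been defined here. For a mutable struct it creates a mutable gradient accumulator; a minimal sketch, following Zygote's definition:

# A Ref holding a named tuple of field gradients. It is created in the
# forward pass and shared with the reverse pass, so gradients for
# mutable structs can be accumulated in place.
grad_mut(x) = RefValue{Any}(nt_nothing(x))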

Now if we try the following (after redefining the @generated pullback function and the callable (methodinstance::Pullback) from part 3), we should get the same results:

z, back = pullback(model, x)
back(ones(4)) # ((weights = [4.0, 10.0, 30.0, 100.0],), [-1.0, 2.0, 11.0, 26.0])

3 Gradient Descent revisited

3.1 Generic Gradient Descent

Now that we have an automatic differentiation engine, it is possible to create a much more generic gradient descent function than in part 1:

function gradient_descent!(
    model,
    loss,
    X::AbstractVecOrMat,
    Y::AbstractVecOrMat
    ; learning_rate::AbstractFloat=0.1,
    max_iters::Integer=100
    )
    losses = Float64[]
    for i in 1:max_iters
        loss_iter, back = pullback(model) do m
            result = m(X)
            loss(result, Y)
        end 
        Δf, Δm = back(1.0)
        update_params!(parameters(model), Δm; learning_rate=learning_rate)
        push!(losses, loss_iter)  
    end
    losses
end

Note that pullback(model) do m; f(m) end is directly equivalent to pullback(m -> f(m), model).

The update_params! function is defined as follows:

function update_params!(params::NamedTuple, grads::NamedTuple; options...)
    for key in keys(params)
        update_params!(params[key], grads[key]; options...)
    end
end

function update_params!(params::Tuple, grads::Tuple; options...)
    for (p, g) in zip(params, grads)
        update_params!(p, g; options...)
    end
end

function update_params!(params, grads; learning_rate::AbstractFloat=0.1)
    params .-= learning_rate .* grads # must broadcast to edit elements and not copies!
end
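
A quick check on a toy parameter set:

ps = (; weights = [1.0, 2.0])
gs = (; weights = [0.5, 1.0])
update_params!(ps, gs; learning_rate=0.1)
ps.weights # [0.95, 1.9]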

The parameters function is defined per model. (Flux uses the generic Functors.jl library to accomplish something similar.)

3.2 Polynomial curve fitting revisited

Let’s create the exact same data set as in part 1:

using StatsBase
target_weights = [15.0, -2.1, 13.9, 1.5]
noise_factor = 0.2
xs = (rand(100) .- 0.5) .* 10
ys = map(x -> evalpoly(x, target_weights), xs)
scale_factor = mean(abs.(ys))
ys .+= randn(length(ys)) * scale_factor * noise_factor

The Polynomial model is defined in the introduction. We also need a custom method for parameters:

parameters(m::Polynomial) = (;weights=m.weights)

Define the model:

model = Polynomial(rand(4))
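
The sanity checks and training below use an mse loss, defined in part 1; a minimal version, assuming the standard mean squared error, is:

mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)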

Some sanity checks:

x = [1.0, 2.0, 3.0]
z, back = pullback(model, x) # ([1.68, 7.21, 21.2], Pullback) 
back([1.0, 1.0, 1.0]) # ((weights = [3.0, 6.0, 14.0, 36.0],), [-1.0, 2.0, 11.0])
z, back = pullback(m->m(x), model) 
back([1.0, 1.0, 1.0]) # (nothing, (weights = [3.0, 6.0, 14.0, 36.0],))
y = [2.0, 4.0, 8.0]
z, back = pullback(m->mse(m(x), y), model) 
back(1.0) # (nothing, (weights = [10.7, 30.5, 87.6, 254.6],))

Train the model:

history = gradient_descent!(model, mse, xs, ys; learning_rate=1e-5, max_iters=2000)

This works just as well as before.

4 Conclusion

We now have a fully working AD package. It has some limitations: for example, it cannot handle control flow or keyword arguments. However, it already works on a wide variety of code. All that might be needed is new rrule definitions. The next and final part of this series is a demonstration of exactly that.