Intro to Zygote.jl


21st Century AD Technique

What is AD

Automatic/Algorithmic Differentiation

Automatic Differentiation

  • One way to compute derivatives
  • Core technique of modern ML/DL optimization

How AD works

  • Wengert List (Tape/Graph)
  • Derivative Definition
  • Chain Rule

Derivative Definition
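
The formula on this slide did not survive; assuming it showed the usual limit definition, it would read as follows, together with the two elementary rules the later examples rely on:

```latex
f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h},
\qquad
\frac{d}{dx}\sin(x) = \cos(x),
\qquad
\frac{d}{dx}\log(x) = \frac{1}{x}
```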

Chain Rule
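
The formula on this slide is also missing; as a sketch, the standard chain rule in the y1, y2 notation used by the Wengert lists in this deck is:

```latex
\frac{d}{dx}\, g(h(x)) = g'(h(x)) \cdot h'(x)
\qquad\Longleftrightarrow\qquad
\frac{dy_2}{dx} = \frac{dy_2}{dy_1} \cdot \frac{dy_1}{dx},
\quad y_1 = h(x),\; y_2 = g(y_1)
```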

Wengert List

  • A list of expressions/instructions
  • Transform each expression using the derivative definition

f(x) = 5\sin(\log(x))

Wengert List of f:
  y1 = log(x)
  y2 = sin(y1)
  y3 = 5 * y2

f'(x) = \frac{d}{dx}f(x) = \frac{5\cos(\log(x))}{x}

Wengert List of f':
  dy1 = 1 / x
  dy2 = cos(y1)
  dy3 = dy2 * dy1
  dy4 = 5 * dy3


How AD works

  1. Get the Wengert List of the given expression
  2. Transform each instruction in the Wengert List
  3. Apply the chain rule
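
The three steps can be sketched by hand for f(x) = 5sin(log(x)). The function names `f_wengert`, `df_wengert` and `df_closed` are made up for this illustration; a real AD tool would generate the second list automatically.

```julia
# Step 1: the Wengert list of f(x) = 5sin(log(x)), one instruction per line.
function f_wengert(x)
    y1 = log(x)
    y2 = sin(y1)
    y3 = 5 * y2
    return y3
end

# Steps 2 and 3: each instruction is replaced by its derivative rule,
# and the pieces are multiplied together with the chain rule.
function df_wengert(x)
    y1  = log(x)      # forward value, still needed by cos
    dy1 = 1 / x       # d(y1)/dx
    dy2 = cos(y1)     # d(y2)/d(y1)
    dy3 = dy2 * dy1   # chain rule: d(y2)/dx
    dy4 = 5 * dy3     # d(y3)/dx
    return dy4
end

# Closed-form derivative, used only to check the result.
df_closed(x) = 5 * cos(log(x)) / x

println(df_wengert(3.0) ≈ df_closed(3.0))
```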

Different Types of AD

  • Forward mode
  • Reverse mode
  • Mixed mode
  • ...

Forward mode

  • Dual number
  • Chain rule multiplication starts from the input (dy1/dx * dy2/dy1)
  • Computationally efficient for multivariable differentiation with more outputs than inputs

Reverse mode

  • Tracker
  • Chain rule multiplication starts from the output (dy/dy2 * dy2/dy1)
  • Computationally efficient for multivariable differentiation with more inputs than outputs

This is the DL situation!! (many parameters in, one scalar loss out)
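
As a worked equation, the two modes compute the same product and differ only in where the bracketing starts. The three-step chain x → y1 → y2 → y is an assumed example matching the notation in the bullets above:

```latex
\frac{dy}{dx}
  = \underbrace{\frac{dy}{dy_2}\left(\frac{dy_2}{dy_1}\,\frac{dy_1}{dx}\right)}_{\text{forward mode: start at the input}}
  = \underbrace{\left(\frac{dy}{dy_2}\,\frac{dy_2}{dy_1}\right)\frac{dy_1}{dx}}_{\text{reverse mode: start at the output}}
```

For a function from R^n to R^m, forward mode needs one pass per input (n passes) to get the full Jacobian, while reverse mode needs one pass per output (m passes). With a scalar loss, m = 1, so a single reverse pass yields the gradient for all parameters.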

Dual Number

struct Dual{T<:Real} <: Real
    x::T  # value
    ϵ::T  # derivative part
end

import Base: +, -, *, /
a::Dual + b::Dual = Dual(a.x + b.x, a.ϵ + b.ϵ)
a::Dual - b::Dual = Dual(a.x - b.x, a.ϵ - b.ϵ)
a::Dual * b::Dual = Dual(a.x * b.x, b.x * a.ϵ + a.x * b.ϵ)
# Quotient rule: (a/b)' = (a'b - ab') / b^2
a::Dual / b::Dual = Dual(a.x / b.x, (b.x * a.ϵ - a.x * b.ϵ) / b.x^2)

Base.sin(d::Dual) = Dual(sin(d.x), d.ϵ * cos(d.x))
Base.cos(d::Dual) = Dual(cos(d.x), - d.ϵ * sin(d.x))
Base.log(d::Dual) = Dual(log(d.x), d.ϵ / d.x)

# Let plain numbers mix with Dual numbers (e.g. 5 * Dual)
Base.convert(::Type{Dual{T}}, x::Dual) where T = Dual(convert(T, x.x), convert(T, x.ϵ))
Base.convert(::Type{Dual{T}}, x::Real) where T = Dual(convert(T, x), zero(T))
Base.promote_rule(::Type{Dual{T}}, ::Type{R}) where {T,R} = Dual{promote_type(T,R)}

# Seed the input with ϵ = 1 and read the derivative off the result
D(f, x) = f(Dual(x, one(x))).ϵ

f(x) = 5*sin(log(x))
df(x) = 5*cos(log(x))/x

# Compare with ≈ rather than ==, since the two differ in floating-point rounding order
D(f, 3.) ≈ df(3.)
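
To illustrate why forward mode gets expensive with many inputs, here is a minimal sketch that needs one forward pass per input. The type `FwdDual` and the functions `forward_gradient` and `two_input` are hypothetical names for this example and are deliberately independent of the `Dual` type above.

```julia
# Hypothetical minimal dual number, restricted to Float64 for brevity.
struct FwdDual
    x::Float64   # value
    ϵ::Float64   # derivative carried along with the value
end

Base.:+(a::FwdDual, b::FwdDual) = FwdDual(a.x + b.x, a.ϵ + b.ϵ)
Base.:*(a::FwdDual, b::FwdDual) = FwdDual(a.x * b.x, b.x * a.ϵ + a.x * b.ϵ)
Base.sin(d::FwdDual) = FwdDual(sin(d.x), d.ϵ * cos(d.x))

# One forward pass per input: seed input i with ϵ = 1 and all others with ϵ = 0.
function forward_gradient(f, xs::Vector{Float64})
    map(eachindex(xs)) do i
        duals = [FwdDual(xs[j], i == j ? 1.0 : 0.0) for j in eachindex(xs)]
        f(duals...).ϵ
    end
end

# Example function of two inputs; its gradient is (b + cos(a), a).
two_input(a, b) = a * b + sin(a)

grad = forward_gradient(two_input, [2.0, 3.0])
println(grad)
```

With n inputs the function is evaluated n times, which is the cost that reverse mode avoids when there is a single output.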


  • Record every operation on the variables to a Wengert List
  • Differentiate the Wengert List w.r.t. the given variable
  • Static graph -> build the Wengert List once (TensorFlow)
  • Dynamic graph -> build the Wengert List on every run (PyTorch, eager mode)
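
The recording idea can be sketched in a few lines. This is a toy under stated assumptions, not how any particular tracker library is implemented: `Tracked`, `Tape`, `track_mul`, `track_sin`, `track_log` and `backward!` are all hypothetical names.

```julia
# Hypothetical tracked scalar: a value plus a slot that accumulates its gradient.
mutable struct Tracked
    value::Float64
    grad::Float64
end

# The tape (Wengert list) stores one backward closure per recorded operation.
const Tape = Vector{Function}

function track_mul(tape::Tape, a::Tracked, b::Tracked)
    out = Tracked(a.value * b.value, 0.0)
    push!(tape, () -> begin
        a.grad += out.grad * b.value
        b.grad += out.grad * a.value
    end)
    return out
end

function track_sin(tape::Tape, a::Tracked)
    out = Tracked(sin(a.value), 0.0)
    push!(tape, () -> (a.grad += out.grad * cos(a.value)))
    return out
end

function track_log(tape::Tape, a::Tracked)
    out = Tracked(log(a.value), 0.0)
    push!(tape, () -> (a.grad += out.grad / a.value))
    return out
end

# Walk the tape from the output back to the inputs.
function backward!(tape::Tape, out::Tracked)
    out.grad = 1.0
    for step in reverse(tape)
        step()
    end
end

# Record f(x) = 5sin(log(x)) at x = 3, then differentiate the recording.
tape = Tape()
x = Tracked(3.0, 0.0)
five = Tracked(5.0, 0.0)
y = track_mul(tape, five, track_sin(tape, track_log(tape, x)))
backward!(tape, y)
println(x.grad)
```

Because the tape is filled while the function runs, this sketch corresponds to the dynamic-graph case: a new tape is built on every evaluation.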

Where is the Wengert List in Forward mode?

Wengert underneath Dual Number

julia> @code_warntype f(3.)
1 ─ %1 = invoke Main.log(_2::Float64)::Float64
│   %2 = invoke Main.sin(%1::Float64)::Float64
│   %3 = (Base.mul_float)(5.0, %2)::Float64
└──      return %3

julia> @code_warntype D(f, 3)
1 ─ %1  = (Base.sitofp)(Float64, x)::Float64
│   %2  = invoke Base.Math.log(%1::Float64)::Float64
│   %3  = (Base.sitofp)(Float64, 1)::Float64
│   %4  = (Base.sitofp)(Float64, x)::Float64
│   %5  = (Base.div_float)(%3, %4)::Float64
│   %6  = invoke Main.sin(%2::Float64)::Float64
│   %7  = invoke Main.cos(%2::Float64)::Float64
│   %8  = (Base.mul_float)(%5, %7)::Float64
│   %9  = (Base.mul_float)(%6, 0.0)::Float64
│   %10 = (Base.mul_float)(5.0, %8)::Float64
│   %11 = (Base.add_float)(%9, %10)::Float64
└──       return %11

Wengert List of f':
  y1 = log(x)
  dy1 = 1 / x
  dy2 = cos(y1)
  dy3 = dy2 * dy1
  dy4 = 5 * dy3

Why Julia

  • Fast
  • Really good compiler design



  • Source-to-source AD
  • Supports control flow, recursion, closures, structs, dictionaries, ...


julia> using Zygote

julia> f(x) = 5x + 3

julia> f(10), f'(10)
(53, 5)

julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
  ret i64 5
}


julia> f(x) = 5*sin(log(x))
f (generic function with 1 method)
julia> f'                                              
#34 (generic function with 1 method)
julia> f'(3.)                  
julia> @code_llvm f'(3.)  
define double @"julia_#34_13685"(double) {
  %1 = call double @julia_log_4663(double %0)
  %2 = call double @julia_sin_13686(double %1)
  %3 = call double @julia_cos_13689(double %1)
  %4 = fmul double %3, 5.000000e+00
  %5 = fdiv double 1.000000e+00, %0
  %6 = fmul double %5, %4
  ret double %6
}

julia> @code_native f'(3.)


julia> f(x) = 5*sin(log(x))
f (generic function with 1 method)
julia> f'                                              
#34 (generic function with 1 method)
julia> f'(3.)                  
julia> ddf(x) = -(5(sin(log(x)) + cos(log(x))))/x^2
ddf (generic function with 1 method)

julia> ddf(3.)

julia> (f')'(3.)


julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);

julia> gradient(x -> fs[readline()](x), 1)

julia> function pow(x::T, n::Int) where T
           r = one(T)
           while n > 0
               n -= 1
               r *= x
           end
           return r
       end
pow (generic function with 1 method)

julia> g(x) = pow(x, 5)
g (generic function with 1 method)

julia> g'(2)

julia> gradient(pow, 2, 5)
(80, nothing)


julia> using Zygote: @adjoint

julia> add(a, b) = a + b

julia> @adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)
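
The right-hand side of an `@adjoint` rule is a pair: the forward value and a closure that maps the output sensitivity Δ to one sensitivity per argument. As a sketch of how such pairs compose, here is the same shape written in plain Julia without Zygote. The names `add_pb`, `mul_pb` and `h_grad` are hypothetical, and the chaining that Zygote would generate automatically is done by hand.

```julia
# Hypothetical helpers: each returns (value, pullback), the same shape as an @adjoint rule.
add_pb(a, b) = (a + b, Δ -> (Δ, Δ))
mul_pb(a, b) = (a * b, Δ -> (Δ * b, Δ * a))

# h(a, b) = (a + b) * a, differentiated by chaining the pullbacks in reverse order.
function h_grad(a, b)
    s, back_add = add_pb(a, b)           # forward: s = a + b
    y, back_mul = mul_pb(s, a)           # forward: y = s * a
    ds, da_direct = back_mul(one(y))     # backward through the multiplication
    da_via_s, db = back_add(ds)          # backward through the addition
    return y, (da_direct + da_via_s, db) # a is used twice, so its sensitivities add up
end

y, (da, db) = h_grad(2.0, 3.0)
println((y, da, db))
```

Here h(a, b) = a^2 + a*b, so the expected gradient is (2a + b, a).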

Source to Source AD

Source-to-source AD

Differentiate SSA Form

Julia Compile Process

SSA Form

  • Static Single Assignment Form
  • Every variable is assigned exactly once
  • Most variables come from function calls
  • All control flow becomes branches
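
One way to see the "control flow becomes branches" point is to print Julia's lowered code for a small loop. This is only a sketch: `pow_loop` is a name made up here, and the lowered form is close to but not strictly SSA, since local variables such as `r` and `n` are still reassigned. The exact printed statements depend on the Julia version, so no output is shown.

```julia
# Hypothetical example function with a loop.
function pow_loop(x, n)
    r = one(x)
    while n > 0
        n -= 1
        r *= x
    end
    return r
end

# `code_lowered` returns the lowered (not yet type-inferred) code for each matching method.
ci = first(code_lowered(pow_loop, (Int, Int)))

# Each entry of `ci.code` is one statement; the `while` loop appears as goto statements.
for (i, stmt) in enumerate(ci.code)
    println("%", i, " = ", stmt)
end
```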

SSA Form

Not just unrolling control flow

function grad_pow(x, n)
    r = 1
    Bs = Tuple{Int, Int}[]    # tape: (r, x) saved before each multiplication
    while n > 0
        push!(Bs, (r, x))
        r *= x
        n -= 1
    end
    dx = 0
    dr = 1
    for i = length(Bs):-1:1   # replay the loop iterations in reverse
        (r, x) = Bs[i]
        dx += dr*r
        dr = dr*x
    end
    return dx
end

function pow(x, n)
    r = 1
    while n > 0
        n -= 1
        r *= x
    end
    return r
end
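
As a quick check of the idea behind `grad_pow`, the sketch below is a variant that accepts any numeric `x` (the version above stores `Tuple{Int, Int}`, so it only handles integers) and compares the result with the analytic derivative n * x^(n-1). The name `grad_pow_generic` is hypothetical.

```julia
# Hypothetical generic variant of grad_pow: same save-then-replay idea, any numeric x.
function grad_pow_generic(x::T, n::Int) where T<:Number
    r = one(T)
    saved = T[]                    # value of r before each multiplication
    while n > 0
        push!(saved, r)
        r *= x
        n -= 1
    end
    dx = zero(T)
    dr = one(T)
    for i in length(saved):-1:1    # replay the loop iterations in reverse
        dx += dr * saved[i]        # contribution of this iteration's `r *= x`
        dr *= x                    # sensitivity of the final result to the earlier r
    end
    return dx
end

println(grad_pow_generic(2, 5))
println(grad_pow_generic(1.5, 4) ≈ 4 * 1.5^3)
```

The number of saved values grows with the number of loop iterations, which is why the backward code cannot be produced by unrolling the loop at compile time.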



  • Compiles every Julia function into a differentiable one
  • Easy to add gradient hooks
  • Differentiable Programming!!

Differentiable Programming

Zygote + PyTorch

Zygote + XLA.jl



By Peter Cheng


