Cassette.jl

Overdub Your Julia Code

About Me

  • Student

This talk is based on the Cassette.jl documentation and the developer's talks.

What is Cassette?

What can Cassette do?

  • Extends the Julia language by injecting new, context-specific behaviors directly into the Julia compiler.
  • Dynamically injects code transformation passes into Julia’s just-in-time (JIT) compilation cycle.

First Thought

for example

Method overloading

  • Common & Standard
  • requires manual implementation

...But Stinks

  • Ultimately thwarted by dispatch and/or structural type constraints in non-generic target programs.
  • Proper usage of overloading-based nonstandard execution tools requires proper genericity criteria, i.e. "what weird subset of Julia do I really support?".
  • Not all relevant Julia language mechanisms are fully exposed/interceptable via method overloading (e.g. control flow, literals, bindings, calling scope)

Why Julia?

Why Julia?

  • Julia provides a general and powerful API (ASTs and Julia IR) that lets you work with the Julia compiler easily

Julia compile loop

WARNING

Situation of v0.7

  • Julia v0.7 is at the release-candidate stage
  • packages need to be updated
  • the documentation is messy

Julia v1.0 Just Released!!

Warnings about Cassette

  • each Cassette version only supports a specific version of Julia
  • highly dependent on Julia compiler internals
  • might have performance and correctness bugs caused by either Cassette or Julia itself

Cassette v0.1.0 was released at JuliaCon 2018

Take a look!

Simple logging

import Cassette: @context, prehook, @overdub

@context PrintCtx

# Before every call made under a `PrintCtx`, print the callee followed by its
# argument tuple; e.g. the call `float(1)` is printed as `float(1,)`.
prehook(::PrintCtx, f, args...) = println(f, args)

# Trace `1/2`: every nested call, down to the `div_float` intrinsic, is printed.
@overdub(PrintCtx(), 1/2)

# /(1, 2)
# float(1,)
# AbstractFloat(1,)
# Float64(1,)
# sitofp(Float64, 1)
# float(2,)
# AbstractFloat(2,)
# Float64(2,)
# sitofp(Float64, 2)
# /(1.0, 2.0)
# div_float(1.0, 2.0)
# 0.5

Counting Calls

import Cassette: @context, prehook, @overdub

# Call counter. `T` is a phantom parameter: it selects which calls are counted
# (those whose first argument after the callee is a `T`).
mutable struct Count{T}
    count::Int
end

@context CountCtx

# Fires before every call `f(arg, rest...)` made under a `CountCtx` whose metadata is
# a `Count{T}` and whose first argument `arg` is a `T`; bumps the stored counter.
function prehook(ctx::CountCtx{Count{T}}, ::Any, ::T, ::Any...) where T
    ctx.metadata.count += 1
end

c = Count{DataType}(0)

# Only calls whose first argument is a type are counted. In `1/2` those are the two
# `sitofp(Float64, ...)` calls shown in the logging example.
@overdub(CountCtx(metadata=c), 1/2)
# 0.5

c
# Count{DataType}(2)

Overdub

Mental Model

# Mental model (pseudocode): under a context, Cassette turns each call f(args...) into
#
begin
    # user hook that runs before the call
    Cassette.prehook(context, f, args...)
    tmp = Cassette.execute(context, f, args...)
    # if `execute` returns the `OverdubInstead` sentinel, recurse into `f`'s body
    # under the same context instead of using `tmp` directly
    tmp = isa(tmp, Cassette.OverdubInstead) ? Cassette.overdub(context, f, args...) : tmp
    # user hook that also sees the call's result
    Cassette.posthook(context, tmp, f, args...)
    # the value of the original call
    tmp
end

Contextual dispatch

Contextual dispatch

  • run a program in a specific context
  • can propagate metadata within a context

example

using Cassette

Cassette.@context TraceCtx

# Record each call as `args => children`, where `args` is `(f, fargs...)` and
# `children` collects the calls made inside it. Overdubbable callees are recursed
# into with `children` as the new metadata; everything else goes to `fallback`.
function Cassette.execute(ctx::TraceCtx, args...)
    children = Any[]
    push!(ctx.metadata, args => children)
    Cassette.canoverdub(ctx, args...) || return Cassette.fallback(ctx, args...)
    childctx = Cassette.similarcontext(ctx, metadata = children)
    return Cassette.overdub(childctx, args...)
end

trace = Any[]
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))

# returns `true`: the nested call tree recorded for `f(x, y, z)`
trace == Any[
   (f,x,y,z) => Any[
       (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]]
       (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]]
       (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
   ]
]

Contextual pass injection

pass injection

  • modify the code that is being passed to the compiler

pass injection

using Cassette
using Cassette: @pass, @overdub


# Whether `x` can be narrowed to a 32-bit value without overflowing its range.
# Values of any other type are never narrowed.
fitsin32bit(x) = false
fitsin32bit(x::Integer) = (typemin(Int32) <= x <= typemax(Int32))
# `typemin(Float32)` is `-Inf32`, so comparing against it accepts every non-NaN float
# (e.g. 1e300, which would silently overflow to `Inf32`). Bound by the largest finite
# Float32 instead; infinities convert exactly, so they still count as fitting.
fitsin32bit(x::AbstractFloat) = isinf(x) || (-floatmax(Float32) <= x <= floatmax(Float32))

# Narrow `x` to its 32-bit counterpart; only meaningful when `fitsin32bit(x)` holds
# (for integers, `convert` throws an `InexactError` otherwise).
to32bit(x::Integer) = convert(Int32, x)
to32bit(x::AbstractFloat) = convert(Float32, x)

# A Cassette pass: a function of (context type, method signature, method body) that
# returns the (possibly rewritten) method body.
bit32pass = @pass (ctxtype, method_signature, method_body) -> begin
    # presumably replaces each IR value matching `fitsin32bit` with `to32bit(value)`,
    # mutating `method_body.code` in place; confirm against Cassette's `replace_match!`
    Cassette.replace_match!(to32bit, fitsin32bit, method_body.code)
    return method_body
end


Cassette.@context Ctx

a = rand()
# NOTE(review): `foo` is not defined in this snippet; define it before running.
# Runs `foo(a)` under `Ctx` with `bit32pass` supplied as the context's pass.
@overdub(Ctx(pass = bit32pass), foo(a))

example

using Cassette

Cassette.@context Ctx

# Mutable holder for a zero-argument closure. The field is `::Any` because every
# intercepted `println` replaces it with a new closure of a different type.
mutable struct Callback
    f::Any
end

# Under `Ctx`, `println` prints nothing. Instead it extends the stored callback, so
# that calling it later first replays everything recorded so far, then this `println`.
function Cassette.execute(ctx::Ctx, ::typeof(println), args...)
    earlier = ctx.metadata.f
    ctx.metadata.f = function ()
        earlier()
        println(args...)
    end
    return nothing
end

example

julia> begin  # call `add` directly (no Cassette): the printlns run immediately
           a = rand(3)
           b = rand(3)
           function add(a, b)
               println("I'm about to add $a + $b")
               c = a + b
               println("c = $c")
               return c
           end
           add(a, b)
       end
I'm about to add [0.457465, 0.62078, 0.954555] + [0.0791336, 0.744041, 0.976194]
c = [0.536599, 1.36482, 1.93075]
3-element Array{Float64,1}:
 0.5365985032259399
 1.3648210555868863
 1.9307494378914405

julia> ctx = Ctx(metadata = Callback(() -> nothing));  # start from a no-op callback

julia> c = Cassette.overdub(ctx, add, a, b)  # println calls are deferred: nothing is printed
3-element Array{Float64,1}:
 0.5365985032259399
 1.3648210555868863
 1.9307494378914405

julia> ctx.metadata.f()  # replay the deferred println calls in their original order
I'm about to add [0.457465, 0.62078, 0.954555] + [0.0791336, 0.744041, 0.976194]
c = [0.536599, 1.36482, 1.93075]

example

using Cassette
using Core: CodeInfo, SlotNumber, SSAValue

Cassette.@context Ctx

# With the callback-threading pass applied, every call `f(args...)` inside an
# overdubbed method receives the current callback as an extra leading argument,
# and every such call yields `(result, callback)`.
function Cassette.execute(ctx::Ctx, callback, f, args...)
    if !Cassette.canoverdub(ctx, f, args...)
        # run the callee as-is and pass the callback through unchanged
        return (Cassette.fallback(ctx, f, args...), callback)
    end
    # recurse into `f`, carrying `callback` as metadata; the rewritten body itself
    # returns `(result, callback)`
    subctx = Cassette.similarcontext(ctx, metadata = callback)
    return Cassette.overdub(subctx, f, args...)
end

# `println` is not executed here: it yields `nothing` as its result, plus a new
# callback that runs the previous one and then this `println`.
function Cassette.execute(ctx::Ctx, callback, ::typeof(println), args...)
    deferred = () -> (callback(); println(args...))
    return (nothing, deferred)
end

example

# Cassette pass (intended to be wrapped with `Cassette.@pass`) that rewrites a method's
# lowered IR so that a callback is threaded through every call and the method returns
# `(result, callback)` instead of `result`.
#
# - first argument: the context type (used only for dispatch on `Ctx`)
# - `::Type{S}`: presumably the method signature type; unused here
# - `ir`: the lowered method body; mutated in place and returned
function sliceprintln(::Type{<:Ctx}, ::Type{S}, ir::CodeInfo) where {S}
    # append a fresh local slot that holds the current callback
    callbackslotname = gensym("callback")
    push!(ir.slotnames, callbackslotname)
    push!(ir.slotflags, 0x00)
    callbackslot = SlotNumber(length(ir.slotnames))
    # IR for `getfield(<context>, :metadata)`, i.e. the initial callback. The
    # `:nooverdub` wrapper presumably keeps Cassette from overdubbing this injected
    # call; confirm against Cassette docs.
    getmetadata = Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), Expr(:contextslot), QuoteNode(:metadata))

    # As used below, `insert_statements!(code, codelocs, count, make)` appears to call
    # `count(stmt, i)` per statement. `nothing` leaves the statement alone; an integer
    # `n` replaces it with the `n` statements returned by `make(stmt, i)`. Each count
    # below matches the length of the vector its `make` returns.

    # insert the initial `callbackslot` assignment into the IR.
    Cassette.insert_statements!(ir.code, ir.codelocs,
                                 (stmt, i) -> i == 1 ? 2 : nothing,
                                 (stmt, i) -> [Expr(:(=), callbackslot, getmetadata), stmt])

    # replace all calls of the form `f(args...)` with `callback(f, args...)`, taking care to
    # properly destructure the returned `(result, callback)` into the appropriate statements
    # The three replacement statements are:
    #   1. `callback(f, args...)`, referenced afterwards as `SSAValue(i)` (assumed to be
    #      the SSA value of the first inserted statement)
    #   2. store element 2 of that tuple (the new callback) back into `callbackslot`
    #   3. element 1 (the call's result), re-assigned to the original left-hand side
    #      when the statement was an assignment
    Cassette.insert_statements!(ir.code, ir.codelocs,
                                 (stmt, i) -> begin
                                    i > 1 || return nothing # don't slice the callback assignment
                                    stmt = Base.Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt
                                    return Base.Meta.isexpr(stmt, :call) ? 3 : nothing
                                 end,
                                 (stmt, i) -> begin
                                     items = Any[]
                                     callstmt = Base.Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt
                                     push!(items, Expr(:call, callbackslot, callstmt.args...))
                                     push!(items, Expr(:(=), callbackslot, Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), SSAValue(i), 2)))
                                     result = Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), SSAValue(i), 1)
                                     if Base.Meta.isexpr(stmt, :(=))
                                         result = Expr(:(=), stmt.args[1], result)
                                     end
                                     push!(items, result)
                                     return items
                                 end)

    # replace return statements of the form `return x` with `return (x, callback)`
    # (the tuple is built at position `i` and then returned via `SSAValue(i)`)
    Cassette.insert_statements!(ir.code, ir.codelocs,
                                  (stmt, i) -> Base.Meta.isexpr(stmt, :return) ? 2 : nothing,
                                  (stmt, i) -> begin
                                      return [
                                          Expr(:call, Expr(:nooverdub, GlobalRef(Core, :tuple)), stmt.args[1], callbackslot)
                                          Expr(:return, SSAValue(i))
                                      ]
                                  end)
    return ir
end

example

# Wrap `sliceprintln` as a Cassette pass so it can be supplied as `Ctx(pass = ...)`.
const sliceprintlnpass = Cassette.@pass sliceprintln

julia> begin  # call `add` directly (no Cassette): the printlns run immediately
           a = rand(3)
           b = rand(3)
           function add(a, b)
               println("I'm about to add $a + $b")
               c = a + b
               println("c = $c")
               return c
           end
           add(a, b)
       end
I'm about to add [0.325019, 0.19358, 0.200598] + [0.195759, 0.653, 0.498859]
c = [0.520778, 0.84658, 0.699457]
3-element Array{Float64,1}:
 0.5207782045663867
 0.846579992552251
 0.6994565474128307

julia> ctx = Ctx(pass=sliceprintlnpass, metadata = () -> nothing);  # apply the pass; start from a no-op callback

julia> result, callback = Cassette.overdub(ctx, add, a, b)  # the pass makes `add` return (result, callback)
#([0.520778, 0.84658, 0.699457], getfield(Main, Symbol("##4#5")){getfield(Main, Symbol("##4#5")){getfield(Main, Symbol("##18#19")),Tuple{String}},Tuple{String}}(getfield(Main, Symbol("##4#5")){getfield(Main, Symbol("##18#19")),Tuple{String}}(getfield(Main, Symbol("##18#19"))(), ("I'm about to add [0.325019, 0.19358, 0.200598] + [0.195759, 0.653, 0.498859]",)), ("c = [0.520778, 0.84658, 0.699457]",)))

julia> callback()  # replay the deferred println calls in their original order
I'm about to add [0.325019, 0.19358, 0.200598] + [0.195759, 0.653, 0.498859]
c = [0.520778, 0.84658, 0.699457]

Contextual Tagging

Contextual Tagging

  • allows you to tag values with respect to a context

example

using Cassette: @context, @pass, @overdub, overdub, hasmetadata, metadata, hasmetameta,
                metameta, untag, tag, enabletagging, untagtype, istagged, istaggedtype,
                Tagged, fallback, canoverdub, similarcontext

# Shorthand for `Core.Typeof`, used to dispatch on specific functions (`::Typ(sin)`).
const Typ = Core.Typeof

@context DiffCtx

# A `DiffCtx` with no context-level metadata (`Nothing`); `T` is presumably the tag
# type that `enabletagging` assigns.
const DiffCtxWithTag{T} = DiffCtx{Nothing,T}

# Under a `DiffCtx`, the metadata (tangent) attached to a `T<:Real` value is a `T`.
Cassette.metadatatype(::Type{<:DiffCtx}, ::Type{T}) where {T<:Real} = T

# Tangent carried by `x` under `context`; values without metadata have a zero
# tangent of the same type as their untagged value.
function tangent(x, context)
    if hasmetadata(x, context)
        return metadata(x, context)
    else
        return zero(untag(x, context))
    end
end

# Derivative of `f` at `x`: tag `x` with a unit tangent in a fresh tagging-enabled
# `DiffCtx`, run `f` under that context, and read the tangent off the result.
function D(f, x)
    diffctx = enabletagging(DiffCtx(), f)
    seed = oftype(x, 1.0)
    output = overdub(diffctx, f, tag(x, diffctx, seed))
    return tangent(output, diffctx)
end

# Derivative rule for `sin` on a value tagged by this context:
# d(sin(x)) = cos(x) * dx.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(sin), x::Tagged{T,<:Real}) where {T}
    vx, dx = untag(x, ctx), tangent(x, ctx)
    return tag(sin(vx), ctx, cos(vx) * dx)
end

example

# Derivative rule for `cos` on a value tagged by this context:
# d(cos(x)) = -sin(x) * dx.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(cos), x::Tagged{T,<:Real}) where {T}
    primal = untag(x, ctx)
    dual = tangent(x, ctx)
    return tag(cos(primal), ctx, -sin(primal) * dual)
end

# Product rule, both operands tagged: d(x * y) = y * dx + x * dy.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(*), x::Tagged{T,<:Real}, y::Tagged{T,<:Real}) where {T}
    xval = untag(x, ctx)
    xdot = tangent(x, ctx)
    yval = untag(y, ctx)
    ydot = tangent(y, ctx)
    return tag(xval * yval, ctx, yval * xdot + xval * ydot)
end

# Product rule, only `x` tagged: the untagged factor `y` just scales the tangent.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(*), x::Tagged{T,<:Real}, y::Real) where {T}
    xval = untag(x, ctx)
    xdot = tangent(x, ctx)
    return tag(xval * y, ctx, y * xdot)
end

# Product rule, only `y` tagged: the untagged factor `x` just scales the tangent.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(*), x::Real, y::Tagged{T,<:Real}) where {T}
    yval = untag(y, ctx)
    ydot = tangent(y, ctx)
    return tag(x * yval, ctx, x * ydot)
end

# Sum rule, both operands tagged: d(x + y) = dx + dy.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(+), x::Tagged{T,<:Real}, y::Tagged{T,<:Real}) where {T}
    xval = untag(x, ctx)
    xdot = tangent(x, ctx)
    yval = untag(y, ctx)
    ydot = tangent(y, ctx)
    return tag(xval + yval, ctx, xdot + ydot)
end

# Sum rule, only `x` tagged: the untagged `y` contributes no tangent.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(+), x::Tagged{T,<:Real}, y::Real) where {T}
    xval = untag(x, ctx)
    xdot = tangent(x, ctx)
    return tag(xval + y, ctx, xdot)
end

# Sum rule, only `y` tagged: the untagged `x` contributes no tangent.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(+), x::Real, y::Tagged{T,<:Real}) where {T}
    yval = untag(y, ctx)
    ydot = tangent(y, ctx)
    return tag(x + yval, ctx, ydot)
end

example

# Each expression below is intended to evaluate to `true`.
D(sin, 1) === cos(1)
D(x -> D(sin, x), 1) === -sin(1)  # nested derivatives
D(x -> sin(x) * cos(x), 1) === cos(1)^2 - sin(1)^2
D(x -> x * D(y -> x * y, 1), 2) === 4  # inner D closes over the outer tagged `x`
D(x -> x * D(y -> x * y, 2), 1) === 2
# NOTE(review): `foo_bar_identity` is not defined in this snippet; it must be defined before running.
D(x -> x * foo_bar_identity(x), 1) === 2.0

x = rand()
# NOTE(review): `===` requires bitwise equality. The product rule yields (3 + x) + (x + 2),
# which may round differently from 2x + 5 for some `x`; consider `≈` for these checks.
D(x -> (x + 2) * (3 + x), x) === 2x + 5
D(x -> CrazyPropModule.crazy_sum_mul([x], [x]), x) === (x + x)
D(x -> CrazyPropModule.crazy_sum_mul([x, 2], [3, x]), x) === 2x + 5

example

# Deliberately convoluted fixture: `crazy_sum_mul` computes `sum(x) * sum(y)` while
# routing intermediate values through a `const` global array, a non-const global, a
# callable mutable struct, and mutable/immutable struct fields. This is presumably
# meant to exercise tag propagation through those constructs (see the `D` examples).
module CrazyPropModule
    # receives the elements of `x` during `crazy_sum_mul`; emptied before it returns
    const CONST_BINDING = Float64[]

    # holds the running sum of `y` during `crazy_sum_mul`; reset to 0.0 before it returns
    global GLOBAL_BINDING = 0.0

    struct Foo
        vector::Vector{Float64}
    end

    mutable struct FooContainer
        foo::Foo
    end

    # callable that adds its stored `x` to its argument
    mutable struct PlusFunc
        x::Float64
    end

    (f::PlusFunc)(x) = f.x + x

    const PLUSFUNC = PlusFunc(0.0)

    # implements a very convoluted `sum(x) * sum(y)`
    function crazy_sum_mul(x::Vector{Float64}, y::Vector{Float64})
        @assert length(x) === length(y)
        fooc = FooContainer(Foo(x))
        tmp = y

        # this loop sets:
        # `const_binding == x`
        # `global_binding == sum(y)` (PLUSFUNC.x carries the running sum between iterations)
        for i in 1:length(y)
            if iseven(i) # `fooc.foo.vector === y && tmp === x`
                v = fooc.foo.vector[i]
                push!(CONST_BINDING, tmp[i])
                global GLOBAL_BINDING = PLUSFUNC(v)
                PLUSFUNC.x = GLOBAL_BINDING
                fooc.foo = Foo(x)
                tmp = y
            else # `fooc.foo.vector === x && tmp === y`
                v = fooc.foo.vector[i]
                push!(CONST_BINDING, v)
                global GLOBAL_BINDING = PLUSFUNC(tmp[i])
                PLUSFUNC.x = GLOBAL_BINDING
                fooc.foo = Foo(y)
                tmp = x
            end
        end

        # accumulate result
        z = sum(CONST_BINDING) * GLOBAL_BINDING

        # reset global state
        empty!(CONST_BINDING)
        PLUSFUNC.x = 0.0
        global GLOBAL_BINDING = 0.0
        return z
    end
end

Conclusion

Cassette.jl

  • write context-specific programs
  • modify Julia IR with pass injection
  • tag Julia values

Reference

  • https://jrevels.github.io/Cassette.jl/latest/
  • https://www.youtube.com/watch?v=_E2zEzNEy-8
  • https://www.youtube.com/watch?v=lyX-isPDS2M

Q&A

Made with Slides.com