Cassette.jl
Overdub Your Julia Code
About Me
- Student
The talk is based on the document & the developer's talk
What is Cassette?
What can cassette do?
- extend the Julia language by directly injecting the Julia compiler with new, context-specific behaviors.
- dynamically injecting code transformation passes into Julia’s just-in-time (JIT) compilation cycle
First Thought
for example
Method overloading
- Common & Standard
- needs manual implementation
...But Stinks
- Ultimately thwarted by dispatch and/or structural type constraints in non-generic target programs.
- Proper usage of overloading-based nonstandard execution tools requires proper genericity criteria, i.e. "what weird subset of Julia do I really support?".
- Not all relevant Julia language mechanisms are fully exposed/interceptable via method overloading (e.g. control flow, literals, bindings, calling scope)
Why Julia?
Why Julia?
- Julia provides a general and powerful API (AST & Julia IR) that lets you play with the Julia compiler easily
Julia compile loop
WARNING
Situation of v0.7
Julia v0.7 is in RC - packages need to be updated
- mess in document
Julia v1.0 Just Released!!
warning of cassette
- each version only supports a specific version of Julia
- highly dependent on the Julia compiler
- might have performance and correctness bugs caused by either Cassette or Julia itself
Cassette v0.1.0 was released at JuliaCon 2018
Take a look!
Simple logging
import Cassette: @context, prehook, @overdub
@context PrintCtx

# Logging hook for `PrintCtx`: print the callee followed by its argument tuple.
function prehook(::PrintCtx, f, args...)
    return println(f, args)
end

@overdub(PrintCtx(), 1/2)
# /(1, 2)
# float(1,)
# AbstractFloat(1,)
# Float64(1,)
# sitofp(Float64, 1)
# float(2,)
# AbstractFloat(2,)
# Float64(2,)
# sitofp(Float64, 2)
# /(1.0, 2.0)
# div_float(1.0, 2.0)
# 0.5
Counting Call
import Cassette: @context, prehook, @overdub
"""
    Count{T}

Mutable call counter holding a single `Int` tally in `count`.

The type parameter `T` is not used by any field; it only labels the counter
with a type.
"""
mutable struct Count{T}
    count::Int
end
@context CountCtx

# Hook for a `CountCtx` whose metadata is a `Count{T}`: it matches only calls
# whose first argument has type `T`, and bumps the counter by one for each.
# Returns the updated count.
function prehook(ctx::CountCtx{Count{T}}, ::Any, ::T, ::Any...) where {T}
    counter = ctx.metadata
    return counter.count += 1
end
# Counter starting at zero; its type parameter `DataType` selects which
# first-argument type the `CountCtx` prehook matches.
c = Count{DataType}(0)
# Run `1/2` under a `CountCtx` carrying `c` as its metadata.
@overdub(CountCtx(metadata=c), 1/2)
# 0.5
julia> c
# Count{DataType}(2)
Overdub
Mental Model
# turn f(args...) into
#
# (schematic only: `context`, `f` and `args` stand for the active context and
# the call being rewritten)
begin
# 1. run the user hook before the call
Cassette.prehook(context, f, args...)
# 2. give the context a chance to handle the call itself
tmp = Cassette.execute(context, f, args...)
# 3. if `execute` answered with an `OverdubInstead` value, overdub `f` with the same context
tmp = isa(tmp, Cassette.OverdubInstead) ? overdub(context, f, args...) : tmp
# 4. run the user hook after the call, passing it the result
Cassette.posthook(context, tmp, f, args...)
# 5. the value of the block is the call's result
tmp
end
Contextual dispatch
Contextual dispatch
- run program in a specific context
- can propagate metadata within a context
example
using Cassette
Cassette.@context TraceCtx
# Record every call seen under `TraceCtx` as an `args => children` pair in the
# context metadata, where `children` collects the calls made while evaluating it.
function Cassette.execute(ctx::TraceCtx, args...)
    children = Any[]
    push!(ctx.metadata, args => children)
    # Calls that cannot be overdubbed are run through the fallback and keep an
    # empty child list.
    Cassette.canoverdub(ctx, args...) || return Cassette.fallback(ctx, args...)
    # Otherwise continue in a context whose metadata is this call's child list.
    nested = Cassette.similarcontext(ctx, metadata = children)
    return Cassette.overdub(nested, args...)
end
# Root list of the trace; it is handed to the context as its metadata below.
trace = Any[]
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
# Run `f(x, y, z)` under a `TraceCtx`, wrapped in a zero-argument closure.
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))
# returns `true`
# Expected shape: each entry pairs a call tuple `(function, args...)` with the
# list of calls recorded beneath it.
trace == Any[
(f,x,y,z) => Any[
(*,x,y) => Any[(Base.mul_float,x,y)=>Any[]]
(*,y,z) => Any[(Base.mul_float,y,z)=>Any[]]
(+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
]
]
Contextual pass injection
pass injection
- modify the code that is being passed to the compiler
pass injection
using Cassette
using Cassette: @pass, @overdub
"""
    fitsin32bit(x)

Return `true` when `x` is an `Integer` within the `Int32` range or an
`AbstractFloat` whose magnitude is representable as a `Float32`; return `false`
for any other value.

For floats, `NaN` is rejected and `±Inf` is accepted. Finite values must satisfy
`abs(x) <= floatmax(Float32)`, because `typemin(Float32)` and `typemax(Float32)`
are `-Inf` and `Inf` and would accept every finite value, including those that
overflow to infinity when converted.
"""
fitsin32bit(x) = false
fitsin32bit(x::Integer) = (typemin(Int32) <= x <= typemax(Int32))
function fitsin32bit(x::AbstractFloat)
    # Non-finite values: infinities convert exactly, NaN stays rejected.
    isfinite(x) || return isinf(x)
    return abs(x) <= floatmax(Float32)
end
"""
    to32bit(x)

Convert an `Integer` to `Int32` or an `AbstractFloat` to `Float32` using
`convert`.
"""
function to32bit(x::Integer)
    return convert(Int32, x)
end

function to32bit(x::AbstractFloat)
    return convert(Float32, x)
end
# Build a Cassette pass from an anonymous transform taking the context type,
# the method signature and the method body; it edits `method_body.code` in
# place and returns the body.
bit32pass = @pass (ctxtype, method_signature, method_body) -> begin
# presumably replaces every element of the code for which `fitsin32bit` is true
# with `to32bit(element)` — confirm against the Cassette documentation
Cassette.replace_match!(to32bit, fitsin32bit, method_body.code)
return method_body
end
Cassette.@context Ctx
a = rand()
# NOTE(review): `foo` is not defined anywhere in this snippet; it stands for
# the function to run with the pass enabled.
@overdub(Ctx(pass = bit32pass), foo(a))
example
using Cassette
Cassette.@context Ctx
"""
    Callback

Mutable holder for a single callable stored in `f`. The field is left untyped
(`Any`) so it can be reassigned to a different closure.
"""
mutable struct Callback
    f::Any
end
# Intercept `println` under `Ctx`: instead of printing now, chain the print onto
# the callback stored in the context metadata so it runs after every earlier
# deferred print. The intercepted call itself returns `nothing`.
function Cassette.execute(ctx::Ctx, ::typeof(println), args...)
    earlier = ctx.metadata.f
    ctx.metadata.f = function ()
        earlier()
        println(args...)
    end
    return nothing
end
example
julia> begin
a = rand(3)
b = rand(3)
# Add `a` and `b`, printing one message before the addition and the result after it.
function add(a, b)
println("I'm about to add $a + $b")
c = a + b
println("c = $c")
return c
end
add(a, b)
end
I'm about to add [0.457465, 0.62078, 0.954555] + [0.0791336, 0.744041, 0.976194]
c = [0.536599, 1.36482, 1.93075]
3-element Array{Float64,1}:
0.5365985032259399
1.3648210555868863
1.9307494378914405
julia> ctx = Ctx(metadata = Callback(() -> nothing));
julia> c = Cassette.overdub(ctx, add, a, b)
3-element Array{Float64,1}:
0.5365985032259399
1.3648210555868863
1.9307494378914405
julia> ctx.metadata.f()
I'm about to add [0.457465, 0.62078, 0.954555] + [0.0791336, 0.744041, 0.976194]
c = [0.536599, 1.36482, 1.93075]
example
using Cassette
using Core: CodeInfo, SlotNumber, SSAValue
Cassette.@context Ctx
# Generic rule for the callback-threading scheme: every call receives the
# current `callback` as an extra leading argument and yields `(result, callback)`.
function Cassette.execute(ctx::Ctx, callback, f, args...)
    if !Cassette.canoverdub(ctx, f, args...)
        # Not overdubbable: run the fallback and hand the callback back unchanged.
        return Cassette.fallback(ctx, f, args...), callback
    end
    # Overdubbable: recurse with the callback stored as the context metadata.
    inner = Cassette.similarcontext(ctx, metadata = callback)
    return Cassette.overdub(inner, f, args...) # return result, callback
end
# `println` rule for the callback-threading scheme: print nothing now; return
# `nothing` as the result together with a new callback that first runs the
# previous callback and then performs the deferred print.
function Cassette.execute(ctx::Ctx, callback, ::typeof(println), args...)
    deferred = () -> begin
        callback()
        println(args...)
    end
    return nothing, deferred
end
example
# Transform used to build a Cassette pass: rewrites the lowered code in `ir` so
# that a "callback" slot is initialised from the context metadata, passed as the
# callee of every call, and returned together with the original return value.
# Takes the context type (any `Ctx` subtype), the signature type `S` (not used
# in the body) and the `CodeInfo` to rewrite; mutates `ir` and returns it.
function sliceprintln(::Type{<:Ctx}, ::Type{S}, ir::CodeInfo) where {S}
# add a fresh, uniquely named slot (flag byte 0x00) that will hold the callback
callbackslotname = gensym("callback")
push!(ir.slotnames, callbackslotname)
push!(ir.slotflags, 0x00)
callbackslot = SlotNumber(length(ir.slotnames))
# expression for `getfield(<context slot>, :metadata)`; the `:nooverdub` wrapper
# presumably keeps Cassette from overdubbing this `getfield` — TODO confirm
getmetadata = Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), Expr(:contextslot), QuoteNode(:metadata))
# insert the initial `callbackslot` assignment into the IR.
# (first closure: for statement 1 answer `2`, otherwise `nothing`; second closure:
# the two statements to use — the assignment followed by the original statement.
# presumably the number is the count of statements the second closure returns)
Cassette.insert_statements!(ir.code, ir.codelocs,
(stmt, i) -> i == 1 ? 2 : nothing,
(stmt, i) -> [Expr(:(=), callbackslot, getmetadata), stmt])
# replace all calls of the form `f(args...)` with `callback(f, args...)`, taking care to
# properly destructure the returned `(result, callback)` into the appropriate statements
# (three statements per call: the rewritten call, the update of `callbackslot`
# from element 2, and the result from element 1 — re-wrapped in the original
# assignment when the call was the right-hand side of one. `SSAValue(i)` is
# presumably the position of the rewritten call — TODO confirm)
Cassette.insert_statements!(ir.code, ir.codelocs,
(stmt, i) -> begin
i > 1 || return nothing # don't slice the callback assignment
stmt = Base.Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt
return Base.Meta.isexpr(stmt, :call) ? 3 : nothing
end,
(stmt, i) -> begin
items = Any[]
callstmt = Base.Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt
push!(items, Expr(:call, callbackslot, callstmt.args...))
push!(items, Expr(:(=), callbackslot, Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), SSAValue(i), 2)))
result = Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), SSAValue(i), 1)
if Base.Meta.isexpr(stmt, :(=))
result = Expr(:(=), stmt.args[1], result)
end
push!(items, result)
return items
end)
# replace return statements of the form `return x` with `return (x, callback)`
# (two statements: build the tuple with `Core.tuple`, then return that value)
Cassette.insert_statements!(ir.code, ir.codelocs,
(stmt, i) -> Base.Meta.isexpr(stmt, :return) ? 2 : nothing,
(stmt, i) -> begin
return [
Expr(:call, Expr(:nooverdub, GlobalRef(Core, :tuple)), stmt.args[1], callbackslot)
Expr(:return, SSAValue(i))
]
end)
return ir
end
example
const sliceprintlnpass = Cassette.@pass sliceprintln
julia> begin
a = rand(3)
b = rand(3)
# Add `a` and `b`, printing one message before the addition and the result after it.
function add(a, b)
println("I'm about to add $a + $b")
c = a + b
println("c = $c")
return c
end
add(a, b)
end
I'm about to add [0.325019, 0.19358, 0.200598] + [0.195759, 0.653, 0.498859]
c = [0.520778, 0.84658, 0.699457]
3-element Array{Float64,1}:
0.5207782045663867
0.846579992552251
0.6994565474128307
julia> ctx = Ctx(pass=sliceprintlnpass, metadata = () -> nothing);
julia> result, callback = Cassette.overdub(ctx, add, a, b)
#([0.520778, 0.84658, 0.699457], getfield(Main, Symbol("##4#5")){getfield(Main, Symbol("##4#5")){getfield(Main, Symbol("##18#19")),Tuple{String}},Tuple{String}}(getfield(Main, Symbol("##4#5")){getfield(Main, Symbol("##18#19")),Tuple{String}}(getfield(Main, Symbol("##18#19"))(), ("I'm about to add [0.325019, 0.19358, 0.200598] + [0.195759, 0.653, 0.498859]",)), ("c = [0.520778, 0.84658, 0.699457]",)))
julia> callback()
I'm about to add [0.325019, 0.19358, 0.200598] + [0.195759, 0.653, 0.498859]
c = [0.520778, 0.84658, 0.699457]
Contextual Tagging
Contextual Tagging
- allows you to "tag" values w.r.t. a context
example
using Cassette: @context, @pass, @overdub, overdub, hasmetadata, metadata, hasmetameta,
metameta, untag, tag, enabletagging, untagtype, istagged, istaggedtype,
Tagged, fallback, canoverdub, similarcontext
const Typ = Core.Typeof
@context DiffCtx
const DiffCtxWithTag{T} = DiffCtx{Nothing,T}
# For `DiffCtx`, the metadata attached to a tagged `Real` of type `T` has type `T`.
function Cassette.metadatatype(::Type{<:DiffCtx}, ::Type{T}) where {T<:Real}
    return T
end

# Derivative part of `x` in `context`: the metadata when `x` carries some,
# otherwise a zero of the same type as the untagged value.
function tangent(x, context)
    if hasmetadata(x, context)
        return metadata(x, context)
    end
    return zero(untag(x, context))
end
"""
    D(f, x)

Differentiate `f` at `x`: tag `x` with a unit seed of its own type, run `f`
under a tagging-enabled `DiffCtx`, and return the tangent carried by the result.
"""
function D(f, x)
    ctx = enabletagging(DiffCtx(), f)
    seed = oftype(x, 1.0)
    tagged_input = tag(x, ctx, seed)
    return tangent(overdub(ctx, f, tagged_input), ctx)
end
# Derivative rule for `sin` on a tagged real: the result carries the value
# `sin(vx)` and the tangent `cos(vx) * dx`.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(sin), x::Tagged{T,<:Real}) where {T}
    # split the tagged input into its plain value and its tangent
    vx, dx = untag(x, ctx), tangent(x, ctx)
    return tag(sin(vx), ctx, cos(vx) * dx)
end
example
# Derivative rule for `cos` on a tagged real: the result carries the value
# `cos(primal)` and the tangent `-sin(primal) * dprimal`.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(cos), x::Tagged{T,<:Real}) where {T}
    primal = untag(x, ctx)
    dprimal = tangent(x, ctx)
    return tag(cos(primal), ctx, -sin(primal) * dprimal)
end
# Product rule, both operands tagged: tangent is `b * da + a * db`.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(*), x::Tagged{T,<:Real}, y::Tagged{T,<:Real}) where {T}
    a = untag(x, ctx)
    da = tangent(x, ctx)
    b = untag(y, ctx)
    db = tangent(y, ctx)
    return tag(a * b, ctx, b * da + a * db)
end

# Product rule, only the left operand tagged: the plain `y` scales the tangent.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(*), x::Tagged{T,<:Real}, y::Real) where {T}
    a = untag(x, ctx)
    da = tangent(x, ctx)
    return tag(a * y, ctx, y * da)
end

# Product rule, only the right operand tagged: the plain `x` scales the tangent.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(*), x::Real, y::Tagged{T,<:Real}) where {T}
    b = untag(y, ctx)
    db = tangent(y, ctx)
    return tag(x * b, ctx, x * db)
end
# Sum rule, both operands tagged: values add and tangents add.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(+), x::Tagged{T,<:Real}, y::Tagged{T,<:Real}) where {T}
    a = untag(x, ctx)
    da = tangent(x, ctx)
    b = untag(y, ctx)
    db = tangent(y, ctx)
    return tag(a + b, ctx, da + db)
end

# Sum rule, only the left operand tagged: its tangent passes through unchanged.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(+), x::Tagged{T,<:Real}, y::Real) where {T}
    a = untag(x, ctx)
    da = tangent(x, ctx)
    return tag(a + y, ctx, da)
end

# Sum rule, only the right operand tagged: its tangent passes through unchanged.
function Cassette.execute(ctx::DiffCtxWithTag{T}, ::Typ(+), x::Real, y::Tagged{T,<:Real}) where {T}
    b = untag(y, ctx)
    db = tangent(y, ctx)
    return tag(x + b, ctx, db)
end
example
# Comparisons of `D` against hand-derived derivatives; the slide shows no
# output, so presumably each line is meant to evaluate to `true`.
D(sin, 1) === cos(1)
# nested differentiation: derivative of the derivative of `sin`
D(x -> D(sin, x), 1) === -sin(1)
D(x -> sin(x) * cos(x), 1) === cos(1)^2 - sin(1)^2
# inner `D` closes over the outer variable `x`
D(x -> x * D(y -> x * y, 1), 2) === 4
D(x -> x * D(y -> x * y, 2), 1) === 2
# NOTE(review): `foo_bar_identity` is not defined anywhere in this file's visible code
D(x -> x * foo_bar_identity(x), 1) === 2.0
x = rand()
D(x -> (x + 2) * (3 + x), x) === 2x + 5
# `CrazyPropModule` is the module defined in the example that follows
D(x -> CrazyPropModule.crazy_sum_mul([x], [x]), x) === (x + x)
D(x -> CrazyPropModule.crazy_sum_mul([x, 2], [3, x]), x) === 2x + 5
example
module CrazyPropModule

# Module-level state that `crazy_sum_mul` deliberately routes its data through.
const CONST_BINDING = Float64[]
global GLOBAL_BINDING = 0.0

# Immutable wrapper around a vector.
struct Foo
    vector::Vector{Float64}
end

# Mutable box whose `Foo` is swapped on every loop iteration.
mutable struct FooContainer
    foo::Foo
end

# Callable accumulator: calling it returns its stored `x` plus the argument.
mutable struct PlusFunc
    x::Float64
end

function (f::PlusFunc)(x)
    return f.x + x
end

const PLUSFUNC = PlusFunc(0.0)

"""
    crazy_sum_mul(x, y)

Compute `sum(x) * sum(y)` in a deliberately convoluted way, passing the values
through a constant global vector, a non-constant global, mutable struct fields
and a callable object. All module-level state is reset before returning.
"""
function crazy_sum_mul(x::Vector{Float64}, y::Vector{Float64})
    @assert length(x) === length(y)
    fooc = FooContainer(Foo(x))
    tmp = y
    # After this loop:
    #   `CONST_BINDING` holds the elements of `x`, in order
    #   `GLOBAL_BINDING` holds the running sum of `y` (accumulated in `PLUSFUNC.x`)
    for i in eachindex(y)
        v = fooc.foo.vector[i]
        if iseven(i)
            # here `fooc.foo.vector === y && tmp === x`
            push!(CONST_BINDING, tmp[i])
            global GLOBAL_BINDING = PLUSFUNC(v)
            PLUSFUNC.x = GLOBAL_BINDING
            fooc.foo = Foo(x)
            tmp = y
        else
            # here `fooc.foo.vector === x && tmp === y`
            push!(CONST_BINDING, v)
            global GLOBAL_BINDING = PLUSFUNC(tmp[i])
            PLUSFUNC.x = GLOBAL_BINDING
            fooc.foo = Foo(y)
            tmp = x
        end
    end
    # combine the two accumulated quantities
    z = sum(CONST_BINDING) * GLOBAL_BINDING
    # restore the module-level state for the next call
    empty!(CONST_BINDING)
    PLUSFUNC.x = 0.0
    global GLOBAL_BINDING = 0.0
    return z
end

end
Conclusion
Cassette.jl
- write context-specific programs
- modify julia IR with pass injection
- tag Julia values
Reference
- https://jrevels.github.io/Cassette.jl/latest/
- https://www.youtube.com/watch?v=_E2zEzNEy-8
- https://www.youtube.com/watch?v=lyX-isPDS2M
Q&A
Intro to Cassette.jl
By Peter Cheng
Intro to Cassette.jl
Overdub Your Julia Code
- 2,029