Staged meta-programming, new LMS frontend and computation graphs
Ruben Fiszel
LAMP at EPFL
Supervised by
Nada Amin and Prof. Odersky
Semester project II
Outline
- Staged meta-programming
- LMS
- Frontend improvements
- Case study: Computation graph
Meta-programming
Programs that treat programs as data
Staged meta-programming
- A meta-program is transformed into an object program by going through one or more stages
- In other words, a meta-program is staged into an object program
- Powerful applications: a staged interpreter is a compiler
- Staging annotations in the meta-program distinguish later-stage values ("representations") from present-stage values
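A minimal sketch of the idea, assuming a hypothetical `Code` type that stands for a later-stage value (LMS itself uses `Rep[T]` and builds an IR rather than source strings). The exponent `n` is a present-stage value, so the recursion on it runs during staging and only the multiplications remain in the object program.

```scala
// Hypothetical representation of a later-stage value: here just source text.
final case class Code(src: String)

object StagedPower {
  // n is present-stage (known while staging): the recursion on n disappears.
  // x is later-stage: we can only build code that will use it.
  def power(x: Code, n: Int): Code = {
    require(n >= 0, "negative exponents are not handled in this sketch")
    if (n == 0) Code("1")
    else if (n == 1) x
    else Code(s"(${x.src} * ${power(x, n - 1).src})")
  }

  // Stage the meta-program into an object program (as Scala source text).
  def stage(n: Int): String = s"(x: Int) => ${power(Code("x"), n).src}"
}
```

For example, `StagedPower.stage(3)` yields the source `(x: Int) => (x * (x * x))`: no trace of `n` or of the recursion is left.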
DSL and staging
- Specialized compilers/transpilers built around a DSL
- Domain-specific optimizations (out of reach for a general-purpose compiler)
- Target multiple languages/hardware (heterogeneous) from a unified DSL
- Useful for HPC
Lightweight Modular Staging (LMS)
Scala library for runtime code generation
Naive program
def f(x: Array[Int], y: Array[Int]): Int = x.zip(y).map { case (a, b) => a * b }.sum
def g(y: Array[Int]) = f(Array(0, 1, 0, 0, 0), y)
Meta program
//x is a present-stage value: the loop over its indices runs (and is unrolled) at staging time
def f(x: Array[Int], y: Rep[Array[Int]]): Rep[Int] = x.indices.map(i => unit(x(i)) * y(i)).reduce(_ + _)
def g(y: Rep[Array[Int]]) = f(Array(0, 1, 0, 0, 0), y)
Object program
val g = (x: Array[Int]) => x(1)
//Optimisations used: (1*a) = a, (0+a) = a, (0*a) = 0
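To make the specialization concrete, here is a toy sketch (hypothetical `Exp`, `ArrayRead` and `Opt` names, not LMS code) whose smart constructors apply the three rewrites from the slide, plus their commuted variants, while the loop over the static `x` is unrolled. Staging `f` with `Array(0, 1, 0, 0, 0)` leaves only the read of `y(1)`.

```scala
// Hypothetical mini-IR, not the LMS one.
sealed trait Exp
final case class Const(v: Int) extends Exp
final case class ArrayRead(arr: String, i: Int) extends Exp
final case class Plus(a: Exp, b: Exp) extends Exp
final case class Times(a: Exp, b: Exp) extends Exp

object Opt {
  // Smart constructors: rewrite while building the IR.
  def times(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(0), _) | (_, Const(0)) => Const(0) // (0*a) = 0
    case (Const(1), e)                 => e        // (1*a) = a
    case (e, Const(1))                 => e
    case _                             => Times(a, b)
  }
  def plus(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(0), e) => e // (0+a) = a
    case (e, Const(0)) => e
    case _             => Plus(a, b)
  }

  // Staged dot product: x is static, y is a later-stage array named "y".
  def f(x: Array[Int]): Exp =
    x.indices.map(i => times(Const(x(i)), ArrayRead("y", i))).reduce((a, b) => plus(a, b))

  // Pretty-printer standing in for a code generator.
  def show(e: Exp): String = e match {
    case Const(v)        => v.toString
    case ArrayRead(a, i) => s"$a($i)"
    case Plus(a, b)      => s"(${show(a)} + ${show(b)})"
    case Times(a, b)     => s"(${show(a)} * ${show(b)})"
  }
}
```

Here `Opt.show(Opt.f(Array(0, 1, 0, 0, 0)))` gives `y(1)`, matching the object program above.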
Staged interpreter is a compiler
Naive program
def matchsearch(regexp: String, text: String): Boolean = ...
Staged interpreter...
def matchsearch(regexp: String, text: Rep[String]): Rep[Boolean] = ...
val f = compile((x: Rep[String]) => matchsearch("^hello$", x))
... is a compiler
//GEN CODE
def f(x: String) = ... //small specialized function for "hello" regex search
//usage
f("hello")
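The same principle works without a code generator. The sketch below assumes a hypothetical toy pattern language (literal characters and `.`, matched against the whole input, roughly like `^...$`) and produces a Scala closure instead of generated source: all work on the pattern happens once in `compile`, and the returned function never looks at the pattern again. This `compile` is a local helper, unrelated to the LMS `compile`.

```scala
object Matcher {
  // Interpreter: inspects the pattern on every call.
  def interpret(p: String, s: String): Boolean =
    p.length == s.length && p.indices.forall(i => p(i) == '.' || p(i) == s(i))

  // "Staged" version: the pattern is present-stage, the text is later-stage.
  // The pattern is analysed once; the closure only does the residual checks.
  def compile(p: String): String => Boolean = {
    val checks: List[(Int, Char)] =
      p.indices.filter(i => p(i) != '.').map(i => (i, p(i))).toList
    val len = p.length
    (s: String) => s.length == len && checks.forall { case (i, c) => s(i) == c }
  }
}
```

Usage: `val f = Matcher.compile("h.llo")` builds the specialized matcher once, after which `f("hello")` and `f("hallo")` both return `true`.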
Frontend
Constructs the IR from the meta-program
Offers common data structures and control structures as "batteries"
Backend
Transforms the IR and finally generates the object program
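A sketch of the frontend/backend split, assuming hypothetical `Builder` and `Backend` names (not the LMS API): the frontend turns each operation into a definition bound to a fresh symbol, and the backend walks those definitions in order to emit the object program as Scala source.

```scala
import scala.collection.mutable.ListBuffer

sealed trait Def
final case class Plus(a: String, b: String) extends Def
final case class Times(a: String, b: String) extends Def

// Frontend: records one definition per operation and returns its symbol.
final class Builder {
  private var counter = 0
  val stms = ListBuffer.empty[(String, Def)] // IR: symbol -> definition

  private def reflect(d: Def): String = {
    val sym = s"x$counter"
    counter += 1
    stms += (sym -> d)
    sym
  }
  def plus(a: String, b: String): String = reflect(Plus(a, b))
  def times(a: String, b: String): String = reflect(Times(a, b))
}

// Backend: emits one statement per recorded definition.
object Backend {
  def emit(params: List[String], stms: Seq[(String, Def)], result: String): String = {
    val body = stms.map {
      case (s, Plus(a, b))  => s"  val $s = $a + $b"
      case (s, Times(a, b)) => s"  val $s = $a * $b"
    }
    val header = params.map(p => s"$p: Int").mkString("(", ", ", ")")
    (List(s"def f$header: Int = {") ++ body ++ List(s"  $result", "}")).mkString("\n")
  }
}
```

Staging `b * (2 + a)` with `b.times("b", b.plus("2", "a"))` records two definitions, and `emit` turns them into a two-statement function body.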
Frontend improvements
Painless staging annotations
Preliminaries
- The IR is a typed Exp "sea of nodes": roughly a tree whose inner nodes are Def nodes and whose leaves are Sym and Const
- The DSL is agnostic of the implementation
- The DSL is separated from its implementation, which is only mixed in at staging time
Preliminaries
trait Exp[+T]
case class Const[T](x: T) extends Exp[T]
case class Sym[T](id: Int) extends Exp[T]
trait Def[+T]
implicit def toAtom[T](d: Def[T]): Exp[T]
//defined elsewhere
val a = Sym[Int](0)
val b = Sym[Int](1)
b*(2+a) //user code is transformed into IR
IntTimes(Sym(1), IntPlus(Const(2), Sym(0))): Exp[Int]
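Below, the slide's IR made runnable, as a sketch: the names follow the slide, but the bookkeeping in `IR` (a table from `Def` to `Sym`, reusing the symbol of a structurally equal `Def`) is an illustrative assumption rather than the actual LMS implementation. Because `toAtom` binds each `Def` to a fresh `Sym`, `b*(2+a)` ends up as `Sym(3)` plus two table entries, instead of the nested term written above.

```scala
import scala.collection.mutable
import scala.language.implicitConversions

trait Exp[+T]
final case class Const[T](x: T) extends Exp[T]
final case class Sym[T](id: Int) extends Exp[T]
trait Def[+T]
final case class IntPlus(a: Exp[Int], b: Exp[Int]) extends Def[Int]
final case class IntTimes(a: Exp[Int], b: Exp[Int]) extends Def[Int]

object IR {
  private var nextId = 0
  // Definitions in creation order: each Def is bound to one symbol.
  val defs = mutable.LinkedHashMap.empty[Def[Any], Sym[Any]]

  def fresh[T]: Sym[T] = { val s = Sym[T](nextId); nextId += 1; s }

  // toAtom: bind a Def to a symbol, reusing the symbol of an equal Def (CSE).
  implicit def toAtom[T](d: Def[T]): Exp[T] =
    defs.getOrElseUpdate(d, fresh[Any]).asInstanceOf[Exp[T]]

  // Operators on staged Ints build Def nodes, converted to symbols by toAtom.
  implicit class IntOps(e: Exp[Int]) {
    def +(o: Exp[Int]): Exp[Int] = IntPlus(e, o)
    def *(o: Exp[Int]): Exp[Int] = IntTimes(e, o)
  }
}
```

With `a = Sym(0)` and `b = Sym(1)`, staging `b * (Const(2) + a)` returns `Sym(3)`, and `IR.defs` holds `IntPlus(Const(2), Sym(0)) -> Sym(2)` and `IntTimes(Sym(1), Sym(2)) -> Sym(3)`.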
The old way
- The frontend declares an abstract type Rep[A]
- The frontend declares infix operations; the backend implements them as Def nodes
- The backend defines Rep[A] = Exp[A]
trait IntDsl extends Base {
  def infix_plus(e1: Rep[Int], e2: Rep[Int]): Rep[Int]
}
trait IntExp extends IntDsl with BaseExp {
  case class IntPlus(e1: Exp[Int], e2: Exp[Int]) extends Def[Int]
  def infix_plus(e1: Exp[Int], e2: Exp[Int]): Exp[Int] = IntPlus(e1, e2)
}
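A runnable approximation of the old style, assuming plain methods instead of `infix_` methods (those rely on the Scala-Virtualized compiler) and a simplified `Exp` without `Def`/`Sym`. The program is written once against the abstract `Rep`, and the backend trait mixed in at the end decides whether it evaluates directly or builds IR.

```scala
// Frontend: abstract representation type and DSL operations.
trait Base { type Rep[+T] }

trait IntDsl extends Base {
  def plus(e1: Rep[Int], e2: Rep[Int]): Rep[Int]
}

// User program: only knows the abstract interface.
trait Prog extends IntDsl {
  def double(x: Rep[Int]): Rep[Int] = plus(x, x)
}

// Backend 1: Rep[T] = T, i.e. plain evaluation, no staging.
trait IntEval extends IntDsl {
  type Rep[+T] = T
  def plus(e1: Int, e2: Int): Int = e1 + e2
}

// Backend 2: Rep[T] = Exp[T], operations build IR nodes.
sealed trait Exp[+T]
final case class Const(x: Int) extends Exp[Int]
final case class IntPlus(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

trait IntExp extends IntDsl {
  type Rep[+T] = Exp[T]
  def plus(e1: Exp[Int], e2: Exp[Int]): Exp[Int] = IntPlus(e1, e2)
}

// The implementation is only mixed in here.
object Evaluated extends Prog with IntEval
object Staged extends Prog with IntExp
```

`Evaluated.double(21)` computes `42`, while `Staged.double(Const(21))` returns the IR node `IntPlus(Const(21), Const(21))`.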
The new way: Lifted types
- Lifted types shadow the Scala type they are the Rep of (e.g. dsl.Int shadows scala.Int)
- Control structures (e.g. if-then-else) and functions are lifted as well
- Each lifted type A has a typeclass instance Rep[A]
- A and Rep[A] become scala.A and dsl.A
- Functions: f(x: Rep[A]) becomes f[A: Rep](x: A)
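A small sketch of shadowing, assuming a hypothetical `dsl` object: inside it, `Int` is a wrapper around an IR expression, and once `dsl._` is imported, user code that mentions `Int` gets the lifted type. The implicit `lift` for literals is also an assumption of this sketch.

```scala
import scala.language.implicitConversions

object dsl {
  sealed trait Exp
  final case class Const(v: Any) extends Exp
  final case class Sym(name: String) extends Exp
  final case class Node(op: String, args: List[Exp]) extends Exp

  // Lifted Int: shadows scala.Int wherever dsl._ is imported.
  final case class Int(e: Exp) {
    def +(y: Int): Int = Int(Node("+", List(e, y.e)))
    def *(y: Int): Int = Int(Node("*", List(e, y.e)))
  }

  // Present-stage literals are lifted to constants.
  implicit def lift(x: scala.Int): Int = Int(Const(x))
}

object UserCode {
  import dsl._
  // Reads like naive Scala, but Int is dsl.Int here, so this builds IR.
  def f(x: Int): Int = x * x + 1
}
```

Calling `UserCode.f(dsl.Int(dsl.Sym("x")))` returns the tree for `(x * x) + 1` instead of a number.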
Benefits
- In practice, most terms can be a Rep, and most optimizations still apply thanks to Const nodes. So why bother with explicit Rep annotations?
- With the proper imports in scope, naive code needs very few modifications
- Elegant signatures thanks to context bounds
- Implicit resolution of overloaded infix_ methods costs O(n^2); the new frontend removes it, which reduces compile times
def ifThenElse[A: Rep](a: Boolean, b: A, c: A): A
How
- A Rep context bound instead of the Rep[T] type constructor
- Lifted types and their methods defined as case classes instead of infix operations
case class Int(e: Exp[scala.Int]) {
  def +(y: Int): Int = Int(IntPlus(e, y.e)) //IntPlus is a Def, converted to Exp by toAtom
}
implicit def repInt: Rep[Int] = ...
trait Rep[T] {
  type Internal
  def from(e: Exp[Internal]): T
  def to(x: T): Exp[Internal]
  def m: Manifest[Internal] //for reflection purposes
}
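A self-contained sketch of the `Rep` typeclass and of `ifThenElse[A: Rep]`, assuming a hypothetical `lifted` object and omitting the `Manifest` member. The context bound gives access to `to` and `from`, so a single generic definition can build an `IfThenElse` node for any lifted type.

```scala
object lifted {
  sealed trait Exp[+T]
  final case class Sym[T](name: String) extends Exp[T]
  final case class IntPlus(a: Exp[scala.Int], b: Exp[scala.Int]) extends Exp[scala.Int]
  final case class IfThenElse[T](c: Exp[scala.Boolean], t: Exp[T], e: Exp[T]) extends Exp[T]

  trait Rep[T] {
    type Internal
    def from(e: Exp[Internal]): T
    def to(x: T): Exp[Internal]
  }

  // Lifted types shadowing scala.Int and scala.Boolean.
  final case class Int(e: Exp[scala.Int]) { def +(y: Int): Int = Int(IntPlus(e, y.e)) }
  final case class Boolean(e: Exp[scala.Boolean])

  implicit val repInt: Rep[Int] = new Rep[Int] {
    type Internal = scala.Int
    def from(e: Exp[scala.Int]): Int = Int(e)
    def to(x: Int): Exp[scala.Int] = x.e
  }
  implicit val repBoolean: Rep[Boolean] = new Rep[Boolean] {
    type Internal = scala.Boolean
    def from(e: Exp[scala.Boolean]): Boolean = Boolean(e)
    def to(x: Boolean): Exp[scala.Boolean] = x.e
  }

  // One generic definition for every lifted type with a Rep instance.
  def ifThenElse[A: Rep](c: Boolean, thenp: A, elsep: A): A = {
    val r = implicitly[Rep[A]]
    r.from(IfThenElse(c.e, r.to(thenp), r.to(elsep)))
  }
}
```

The same `ifThenElse` works on lifted `Int` and lifted `Boolean` alike; only the implicit instance changes.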
trait Prog extends Ints {
  //the old way
  def ackermann(m: Int): Rep[Int => Int] = fun {
    (n: Rep[Int]) =>
      if (m==0) n+1
      else if (n==0) ackermann(m-1)(1)
      else ackermann(m-1)(ackermann(m)(n-1))
  }
}
User code
trait Prog extends Ints {
  //the new way
  def ackermann(m: scala.Int): Int => Int = {
    (n: Int) =>
      if (m==0) n+1
      else if (n==0) ackermann(m-1)(1)
      else ackermann(m-1)(ackermann(m)(n-1))
  }
}
trait ProgImpl extends Prog with IntOptImpl
trait Compile extends ScalaCompile {
  self: IntExp =>
  val codegen = new ScalaGenInt {
    val IR: self.type = self
  }
}
object App extends ProgImpl with Compile {
  val compiled = compile(prog)
  compiled(2)
  compiled(3)
}
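Because `m` is a present-stage `scala.Int` while `n` is staged, specializing `ackermann(2)` should leave one function per value of `m`. The sketch below, a hand-written string generator rather than LMS (names hypothetical), shows what such an object program could look like. It memoizes on `m`, which is one way to stop `ackermann(m)(n-1)` from unfolding forever during staging.

```scala
import scala.collection.mutable

object AckermannStager {
  // Generated function source per static m, in staging order.
  private val generated = mutable.LinkedHashMap.empty[Int, String]

  // Returns the name of the specialized function for this static m.
  // The body is staged only once; a placeholder marks "being staged",
  // so the recursive reference ackermann(m) inside the body just returns a name.
  def ackermann(m: Int): String = {
    val name = s"a$m"
    if (!generated.contains(m)) {
      generated(m) = ""
      val body =
        if (m == 0) "n + 1"
        else s"if (n == 0) ${ackermann(m - 1)}(1) else ${ackermann(m - 1)}(${ackermann(m)}(n - 1))"
      generated(m) = s"def $name(n: Int): Int = $body"
    }
    name
  }

  def program(m: Int): String = {
    generated.clear()
    val entry = ackermann(m)
    (generated.values.toList :+ s"// entry point: $entry").mkString("\n")
  }
}
```

`AckermannStager.program(2)` produces three specialized functions, `a2`, `a1` and `a0`, where `a0` is simply `n + 1`.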
The rest
Delite
- Delite is a closely related project from the PPL lab at Stanford
- Built on top of scala-lms to provide tools for building high-performance DSLs that rely on parallel collections
- The new frontend is useful to them as well
My contribution
- Porting the entire scala-lms codebase to this new frontend
- Complete rewrites of most of common/, in particular functions, if-then-else, primitive types, Arrays and Lists
- Mechanical rewrites for the rest (most of internal/)
Case study
Computation graphs
Side goals
- Proof of concept for feature parity of the new frontend
- Features used: interface inheritance of lifted types, Lists, matrices, functions, if-then-else, codegen, shadowed types
- Show the benefits of staged meta-programming
Computation graphs
- A DAG that represents the flow of a computation
- Very common in machine learning libraries such as TensorFlow or Deeplearning4j
- Each node represents an operation (+, -, *, dot product, max, min) applied to its inputs; its output is the result
- Build once, run often
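To fix ideas, a minimal scalar computation graph, assuming hypothetical node classes (the project's graphs also handle matrices). The graph is built once as a DAG, and `eval` interprets it on each run. This naive evaluator walks the graph from the output and would recompute a shared node once per use.

```scala
sealed trait Node
final case class Input(name: String) extends Node
final case class Constant(v: Double) extends Node
final case class Add(a: Node, b: Node) extends Node
final case class Mul(a: Node, b: Node) extends Node
final case class Max(a: Node, b: Node) extends Node

object Graph {
  // Interpretive evaluation: walks the graph on every run.
  def eval(n: Node, env: Map[String, Double]): Double = n match {
    case Input(name) => env(name)
    case Constant(v) => v
    case Add(a, b)   => eval(a, env) + eval(b, env)
    case Mul(a, b)   => eval(a, env) * eval(b, env)
    case Max(a, b)   => math.max(eval(a, env), eval(b, env))
  }
}
```

Building `Max(Add(Mul(Input("x"), Constant(3.0)), Constant(1.0)), Constant(0.0))` once and running it for `x = 2.0` gives `7.0`; for `x = -5.0` it gives `0.0`.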
Features achieved
- Arithmetic CG
- Evaluable Graph
- Matrices as Data
- Differentiable CG through Backpropagation
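Backpropagation on such a graph is reverse-mode differentiation. A scalar sketch, with hypothetical class names: values are computed in topological order, then gradients are accumulated from the output back to the inputs using the local derivatives of `+` and `*`.

```scala
import scala.collection.mutable

sealed trait Node
final case class Input(name: String) extends Node
final case class Add(a: Node, b: Node) extends Node
final case class Mul(a: Node, b: Node) extends Node

object Backprop {
  // Topological order (inputs before their users) by depth-first search.
  private def topo(root: Node): List[Node] = {
    val seen = mutable.LinkedHashSet.empty[Node]
    def visit(n: Node): Unit = if (!seen.contains(n)) {
      n match {
        case Add(a, b) => visit(a); visit(b)
        case Mul(a, b) => visit(a); visit(b)
        case Input(_)  => ()
      }
      seen += n
    }
    visit(root)
    seen.toList
  }

  // Returns the output value and d(output)/d(input) for every named input.
  def gradients(root: Node, env: Map[String, Double]): (Double, Map[String, Double]) = {
    val order = topo(root)
    // Forward pass.
    val value = mutable.Map.empty[Node, Double]
    order.foreach {
      case n @ Input(name) => value(n) = env(name)
      case n @ Add(a, b)   => value(n) = value(a) + value(b)
      case n @ Mul(a, b)   => value(n) = value(a) * value(b)
    }
    // Backward pass: each node's gradient is complete before it is propagated.
    val grad = mutable.Map.empty[Node, Double].withDefaultValue(0.0)
    grad(root) = 1.0
    order.reverse.foreach {
      case n @ Add(a, b) => grad(a) += grad(n); grad(b) += grad(n)
      case n @ Mul(a, b) => grad(a) += grad(n) * value(b); grad(b) += grad(n) * value(a)
      case Input(_)      => ()
    }
    val inputGrads = order.collect { case n @ Input(name) => name -> grad(n) }.toMap
    (value(root), inputGrads)
  }
}
```

For `f(x, y) = (x + y) * x` at `x = 3`, `y = 4`, this gives the value `21.0` and the gradients `df/dx = 10.0`, `df/dy = 3.0`.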
Improvements from Staging
- Topology, dimension and acyclicity checks at staging time => safer than at runtime (especially with String nodes)
- Unrolling of the graph abstractions => Performance improvements
- Automatic optimisations on the unrolled operations => Performance improvements
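A sketch of these staging-time benefits, assuming a hypothetical graph whose nodes refer to each other by String name. `check` rejects unknown names, wrong arities and cycles once, before anything runs, and `compile` then unrolls the graph into directly linked closures, so evaluating it performs no name lookups apart from reading the inputs. Closures stand in here for LMS code generation.

```scala
import scala.collection.mutable

object StagedGraph {
  final case class NodeSpec(op: String, inputs: List[String])
  type Fn = Map[String, Double] => Double

  private def arity(op: String): Int = op match {
    case "input"       => 0
    case "add" | "mul" => 2
    case other         => throw new IllegalArgumentException(s"unknown op '$other'")
  }

  // Staging-time checks: unknown names, wrong arity, cycles.
  // Returns node names in evaluation order (inputs before their users).
  def check(g: Map[String, NodeSpec], out: String): List[String] = {
    val order = mutable.ListBuffer.empty[String]
    def visit(n: String, path: Set[String]): Unit = {
      require(g.contains(n), s"unknown node '$n'")
      require(!path.contains(n), s"cycle through '$n'")
      if (!order.contains(n)) {
        val spec = g(n)
        require(spec.inputs.size == arity(spec.op), s"'$n' has the wrong number of inputs")
        spec.inputs.foreach(i => visit(i, path + n))
        order += n
      }
    }
    visit(out, Set.empty)
    order.toList
  }

  // Unrolling: each node becomes a closure wired directly to its inputs' closures.
  def compile(g: Map[String, NodeSpec], out: String): Fn = {
    val fns = mutable.Map.empty[String, Fn]
    check(g, out).foreach { n =>
      val spec = g(n)
      val f: Fn = spec.op match {
        case "input" => (env: Map[String, Double]) => env(n)
        case "add" =>
          val (a, b) = (fns(spec.inputs(0)), fns(spec.inputs(1)))
          (env: Map[String, Double]) => a(env) + b(env)
        case _ => // "mul": check has already rejected any other op
          val (a, b) = (fns(spec.inputs(0)), fns(spec.inputs(1)))
          (env: Map[String, Double]) => a(env) * b(env)
      }
      fns(n) = f
    }
    fns(out)
  }
}
```

A graph with a dangling reference or a cycle fails inside `compile`, before it is ever run; a valid one becomes an ordinary `Map[String, Double] => Double`.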
Benchmark
Conclusion
Semester project II: Ruben Fiszel