Staged meta-programming, new LMS frontend and computation graphs    

 

Ruben Fiszel

LAMP at EPFL

Supervised by

Nada Amin and Prof. Odersky

Semester project II

Outline

  • Staged meta-programming
  • LMS
  • Frontend improvements
  • Case study: Computation graph

Meta-programming

Programs that treat programs as data

Staged meta-programming

  • A meta-program is staged into an object program, i.e. transformed into it by going through one or more stages
  • Powerful applications: a staged interpreter is a compiler
  • Staging annotations in the meta-program distinguish later-stage values ("representations") from present-stage values.
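A minimal sketch of the idea in plain Scala (not LMS): the hypothetical `Code` wrapper stands in for a later-stage value, while the exponent `n` is an ordinary present-stage Int. Staging `power` runs the recursion on `n` immediately and leaves only the multiplications in the object program.

```scala
object StagedPower {
  // A later-stage value: here simply the source text of an expression.
  final case class Code(src: String)

  // n is known now (present stage); x is only known later (it is Code).
  // The recursion on n runs during staging; only the multiplications remain.
  def power(x: Code, n: Int): Code = {
    require(n >= 0, "negative exponents are not handled in this sketch")
    if (n == 0) Code("1")
    else if (n == 1) x
    else Code(s"(${x.src} * ${power(x, n - 1).src})")
  }

  // Object program: source of a function specialized for a fixed n.
  def specialize(n: Int): String = s"(x: Int) => ${power(Code("x"), n).src}"
}
```

`StagedPower.specialize(3)` yields the source `(x: Int) => (x * (x * x))`: no test on `n` survives.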

DSL and staging

  • Specialized compilers/transpilers enriched with DSLs
  • Domain-specific optimizations (out of reach for a standard compiler)
  • Target multiple languages/hardware (heterogeneous) from a unified DSL
  • Useful for HPC

Lightweight Modular Staging (LMS)

Scala library for runtime code generation

Meta program

def f(x: Array[Int], y: Rep[Array[Int]]): Rep[Int] = //x is known now, y only at run time
  x.indices.map(i => x(i) * y(i)).reduce(_ + _)
def g(y: Rep[Array[Int]]) = f(Array(0, 1, 0, 0, 0), y)

Object program

val g = (y: Array[Int]) => y(1)

//Optimisations used: (1*a) = a, (0+a) = a, (0*a) = 0
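The specialization above can be reproduced with smart constructors, a common way to apply such rewrites while the IR is being built. This sketch uses its own hypothetical node classes (`Const`, `ArrayAt`, `Plus`, `Times`) rather than the LMS ones, and it also folds pairs of constants, which goes slightly beyond the three rules listed.

```scala
object DotSpecialization {
  sealed trait Exp
  final case class Const(v: Int) extends Exp
  final case class ArrayAt(arr: String, i: Int) extends Exp // y(i): unknown until run time
  final case class Plus(a: Exp, b: Exp) extends Exp
  final case class Times(a: Exp, b: Exp) extends Exp

  // Smart constructors: rewrites (0*a) = 0, (1*a) = a, (0+a) = a applied on the fly.
  def times(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(0), _) | (_, Const(0)) => Const(0)
    case (Const(1), e)                 => e
    case (e, Const(1))                 => e
    case (Const(x), Const(y))          => Const(x * y)
    case _                             => Times(a, b)
  }

  def plus(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(0), e)        => e
    case (e, Const(0))        => e
    case (Const(x), Const(y)) => Const(x + y)
    case _                    => Plus(a, b)
  }

  // x is a present-stage array; y is only the name of a run-time array.
  def f(x: Array[Int], y: String): Exp =
    x.indices.map(i => times(Const(x(i)), ArrayAt(y, i))).foldLeft(Const(0): Exp)(plus)
}
```

With `x = Array(0, 1, 0, 0, 0)`, every product but one collapses to `Const(0)` and the sum collapses to `ArrayAt("y", 1)`, i.e. the object program `y(1)`.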

Naive program

def f(x: Array[Int], y: Array[Int]) = x.zip(y).map { case (a, b) => a * b }.sum

def g(y: Array[Int]) = f(Array(0, 1, 0, 0, 0), y)

Staged interpreter is a compiler

Naive program

def matchsearch(regexp: String, text: String): Boolean = ...

Staged interpreter ...

def matchsearch(regexp: String, text: Rep[String]): Rep[Boolean] = ...
val f = compile((x: Rep[String]) => matchsearch("^hello$", x))

//GEN CODE
def f(x: String): Boolean = ... //small function specialized for the "^hello$" regex search

//usage
f("hello")

... is a compiler
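To see why the staged interpreter behaves like a compiler, here is a deliberately tiny code generator. It is a simplification: it only accepts fully anchored literal patterns such as `^hello$` (no `.` or `*`), and `genMatch` is a hypothetical name. The loop over the pattern runs at staging time, so the generated source contains one comparison per character and no trace of the regexp.

```scala
object StagedRegex {
  // Staging-time code generator for a tiny regex subset: a literal pattern
  // anchored at both ends, e.g. "^hello$". The regexp is a present-stage value;
  // the text `x` only exists in the generated code (a later-stage value).
  def genMatch(regexp: String): String = {
    require(regexp.length >= 2 && regexp.startsWith("^") && regexp.endsWith("$"),
      "this sketch only supports patterns of the form ^literal$")
    val lit = regexp.substring(1, regexp.length - 1)
    require(lit.forall(c => ".*^$\\'".indexOf(c) < 0), "literal characters only")
    // The loop over the pattern runs now; it leaves one comparison per character.
    val checks = lit.zipWithIndex.map { case (c, i) => s"x.charAt($i) == '$c'" }
    val body = (s"x.length == ${lit.length}" +: checks).mkString(" && ")
    s"def f(x: String): Boolean = $body"
  }
}
```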

Frontend

Constructs the IR from the meta-program

 

Offers common data structures and control structures, batteries included

Backend

Transforms the IR and finally generates the object program

Frontend improvements

Painless staging annotations

Preliminaries

 

  • The IR is a typed "sea of nodes" of Exp: roughly a tree whose inner nodes are Def and whose leaves are Sym and Const.
  • The DSL is agnostic of the implementation
  • The DSL is separated from its implementation, which is only mixed in at staging time

Preliminaries

trait Exp[+T]
case class Const[T](x: T) extends Exp[T]
case class Sym[T](id: Int) extends Exp[T]
trait Def[+T]

implicit def toAtom[T](d: Def[T]): Exp[T]
//defined elsewhere
val a = Sym[Int](0)
val b = Sym[Int](1)

b*(2+a) //user code is transformed into IR
IntTimes(Sym(1), IntPlus(Const(2), Sym(0))): Exp[Int]
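In a sea-of-nodes IR the Defs are usually not kept nested: each Def is bound to a fresh Sym. The sketch below assumes a simplified toAtom that records these bindings and reuses the Sym of an identical Def (common subexpression elimination); `fresh`, `defs` and `IntOps` are hypothetical names, not the scala-lms API. It writes `Const(2)` explicitly because it has no implicit lifting of literals.

```scala
import scala.collection.mutable
import scala.language.implicitConversions

object MiniIR {
  trait Exp[+T]
  case class Const[T](x: T) extends Exp[T]
  case class Sym[T](id: Int) extends Exp[T]
  trait Def[+T]
  case class IntPlus(a: Exp[Int], b: Exp[Int]) extends Def[Int]
  case class IntTimes(a: Exp[Int], b: Exp[Int]) extends Def[Int]

  private var nextId = 0
  def fresh[T]: Sym[T] = { val s = Sym[T](nextId); nextId += 1; s }

  // Def -> Sym bindings in creation order; an identical Def reuses its Sym (CSE).
  val defs = mutable.LinkedHashMap.empty[Def[Any], Sym[Any]]
  implicit def toAtom[T](d: Def[T]): Exp[T] =
    defs.getOrElseUpdate(d, fresh[Any]).asInstanceOf[Exp[T]]

  // Frontend operators: user code such as b * (Const(2) + a) builds Defs.
  implicit class IntOps(a: Exp[Int]) {
    def +(b: Exp[Int]): Exp[Int] = IntPlus(a, b)
    def *(b: Exp[Int]): Exp[Int] = IntTimes(a, b)
  }
}
```

With `a = Sym(0)` and `b = Sym(1)`, `b * (Const(2) + a)` binds `IntPlus(Const(2), Sym(0))` to `Sym(2)` and `IntTimes(Sym(1), Sym(2))` to `Sym(3)`.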

The old way

  • On the frontend, declare an abstract type Rep[A]
  • On the frontend, declare infix_ operations on Rep values
  • On the backend, define Rep[A] = Exp[A] and implement the infix_ operations as Def nodes
trait IntDsl extends Base {
    def infix_plus(e1: Rep[Int], e2: Rep[Int]): Rep[Int]
}

trait IntExp extends IntDsl with BaseExp {
    case class IntPlus(e1: Exp[Int], e2: Exp[Int]) extends Def[Int]

    def infix_plus(e1: Exp[Int], e2: Exp[Int]): Exp[Int] = IntPlus(e1, e2) //Def converted to Exp by toAtom
}
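The point of the old way is that the same user program can be mixed with different implementations of Rep. A plain-Scala sketch under two assumptions: infix_plus is called directly (automatic infix_ rewriting needs Scala-Virtualized), and IntPlus extends Exp directly instead of going through a Def node. `IntEval` is a hypothetical interpreter implementation added for contrast.

```scala
object OldWay {
  // Frontend: the DSL only knows that some Rep[T] exists.
  trait Base { type Rep[+T] }
  trait IntDsl extends Base {
    def infix_plus(e1: Rep[Int], e2: Rep[Int]): Rep[Int]
  }

  // User program, written once against the interface.
  trait Prog extends IntDsl {
    // Without Scala-Virtualized, x + x is not rewritten, so infix_plus is called directly.
    def double(x: Rep[Int]): Rep[Int] = infix_plus(x, x)
  }

  // Hypothetical interpreter backend: Rep[T] = T, operations run immediately.
  trait IntEval extends IntDsl {
    type Rep[+T] = T
    def infix_plus(e1: Int, e2: Int): Int = e1 + e2
  }

  // Staging backend: Rep[T] = Exp[T], operations build IR nodes.
  trait Exp[+T]
  case class Sym[T](id: Int) extends Exp[T]
  case class IntPlus(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
  trait IntExp extends IntDsl {
    type Rep[+T] = Exp[T]
    def infix_plus(e1: Exp[Int], e2: Exp[Int]): Exp[Int] = IntPlus(e1, e2)
  }
}
```

Mixing `Prog` with `IntEval` computes `double(21) == 42`; mixing it with `IntExp` returns the node `IntPlus(Sym(0), Sym(0))` for `double(Sym(0))`.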

The new way: Lifted types

 

  • A lifted type shadows the type it is the Rep of (dsl.Int shadows scala.Int)
  • Control structures (e.g. if-then-else) and functions are lifted types too
  • Lifted types have typeclass instances Rep[A]
  • A and Rep[A] become scala.A and dsl.A
  • Functions: f(x: Rep[A]) becomes f[A: Rep](x: A)

Benefits

  • In practice, most terms could be a Rep, and most optimizations would still apply thanks to Const nodes. So why bother separating A from Rep[A]?
  • With the proper imports in scope, naive code needs very few modifications
  • Elegant signatures thanks to context bounds
def ifThenElse[A: Rep](cond: Boolean, thenp: => A, elsep: => A): A
  • Implicit resolution of overloaded infix_ methods is O(n^2) => removed with the new frontend => shorter compile times
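A sketch of how a single ifThenElse signature can serve every lifted type, assuming a minimal Rep typeclass with only `from`/`to` (the `m` manifest is left out) and a hypothetical `IfThenElse` node. When the condition is a Const, the branch is resolved at staging time and no node is emitted.

```scala
object LiftedIf {
  // Minimal IR (hypothetical node classes).
  sealed trait Exp[+T]
  final case class Const[T](x: T) extends Exp[T]
  final case class Sym[T](id: Int) extends Exp[T]
  final case class IfThenElse[T](c: Exp[Boolean], t: Exp[T], e: Exp[T]) extends Exp[T]

  // Lifted types shadowing scala.Int and scala.Boolean.
  object dsl {
    final case class Int(e: Exp[scala.Int])
    final case class Boolean(e: Exp[scala.Boolean])
  }

  // Typeclass: A is a lifted type backed by an Exp[Internal].
  trait Rep[A] {
    type Internal
    def from(e: Exp[Internal]): A
    def to(x: A): Exp[Internal]
  }

  implicit val repInt: Rep[dsl.Int] = new Rep[dsl.Int] {
    type Internal = Int
    def from(e: Exp[Int]): dsl.Int = dsl.Int(e)
    def to(x: dsl.Int): Exp[Int] = x.e
  }

  // One signature for every lifted type, thanks to the context bound.
  def ifThenElse[A: Rep](cond: dsl.Boolean, thenp: => A, elsep: => A): A = {
    val r = implicitly[Rep[A]]
    cond.e match {
      case Const(true)  => thenp // condition known at staging time: no branch node
      case Const(false) => elsep
      case c            => r.from(IfThenElse(c, r.to(thenp), r.to(elsep)))
    }
  }
}
```

Adding support for another lifted type only requires another Rep instance; `ifThenElse` itself is unchanged.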

How

  • Using a Rep context bound instead of the Rep[_] type constructor
  • Lifted types and their methods defined as case classes instead of infix_ operations
case class Int(e: Exp[scala.Int]) {
    def +(y: Int): Int = Int(IntPlus(e, y.e))
}

trait Rep[T] {
    type Internal
    def from(e: Exp[Internal]): T
    def to(x: T): Exp[Internal]
    def m: Manifest[Internal] //for reflection purposes
}

implicit def repInt: Rep[Int] = ...
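The following sketch puts these pieces together: a lifted `dsl.Int` whose operators build IR nodes, an implicit lift of literals to Const, and a Rep instance that moves between the lifted value and its Exp. It is a simplification: Def nodes are folded into Exp, the `m` manifest is omitted, and the Rep instance is declared with a refined type so that `Internal` is known to be `scala.Int` at use sites. `liftInt` and `UserCode` are hypothetical names.

```scala
import scala.language.implicitConversions

object LiftedInt {
  // Minimal IR (Def nodes folded into Exp for brevity).
  sealed trait Exp[+T]
  final case class Const[T](x: T) extends Exp[T]
  final case class Sym[T](id: Int) extends Exp[T]
  final case class IntPlus(a: Exp[Int], b: Exp[Int]) extends Exp[Int]
  final case class IntTimes(a: Exp[Int], b: Exp[Int]) extends Exp[Int]

  object dsl {
    // Lifted Int: shadows scala.Int once `dsl._` is imported.
    final case class Int(e: Exp[scala.Int]) {
      def +(y: Int): Int = Int(IntPlus(e, y.e))
      def *(y: Int): Int = Int(IntTimes(e, y.e))
    }
    // Present-stage literals are lifted into Const nodes.
    implicit def liftInt(x: scala.Int): Int = Int(Const(x))
  }

  trait Rep[T] {
    type Internal
    def from(e: Exp[Internal]): T
    def to(x: T): Exp[Internal]
  }

  // Refined type so that callers know Internal = scala.Int.
  implicit val repInt: Rep[dsl.Int] { type Internal = Int } = new Rep[dsl.Int] {
    type Internal = Int
    def from(e: Exp[Int]): dsl.Int = dsl.Int(e)
    def to(x: dsl.Int): Exp[Int] = x.e
  }
}

// User code: with the DSL imported, it reads like code on plain Int.
object UserCode {
  import LiftedInt.dsl._
  def poly(x: Int): Int = x * x + 1
}
```

`repInt.to(UserCode.poly(repInt.from(Sym(0))))` returns the IR `IntPlus(IntTimes(Sym(0), Sym(0)), Const(1))`.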
trait Prog extends Ints {

  //the old way
  def ackermann(m: Int): Rep[Int => Int] = fun {
    (n: Rep[Int]) =>
      if (m == 0) n+1
      else if (n == 0) ackermann(m-1)(1)
      else ackermann(m-1)(ackermann(m)(n-1))
  }

}


User code

trait Prog extends Ints {

  //the new way
  def ackermann(m: scala.Int): Int => Int = {
    (n: Int) =>
      if (m == 0) n+1
      else if (n == 0) ackermann(m-1)(1)
      else ackermann(m-1)(ackermann(m)(n-1))
  }

}
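Since m is a present-stage value, staging ackermann unfolds the tests on m and leaves one function per level. The sketch below shows this twice in plain Scala: `gen` emits the source of the specialized functions, and `specialize` builds the equivalent closures. Both names are hypothetical and the generated text is illustrative, not actual LMS output.

```scala
object StagedAckermann {
  // m is a present-stage value: the tests on m are resolved while generating,
  // so the output contains one function per level m, m-1, ..., 0.
  def gen(m: Int): String = {
    require(m >= 0)
    (m to 0 by -1).map { k =>
      val body =
        if (k == 0) "n + 1"
        else s"if (n == 0) a${k - 1}(1) else a${k - 1}(a$k(n - 1))"
      s"def a$k(n: Int): Int = $body"
    }.mkString("\n")
  }

  // The same specialization with closures instead of source text.
  def specialize(m: Int): Int => Int = {
    require(m >= 0)
    if (m == 0) ((n: Int) => n + 1)
    else {
      val prev = specialize(m - 1) // built once, at staging time
      def self(n: Int): Int = if (n == 0) prev(1) else prev(self(n - 1))
      (n: Int) => self(n)
    }
  }
}
```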


trait ProgImpl extends Prog with IntOptImpl 

trait Compile extends ScalaCompile {
  self: IntExp =>

  val codegen =  new ScalaGenInt {
    val IR:self.type = self

  }
}

object App extends ProgImpl with Compile {

    val compiled = compile(prog)
    compiled(2)
    compiled(3)

}

The rest

Delite

  • Delite is a closely related project from the PPL lab at Stanford
  • Written on top of scala-lms, it provides tools to build high-performance DSLs relying on parallel collections
  • The new frontend is useful to them too

My contribution

  • Porting the entire scala-lms codebase to this new frontend.
  • Some complete rewrites: most of common/, in particular functions, if-then-else, primitive types, Array and List
  • Other parts were mechanical rewrites (most of internal/).

Case study

Computation graphs

Side goal

  • Proof-of-concept for feature-parity of the new frontend
  • Features used: Interface inheritance between lifted types, List, Matrices, Functions, If-Then-Else, Codegen, Shadowed types
  • Show benefits of staged meta-programming

Computation graphs

  • A DAG that represents the flow of a computation
  • Very common in machine learning libraries such as TensorFlow or Deeplearning4j
  • Each node applies an operation (+, -, *, dot product, max, min) to its inputs; its output is the result
  • Build once, run often

Features achieved

  • Arithmetic CG
  • Evaluable Graph
  • Matrices as Data
  • Differentiable CG through Backpropagation
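A minimal sketch of an evaluable, differentiable arithmetic graph, assuming scalar Double nodes instead of matrices and hypothetical node names. `grad` applies the chain rule from the output back to the inputs (reverse mode). For brevity it walks the graph as a tree, so a shared node is revisited once per path, whereas a real implementation would visit each node once in reverse topological order.

```scala
object CompGraph {
  // Nodes of an arithmetic computation graph.
  sealed trait Node
  final case class Input(name: String) extends Node
  final case class Lit(v: Double) extends Node
  final case class Add(a: Node, b: Node) extends Node
  final case class Mul(a: Node, b: Node) extends Node

  // Forward pass: evaluate the graph for given input values.
  def eval(n: Node, env: Map[String, Double]): Double = n match {
    case Input(x)  => env(x)
    case Lit(v)    => v
    case Add(a, b) => eval(a, env) + eval(b, env)
    case Mul(a, b) => eval(a, env) * eval(b, env)
  }

  // Backward pass: propagate the adjoint `seed` towards the inputs,
  // summing the contributions that reach each input.
  def grad(n: Node, env: Map[String, Double], seed: Double = 1.0,
           acc: Map[String, Double] = Map.empty): Map[String, Double] = n match {
    case Input(x)  => acc.updated(x, acc.getOrElse(x, 0.0) + seed)
    case Lit(_)    => acc
    case Add(a, b) => grad(b, env, seed, grad(a, env, seed, acc))
    case Mul(a, b) => grad(b, env, seed * eval(a, env), grad(a, env, seed * eval(b, env), acc))
  }
}
```

For f = x*y + x at x = 3, y = 4: `eval` gives 15.0 and `grad` gives df/dx = y + 1 = 5.0 and df/dy = x = 3.0.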

Improvements from Staging

  • Topology, dimension and acyclicity checks at staging time => safer than at run time (especially with String nodes)
  • Unrolling of the graph abstractions => Performance improvements
  • Automatic optimisations on the unrolled operations => Performance improvements
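A sketch of what staging buys for a graph whose nodes refer to each other by name: the checks (unknown node, cycle) happen once, before any code is emitted, and the output is straight-line code in which a shared node is computed only once. Node names and the `stage` function are hypothetical; dimension checks and algebraic optimisations are left out.

```scala
import scala.collection.mutable

object StageGraph {
  // Nodes refer to their inputs by name ("String nodes"): a typo or a cycle
  // can only be detected by inspecting the whole graph.
  sealed trait Op
  final case class In(param: String) extends Op
  final case class Plus(a: String, b: String) extends Op
  final case class Times(a: String, b: String) extends Op

  // All checks run once, at staging time; the emitted code contains no graph.
  def stage(g: Map[String, Op], out: String): Either[String, String] = {
    val order = mutable.ArrayBuffer.empty[String]
    val done = mutable.Set.empty[String]
    // DFS topological sort; `path` holds the nodes on the current DFS path.
    def visit(n: String, path: Set[String]): Option[String] =
      if (done(n)) None
      else if (path(n)) Some(s"cycle through $n")
      else {
        g.get(n) match {
          case None => Some(s"unknown node $n")
          case Some(op) =>
            val deps = op match {
              case In(_)       => Nil
              case Plus(a, b)  => List(a, b)
              case Times(a, b) => List(a, b)
            }
            deps.iterator.map(d => visit(d, path + n)).collectFirst { case Some(err) => err } match {
              case some @ Some(_) => some
              case None =>
                done += n
                order += n
                None
            }
        }
      }
    visit(out, Set.empty) match {
      case Some(err) => Left(err)
      case None =>
        val lines = order.map { n =>
          g(n) match {
            case In(p)       => s"val $n = $p"
            case Plus(a, b)  => s"val $n = $a + $b"
            case Times(a, b) => s"val $n = $a * $b"
          }
        }
        Right((lines :+ out).mkString("\n"))
    }
  }
}
```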

Benchmark

Conclusion
