Staged programming in Scala 3
Chris Birchall
47 Degrees
# ALL THE WORLD'S A STAGE
# FUNDAMENTALS
"code about code"
we will be dealing with functions that:
Q: Why would we want to do that?
# FUNDAMENTALS
// Plain Scala: these statements are evaluated directly (x + y + 5 is 10).
// The same code inside a quote '{ ... } is not evaluated; it becomes an Expr[Int] representing it
// (real code also needs a given Quotes in scope to build the quote).
val x = 3
val y = 2
x + y + 5val expr: Expr[Int] = '{
val x = 3
val y = 2
x + y + 5
}quote
plain old Scala code
representation of that code
# FUNDAMENTALS
// Plain code vs. its quoted representation. Inside a quote, $expr splices an existing Expr back in.
val x = 3
val y = 2
x + y + 5quote
plain old
Scala code
'{
...
$expr
...
}val expr: Expr[Int] = '{
val x = 3
val y = 2
x + y + 5
}representation
of that code
splice
# FUNDAMENTALS
Quoting and splicing are duals. For any expression e:
${'{e}} = e
'{${e}} = e
# FUNDAMENTALS
// Ordinary code outside any quote or splice is at level 0.
val x = 1Level
0
Plain old Scala code is level 0
# FUNDAMENTALS
// The body of '{ ... } is one level above the surrounding definition (level 0 -> level 1).
def foo(using Quotes): Expr[Int] = '{ 1 + 2 }Level
0
1
Quoting increases level by 1
# FUNDAMENTALS
// Inside the quote (level 1), ${x} drops back to level 0, where the parameter x: Expr[Int] is bound.
def foo(x: Expr[Int])(using Quotes): Expr[Int] = '{ ${x} + 1 }Level
0
1
Splicing reduces level by 1
# FUNDAMENTALS
Variables must be bound and used on the same level
(x is bound on level 0, trying to use on level 1)
// Deliberately ill-staged examples that break the "bind and use on the same level" rule.
// foo: x is bound at level 0 but used inside the level-1 quote.
def foo(using Quotes): Expr[Int] =
val x = 1
'{ x + 2 }(x is bound on level 1, trying to use on level 0)
// bar: x is bound inside the quote (level 1) but spliced with $x, i.e. used at level 0.
def bar(using Quotes): Expr[Int => Int] =
'{ (x: Int) => $x + 2 }# FUNDAMENTALS
Quotes indicate that we are constructing a computation that will run in a future stage.
A splice indicates a computation that must be performed immediately, while the quoted computation is being built.
// Placeholder for a code-generating function (body elided with ???).
// NOTE(review): a real implementation would also need a (using Quotes) clause to build an Expr.
def someComplicatedFunction(): Expr[Int] = ???
// returns '{ 4 + 5 }
// The splice runs someComplicatedFunction() immediately, while the surrounding quote is being built.
'{ ${someComplicatedFunction()} * 2 }# FUNDAMENTALS
// Analogy: inside s"...", ${...} evaluates an expression and splices its result into the string,
// just as $ splices an Expr into a quote.
s"foo ${val y = "yeah"; s"hello $y wow"} baz"val x = "bar"
s"foo $x baz"# MACROS
import scala.quoted.*
// Macro implementation: builds `if (!pred) f` by splicing the two quoted arguments into a quote.
def unlessImpl(pred: Expr[Boolean], f: Expr[Unit])(using Quotes): Expr[Unit] =
'{ if (!$pred) $f }# MACROS
// Macro entry point: the call is inlined and the top-level splice runs unlessImpl,
// which receives the arguments as quoted expressions ('pred, 'f).
inline def unless(pred: Boolean)(f: => Unit): Unit = ${ unlessImpl('pred, 'f) }import scala.quoted.*
def unlessImpl(pred: Expr[Boolean], f: Expr[Unit])(using Quotes): Expr[Unit] =
'{ if (!$pred) $f }# MACROS
// Call site of the macro; `x` is presumably an Int in scope (not shown on the slide).
unless(x >= 1){ println("x was less than 1") }inline def unless(pred: Boolean)(f: => Unit): Unit = ${ unlessImpl('pred, 'f) }import scala.quoted.*
def unlessImpl(pred: Expr[Boolean], f: Expr[Unit])(using Quotes): Expr[Unit] =
'{ if (!$pred) $f }# MACROS
// Expansion of the `unless` call, one step per line: inline the macro, then evaluate the splices.
unless(x >= 1){ println("x was less than 1")}${ unlessImpl('{ x >= 1 }, '{ println("x was less than 1") }) }inlining
if (!${'{ x >= 1 }}) ${'{ println("x was less than 1") }}splice
if (!(x >= 1)) { println("x was less than 1") }more splices
# MACROS
Without a macro:
// Naive recursive factorial, evaluated at runtime.
// The base case covers every n <= 1, so 0 (0! = 1) and negative inputs terminate instead of
// recursing until the stack overflows. The Int result overflows for n > 12.
def factorial(n: Int): Int =
  n match {
    case x if x <= 1 => 1
    case x => x * factorial(x - 1)
  }
Let's move that recursion from runtime to compile time
# MACROS
Without a macro:
// Naive recursive factorial, evaluated at runtime.
// The base case covers every n <= 1, so 0 (0! = 1) and negative inputs terminate instead of
// recursing until the stack overflows. The Int result overflows for n > 12.
def factorial(n: Int): Int =
  n match {
    case x if x <= 1 => 1
    case x => x * factorial(x - 1)
  }
Equivalent macro:
// Deliberately broken first attempt: n is an Expr[Int] (a representation of code),
// so it cannot be matched against the literal 1 or used in arithmetic like an Int.
def factorialMacro(n: Expr[Int]): Expr[Int] =
n match {
case 1 => '{1}
case x => '{ x * ${factorialMacro(x - 1)} }
}can't match on an Expr like that
# MACROS
// Second attempt: n.valueOrError extracts the Int known at compile time, so the match works,
// but the level-0 value x is referenced inside the level-1 quote (see the compiler error below).
def factorialMacro(n: Expr[Int]): Expr[Int] =
n.valueOrError match {
case 1 => '{1}
case x => '{ x * ${factorialMacro(...)} }
}can't reference value x like that
[error] 14 | case x => '{ x * ${factorialMacro(...)} }
[error] | ^
[error] | access to value x from wrong staging level:
[error] | - the definition is at level 0,
[error] | - but the access is at level 1.
# MACROS
// Third attempt: splicing with $x fails because x is a plain Int, not an Expr.
def factorialMacro(n: Expr[Int]): Expr[Int] =
n.valueOrError match {
case 1 => '{1}
case x => '{ $x * ${factorialMacro(...)} }
}nope, that doesn't make sense
[error] 14 | case x => '{ $x * ${factorialMacro(...)} }
[error] | ^
[error] | Found: (x : Int)
[error] | Required: quoted.Expr[Any]
umm... splice it?
# MACROS
// Fourth attempt: quoting x inside a quote yields a nested quoted.Expr[Int], which has no `*` method.
def factorialMacro(n: Expr[Int]): Expr[Int] =
n.valueOrError match {
case 1 => '{1}
case x => '{ '{x} * ${factorialMacro(...)} }
}no, that gives us weird nested quotes
[error] 14 | case x => '{ '{x} * ${factorialMacro(...)} }
[error] | ^^^^^^
[error] |value * is not a member of quoted.Expr[Int], ...
ok, quote it?
# MACROS
// Compile-time factorial. n must be a value known at compile time (valueOrError reports an
// error otherwise). Each step lifts the static Int x back into code with Expr(x), which needs a
// ToExpr[Int] and a given Quotes; so do valueOrError and '{...}, hence the (using Quotes) clause.
// The base case covers every x <= 1, so expanding factorial(0) or a negative constant
// terminates instead of recursing forever inside the compiler.
def factorialMacro(n: Expr[Int])(using Quotes): Expr[Int] =
  n.valueOrError match {
    case x if x <= 1 => '{1}
    case x => '{ ${Expr(x)} * ${factorialMacro(Expr(x - 1))} }
  }
we need to lift the static value into a representation
// Abridged excerpt (presumably from the scala.quoted.Expr companion) showing how Expr(x) lifts
// a value: it delegates to the ToExpr[T] instance in scope.
object Expr {
...
/** Creates an expression that will construct the value `x` */
def apply[T](x: T)(using ToExpr[T])(using Quotes): Expr[T] =
scala.Predef.summon[ToExpr[T]].apply(x)
}# MACROS
// Macro entry point: the splice runs factorialMacro at compile time on the quoted argument 'n.
// The argument must be a compile-time constant for valueOrError to succeed.
inline def factorial(n: Int): Int = ${ factorialMacro('n) }finally we can call our macro!
println(factorial(5)) // prints 120 # MACROS
// Deliberately fails to compile: readInt() is only known at runtime, so the macro cannot extract it.
println("Give me a number and I'll calculate its factorial for you")
println(factorial(scala.io.StdIn.readInt()))can we build factorial-as-a-service?
[error] 8 | println(factorial(scala.io.StdIn.readInt()))
[error] | ^^^^^^^^^^^^^^^^^^^^^^^^
[error] |Expected a known value.
[error] |
[error] |The value of: n$proxy1
[error] |could not be extracted using scala.quoted.FromExpr$PrimitiveFromExpr@4755601c
Nope, a macro can't match on a value that's not known until runtime.
# STAGED PROGRAMMING
# STAGED PROGRAMMING
// Staged factorial: n is an ordinary runtime Int. The recursion runs while the code is being
// generated and produces a fully unrolled product of Int literals.
// The base case covers every x <= 1, so 0 and negative inputs terminate instead of recursing forever.
def factorialStaged(n: Int)(using Quotes): Expr[Int] =
  n match {
    case x if x <= 1 => '{1}
    case x => '{${Expr(x)} * ${factorialStaged(x - 1)}}
  }
import scala.quoted.staging.*
// Generates the code for factorial(n) at runtime and executes it with staging `run`,
// using the Compiler created here.
// NOTE(review): Compiler.make runs on every call; sharing one given Compiler would avoid
// repeating that setup if this is called often.
def runFactorialStaged(n: Int): Int =
  given Compiler = Compiler.make(getClass.getClassLoader)
  run(factorialStaged(n))
# STAGED PROGRAMMING
Macros: use quotes 'n' splices to construct a program fragment at compile time, and inline it at compile time.
Runtime staged programming: use quotes 'n' splices to construct a program fragment at runtime, and then compile and run it at runtime.
# STAGED PROGRAMMING
/** Returns true when `a` equals some element of `list`, checking elements recursively from the head. */
def member[A](list: List[A])(a: A): Boolean =
  list match
    case head :: tail => (a == head) || member(tail)(a)
    case Nil          => false
# STAGED PROGRAMMING
// First definition: the runtime version, which walks the list on every call.
// Second definition: memberStaged. The list is known while generating code and only `a` is a code
// value, so the recursion unrolls into a chain of `==` tests joined by `||` and ending in `false`.
def member[A](list: List[A])(a: A): Boolean =
list match {
case Nil => false
case x :: xs => (a == x) || member(xs)(a)
}def memberStaged[A: Type: ToExpr](list: List[A])(a: Expr[A])(using Quotes): Expr[Boolean] =
list match {
case Nil => '{ false }
case x :: xs => '{ ($a == ${Expr(x)}) || ${memberStaged(xs)(a)} }
}# STAGED PROGRAMMING
// Staged membership test. `list` is known while generating code and `a` is a code value, so the
// recursion unrolls into a == x1 || (a == x2 || ... || false). Type[A] and ToExpr[A] are
// required so each element can be lifted into code with Expr(x).
def memberStaged[A: Type: ToExpr](list: List[A])(a: Expr[A])(using Quotes): Expr[Boolean] =
  list match {
    case Nil => '{ false }
    case x :: xs => '{ ($a == ${Expr(x)}) || ${memberStaged(xs)(a)} }
  }
// Builds a String => Boolean specialised to `list` at runtime, prints the generated code and
// returns the resulting function. The argument to `run` contains several statements, so it must
// be a braced block: a parenthesised argument may contain only a single expression.
def stage(list: List[String]): String => Boolean =
  given Compiler = Compiler.make(getClass.getClassLoader)
  run {
    val code: Expr[String => Boolean] = '{ (x: String) => ${memberStaged(list)('x)} }
    println(s"Staged code: ${code.show}")
    code
  }
# STAGED PROGRAMMING
// Builds a String => Boolean specialised to `list` at runtime, prints the generated code and
// returns the resulting function. The argument to `run` contains several statements, so it must
// be a braced block: a parenthesised argument may contain only a single expression.
def stage(list: List[String]): String => Boolean =
  given Compiler = Compiler.make(getClass.getClassLoader)
  run {
    val code: Expr[String => Boolean] = '{ (x: String) => ${memberStaged(list)('x)} }
    println(s"Staged code: ${code.show}")
    code
  }
val contains: String => Boolean = stage(List("foo", "bar", "baz"))
// Staged code: ((x: String) => x.==("foo").||(x.==("bar").||(x.==("baz").||(false))))
contains("bar") // true
contains("wow") // false
# STAGED PROGRAMMING
Quantified Boolean Formula
# STAGED PROGRAMMING
Let's implement a DSL for QBF
Two main approaches
# STAGED PROGRAMMING
/** Abstract syntax of quantified Boolean formulas: variables, connectives, implication and quantifiers. */
enum QBF {
  case Var(name: String)
  case And(a: QBF, b: QBF)
  case Or(a: QBF, b: QBF)
  case Not(a: QBF)
  case Implies(ante: QBF, cons: QBF)
  case Forall(name: String, a: QBF)
  case Exists(name: String, a: QBF)
}
# STAGED PROGRAMMING
/** Evaluates `qbf` directly, looking up free variables in `env`.
  * A quantifier is decided by evaluating its body with the variable bound to true and then to false.
  * An unbound variable makes `env(name)` throw NoSuchElementException.
  */
def eval(qbf: QBF, env: Map[String, Boolean]): Boolean =
  def withBinding(name: String, body: QBF)(value: Boolean): Boolean =
    eval(body, env.updated(name, value))
  qbf match
    case Var(name)           => env(name)
    case Not(a)              => !eval(a, env)
    case And(a, b)           => eval(a, env) && eval(b, env)
    case Or(a, b)            => eval(a, env) || eval(b, env)
    // ante => cons, encoded as cons || (!ante && !cons), which is equivalent to !ante || cons
    case Implies(ante, cons) => eval(Or(cons, And(Not(ante), Not(cons))), env)
    case Forall(name, a)     => withBinding(name, a)(true) && withBinding(name, a)(false)
    case Exists(name, a)     => withBinding(name, a)(true) || withBinding(name, a)(false)

/** Evaluates a formula starting from an empty environment, so it must have no free variables. */
def evaluate(qbf: QBF): Boolean = eval(qbf, Map.empty)
# STAGED PROGRAMMING
// Staged evaluator: same case structure as the runtime evaluator, but the environment maps
// variables to code (Expr[Boolean]) and each case builds code instead of computing a Boolean.
// evaluateStaged then builds the code for a formula with an empty environment and executes it via
// staging `run`, using the Compiler created in its body.
def evalStaged(qbf: QBF, env: Map[String, Expr[Boolean]])(using Quotes): Expr[Boolean] =
qbf match {
case Var(name) => env(name)
case And(a, b) => '{ ${evalStaged(a, env)} && ${evalStaged(b, env)} }
case Or(a, b) => '{ ${evalStaged(a, env)} || ${evalStaged(b, env)} }
case Not(a) => '{ ! ${evalStaged(a, env)} }
// ante => cons, encoded as cons || (!ante && !cons), which is equivalent to !ante || cons
case Implies(ante, cons) => evalStaged(Or(cons, And(Not(ante), Not(cons))), env)
// Quantifiers: the generated code defines a local `check` and calls it with true and false;
// 'value is the quoted reference to check's parameter, bound to `name` in the environment.
case Forall(name, a) => '{
def check(value: Boolean) = ${evalStaged(a, env + (name -> 'value))}
check(true) && check(false)
}
case Exists(name, a) => '{
def check(value: Boolean) = ${evalStaged(a, env + (name -> 'value))}
check(true) || check(false)
}
}def evaluateStaged(qbf: QBF): Boolean =
given Compiler = Compiler.make(getClass.getClassLoader)
run(evalStaged(qbf, Map.empty))# STAGED PROGRAMMING
{
def check(value: scala.Boolean): scala.Boolean =
value.||(value.||(value.unary_!).||(value.unary_!.&&(value.||(value.unary_!).unary_!)))
check(true).&&(check(false))
}# STAGED PROGRAMMING
If we can construct Expr[T], there's nothing* to stop us from constructing Expr[Expr[T]], Expr[Expr[Expr[T]]], ...
*apart from a desire to preserve our sanity, and a slightly clunky developer experience in Scala 3
# STAGED PROGRAMMING
Examples
# CONCLUSION
# CONCLUSION
# CONCLUSION
# CONCLUSION