Staged programming in Scala 3
Chris Birchall
47 Degrees
# ALL THE WORLD'S A STAGE
# FUNDAMENTALS
"code about code"
We will be dealing with functions that take code as input and/or produce code as output.
Q: Why would we want to do that?
# FUNDAMENTALS
// Plain old Scala code: runs immediately and evaluates to 10.
val x = 3
val y = 2
x + y + 5
// Quoting the same code with '{ ... } yields a representation of it (an Expr[Int])
// instead of running it.
val expr: Expr[Int] = '{
val x = 3
val y = 2
x + y + 5
}
quote
plain old
Scala code
representation
of that code
# FUNDAMENTALS
// Plain old Scala code (evaluates to 10).
val x = 3
val y = 2
x + y + 5
quote
plain old
Scala code
'{
...
$expr
...
}
// Quote: `expr` holds a representation of the code below, not its value.
// A splice `$expr` inside another quote inserts this code at that position.
val expr: Expr[Int] = '{
val x = 3
val y = 2
x + y + 5
}
representation
of that code
splice
# FUNDAMENTALS
${'{e}} = e
'{${e}} = e
Quoting and splicing are duals.
For any expression e:
# FUNDAMENTALS
// Ordinary code outside any quote is at level 0.
val x = 1
Level
0
Plain old Scala code is level 0
# FUNDAMENTALS
// The definition itself is at level 0; the body of the quote '{ 1 + 2 } is at level 1.
def foo(using Quotes): Expr[Int] = '{ 1 + 2 }
Level
0
1
Quoting increases level by 1
# FUNDAMENTALS
// Inside '{ ... } we are at level 1; the splice ${x} drops back to level 0,
// where the Expr[Int] parameter `x` is bound.
def foo(x: Expr[Int])(using Quotes): Expr[Int] = '{ ${x} + 1 }
Level
0
1
Splicing reduces level by 1
# FUNDAMENTALS
Variables must be bound and used on the same level
(x is bound on level 0, trying to use on level 1)
// Deliberately does NOT compile: `x` is bound at level 0 (outside the quote)
// but used at level 1 (inside '{ ... }).
def foo(using Quotes): Expr[Int] =
val x = 1
'{ x + 2 }
(x is bound on level 1, trying to use on level 0)
// Deliberately does NOT compile: the lambda parameter `x` is bound at level 1
// (inside the quote) but `$x` tries to use it at level 0.
def bar(using Quotes): Expr[Int => Int] =
'{ (x: Int) => $x + 2 }
# FUNDAMENTALS
Quotes indicate that we are constructing a computation which will run in a future stage
A splice indicates that we must perform an immediate computation while building the quoted computation
/** Stand-in for an arbitrary code-generating function; its body is elided (`???`). */
def someComplicatedFunction(): Expr[Int] = ???
// returns '{ 4 + 5 }
// The splice runs someComplicatedFunction() immediately, while the quote is being
// built; the code it returns (e.g. 4 + 5) becomes the left operand of `* 2`.
'{ ${someComplicatedFunction()} * 2 }
# FUNDAMENTALS
s"foo ${val y = "yeah"; s"hello $y wow"} baz"
val x = "bar"
s"foo $x baz"
# MACROS
import scala.quoted.*
/** Macro implementation of `unless`: builds code that runs `f` only when `pred` is false.
  *
  * @param pred quoted condition
  * @param f    quoted body, run only when `pred` evaluates to false
  * @return the quoted expression `if (!pred) f`
  */
def unlessImpl(pred: Expr[Boolean], f: Expr[Unit])(using Quotes): Expr[Unit] =
'{ if (!$pred) $f }
# MACROS
/** Runs `f` unless `pred` holds. The call is expanded at compile time: the
  * arguments are quoted (`'pred`, `'f`) and the code built by `unlessImpl`
  * is spliced in place of the call.
  */
inline def unless(pred: Boolean)(f: => Unit): Unit = ${ unlessImpl('pred, 'f) }
import scala.quoted.*
/** Macro implementation: builds the quoted expression `if (!pred) f`. */
def unlessImpl(pred: Expr[Boolean], f: Expr[Unit])(using Quotes): Expr[Unit] =
'{ if (!$pred) $f }
# MACROS
// Example call site; after expansion it becomes
// if (!(x >= 1)) { println("x was less than 1") }
unless(x >= 1){ println("x was less than 1") }
/** Runs `f` unless `pred` holds; expanded at compile time via `unlessImpl`. */
inline def unless(pred: Boolean)(f: => Unit): Unit = ${ unlessImpl('pred, 'f) }
import scala.quoted.*
/** Macro implementation: builds the quoted expression `if (!pred) f`. */
def unlessImpl(pred: Expr[Boolean], f: Expr[Unit])(using Quotes): Expr[Unit] =
'{ if (!$pred) $f }
# MACROS
unless(x >= 1){ println("x was less than 1")}
${ unlessImpl('{ x >= 1 }, '{ println("x was less than 1") }) }
inlining
if (!${'{ x >= 1 }}) ${'{ println("x was less than 1") }}
splice
if (!(x >= 1)) { println("x was less than 1") }
more splices
# MACROS
Without a macro:
/** Plain recursive factorial, evaluated entirely at runtime.
  *
  * NOTE(review): `1` is the only base case, so for `n <= 0` the recursion never
  * reaches it and keeps going until the stack overflows.
  */
def factorial(n: Int): Int =
n match {
case 1 => 1
case x => x * factorial(x - 1)
}
Let's move that recursion from runtime to compile time
# MACROS
Without a macro:
/** Plain recursive factorial, evaluated entirely at runtime.
  *
  * NOTE(review): `1` is the only base case, so for `n <= 0` the recursion never
  * reaches it and keeps going until the stack overflows.
  */
def factorial(n: Int): Int =
n match {
case 1 => 1
case x => x * factorial(x - 1)
}
Equivalent macro:
// Deliberately broken first attempt: `n` is an Expr[Int] (a representation of
// code), not an Int, so it cannot be matched against the literal `1`, and `x - 1`
// is not defined on an Expr.
def factorialMacro(n: Expr[Int]): Expr[Int] =
n match {
case 1 => '{1}
case x => '{ x * ${factorialMacro(x - 1)} }
}
can't match on an Expr like that
# MACROS
// Deliberately broken attempt (`...` is a slide placeholder): `valueOrError` yields
// the Int `x`, but `x` is bound at level 0 and referenced inside the quote at
// level 1, which the staging rules reject.
def factorialMacro(n: Expr[Int]): Expr[Int] =
n.valueOrError match {
case 1 => '{1}
case x => '{ x * ${factorialMacro(...)} }
}
can't reference value x like that
[error] 14 | case x => '{ x * ${factorialMacro(...)} }
[error] | ^
[error] | access to value x from wrong staging level:
[error] | - the definition is at level 0,
[error] | - but the access is at level 1.
# MACROS
// Deliberately broken attempt: `$x` splices `x`, but `x` is a plain Int while a
// splice requires an Expr.
def factorialMacro(n: Expr[Int]): Expr[Int] =
n.valueOrError match {
case 1 => '{1}
case x => '{ $x * ${factorialMacro(...)} }
}
nope, that doesn't make sense
[error] 14 | case x => '{ $x * ${factorialMacro(...)} }
[error] | ^
[error] | Found: (x : Int)
[error] | Required: quoted.Expr[Any]
umm... splice it?
# MACROS
// Deliberately broken attempt: '{x} nested inside a quote denotes an Expr[Int]
// at that position, and `*` is not a member of Expr[Int].
def factorialMacro(n: Expr[Int]): Expr[Int] =
n.valueOrError match {
case 1 => '{1}
case x => '{ '{x} * ${factorialMacro(...)} }
}
no, that gives us weird nested quotes
[error] 14 | case x => '{ '{x} * ${factorialMacro(...)} }
[error] | ^^^^^^
[error] |value * is not a member of quoted.Expr[Int], ...
ok, quote it?
# MACROS
/** Compile-time factorial: extracts the statically known value of `n` and unrolls
  * the multiplication into nested code, e.g. `3` becomes `3 * (2 * 1)`.
  *
  * `Expr(x)` lifts the static Int `x` into a code representation (via `ToExpr[Int]`),
  * which is what allows it to be spliced into the quote. Both the quotes and
  * `valueOrError` need a given `Quotes`, hence the `using` clause.
  *
  * NOTE: `1` is the only base case, so a non-positive `n` never reaches it.
  *
  * @param n quoted argument; must be known at compile time, otherwise
  *          `valueOrError` reports "Expected a known value"
  */
def factorialMacro(n: Expr[Int])(using Quotes): Expr[Int] =
  n.valueOrError match {
    case 1 => '{1}
    case x => '{ ${Expr(x)} * ${factorialMacro(Expr(x - 1))} }
  }
we need to lift the static value `x` into a code representation with `Expr(x)`, which uses a `ToExpr` instance:
// Abridged excerpt of the `Expr` companion object ("..." marks omitted members):
// `Expr(x)` lifts a plain value into code using a `ToExpr[T]` instance.
object Expr {
...
/** Creates an expression that will construct the value `x` */
def apply[T](x: T)(using ToExpr[T])(using Quotes): Expr[T] =
scala.Predef.summon[ToExpr[T]].apply(x)
}
# MACROS
/** Factorial unrolled at compile time: the call is replaced by the code built by
  * `factorialMacro`. The argument must be a value known at compile time.
  */
inline def factorial(n: Int): Int = ${ factorialMacro('n) }
finally we can call our macro!
println(factorial(5)) // prints 120
# MACROS
println("Give me a number and I'll calculate its factorial for you")
println(factorial(scala.io.StdIn.readInt()))
can we build factorial-as-a-service?
[error] 8 | println(factorial(scala.io.StdIn.readInt()))
[error] | ^^^^^^^^^^^^^^^^^^^^^^^^
[error] |Expected a known value.
[error] |
[error] |The value of: n$proxy1
[error] |could not be extracted using scala.quoted.FromExpr$PrimitiveFromExpr@4755601c
nope, a macro can't match on a value that's not known until runtime
# STAGED PROGRAMMING
# STAGED PROGRAMMING
/** Runtime-staged factorial: unrolls the runtime value `n` into nested
  * multiplication code such as `3 * (2 * 1)`.
  *
  * `0` is handled as a base case (0! = 1). Without it, any `n <= 0` would recurse
  * forever while the code is being built.
  *
  * @param n non-negative input
  * @throws IllegalArgumentException if `n` is negative
  */
def factorialStaged(n: Int)(using Quotes): Expr[Int] =
  require(n >= 0, s"factorial is undefined for negative n: $n")
  n match {
    case 0 | 1 => '{1}
    case x => '{${Expr(x)} * ${factorialStaged(x - 1)}}
  }

import scala.quoted.staging.*

/** Builds the factorial code for `n`, then compiles and runs it at runtime.
  * A new staging `Compiler` is created on every call.
  */
def runFactorialStaged(n: Int): Int =
  given Compiler = Compiler.make(getClass.getClassLoader)
  run(factorialStaged(n))
# STAGED PROGRAMMING
| Macros | Runtime staged programming |
|---|---|
| Use quotes 'n' splices... | Use quotes 'n' splices... |
| to construct a program fragment... | to construct a program fragment... |
| at compile time... | at runtime... |
| and inline it at compile time | and then compile and run it at runtime |
# STAGED PROGRAMMING
/** Returns true if `a` equals some element of `list`. Elements are compared with
  * `==` from head to tail, stopping at the first match (`||` short-circuits).
  */
def member[A](list: List[A])(a: A): Boolean =
list match {
case Nil => false
case x :: xs => (a == x) || member(xs)(a)
}
# STAGED PROGRAMMING
/** Returns true if `a` equals some element of `list`. Elements are compared with
  * `==` from head to tail, stopping at the first match (`||` short-circuits).
  */
def member[A](list: List[A])(a: A): Boolean =
list match {
case Nil => false
case x :: xs => (a == x) || member(xs)(a)
}
/** Staged `member`: the list is known while the code is being built; the element
  * `a` is known only when the generated code runs. The recursion over `list` happens
  * now and unrolls into `(a == x1) || ((a == x2) || (... || false))`.
  *
  * @tparam A element type; `Type` lets it appear in quotes, and `ToExpr` lets each
  *           list element be lifted into code with `Expr(x)`
  */
def memberStaged[A: Type: ToExpr](list: List[A])(a: Expr[A])(using Quotes): Expr[Boolean] =
list match {
case Nil => '{ false }
case x :: xs => '{ ($a == ${Expr(x)}) || ${memberStaged(xs)(a)} }
}
# STAGED PROGRAMMING
/** Staged `member`: the list is known while the code is being built; the element
  * `a` is known only when the generated code runs. The recursion over `list` happens
  * now and unrolls into `(a == x1) || ((a == x2) || (... || false))`.
  *
  * @tparam A element type; `Type` lets it appear in quotes, and `ToExpr` lets each
  *           list element be lifted into code with `Expr(x)`
  */
def memberStaged[A: Type: ToExpr](list: List[A])(a: Expr[A])(using Quotes): Expr[Boolean] =
list match {
case Nil => '{ false }
case x :: xs => '{ ($a == ${Expr(x)}) || ${memberStaged(xs)(a)} }
}
/** Generates a membership test specialised to `list`, prints the generated code,
  * then compiles and runs it, returning the resulting `String => Boolean`.
  * A new staging `Compiler` is created on every call.
  */
def stage(list: List[String]): String => Boolean =
given Compiler = Compiler.make(getClass.getClassLoader)
run(
// 'x quotes the lambda's parameter so memberStaged can splice comparisons against it
val code: Expr[String => Boolean] = '{ (x: String) => ${memberStaged(list)('x)} }
println("Staged code: " + code.show)
code
)
# STAGED PROGRAMMING
/** Generates a membership test specialised to `list`, prints the generated code,
  * then compiles and runs it, returning the resulting `String => Boolean`.
  * A new staging `Compiler` is created on every call.
  */
def stage(list: List[String]): String => Boolean =
given Compiler = Compiler.make(getClass.getClassLoader)
run(
// 'x quotes the lambda's parameter so memberStaged can splice comparisons against it
val code: Expr[String => Boolean] = '{ (x: String) => ${memberStaged(list)('x)} }
println("Staged code: " + code.show)
code
)
// Staging happens once here; `contains` is then an ordinary compiled function.
val contains: String => Boolean = stage(List("foo", "bar", "baz"))
// Staged code: ((x: String) => x.==("foo").||(x.==("bar").||(x.==("baz").||(false))))
contains("bar") // true
contains("wow") // false
# STAGED PROGRAMMING
Quantified Boolean Formula
# STAGED PROGRAMMING
Let's implement a DSL for QBF
Two main approaches
# STAGED PROGRAMMING
/** Abstract syntax of quantified Boolean formulas (QBF). */
enum QBF:
// variable reference, resolved by name in the evaluation environment
case Var(name: String)
// conjunction
case And(a: QBF, b: QBF)
// disjunction
case Or(a: QBF, b: QBF)
// negation
case Not(a: QBF)
// implication: ante implies cons
case Implies(ante: QBF, cons: QBF)
// universal quantifier: binds `name` within `a`
case Forall(name: String, a: QBF)
// existential quantifier: binds `name` within `a`
case Exists(name: String, a: QBF)
# STAGED PROGRAMMING
/** Reference (unstaged) interpreter for QBF formulas.
  *
  * NOTE(review): the unqualified case names assume `import QBF.*` is in scope.
  *
  * @param qbf formula to evaluate
  * @param env truth values of the variables bound by enclosing quantifiers
  * @return the truth value of `qbf` under `env`
  * @throws NoSuchElementException if `qbf` uses a `Var` that is not in `env`
  */
def eval(qbf: QBF, env: Map[String, Boolean]): Boolean =
qbf match {
case Var(name) => env(name)
case And(a, b) => eval(a, env) && eval(b, env)
case Or(a, b) => eval(a, env) || eval(b, env)
case Not(a) => !(eval(a, env))
// cons || (!ante && !cons) simplifies to !ante || cons, i.e. material implication
case Implies(ante, cons) => eval(Or(cons, And(Not(ante), Not(cons))), env)
// quantifiers evaluate the body under both truth values of the bound variable
case Forall(name, a) =>
def check(value: Boolean) = eval(a, env + (name -> value))
check(true) && check(false)
case Exists(name, a) =>
def check(value: Boolean) = eval(a, env + (name -> value))
check(true) || check(false)
}
/** Evaluates a formula with no free variables (starts from an empty environment). */
def evaluate(qbf: QBF): Boolean = eval(qbf, Map.empty)
# STAGED PROGRAMMING
/** Staged counterpart of `eval`: walks the formula now and generates the code that
  * evaluates it later. Each quantifier becomes a local `def check` in the generated
  * code, so no `QBF` value remains in the generated program.
  *
  * @param qbf formula to compile
  * @param env code for each variable bound by enclosing quantifiers (a reference to
  *            the parameter of the enclosing generated `check`)
  * @return code computing the truth value of `qbf`
  * @throws NoSuchElementException while generating code, if a `Var` is not in `env`
  */
def evalStaged(qbf: QBF, env: Map[String, Expr[Boolean]])(using Quotes): Expr[Boolean] =
qbf match {
case Var(name) => env(name)
case And(a, b) => '{ ${evalStaged(a, env)} && ${evalStaged(b, env)} }
case Or(a, b) => '{ ${evalStaged(a, env)} || ${evalStaged(b, env)} }
case Not(a) => '{ ! ${evalStaged(a, env)} }
// same encoding as `eval`: cons || (!ante && !cons), equivalent to !ante || cons
case Implies(ante, cons) => evalStaged(Or(cons, And(Not(ante), Not(cons))), env)
// 'value quotes the generated `check` parameter so the body can refer to it
case Forall(name, a) => '{
def check(value: Boolean) = ${evalStaged(a, env + (name -> 'value))}
check(true) && check(false)
}
case Exists(name, a) => '{
def check(value: Boolean) = ${evalStaged(a, env + (name -> 'value))}
check(true) || check(false)
}
}
/** Generates code for a closed formula, then compiles and runs it using a new
  * staging `Compiler`.
  */
def evaluateStaged(qbf: QBF): Boolean =
given Compiler = Compiler.make(getClass.getClassLoader)
run(evalStaged(qbf, Map.empty))
# STAGED PROGRAMMING
{
def check(value: scala.Boolean): scala.Boolean =
value.||(value.||(value.unary_!).||(value.unary_!.&&(value.||(value.unary_!).unary_!)))
check(true).&&(check(false))
}
# STAGED PROGRAMMING
If we can construct
Expr[T]
there's nothing* to stop us from constructing
Expr[Expr[T]] Expr[Expr[Expr[T]]]
...
*apart from a desire to preserve our sanity, and a slightly clunky developer experience in Scala 3
# STAGED PROGRAMMING
Examples
# CONCLUSION
# CONCLUSION
# CONCLUSION
# CONCLUSION