A Gentle Introduction to Recursion Schemes

Jean-Remi Desjardins - LambdaConf 2016

Outline

  • Motivate it
  • How it works
  • Examples
  • How to use it

Recursion Schemes

Outline

  • Motivate it
  • How it works
  • Examples
  • How to use it
  • Filesystems
  • Compilers
  • JSON
  • etc

Recursive Data Types

  • while loops 💩
  • ad hoc recursion 🙂
  • recursion schemes 💪

How do we reason about, transform, and process recursive data?

define recursive algorithms once and for all

Recursion Schemes

define your recursive data types in fixed point style

/** One node of a binary tree written in fixed-point style: the children have an abstract
  * type `A` instead of `CompositeTree` itself, so the recursion can be supplied later by a
  * fixed-point wrapper such as `Fix`.
  */
case class CompositeTree[A](
  leftChild: A,
  rightChild: A
)

Benefits of Recursion Schemes

  • define recursive algorithms only once
  • decouple how a function recurses over data from what the function actually does
  • avoid general recursion
  • more concise and DRY code

Drawback:

  • Arguably uglier data types

Example

Modeling a tournament

Initial Model

/** A tournament entrant, represented as a value class over a `String`. */
final class Participant(value: String) extends AnyVal

/** A tournament draw: a tree of matches. */
sealed trait Draw

/** Leaf: both contestants are already-known participants. */
case class Match(left: Participant, right: Participant) extends Draw

/** Inner node: each contestant comes out of a sub-draw. */
case class FutureMatch(left: Draw, right: Draw) extends Draw

Match

FutureMatch

Initial Model

/** A tournament entrant, represented as a value class over a `String`. */
final class Participant(value: String) extends AnyVal

/** A tournament draw: leaves are `Match`es, inner nodes are `FutureMatch`es. */
sealed trait Draw
/** Leaf: both contestants are already-known participants. */
case class Match(left: Participant, right: Participant) extends Draw
/** Inner node: each contestant comes out of a sub-draw. */
case class FutureMatch(left: Draw, right: Draw) extends Draw

/** Number of participants in `d`.
  *
  * Hand-written ("ad hoc") recursion: a leaf match holds two participants and an inner
  * node holds as many as its two sub-draws together.
  */
def size(d: Draw): Int = d match {
  case FutureMatch(l, r) => size(l) + size(r)
  case Match(_, _)       => 2
}

Fixed Point Style

/** A tournament entrant, represented as a value class over a `String`. */
final class Participant(value: String) extends AnyVal

/** One layer of a draw. The type parameter `A` stands wherever the recursive draw type
  * would otherwise appear.
  */
sealed trait DrawF[A]
/** Leaf layer: both contestants are known participants; `A` is unused here. */
case class MatchF[A](left: Participant, right: Participant) extends DrawF[A]
/** Inner layer: both contestants come from an `A` (a sub-draw once tied up with `Fix`). */
case class FutureMatchF[A](left: A, right: A) extends DrawF[A]

/** The recursive draw, recovered as the fixed point of `DrawF`. */
type Draw = Fix[DrawF]

Match

FutureMatch

Fixed Point Style

/** Number of participants in `d`, computed with a catamorphism.
  *
  * The algebra describes a single layer only: by the time it runs, `cata` has already
  * replaced every sub-draw with its size, so an inner node just adds its two children.
  * Note: `cata` needs a `Functor[DrawF]` instance in scope (not shown here).
  */
def size(d: Draw): Int =
  d.cata {
    case FutureMatchF(leftSize, rightSize) => leftSize + rightSize
    case MatchF(_, _)                      => 2
  }

/** A tournament entrant, represented as a value class over a `String`. */
final class Participant(value: String) extends AnyVal

/** The recursive draw: the fixed point of the one-layer functor `DrawF`. */
type Draw = Fix[DrawF]

/** One layer of a draw; `A` stands wherever a sub-draw would appear. */
sealed trait DrawF[A]
case class FutureMatchF[A](left: A, right: A) extends DrawF[A]
case class MatchF[A](left: Participant, right: Participant) extends DrawF[A]

Initial model

sealed trait Draw
case class FutureMatch(left: Draw, right: Draw) extends Draw
case class Match(left: Participant, right: Participant) extends Draw

/** The participant in the draw with the best chance of winning the tournament.
  *
  * Hand-written recursion: every sub-draw is reduced to its most likely winner, paired
  * with that winner's estimated probability of getting there.
  */
def bestChanceOfWinning(d: Draw): Task[Participant] = {
  def bestChanceOfWinningR(d: Draw): Task[(Participant, Float)] =
    d match {
      case Match(left, right) =>
        // `chance` is the probability that `left` beats `right`.
        probabilityOfWin(left, right).map(chance =>
          if (chance > 0.5) (left, chance) else (right, 1 - chance)
        )
      case FutureMatch(left, right) =>
        // Results are bound first and destructured second -- presumably to avoid the
        // `withFilter` call Scala 2 inserts for a tuple pattern in a generator.
        for {
          leftBest  <- bestChanceOfWinningR(left)
          (leftPart, leftChance) = leftBest
          rightBest <- bestChanceOfWinningR(right)
          (rightPart, rightChance) = rightBest
          chance    <- probabilityOfWin(leftPart, rightPart)
        } yield {
          // Chance of reaching this match multiplied by the chance of winning it.
          val leftNewChance  = leftChance * chance
          val rightNewChance = rightChance * (1 - chance)
          if (leftNewChance > rightNewChance) (leftPart, leftNewChance)
          else (rightPart, rightNewChance)
        }
    }
  // The result lives inside `Task`: map over it to drop the probability.
  bestChanceOfWinningR(d).map(_._1)
}

/** Probability that `candidate` beats `over` (declaration only; implementation assumed elsewhere). */
def probabilityOfWin(candidate: Participant, over: Participant): Task[Float]

Fixed Point Style

// The participant in the draw with the best chance of winning the tournament
/** The participant in the draw with the best chance of winning the tournament.
  *
  * `cataM` folds the draw bottom-up inside `Task`. By the time the algebra sees a
  * `FutureMatchF`, both children have already been reduced to
  * `(most likely winner, probability of getting there)`.
  */
def bestChanceOfWinning(d: Draw): Task[Participant] =
  // Explicit type arguments: a pattern-matching function literal needs a fully known
  // parameter type (here DrawF[(Participant, Float)]), and nothing else fixes M and A.
  // NOTE(review): assumes matryoshka's `cataM[M[_], A]` type-parameter order.
  d.cataM[Task, (Participant, Float)] {
    case MatchF(left, right) =>
      probabilityOfWin(left, right).map(chance =>
        if (chance > 0.5) (left, chance) else (right, 1 - chance)
      )
    case FutureMatchF((leftPart, leftChance), (rightPart, rightChance)) =>
      probabilityOfWin(leftPart, rightPart).map { chance =>
        // Chance of reaching this match multiplied by the chance of winning it.
        val leftNewChance  = leftChance * chance
        val rightNewChance = rightChance * (1 - chance)
        if (leftNewChance > rightNewChance) (leftPart, leftNewChance)
        else (rightPart, rightNewChance)
      }
  }.map(_._1) // the result lives inside `Task`: map over it to drop the probability

/** Probability that `candidate` beats `over` (declaration only; implementation assumed elsewhere). */
def probabilityOfWin(candidate: Participant, over: Participant): Task[Float]

/** The recursive draw: the fixed point of the one-layer functor `DrawF`. */
type Draw = Fix[DrawF]

sealed trait DrawF[A]
case class FutureMatchF[A](left: A, right: A) extends DrawF[A]
case class MatchF[A](left: Participant, right: Participant) extends DrawF[A]

Comparison

// Fixed-point version: `cataM` owns the recursion and the sequencing of effects; the
// algebra only says what to do with a single layer.
def bestChanceOfWinning(d: Draw): Task[Participant] =
  d.cataM[Task, (Participant, Float)] {
    // Leaf case (the same logic in both versions).
    case MatchF(left, right) =>
      probabilityOfWin(left, right).map(chance =>
        if (chance > 0.5) (left, chance) else (right, 1 - chance)
      )
    case FutureMatchF((leftPart, leftChance), (rightPart, rightChance)) =>
      probabilityOfWin(leftPart, rightPart).map { chance =>
        val leftNewChance  = leftChance * chance
        val rightNewChance = rightChance * (1 - chance)
        if (leftNewChance > rightNewChance) (leftPart, leftNewChance)
        else (rightPart, rightNewChance)
      }
  }.map(_._1)

// Hand-written version: the same per-node logic, plus the explicit recursive calls and the
// plumbing that threads their results through `Task`.
def bestChanceOfWinning(d: Draw): Task[Participant] = {
  def bestChanceOfWinningR(d: Draw): Task[(Participant, Float)] =
    d match {
      // Leaf case (the same logic in both versions).
      case Match(left, right) =>
        probabilityOfWin(left, right).map(chance =>
          if (chance > 0.5) (left, chance) else (right, 1 - chance)
        )
      case FutureMatch(left, right) =>
        for {
          leftBest  <- bestChanceOfWinningR(left)
          (leftPart, leftChance) = leftBest
          rightBest <- bestChanceOfWinningR(right)
          (rightPart, rightChance) = rightBest
          chance    <- probabilityOfWin(leftPart, rightPart)
        } yield {
          val leftNewChance  = leftChance * chance
          val rightNewChance = rightChance * (1 - chance)
          if (leftNewChance > rightNewChance) (leftPart, leftNewChance)
          else (rightPart, rightNewChance)
        }
    }
  bestChanceOfWinningR(d).map(_._1)
}

One More Example - Unfolds

Initial Model - unfold

sealed trait Draw
case class FutureMatch(left: Draw, right: Draw) extends Draw
case class Match(left: Participant, right: Participant) extends Draw

/** Builds a draw from a list of participants by recursively halving the list.
  *
  * Returns `None` as soon as any half holds fewer than two participants. With this
  * splitting rule, that happens exactly when the list length is not a power of two
  * (>= 2); for example, 3 participants split into 1 + 2.
  */
def fromList(a: List[Participant]): Option[Draw] =
  a match {
    case xs if xs.size < 2 => None
    case x1 :: x2 :: Nil => Some(Match(x1, x2))
    case xs =>
      // `List` has no `split`; `splitAt(n)` returns (first n elements, remaining elements).
      val (left, right) = xs.splitAt(xs.size / 2)
      // Combine the two optional sub-draws: Some only when both halves succeeded.
      (fromList(left) ⊛ fromList(right))(FutureMatch)
  }

Fixed point style - unfold

  // Fixed-point style: `fromList` as a monadic unfold (anamorphism) in `Option`.
/** Builds a draw from a list of participants by recursively halving the list.
  *
  * The coalgebra only says how to turn one list into one layer; `anaM` repeats it on the
  * sub-lists, and a `None` from any step makes the whole result `None`. With this
  * splitting rule that happens exactly when the list length is not a power of two (>= 2).
  * NOTE(review): `Fix[DrawF].anaM(a)` is slide shorthand; check the exact matryoshka
  * entry point for `anaM`.
  */
def fromList(a: List[Participant]): Option[Draw] =
  Fix[DrawF].anaM(a){
    case xs if xs.size < 2 => None
    case x1 :: x2 :: Nil => Some(MatchF(x1, x2))
    case xs =>
      // `List` has no `split`; `splitAt(n)` returns (first n elements, remaining elements).
      val (left, right) = xs.splitAt(xs.size / 2)
      Some(FutureMatchF(left, right))
  }

type Draw = Fix[DrawF]

sealed trait DrawF[A]
case class FutureMatchF[A](left: A, right: A) extends DrawF[A]
case class MatchF[A](left: Participant, right: Participant) extends DrawF[A]

Unfold Comparison

// Fixed-point version: the coalgebra describes how to split one list into one layer;
// `anaM` repeats it on the sub-lists and the result is `None` if any step is `None`.
def fromList(a: List[Participant]): Option[Draw] =
  Fix[DrawF].anaM(a){
    case xs if xs.size < 2 => None
    case x1 :: x2 :: Nil => Some(MatchF(x1, x2))
    case xs =>
      val (left, right) = xs.splitAt(xs.size / 2)
      Some(FutureMatchF(left, right))
  }

// Hand-written version: the same splitting logic, plus the explicit recursive calls and
// the applicative plumbing that combines their `Option` results.
def fromList(a: List[Participant]): Option[Draw] =
  a match {
    case xs if xs.size < 2 => None
    case x1 :: x2 :: Nil => Some(Match(x1, x2))
    case xs =>
      val (left, right) = xs.splitAt(xs.size / 2)
      (fromList(left) ⊛ fromList(right))(FutureMatch)
  }

Intuition

sealed trait DrawF[A]
case class FutureMatchF[A](left: A, right: A) extends DrawF[A]
case class MatchF[A](left: Participant, right: Participant) extends DrawF[A]

// The type parameter can be filled with anything...
val intMatch: FutureMatchF[Int] = FutureMatchF(3, 4)
val stringMatch: FutureMatchF[String] = FutureMatchF("Bob", "Dylan")

// ...but filling it with a draw again never ends (this type cannot be written down):
// val draw: FutureMatchF[DrawF[DrawF[DrawF[...]]]]  // How to model a recursive draw?

Outline

  • Motivate it
  • How it works
  • Examples
  • How to use it

How it works

// Nesting draws inside draws directly would need an infinite type:
// val draw: FutureMatchF[DrawF[DrawF[DrawF[...]]]]  // How to model a recursive draw?

/** The fixed point of a functor `F`: one layer of `F` whose holes hold `Fix[F]` again.
  * This ties the recursive knot without writing an infinite type. The recursion schemes
  * (elided on this slide) are defined over it.
  */
final case class Fix[F[_]](unFix: F[Fix[F]]) {
  // def cata ...
  // def cataM ...
  // def ana ...
  // def anaM ...
  // ...
}

type Draw = Fix[DrawF]

Theory:

  • Fix is a type-level fixed-point combinator: the type analogue of the Y combinator
  • Fix[DrawF] is the fixed point of the DrawF functor

How it works

// Walking a draw by hand. Each `unFix` peels one `Fix` layer off and exposes a `DrawF`
// whose holes hold `Fix[DrawF]` again. Statically, `unFix` returns `DrawF[Fix[DrawF]]`,
// so the concrete case is recovered by pattern matching; this throws a `MatchError` if the
// draw does not have the shape assumed here.
val draw: Fix[DrawF] // placeholder: a FutureMatchF whose left child is a FutureMatchF whose left child is a MatchF
val futureMatch @ FutureMatchF(leftSide, _) = draw.unFix   // futureMatch: FutureMatchF[Fix[DrawF]], leftSide: Fix[DrawF]
val result @ FutureMatchF(result1, _) = leftSide.unFix     // result: FutureMatchF[Fix[DrawF]], result1: Fix[DrawF]
val result2 @ MatchF(leftParticipant, _) = result1.unFix   // result2: MatchF[Fix[DrawF]], leftParticipant: Participant

An instance of Functor is required

/** Catamorphism: folds `t` bottom-up with the one-layer algebra `f`.
  *
  * `project` exposes the outermost layer, every child is folded recursively through the
  * functor's map (`∘`), and `f` then combines the already-folded children. The recursion
  * depth grows with the depth of `t`.
  */
def cata[F[_]: Functor, A](t: T[F])(f: F[A] => A): A = {
  val layer = project(t)
  val folded = layer ∘ (child => cata(child)(f))
  f(folded)
}

Outline

  • Motivate it
  • How it works
  • Examples
  • How to use it

AST Example (stolen from the Sum Type of Way blog)

-- | A small expression AST in ordinary recursive style: every position that holds a
-- sub-expression names 'Expr' directly.
data Expr
  = Index Expr Expr
  | Call Expr [Expr]
  | Unary String Expr
  | Binary Expr String Expr
  | Paren Expr
  | Literal Lit
  deriving (Show, Eq)

-- | Remove every 'Paren' node anywhere in the tree. For example, this turns
--
-- >    (((anArray[(10)])))
--
-- into
--
-- >    anArray[10]
--
-- Because the recursion is written by hand, every constructor needs its own case.
flatten :: Expr -> Expr
flatten expr = case expr of
  Paren inner       -> flatten inner
  Literal lit       -> Literal lit
  Index arr idx     -> Index (flatten arr) (flatten idx)
  Call fn args      -> Call (flatten fn) (map flatten args)
  Unary op arg      -> Unary op (flatten arg)
  Binary lhs op rhs -> Binary (flatten lhs) op (flatten rhs)

Fixed point style

-- | The same AST in fixed-point style: the parameter @a@ marks every position that held
-- a sub-expression in the recursive version, including the callee of 'Call'.
-- Deriving 'Functor' requires the DeriveFunctor language extension.
data Expr a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving (Show, Eq, Functor)

-- Assumes @newtype Term f = In (f (Term f))@ (with 'In' as its constructor) is defined
-- elsewhere; it is not shown on this slide.

-- | Strip the 'Paren' wrapper from a single node; every other node is left unchanged.
flattenTerm :: Term Expr -> Term Expr
flattenTerm (In (Paren e)) = e  -- strip one Paren layer
flattenTerm other = other       -- do nothing otherwise

-- | Remove every 'Paren' in the tree. For example, this turns
--
-- >    (((anArray[(10)])))
--
-- into
--
-- >    anArray[10]
--
-- 'bottomUp' handles the traversal, so only the 'Paren' case has to be written.
flatten :: Term Expr -> Term Expr
flatten = bottomUp flattenTerm

-- | Rewrite a term bottom-up: rewrite every child first, rebuild the node, then apply
-- @fn@ to the rebuilt node.
bottomUp :: Functor f => (Term f -> Term f) -> Term f -> Term f
bottomUp fn (In t) = fn (In (fmap (bottomUp fn) t))

Example in Quasar

// For reference, the shape of the `InvokeF` node matched below:
// final case class InvokeF[A](func: Func, values: List[A]) extends LogicalPlan[A]

/** Evaluates a logical plan against the in-memory filesystem state.
  *
  * The plan is optimized first and then folded with a paramorphism (`para`). At each node
  * the algebra sees every child as a pair `(original child plan, child's computed result)`.
  *
  * @param lp the logical plan to evaluate
  * @return the resulting `Data` values, or a `FileSystemError` when a read path is missing
  *         or a node cannot be evaluated in memory (reported as unsupported for the
  *         whole optimized plan)
  */
def simpleEvaluation(lp: Fix[LogicalPlan]): FileSystemErrT[InMemoryFs, Vector[Data]] = {
  val optLp = Optimizer.optimize(lp)
  // Read the in-memory state once and evaluate the whole plan as a function of it.
  EitherT[InMemoryFs, FileSystemError, Vector[Data]](State.gets { mem =>
    import quasar.LogicalPlan._
    import quasar.std.StdLib.set.{Drop, Take}
    import quasar.std.StdLib.identity.Squash
    optLp.para[FileSystemError \/ Vector[Data]] {
      // Reads look the file up in memory; a missing file becomes a path-not-found error.
      case ReadF(path) =>
        // Documentation on `QueryFile` guarantees absolute paths, so calling `mkAbsolute`
        // NOTE(review): the sentence above is cut off in the source; presumably it goes on
        // to say the call is safe here -- confirm against the original Quasar code.
        val aPath = mkAbsolute(rootDir, path)
        fileL(aPath).get(mem).toRightDisjunction(pathErr(pathNotFound(aPath)))
      // Drop/Take are evaluated only when the count is a constant integer. The count is
      // matched on its original plan (first tuple element) and the source on its computed
      // result (second element). A count that `safeToInt` cannot convert is unsupported.
      case InvokeF(Drop, (_,src) :: (Fix(ConstantF(Data.Int(skip))),_) :: Nil) =>
        src.flatMap(s => skip.safeToInt.map(s.drop).toRightDisjunction(unsupported(optLp)))
      case InvokeF(Take, (_,src) :: (Fix(ConstantF(Data.Int(limit))),_) :: Nil) =>
        src.flatMap(s => limit.safeToInt.map(s.take).toRightDisjunction(unsupported(optLp)))
      // Squash passes its source result through unchanged here.
      case InvokeF(Squash,(_,src) :: Nil) => src
      // A constant evaluates to a single value.
      case ConstantF(data) => Vector(data).right
      // Anything else is looked up among the canned query responses, whose keys are
      // optimized plans. The lookup key is this node rebuilt from its original child
      // plans (`_._1`).
      case other =>
            queryResponsesL
              .get(mem)
              .mapKeys(Optimizer.optimize)
              .get(Fix(other.map(_._1)))
              .toRightDisjunction(unsupported(optLp))
    }
  })
}

Example in Quasar

/** Paramorphism: like `cata`, except the algebra receives each child both as its original
  * subtree (`Fix[F]`) and as its already-folded result (`A`).
  */
def para[F[_]: Functor, A](t: Fix[F])(f: F[(Fix[F], A)] => A): A

Outline

  • Motivate it
  • How it works
  • Examples
  • How to use it

All the Recursions

Matryoshka

// sbt: adds the Matryoshka recursion-schemes library; `%%` appends the Scala binary
// version to the artifact name.
libraryDependencies += "com.slamdata" %% "matryoshka-core" % "0.11.1"

Recursion Schemes (Haskell)

Acknowledgment

Made with Slides.com