Zainab Ali
/** One example to classify (presumably a Titanic passenger, given the Survived/Died labels).
  *
  * @param gender      gender, e.g. `Female` / `Male`
  * @param age         age category, e.g. `Adult`
  * @param ticketClass ticket class
  * @param familySize  family size
  */
case class Example(
  gender: Gender,
  age: Age,
  ticketClass: TicketClass,
  familySize: FamilySize
)
/** The two outcomes a classifier can predict (binary classification). */
sealed trait Label

/** Outcome: the example survived. */
case object Survived extends Label

/** Outcome: the example died. */
case object Died extends Label
1000 examples
binary classification
/** A classifier: maps an [[Example]] to a predicted [[Label]]. */
def predict(example: Example): Label
/** Baseline classifier: ignores the example and always predicts [[Died]]. */
def predict(example: Example): Label = {
  val alwaysDied: Label = Died
  alwaysDied
}
/** Empirical risk: the fraction of predictions that disagree with the actual label.
  *
  * @param predictions predicted label for each example id
  * @param actual      true label for each example id; must contain every id in `predictions`
  * @return a value in [0, 1]; 0.0 when `predictions` is empty (instead of 0/0 = NaN)
  * @throws NoSuchElementException if a predicted id has no entry in `actual`
  */
def risk(predictions: Map[Int, Label],
         actual: Map[Int, Label]): Double =
  if (predictions.isEmpty) 0.0
  else {
    // Count mismatches directly: mapping the Map to Doubles and dividing the
    // resulting collection by an Int does not type-check.
    val incorrect = predictions.count { case (id, prediction) => prediction != actual(id) }
    incorrect.toDouble / predictions.size
  }
Risk: the fraction of predictions that are incorrect
/** Rule-based classifier: predicts from gender alone (females survive, males die). */
def predict(example: Example): Label =
  example.gender match {
    case Female => Survived
    case Male   => Died
  }
/** A decision tree: either a terminal prediction or a split on a feature. */
sealed trait Tree

/** Terminal node carrying the predicted label. */
case class Leaf(label: Label) extends Tree

/** Split node: each child subtree is keyed by one value of `feature`. */
case class Node(
  feature: Feature,
  children: Map[Value, Tree]
) extends Tree
/** One layer of a decision tree with the recursion factored out: child positions hold an `A`. */
sealed trait TreeF[A]

/** Terminal layer; `A` is unused because a leaf has no children. */
case class Leaf[A](label: Label) extends TreeF[A]

/** Splitting layer whose children, keyed by values of `feature`, are of type `A`. */
case class Node[A](
  feature: Feature,
  children: Map[Value, A]
) extends TreeF[A]
/** Fixed point of a type constructor `F`: e.g. `Fix[TreeF]` nests `TreeF` layers into a whole tree. */
case class Fix[F[_]](unFix: F[Fix[F]])
/** A one-split tree: split on `gender`; "male" leads to Died, "female" to Survived. */
val tree: Fix[TreeF] = Fix(Node(gender, Map(
  "male" -> Fix(Leaf[Fix[TreeF]](Died)),
  "female" -> Fix(Leaf[Fix[TreeF]](Survived))
)))
/** A full decision tree: `TreeF` layers nested through `Fix`. */
type Tree = Fix[TreeF]
/** Functor for [[TreeF]]: applies `f` to every child, leaving labels and features untouched.
  * The recursion schemes `ana`, `cata` and `hylo` require this instance (their `F[_]: Functor` bound).
  */
implicit val treeFunctor: Functor[TreeF] =
  new Functor[TreeF] {
    def map[A, B](fa: TreeF[A])(f: A => B): TreeF[B] =
      fa match {
        case Leaf(l) => Leaf(l)
        case Node(feature, children) =>
          // Strict map instead of `mapValues`: in 2.13 `mapValues` is deprecated and
          // returns a MapView (not a Map); in 2.12 it is a lazy view that re-applies
          // `f` on every access.
          Node(feature, children.map { case (v, child) => v -> f(child) })
      }
  }
/** A coalgebra expands a seed `A` into one layer `F[A]` whose child positions hold the next seeds. */
type Coalgebra[F[_], A] = A => F[A]
/** Anamorphism (generalized unfold): grows a `Fix[F]` from the seed `a` using `coalgebra`.
  * Requires a `Functor[F]` instance (context bound).
  */
def ana[F[_]: Functor, A](a: A)(
coalgebra: Coalgebra[F, A]): Fix[F]
`ana`: a generalized unfold — grows a structure from a seed, one layer at a time
/** Seed for tree construction: the examples at this node and the features still available to split on. */
type Input = (List[Example], Set[Feature])

/** Coalgebra that grows a decision tree one layer at a time.
  *
  * While features remain, it splits on the feature chosen by `maxGain` and seeds one
  * child per value of that feature, with that feature removed from the candidates.
  * When no features remain, it emits a leaf predicting the most frequent label.
  */
val build: Coalgebra[TreeF, Input] = {
  case (examples, features) =>
    if (features.nonEmpty) {
      // Discard the gain: binding it as `maxGain` would shadow the function being
      // called on the right-hand side and fail to compile.
      val (feature, _) = maxGain(examples, features)
      val nextFeatures = features - feature
      val nextExamples = groupByValue(examples, feature)
      // Strict map (not the deprecated/lazy `mapValues`): one seed per feature value.
      Node(feature, nextExamples.map { case (v, xs) => v -> ((xs, nextFeatures)) })
    } else {
      // No features left: predict the majority label of the examples that reached here.
      // NOTE(review): assumes `examples` is non-empty here; `maxBy` throws on an empty map.
      Leaf(labelCounts(examples).maxBy(_._2)._1)
    }
}
/** The decision tree unfolded with `build` from the seed `(examples, features)`. */
val tree: Tree = (examples, features).ana(build)
/** Coalgebra that walks one example down a tree.
  *
  * At a leaf it stops with `Left(label)`. At a node it continues with `Right(child)`,
  * picking the child keyed by the example's value for that node's feature. Unfolding
  * with it therefore yields the example's path through the tree.
  *
  * NOTE(review): `children(...)` throws NoSuchElementException if the example has a
  * feature value that no child covers.
  */
def explore(example: Example):
  Coalgebra[Label Either ?, Tree] =
  _.unFix match {
    case Leaf(label) => Left(label)
    case Node(feature, children) =>
      Right(children(value(feature, example)))
  }
Anamorphism with Either: unfolding the tree with `explore` traces one example's path
// Liz Walton's record; the remaining fields are elided on the slide.
// Arguments follow Example's field order: gender first, then age.
val lizWalton: Example = Example(Female, Adult, ...)
// Unfolding the tree with `explore` records her path: a chain of Rights ending in a Left.
val path: Fix[Label Either ?] = tree.ana(explore(lizWalton))
//path = Fix(Right(Fix(Right(... Fix(Left(Survived)) ...))))
/** An algebra collapses one layer `F[A]`, whose child positions already hold results, into a result `A`. */
type Algebra[F[_], A] = F[A] => A
/** Catamorphism (generalized fold): collapses `fix` into an `A` using `algebra`.
  * Requires a `Functor[F]` instance (context bound).
  */
def cata[F[_]: Functor, A](fix: Fix[F])(
algebra: Algebra[F, A]): A
`cata`: a generalized fold — collapses a structure into a single value
/** Algebra that reads the label out of an exploration path: both sides of the Either carry a [[Label]]. */
val collapse: Algebra[Label Either ?, Label] = {
  case Left(label)  => label
  case Right(label) => label
}
// Folding the path yields the label it ends in.
val prediction = path.cata(collapse)
//prediction = Survived
/** Hylomorphism (generalized refold): unfolds from the seed `a` with `coalgebra` and
  * folds the result with `algebra` — the same result as `cata(ana(a)(coalgebra))(algebra)`.
  *
  * `F` is passed to the aliases as the type constructor itself; writing `F[_]` there
  * would pass a proper (wildcard) type and fail to kind-check.
  *
  * @param algebra   collapses one layer of results into a `B`
  * @param coalgebra expands a seed `A` into one layer of new seeds
  */
def hylo[F[_]: Functor, A, B](a: A)(
  algebra: Algebra[F, B],
  coalgebra: Coalgebra[F, A]): B
`hylo`: a generalized refold — an unfold (`ana`) followed by a fold (`cata`)
/** Classifies `example` with `tree`: explores the example's path and collapses it to its label. */
def predict(tree: Tree)(example: Example): Label = {
  val walk = explore(example)
  tree.hylo(collapse, walk)
}
/** A [[TreeF]] layer annotated with a value of type `A` (e.g. label counts or cost info). */
case class AttrF[A, B](a: A, tree: TreeF[B])

/** Functor for [[AttrF]] in its second parameter: keeps the annotation unchanged and maps
  * the children through `treeFunctor`.
  */
implicit def attrFunctor[A]: Functor[AttrF[A, ?]] =
  new Functor[AttrF[A, ?]] {
    def map[B, C](fa: AttrF[A, B])(f: B => C): AttrF[A, C] =
      AttrF(fa.a, treeFunctor.map(fa.tree)(f))
  }
/** Per-label tallies, as produced by `labelCounts` over a list of examples. */
type Counts = Map[Label, Int]

/** Coalgebra like `build`, but each layer is also annotated with the label counts of the
  * examples that reached it (consumed later by `costInfo`).
  */
def buildCounts: Coalgebra[AttrF[Counts, ?], Input] = {
  case (examples, features) =>
    val counts = labelCounts(examples)
    val layer = build((examples, features))
    AttrF(counts, layer)
}
/** The count-annotated tree, unfolded from `(examples, features)` with `buildCounts`. */
val tree = (examples, features).ana(buildCounts)
/** Statistics attached to each subtree, used when deciding where to prune.
  *
  * @param leafCount number of leaves in the subtree
  * @param risk      presumably the number of misclassified examples in the subtree — TODO confirm
  * @param counts    label counts of the examples reaching this node
  *
  * NOTE(review): `minCost` and `prune` read `info.cost`, which this class does not declare.
  */
case class CostInfo(
  leafCount: Int,
  risk: Int,
  counts: Counts
)
/** Algebra that replaces each layer's label counts with [[CostInfo]].
  * A leaf gets its info from its own counts. A node computes its info from its
  * (already annotated) children via `nodeCostInfo`.
  */
val costInfo: Algebra[AttrF[Counts, ?], Attr[CostInfo]] = {
  // `t @ Leaf(_)` binds the whole leaf; `t: Leaf(_)` is not a valid pattern.
  case AttrF(counts, t @ Leaf(_)) =>
    Fix(AttrF(leafCostInfo(counts, t), t))
  case AttrF(_, t @ Node(_, children)) =>
    Fix(AttrF(nodeCostInfo(children), t))
}
// Unfold the count-annotated tree and fold it into a cost-annotated one in a single refold.
// Arguments follow hylo's (algebra, coalgebra) order, as in `predict(tree)(example)`.
val tree = (examples, features).hylo(costInfo, buildCounts)
/** Algebra giving the smallest node `cost` anywhere in a subtree.
  * Leaves contribute +∞, so only nodes (the pruning candidates) can be the minimum.
  * NOTE(review): reads `info.cost`, which `CostInfo` as declared does not define.
  */
val minCost: Algebra[AttrF[CostInfo, ?], Double] = {
  case AttrF(_, Leaf(_)) => Double.PositiveInfinity
  case AttrF(info, Node(_, children)) =>
    // `children.values` is an Iterable, which has no `::`; convert before prepending.
    (info.cost :: children.values.toList).min
}
/** Algebra that collapses every node whose `cost` equals `minCost` into a leaf.
  *
  * Under `cata`, children are processed before their parent. A node that is kept gets
  * fresh [[CostInfo]], because some of its descendants may just have been collapsed.
  *
  * @param minCost the minimum cost in the tree, as produced by `tree.cata(minCost)`
  */
def prune(minCost: Double):
  Algebra[AttrF[CostInfo, ?], Attr[CostInfo]] = {
  case AttrF(c, Leaf(l)) =>
    Fix(AttrF(c, Leaf(l)))
  case AttrF(info, n @ Node(_, children)) =>
    // Exact equality is intended when `minCost` was computed from these same cost values.
    if (info.cost == minCost) {
      val leaf = makeLeaf(info)
      Fix(AttrF(leafCostInfo(info.counts, leaf), leaf))
    } else {
      Fix(AttrF(nodeCostInfo(children), n))
    }
}
// Two pruning rounds. Each round finds the smallest node cost in the current tree
// and collapses the nodes that attain it, producing a smaller subtree.
// hylo takes (algebra, coalgebra), matching its signature and its use in `predict`.
val tree = (examples, features)
  .hylo(costInfo, buildCounts)
val cost1 = tree.cata(minCost)
val subTree1 = tree.cata(prune(cost1))
val cost2 = subTree1.cata(minCost)
val subTree2 = subTree1.cata(prune(cost2))
