Big Data at the Intersection of Typed FP and Category Theory Abstract Algebra

About me

  • Software Engineer on data science team @ Coatue Management
  • Scala for the last 6 years, data engineering for the last 2.5
  • Based in New York, NY (Brooklyn)
  • Things I like:
    • Coffee, beer, whiskey
    • Travel
    • Music (hip hop, synthwave)
    • Sports (soccer, bouldering)
    • Fountain pens (!!)
    • Sneakers

Where I work

  • What we do: data engineering @ Coatue
  • Stack
    • Scala (cats, shapeless, fs2, http4s, etc.)
    • Spark
    • AWS (EMR, Redshift, etc.)
    • R, Python
    • Tableau
  • Offices in NYC and Menlo Park, CA
  • Email: lcao@coatue.com, Twitter: @oacgnol

What this talk is about

  • Aggregations at small and large scales
    • Small = local JVMs, List[T]
    • Large = big data, e.g. Spark's Dataset[T] or RDD[T]
  • Semigroups and monoids for the working data engineer: how are they related to aggregation functions?
    • How these abstractions enable code reuse
    • Defined with laws: a set of axioms that instances of these abstractions (should) obey

First: an aggregation in SQL

> select * from my_table;

 name | x |  y   
------+---+------
 a    | 1 |  500
 b    | 2 | 1000
 c    | 3 | 5000

> select sum(x) from my_table;

6

Nice and simple.

Translation to local Scala

val nums = List(1,2,3)

val sum = nums.fold(0) { (x1, x2) =>
  x1 + x2
}

// sum = 6

Still pretty simple.

Translation to Spark

import spark.implicits._ // provides the Encoder[Int] that createDataset needs

val nums: Dataset[Int] =
  spark.createDataset(List(1,2,3))

val sum = nums.rdd.fold(0) { (x1, x2) =>
  x1 + x2
}

// sum = 6
  • Folds (reduces) in Spark are distributed
    • Parallelization across executors
    • Fold within partitions first, then fold final results
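To see why the zero value matters in that two-step shape, here is a minimal local sketch. `distributedFold` is a hypothetical helper, not a Spark API: it splits a `List` into chunks that stand in for partitions, so it only approximates what `RDD#fold` does.

```scala
object DistributedFoldSketch {
  // Hypothetical stand-in for RDD#fold: split `xs` into roughly
  // `numPartitions` chunks, fold each chunk starting from `zero`,
  // then fold the per-chunk results starting from `zero` again.
  def distributedFold[T](xs: List[T], numPartitions: Int)(zero: T)(op: (T, T) => T): T = {
    val chunkSize = math.max(1, math.ceil(xs.size.toDouble / numPartitions).toInt)
    val partials: List[T] = xs.grouped(chunkSize).toList.map(_.fold(zero)(op))
    partials.fold(zero)(op)
  }
}
```

With `0` as the zero, `distributedFold(List(1, 2, 3), 3)(0)(_ + _)` gives `6`, the same as the local fold. With `1` as the zero it gives `10`, because the zero is folded in once per chunk and once more when merging, so it has to be an identity for the operation.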

Hmm...

val nums = spark.createDataset(List(1,2,3))

val sum = nums.rdd.fold(0) { (x1, x2) =>
  x1 + x2
}

// sum = 6

Can we abstract these folds away for reuse?

val nums = List(1,2,3)

val sum = nums.fold(0) { (x1, x2) =>
  x1 + x2
}

// sum = 6

First pass at a typeclass

trait Combinable[A] {
  def zero: A
  def combine(x: A, y: A): A
}

implicit val addCombinable = new Combinable[Int] {
  def zero: Int = 0
  def combine(x: Int, y: Int): Int = x + y
}

def fold[T](
  list: List[T]
)(implicit m: Combinable[T]): T = {
  list.fold(m.zero) { (t1, t2) =>
    m.combine(t1, t2)
  }
}

fold(List(1,2,3)) == 6
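The payoff of the typeclass is that one `fold` works for any type that has an instance. A self-contained sketch that repeats the definitions above; the second instance, `concatCombinable` for `String`, is an illustrative addition and not from the original slides.

```scala
object CombinableSketch {
  trait Combinable[A] {
    def zero: A
    def combine(x: A, y: A): A
  }

  // Integer addition: 0 is the identity
  implicit val addCombinable: Combinable[Int] = new Combinable[Int] {
    def zero: Int = 0
    def combine(x: Int, y: Int): Int = x + y
  }

  // Hypothetical second instance: string concatenation, "" is the identity
  implicit val concatCombinable: Combinable[String] = new Combinable[String] {
    def zero: String = ""
    def combine(x: String, y: String): String = x + y
  }

  // One generic fold, reused for every type with a Combinable instance
  def fold[T](list: List[T])(implicit m: Combinable[T]): T =
    list.fold(m.zero)((t1, t2) => m.combine(t1, t2))
}
```

After `import CombinableSketch._`, `fold(List(1, 2, 3))` returns `6` and `fold(List("a", "b", "c"))` returns `"abc"`, with no change to `fold` itself.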

AKA: Monoid[A]

trait Monoid[A] {
  def empty: A
  def combine(x: A, y: A): A
}

implicit val addMonoid = new Monoid[Int] {
  def empty: Int = 0
  def combine(x: Int, y: Int): Int = x + y
}

def fold[T](
  list: List[T]
)(implicit m: Monoid[T]): T = {
  list.fold(m.empty) { (t1, t2) =>
    m.combine(t1, t2)
  }
}

fold(List(1,2,3)) == 6

Monoid laws

  • Laws: an extra set of rules that instances should also adhere to
  • From Scaladoc:
    • List#fold: "a binary operator that must be associative."
    • RDD#fold: "using a given associative function [...] for functions that are not commutative, the result may differ from that of a fold applied to a non-distributed collection"
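The commutativity caveat comes from the merge step: partition results may be combined in whatever order their tasks happen to finish. A small local sketch, assuming string concatenation as the non-commutative function and a hand-picked permutation as the completion order:

```scala
object MergeOrderSketch {
  // Per-partition results for an associative but NOT commutative op
  val partials: List[String] = List("a", "b", "c")

  // Merge in partition order vs. an assumed, different completion order
  val inPartitionOrder: String = partials.fold("")(_ + _)
  val inCompletionOrder: String = List(2, 0, 1).map(i => partials(i)).fold("")(_ + _)

  // A commutative op such as Int addition is unaffected by the order
  val sums: List[Int] = List(1, 2, 3)
  val sumInOrder: Int = sums.fold(0)(_ + _)
  val sumReordered: Int = List(2, 0, 1).map(i => sums(i)).fold(0)(_ + _)
}
```

Here `inPartitionOrder` is `"abc"` while `inCompletionOrder` is `"cab"`; both sums are `6`.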

Monoid laws

  • Associativity
    • combine(combine(x, y), z) ==
      combine(x, combine(y, z))
  • Commutativity (not required of every monoid; this is the extra law of a commutative monoid)
    • combine(x, y) == combine(y, x)
  • Relevant in distributed aggregation because neither the order nor the grouping of combines is guaranteed across partitions
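These laws can be spot-checked by brute force over a small sample of values, which is roughly what property-based law testing automates with random inputs. A minimal stdlib-only sketch; `isAssociative` and `isCommutative` are hypothetical helpers, not cats APIs.

```scala
object LawSketch {
  // combine(combine(x, y), z) == combine(x, combine(y, z)) for every triple
  def isAssociative[A](sample: List[A])(combine: (A, A) => A): Boolean =
    sample.forall { x =>
      sample.forall { y =>
        sample.forall { z =>
          combine(combine(x, y), z) == combine(x, combine(y, z))
        }
      }
    }

  // combine(x, y) == combine(y, x) for every pair
  def isCommutative[A](sample: List[A])(combine: (A, A) => A): Boolean =
    sample.forall(x => sample.forall(y => combine(x, y) == combine(y, x)))

  val ints: List[Int] = List(-2, -1, 0, 1, 2)
}
```

Addition passes both checks, subtraction fails both, and string concatenation is associative but not commutative, which is exactly the kind of function the `RDD#fold` Scaladoc warns about.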

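Monoid laws

  • Let's add commutativity to our monoid:

import cats.kernel.CommutativeMonoid
import cats.kernel.laws.discipline.CommutativeMonoidTests

val addMonoid = new CommutativeMonoid[Int] {
  def empty: Int = 0
  def combine(x: Int, y: Int): Int = x + y
}

// Run through cats law testing (checkAll comes from discipline's
// test-framework integration; Eq[Int] and Arbitrary[Int] must be in scope)
checkAll(
  "addingMonoid",
  CommutativeMonoidTests[Int](addMonoid)
    .commutativeMonoid)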

Monoid laws

[info] SimpleFoldsTest:
[info] - addingMonoid.commutativeMonoid.associative
[info] - addingMonoid.commutativeMonoid.collect0
[info] - addingMonoid.commutativeMonoid.combine all
[info] - addingMonoid.commutativeMonoid.combineAllOption
[info] - addingMonoid.commutativeMonoid.commutative
[info] - addingMonoid.commutativeMonoid.is id
[info] - addingMonoid.commutativeMonoid.left identity
[info] - addingMonoid.commutativeMonoid.repeat0
[info] - addingMonoid.commutativeMonoid.repeat1
[info] - addingMonoid.commutativeMonoid.repeat2
[info] - addingMonoid.commutativeMonoid.right identity

(yes, all the tests are green and passing)

Monoids and Dataset

  • Enforce commutativity as well on the typeclass instance:

import cats.kernel.CommutativeMonoid
import org.apache.spark.sql.Dataset

def fold[T](
  ds: Dataset[T]
)(implicit cm: CommutativeMonoid[T]): T = {
  ds.rdd.fold(cm.empty) { (t1, t2) =>
    cm.combine(t1, t2)
  }
}

  • Having followed Spark's guidance (an associative, commutative combine with a true identity as the zero), we can trust Spark to parallelize the fold at runtime

How about a custom type?

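import cats.implicits._ // Semigroup instances and `combine` syntax for Int and Option[Long]
import cats.kernel.CommutativeSemigroup

case class MyType(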
  name: String,
  x: Int,
  y: Option[Long])

// Can't really define empty for MyType
// Semigroup is like a Monoid without an `empty`
val cg = new CommutativeSemigroup[MyType] {
  override def combine(
    mt1: MyType,
    mt2: MyType
  ): MyType = {
    MyType(
      name = mt1.name, // pick a side... hmm...
      x = mt1.x combine mt2.x,
      y = mt1.y combine mt2.y)
  }
}

How about a custom type?

checkAll(
  "CommutativeSemigroup[MyType]",
  CommutativeSemigroupTests[MyType](cg)
    .commutativeSemigroup)

[info] SimpleCustomMonoidsTest:
[info] - CommutativeSemigroup[MyType].commutativeSemigroup.associative
[info] - CommutativeSemigroup[MyType].commutativeSemigroup.combineAllOption
[info] - CommutativeSemigroup[MyType].commutativeSemigroup.commutative *** FAILED ***
[info]   GeneratorDrivenPropertyCheckFailedException was thrown during property evaluation.
[info]    (Discipline.scala:14)
[info]     Falsified after 4 successful property evaluations.
[info]     Location: (Discipline.scala:14)
[info]     Occurred when passed generated values (
[info]       arg0 = MyType(儐,-193145333,None),
[info]       arg1 = MyType(,0,None)
[info]     )
[info]     Label of failing property:
[info]       Expected: MyType(,-193145333,None)
[info]       Received: MyType(儐,-193145333,None)

😢
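The failure can be reproduced by hand with the generated values from the log. A stdlib-only sketch; `combineOptionLong` is a hypothetical stand-in that assumes cats' `Option` semigroup behavior (two `Some`s are combined, `None` acts as the identity).

```scala
object MyTypeCommutativitySketch {
  case class MyType(name: String, x: Int, y: Option[Long])

  // Assumed to mirror cats' Semigroup[Option[Long]]: Some + Some sums, None is the identity
  def combineOptionLong(a: Option[Long], b: Option[Long]): Option[Long] =
    (a, b) match {
      case (Some(l), Some(r)) => Some(l + r)
      case (Some(l), None)    => Some(l)
      case (None, r)          => r
    }

  // Same shape as the slide's semigroup: the name is taken from the left side
  def combine(mt1: MyType, mt2: MyType): MyType =
    MyType(mt1.name, mt1.x + mt2.x, combineOptionLong(mt1.y, mt2.y))

  // The counterexample reported by the property check
  val arg0: MyType = MyType("儐", -193145333, None)
  val arg1: MyType = MyType("", 0, None)
}
```

`combine(arg0, arg1)` keeps the name `"儐"` while `combine(arg1, arg0)` keeps `""`, so the two sides differ only in the field where we had to "pick a side".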

How would I do this in SQL?

select
  name,
  sum(x) as x,
  sum(y) as y
from
  my_table
group by name;
  • What I actually want is to combine the values of rows that share a key (name)

Monoid[List[MyType]]

  • Let's try translating that into a Monoid in Scala:
import cats.Monoid
import cats.implicits._

implicit val myTypeListMonoid: Monoid[List[MyType]] = new Monoid[List[MyType]] {
  override def empty: List[MyType] = List.empty[MyType]

  override def combine(
    x: List[MyType],
    y: List[MyType]
  ): List[MyType] = {
    (x ++ y).groupBy(_.name)
      .map { case (k, myTypes) =>
        myTypes.reduce { (mt1, mt2) =>
          MyType(
            name = mt1.name,
            x = mt1.x combine mt2.x,
            y = mt1.y combine mt2.y)
        }
      }.toList
  }
}
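Run locally, this list monoid behaves like the SQL `group by` above. A stdlib-only sketch of the same `combine`, with the cats `combine` calls replaced by `+` and a hand-written `Option` merge (an assumption about cats' behavior). It also shows why the identity law is delicate here: combining with `empty` still collapses duplicate names, so `combine(xs, empty) == xs` only holds when `xs` is already reduced by key.

```scala
object ListMonoidSketch {
  case class MyType(name: String, x: Int, y: Option[Long])

  // Assumed to mirror cats' Semigroup[Option[Long]]
  def combineOpt(a: Option[Long], b: Option[Long]): Option[Long] =
    (a.toList ++ b.toList).reduceOption(_ + _)

  val empty: List[MyType] = List.empty[MyType]

  // Same logic as the slide: concatenate, group by name, reduce each group.
  // groupBy does not preserve row order, so compare results order-insensitively.
  def combine(xs: List[MyType], ys: List[MyType]): List[MyType] =
    (xs ++ ys).groupBy(_.name).toList.map { case (_, myTypes) =>
      myTypes.reduce { (mt1, mt2) =>
        MyType(mt1.name, mt1.x + mt2.x, combineOpt(mt1.y, mt2.y))
      }
    }
}
```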

Monoid[Dataset[MyType]]

implicit def myTypeDatasetMonoid(
  implicit spark: SparkSession
) = new Monoid[Dataset[MyType]] {
  import spark.implicits._
  import cats.implicits._ // `combine` syntax for the Int and Option[Long] fields

  override def empty: Dataset[MyType] =
    spark.emptyDataset[MyType]

  override def combine(
    ds1: Dataset[MyType],
    ds2: Dataset[MyType]
  ): Dataset[MyType] = {
    ds1.union(ds2).groupByKey(_.name)
      .reduceGroups { (mt1, mt2) =>
        MyType(
          name = mt1.name,
          x = mt1.x combine mt2.x,
          y = mt1.y combine mt2.y)
      }
      .map(_._2)
  }
}

Caveats

  • We have to write a concrete instance for every element type, and then again for every collection type (the instances are monomorphic)
  • Could we abstract more logic away to do this? (Hint: typeclasses, again)

Pulling out a common pattern

  • We are grouping by a key and combining only the values
  • Try a simple typeclass for this:
trait KeyValue[T, K, V] extends Serializable {
  def to(t: T): (K, V)
  def from(k: K, v: V): T
}

Pulling out a common pattern

object MyType {
  case class Values(
    x: Int,
    y: Option[Long])
}

implicit val kv =
  new KeyValue[MyType, String, MyType.Values] {
    def to(mt: MyType): (String, MyType.Values) =
      (mt.name, MyType.Values(mt.x, mt.y))
    def from(k: String, v: MyType.Values): MyType =
      MyType(k, v.x, v.y)
  }
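One property we would expect of any `KeyValue` instance (our assumption; the slides don't state it as a law) is that splitting and rebuilding is lossless: `from(to(t)) == t`. A stdlib-only sketch that checks this for the `MyType` instance; `roundTrips` is a hypothetical helper.

```scala
object KeyValueSketch {
  trait KeyValue[T, K, V] extends Serializable {
    def to(t: T): (K, V)
    def from(k: K, v: V): T
  }

  case class MyType(name: String, x: Int, y: Option[Long])
  object MyType {
    case class Values(x: Int, y: Option[Long])
  }

  implicit val kv: KeyValue[MyType, String, MyType.Values] =
    new KeyValue[MyType, String, MyType.Values] {
      def to(mt: MyType): (String, MyType.Values) = (mt.name, MyType.Values(mt.x, mt.y))
      def from(k: String, v: MyType.Values): MyType = MyType(k, v.x, v.y)
    }

  // Hypothetical helper: does from(to(t)) give back t for every sample value?
  def roundTrips[T, K, V](samples: List[T])(implicit kver: KeyValue[T, K, V]): Boolean =
    samples.forall { t =>
      val (k, v) = kver.to(t)
      kver.from(k, v) == t
    }
}
```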

Pulling out a common pattern

import cats.implicits._

// CommutativeSemigroup that combines
// only the _values_ of MyType
implicit val sg =
  new CommutativeSemigroup[MyType.Values] {
    override def combine(
      v1: MyType.Values,
      v2: MyType.Values
    ): MyType.Values = {
      MyType.Values(
        // delegate to cats Semigroup instances
        x = v1.x combine v2.x,
        y = v1.y combine v2.y)
    }
  }
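Because the key is no longer part of what gets combined, there is no side to pick, and the commutativity check that failed for `MyType` holds for the values alone. A stdlib-only sketch reusing the numbers from the failing test; the `Option` merge is assumed to mirror cats' `Semigroup[Option[Long]]`.

```scala
object ValuesSemigroupSketch {
  case class Values(x: Int, y: Option[Long])

  // Assumed to mirror cats' Semigroup[Option[Long]]
  def combineOpt(a: Option[Long], b: Option[Long]): Option[Long] =
    (a.toList ++ b.toList).reduceOption(_ + _)

  // Values-only combine: each field uses a commutative operation
  def combine(v1: Values, v2: Values): Values =
    Values(v1.x + v2.x, combineOpt(v1.y, v2.y))

  val v0: Values = Values(-193145333, None)
  val v1: Values = Values(0, None)
  val v2: Values = Values(7, Some(5L))
}
```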

Now try the new approach...

implicit def kvListMonoid[T, K, V: CommutativeSemigroup](
  implicit kver: KeyValue[T, K, V]
) = new Monoid[List[T]] {
  override def empty: List[T] = List.empty[T]

  override def combine(
    x: List[T],
    y: List[T]): List[T] = {
    (x ++ y).map(kver.to)
      .groupBy(_._1)
      .map { case (k, kvs) =>
        val combined: V = kvs
          .map(_._2)
          .reduce(_ combine _)

        kver.from(k, combined)
      }.toList
  }
}
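To see the reuse, here is a stdlib-only sketch of the same generic combine applied to a record type other than `MyType`. The small `CommutativeSemigroup` trait is a stand-in for the cats type class, and `WordCount` and `combineByKey` are hypothetical names, not from the original.

```scala
object KvListMonoidSketch {
  // Minimal stand-ins for the cats type class and the slide's KeyValue
  trait CommutativeSemigroup[A] { def combine(x: A, y: A): A }
  trait KeyValue[T, K, V] {
    def to(t: T): (K, V)
    def from(k: K, v: V): T
  }

  // Generic combine: split into (key, value), group by key, reduce the values
  def combineByKey[T, K, V](x: List[T], y: List[T])(
    implicit kver: KeyValue[T, K, V], sg: CommutativeSemigroup[V]
  ): List[T] =
    (x ++ y).map(t => kver.to(t)).groupBy(_._1).toList.map { case (k, kvs) =>
      kver.from(k, kvs.map(_._2).reduce((v1, v2) => sg.combine(v1, v2)))
    }

  // Hypothetical record type: word -> count
  case class WordCount(word: String, count: Long)

  implicit val longSum: CommutativeSemigroup[Long] = new CommutativeSemigroup[Long] {
    def combine(x: Long, y: Long): Long = x + y
  }

  implicit val wordCountKv: KeyValue[WordCount, String, Long] =
    new KeyValue[WordCount, String, Long] {
      def to(wc: WordCount): (String, Long) = (wc.word, wc.count)
      def from(k: String, v: Long): WordCount = WordCount(k, v)
    }
}
```

Only the two small instances are specific to `WordCount`; the grouping and reducing logic is shared, which is the point of pulling out `KeyValue`.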

New Monoid[Dataset[T]]

def kvDatasetMonoid[T: Encoder, K: Encoder, V: Encoder: CommutativeSemigroup](
  implicit kver: KeyValue[T, K, V], spark: SparkSession
) = new Monoid[Dataset[T]] {
  import cats.implicits._

  private val tupleEncoder: Encoder[(K, V)] = Encoders.tuple[K, V](
    implicitly[Encoder[K]],
    implicitly[Encoder[V]])

  override def empty: Dataset[T] = spark.emptyDataset[T]

  override def combine(ds1: Dataset[T], ds2: Dataset[T]): Dataset[T] = {
    ds1.union(ds2).map(kver.to(_))(tupleEncoder)
      .groupByKey((kv: (K, V)) => kv._1)
      .reduceGroups { (kv1: (K, V), kv2: (K, V)) =>
        val (k1, v1) = kv1
        val (k2, v2) = kv2
        (k1, v1 combine v2)
      }
      .map { kkv: (K, (K, V)) =>
        // reduceGroups yields (key, reduced (key, value) pair)
        val (k, (_, v)) = kkv
        kver.from(k, v)
      }
  }
}

😅

What did we just do?

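  • Reusable CommutativeSemigroup[V] instances that combine only the values
  • Generic Monoid[List[T]] and Monoid[Dataset[T]] instances for any T with a KeyValue instance (plus Encoders, for Dataset)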

Caveats, other ideas

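  • The new monoids we've just defined are a little finicky with respect to law checking (see the talk repo to be linked)
    • e.g. combining with empty still collapses duplicate keys (and groupBy may reorder rows), so the identity laws only hold up to ordering, and only for inputs already reduced by key
    • Note: sometimes unlawful instances can be useful (see: alleycats)
    • They also probably need performance tuning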
  • Shapeless for automatic typeclass derivation?
  • Other big data frameworks:
    • Flink, Hadoop, Beam, etc.

Further reading

Thanks for listening!
