
Functional Parallelism with Collections

Welcome to week 6 of CS-214 — Software Construction!

As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones. Exercises or questions marked 🧪 are intended to build up to concepts used in this week’s lab.

You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.

We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).

Using map and reduce

This exercise is intended to help you get more familiar with the .map and .reduce APIs on parallel collections. To this end, we’ll write code to compute statistics on the contents of a book.

Our input book has already been split into chapters, and each chapter has been split into words:

type Word = String
type Chapter = Seq[Word]
type Book = Seq[Chapter]

src/main/scala/parallelism/BookStats.scala

All words are in lowercase, and punctuation has been removed. For example:

val bookExample: Book =
  Vector(Vector("the", "quick", "brown", "fox", "is", "jumping", "over"), Vector("the", "lazy", "dog"))

src/main/scala/parallelism/BookStats.scala
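Here is the map-then-reduce pattern applied to the example book, using a different statistic so the exercises stay open: the total number of letters. The name `letterCount` is ours and is not part of the scaffold. The sketch uses ordinary sequential collections so that it runs without the parallel collections library. On a `.par` collection, the same `.map`/`.reduce` calls would run in parallel.

```scala
type Word = String
type Chapter = Seq[Word]
type Book = Seq[Chapter]

val bookExample: Book =
  Vector(Vector("the", "quick", "brown", "fox", "is", "jumping", "over"), Vector("the", "lazy", "dog"))

// Hypothetical statistic: total number of letters in the book.
// map: each chapter becomes its letter count; reduce: the counts are added up.
// reduceOption avoids the exception that reduce throws on an empty collection.
def letterCount(b: Book): Int =
  b.map(chapter => chapter.map(_.length).reduceOption(_ + _).getOrElse(0))
    .reduceOption(_ + _)
    .getOrElse(0)
```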

Complete the following functions, using the .map and .reduce APIs whenever possible.

  1. The length function calculates the total number of words in the book:

    def length(b: Book): Int =
      ???
    

    src/main/scala/parallelism/BookStats.scala

  2. The maxChapterLength function finds the length of the longest chapter of the book:

    def maxChapterLength(b: Book): Int =
      ???
    

    src/main/scala/parallelism/BookStats.scala

  3. The longestWord function identifies the longest word in the book:

    def longestWord(b: Book): Word =
      ???
    

    src/main/scala/parallelism/BookStats.scala

  4. The countWord function determines how many times a given word appears in the book:

    def countWord(b: Book, w: Word): Int =
      ???
    

    src/main/scala/parallelism/BookStats.scala

  5. The mostCommonWord function finds the most frequently used word in the book:

    def mostCommonWord(b: Book): Word =
      ???
    

    src/main/scala/parallelism/BookStats.scala

To test your implementation on a computer, run testOnly -- "*bookstats:*".

Writing your own parallel map

The parallel collections library lets us parallelize operations like map and reduce effortlessly. What would we do without it? We’d write these operations ourselves!

Write a parallel version of map, using Threads rather than parallel collections.

extension [A](seq: Seq[A])
  def parMap[B](f: A => B): Seq[B] =
    ???

src/main/scala/parallelism/MapReduce.scala
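Below is one possible sketch. It assumes that starting one `Thread` per element is acceptable. That is simple but wasteful for large inputs, where splitting the sequence into a few chunks would be better. The helper class `Task` is hypothetical: it is a thread that remembers the value it computed. `join` waits for each thread and also makes its result visible to the calling thread.

```scala
// Hypothetical helper: a thread that stores the value it computes.
final class Task[B](compute: => B) extends Thread:
  var result: Option[B] = None
  override def run(): Unit = result = Some(compute)

extension [A](seq: Seq[A])
  def parMap[B](f: A => B): Seq[B] =
    // toVector first, so that every task is created exactly once
    val tasks = seq.toVector.map(a => new Task(f(a)))
    tasks.foreach(_.start())
    tasks.foreach(_.join()) // after join, each task's result is visible here
    tasks.map(_.result.get)
```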

Parallel Aggregation ⭐️

In the world of functional programming, methods like foldLeft serve as foundational building blocks to crunch large datasets or work with complex structures, as they encapsulate very generic operations on collections.

At its core, the fold operation takes a collection and recursively combines its elements using a function, starting with an initial accumulator value. It’s akin to folding a long piece of paper into smaller sections, step-by-step.

However, fold’s power doesn’t come without some limitations. The operation is inherently sequential. Each step relies on the outcome of the previous step. In our paper analogy, imagine trying to fold the next section of the paper before the previous fold has been completed.
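To see this sequential dependency, we can fold with an operator that records the shape of the computation as a string. Each step needs the accumulator produced by the previous step.

```scala
// Record how foldLeft nests its calls: op(op(op(z, a), b), c)
val trace = List("a", "b", "c").foldLeft("z")((acc, x) => s"op($acc, $x)")

// A sum is built the same way: ((0 + 1) + 2) + 3
val total = List(1, 2, 3).foldLeft(0)(_ + _)
```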

Siblings: reduce and fold

For a minute, take a look instead at the reduce operation:

def reduce(op: (A, A) => A): A

which is quite close to fold itself: it repeatedly combines elements of the collection until a single element is left.

Reduce as a fold

Implement reduce in terms of foldLeft or foldRight on List[A] by completing the skeleton in FoldReduce.scala:

extension [A](l: List[A])
  def reduceWithFold(op: (A, A) => A): A =
    ???

src/main/scala/parallelism/FoldReduce.scala
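One possible sketch uses the head of the list as the initial accumulator and folds over the tail. An empty list has no element to return, so the sketch throws in that case, as the library's `reduce` does. The exact exception type and message here are our choice.

```scala
extension [A](l: List[A])
  def reduceWithFold(op: (A, A) => A): A = l match
    case Nil          => throw new UnsupportedOperationException("empty.reduceWithFold")
    case head :: tail => tail.foldLeft(head)(op) // op(op(head, x1), x2) ...
```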

Properties of folds

With an idea of how to implement reduce with foldLeft (as an example, foldRight is similar), let’s try to reverse-engineer why we were able to do this. This may seem like a strange thing to ask, but let’s begin with the signature of foldLeft:

abstract class List[+A]:
  def foldLeft[B](z: B)(op: (B, A) => B): B

On the other hand, look at the instantiated signature that actually appears when implementing reduce (not quite correct to write, just for demonstration):

// restricted case:
  def foldLeft(z: A)(op: (A, A) => A): A

Thus, we had to restrict foldLeft to the special case where the operator op acts only on A, i.e. B is A. In doing so, we lost much of the general power of foldLeft, where B and op were unconstrained.

With reduce, however, it is possible to have a parallel implementation! Are there any conditions we must impose on op or the input list l for this parallelization to be safe, i.e., deterministic?

Hint

In the operation List(1, 2, 3).par.reduce(_ - _), there are two possible evaluations:

Option 1:
1 - (2 - 3) === 2

Option 2:
(1 - 2) - 3 === -4

but with List(1, 2, 3).par.reduce(_ + _):

Option 1:
1 + (2 + 3) === 6

Option 2:
(1 + 2) + 3 === 6
Solution

The operator op must be associative. A parallel reduce keeps the elements in their original order but may group them in any way, so every grouping must give the same result.
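You can reproduce both groupings from the hint sequentially. `reduceLeft` groups from the left and `reduceRight` groups from the right. For an associative operator the two agree; for subtraction they do not.

```scala
val xs = List(1, 2, 3)
val minusLeft  = xs.reduceLeft(_ - _)  // (1 - 2) - 3
val minusRight = xs.reduceRight(_ - _) // 1 - (2 - 3)
val plusLeft   = xs.reduceLeft(_ + _)  // (1 + 2) + 3
val plusRight  = xs.reduceRight(_ + _) // 1 + (2 + 3)
```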

Parallel reduction

Implement a parallel version of reduce based on this idea (do not worry about complexity, for now):

extension [A](l: List[A])
  def reducePar(op: (A, A) => A): A =
    ???

src/main/scala/parallelism/FoldReduce.scala
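Here is a possible sketch. It assumes `scala.concurrent.Future` is an acceptable stand-in for the parallel constructs used in class. The list is split in half: a `Future` reduces the right half while the current thread reduces the left half, and the two results are combined with `op`. Because `op` is associative, this grouping gives the same result as a sequential `reduce`. The cutoff below which we reduce sequentially is an arbitrary choice.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global

extension [A](l: List[A])
  def reducePar(op: (A, A) => A): A =
    val cutoff = 4 // arbitrary: below this size, parallelism is not worth it
    if l.length <= cutoff then l.reduce(op) // throws on an empty list, like reduce
    else
      val (left, right) = l.splitAt(l.length / 2)
      val rightResult = Future(right.reducePar(op)) // may run on another thread
      val leftResult = left.reducePar(op)
      op(leftResult, Await.result(rightResult, Duration.Inf))
```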

Discussion

With this, we have three conditions that allowed us to first define reduce in terms of foldLeft, and then parallelize it:

  1. restriction of the output type: B =:= A
  2. restriction of the operator type: op: (B, B) => B
  3. restriction of the operator behaviour: op is associative

Discuss these conditions in a group. Are they necessary? What do they change from the sequential constraints in foldLeft?

What was fold again?

This jumble with types and restrictions makes a case to go back to the signature of foldLeft, and take a closer look at what it does.

abstract class List[+A]:
  def foldLeft[B](z: B)(op: (B, A) => B): B

which we can represent as a more intuitive “brick diagram”:

[Figure: foldLeft brick diagram]

Here, elements of type A are shown as blocks with a rounded protrusion, which fits into rounded holes, accepting values of type A, while those of type B are shown with a sharp triangle. Applying a function corresponds to “inserting” its inputs, and getting the correct output type on the other side of the block.

In this picture, what does the function (x, y) => op(op(z, x), y) look like?

Solution

[Figure: brick diagram of (x, y) => op(op(z, x), y)]

So, in this interpretation, let’s look at what a partial/parallel result of foldLeft might look like:

[Figure: two partial foldLeft results whose types cannot be combined]

On the left we have a foldLeft result on a list of two elements, and on the top, a result with one element. However, now we have two triangles, and no magic ??? that accepts two triangles and produces a triangle, which our output expects.

This suggests that what we are missing to parallelize foldLeft is a way to combine partial results, each of type B.

Aggregating what we’ve learnt

Adding the missing combination operator to foldLeft gives us a function called aggregate:

abstract class List[+A]:
  def aggregate[B](z: => B)(seqop: (B, A) => B, combop: (B, B) => B): B

Given all the missing pieces, combine them to implement parallel aggregate using map and reduce by completing this skeleton:

extension [A](l: List[A])
  def aggregate[B](z: B)(seqop: (B, A) => B, combop: (B, B) => B): B =
    ???

src/main/scala/parallelism/FoldReduce.scala
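Here is one possible sketch of the idea. It uses `Future`s instead of a parallel collection, an assumption made so that the code runs with the standard library alone. The list is split into chunks and each chunk is folded sequentially with `seqop` starting from `z` (the map step). The partial results are then combined with `combop` (the reduce step). The chunk size is arbitrary. The name `aggregateChunks` is ours; it gives the method a distinct name so that it cannot clash with any existing `aggregate` member of `List`.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global

extension [A](l: List[A])
  def aggregateChunks[B](z: B)(seqop: (B, A) => B, combop: (B, B) => B): B =
    val chunkSize = 4 // arbitrary
    if l.isEmpty then z
    else
      // map: one sequential foldLeft per chunk, each in its own Future
      val partials = l.grouped(chunkSize).toList.map(chunk => Future(chunk.foldLeft(z)(seqop)))
      // reduce: wait for the partial results and combine them with combop
      partials.map(Await.result(_, Duration.Inf)).reduce(combop)
```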

Many ways to aggregate

Laws of aggregation

There are many “correct” ways to implement aggregate, since we have not specified it well enough yet. These implementations may differ from each other in subtle ways. Look at the following candidate implementations for aggregate:

import scala.collection.parallel.CollectionConverters.*

extension [A](l: Seq[A])
  def aggregate1[B](z: B)(f: (B, A) => B, g: (B, B) => B): B =
    l.par
      .map(f(z, _))
      .reduce(g)

  def aggregate2[B](z: B)(f: (B, A) => B, g: (B, B) => B): B =
    l.foldLeft(z)(f)

  def aggregate3[B](z: B)(f: (B, A) => B, g: (B, B) => B): B =
    if l.length <= 1 then l.foldLeft(z)(f)
    else
      l.grouped(l.length / 2)
        .toSeq
        .par
        .map(s => s.aggregate3(z)(f, g))
        .reduce(g)

  def aggregate4[B](z: B)(f: (B, A) => B, g: (B, B) => B): B =
    if l.length <= 1 then l.foldLeft(z)(f)
    else
      l.grouped(l.length / 2)
        .toSeq
        .par
        .map(s => s.aggregate4(z)(f, g))
        .foldLeft(z)(g)

src/main/scala/parallelism/AggregateImpl.scala

Are all these implementations “correct”? Can you find cases where they fail to agree with your intuition of what aggregate should do?

Similar to the associativity of op that you discovered as a condition for safely parallelizing reduce, can you use your test cases to narrow down a condition that ensures aggregate behaves the same as the sequential foldLeft?

Hint

Run aggregate on a small example, consider two different ways of splitting it in parallel, and ask under which conditions the two splits would necessarily give the same result.

Your condition should involve both seqop and combop.
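A small test harness can help when experimenting with candidate conditions. The sketch below compares the sequential `foldLeft` with one particular parallel split: each half is folded from `z` and the two halves are combined with `combop`. The helper name `agreesOnSplit` is ours. It checks only one split, so a `true` result is evidence, not proof.

```scala
// Compare foldLeft with the split (left, right), folded separately and combined.
def agreesOnSplit[A, B](xs: List[A], at: Int)(z: B)(seqop: (B, A) => B, combop: (B, B) => B): Boolean =
  val (left, right) = xs.splitAt(at)
  val sequential = xs.foldLeft(z)(seqop)
  val split = combop(left.foldLeft(z)(seqop), right.foldLeft(z)(seqop))
  sequential == split

val xs = List(1, 2, 3, 4)
// Summing from z = 0: the split agrees with the sequential fold.
val sumFromZero = agreesOnSplit(xs, 2)(0)(_ + _, _ + _)
// Summing from z = 1: z is counted once per partial result, so the split disagrees.
val sumFromOne = agreesOnSplit(xs, 2)(1)(_ + _, _ + _)
```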

What’s special about this? 🔥

Taking a List(a, b, c, d) and inputs f, g, and z for seqop, combop, and z respectively, attempt to write down a few possible ways aggregate can be computed over it. As an example:

g(f(f(z, a), b), f(f(z, c), d))

Now, simplify this expression as far as possible using the rules you derived above. What do you obtain?

Analyzing the implementations 🔥

Can you come up with minimal mathematical conditions on seqop, combop, and z such that the implementations above will produce different results? In which cases are the functions the same on all inputs?

Hint

Try to identify which different operations the implementations use, and what the safety conditions of those smaller operations are.

Writing the concrete output on a small symbolic list, e.g. List(a, b, c), can go a long way in identifying general hypotheses to test.

Revisiting recursion exercises with map and reduce

Many of the list functions that we studied in week 1 can be expressed as combinations of map and reduce, and parallelized. Go over them and identify the ones that can and the ones that cannot.

Parallel parentheses matching ⭐️

You may need to do the aggregate exercise above first to understand the last subsection of this exercise.

Imagine for a moment that you’re building a modern architectural masterpiece. This building, unlike any other, is held together by a series of intricate and delicate balancing acts. Each element, be it a beam, a plank, or a brick, needs another corresponding element to keep it balanced. If even one piece is out of place, the entire structure could collapse.

In mathematical expressions and your code, this glorious task is performed by the humble parentheses (). What happens if one parenthesis is left unmatched? A mathematical expression becomes unsolvable, a sentence becomes confusing, and your code reports a hundred errors (sigh).

Let’s take a look at strings and identify these balancing acts.

What is ‘balance’?

Here, we say that a string is balanced if every opening parenthesis ( is matched to a unique closing parenthesis ) that follows it in the string and, conversely, every closing parenthesis ) is matched to a unique opening parenthesis ( that precedes it.

Which of the following strings are balanced?

  1. (o_()
  2. (if (zero? x) max (/ 1 x))
  3. :-)
  4. ())(
  5. I told him (that it's not (yet) done). (But he wasn't listening)
  6. (5 + 6))(7 + 8
  7. ((()())())
Answer

2, 5, and 7 are balanced, the others are not.

Not all strings are made the same

Before we jump into actually writing code to solve our problem, it can be worthwhile to look at several examples to understand the structure of the problem first.

Consider the strings

  1. ))()
  2. ((()
  3. ((())

From the perspective of the balancing problem, are all of these strings the same? Or are there noticeable differentiating factors between 1 and 2, between 2 and 3? Assume that you can extend the strings on the right, but not on the left.

Solution

String 1 cannot be extended on the right to be balanced: it is in an unrecoverable state, and any extension of it will be unbalanced.

Strings 2 and 3, on the other hand, can both be extended to balanced strings. String 3 is “closer” to being balanced, as it needs only one more closing parenthesis.

So, the number of open nested parentheses seems to be something we naturally look for. Can we make an implementation around this idea?

We can get the same analysis from right-to-left as well.

So far, we have tackled all of our problems recursively. Based on your observation, can you come up with a recursive pattern on strings to check for balancing? What properties or quantities does your recursion rely on?

Use these ideas to write a function that recurses on a string, represented as a list of characters, List[Char], to check whether it is balanced:

def isBalancedRecursive(str: List[Char]): Boolean =
  ???

src/main/scala/parallelism/ParenthesesBalancing.scala

You may use your own helper functions as you want inside isBalancedRecursive.
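Here is one possible recursive sketch, based on the number of currently open parentheses. The helper name `loop` is ours. An opening parenthesis increments the count and a closing one decrements it. The string becomes unrecoverable as soon as a `)` arrives while the count is 0.

```scala
def isBalancedRecursive(str: List[Char]): Boolean =
  // numOpen: parentheses opened so far that still wait for their ')'
  def loop(rest: List[Char], numOpen: Int): Boolean = rest match
    case Nil         => numOpen == 0
    case '(' :: tail => loop(tail, numOpen + 1)
    case ')' :: tail =>
      if numOpen == 0 then false // a ')' with nothing to close: unrecoverable
      else loop(tail, numOpen - 1)
    case _ :: tail   => loop(tail, numOpen) // other characters are ignored
  loop(str, 0)
```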

Folding Up

As we have seen again and again, most recursive functions we write are really standard functional operations in disguise. Rewrite your recursive function using a fold operation on str by completing the skeleton

def isBalancedFold(str: List[Char]): Boolean =
  ???

src/main/scala/parallelism/ParenthesesBalancing.scala
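Here is a possible fold-based sketch. A fold cannot stop early, so we use a negative count to mark the unrecoverable state: once the count drops below 0, it stays there for the rest of the fold. For `)(`, this yields -1.

```scala
def isBalancedFold(str: List[Char]): Boolean =
  val numOpen = str.foldLeft(0) { (count, c) =>
    if count < 0 then count          // unrecoverable: some ')' had no matching '('
    else if c == '(' then count + 1
    else if c == ')' then count - 1
    else count                       // other characters do not matter
  }
  numOpen == 0
```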

Parallelizing Folds

Since aggregate is the parallel counterpart of a fold, we now have an opportunity to take our balancing check and extend it to run in parallel.

Can we use the number of open parentheses as before and simply apply aggregate?

def isBalancedParSimple(str: List[Char]): Boolean =
  val foldingFunction: (Int, Char) => Int = ??? // your folding function

  val numOpen = str.par.aggregate(0)(foldingFunction, _ + _)

  (numOpen == 0)

src/main/scala/parallelism/ParenthesesBalancing.scala

The aggregate here applies fold as you did above, and combines the partial results by adding the resulting numOpen values. Does this always produce the correct balancing decision for any string?

Solution

No. Consider a simple string )(. In the sequential fold version, you would find that the string is in an unrecoverable state, and return numOpen = -1. However, here, since we are computing numOpen in parallel, the following can happen:

          parallel thread 1   parallel thread 2
input            ")"                 "("
numOpen          -1                  +1
after reduce    (-1) + (+1) === 0

and since == 0 is our condition for balancing, we will claim that the string )( is balanced, which is incorrect according to our definition.

Design a new parallel implementation based on aggregate by filling in the skeleton below. Change occurrences of the type Any to the type you want to use as your result type.

def isBalancedPar(str: List[Char]): Boolean =
  val seqOp: (Any, Char) => Any = ???
  val combOp: (Any, Any) => Any = ???

  str.par.aggregate(???)(seqOp, combOp) == ???

src/main/scala/parallelism/ParenthesesBalancing.scala

Hint

In foldLeft or foldRight, we know that we traverse the data in one fixed direction. So, we maintained a numOpen value, which represented the number of parentheses we still needed to close while moving in this direction.

However, when we work in parallel and use reduce, the “reduction” happens simultaneously in both directions:

         1      2      3
String   (((    )))    ()()
numOpen  +3     -3     0

reduce could first combine the results from 1 and 2, or from 2 and 3. Thus, string 2 can now be “extended” in either direction.

So, we can maintain two values instead of a single numOpen: the number of parentheses that must be opened to the left of a string, and the number that must be closed to its right.

For example, the string )))()(( would have the result opens = (3, 2). Your aggregating combOp now has to combine two tuples representing these values.
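Here is a sketch of this idea. It assumes the pair is ordered as (unmatched `)`, unmatched `(`). When two pairs are combined, the opening parentheses of the left part close the closing parentheses of the right part. The parallel collections library may not be on your classpath, so the sketch simulates a parallel run: it folds each chunk independently and then combines the chunk results with `combOp`. The names `Parens`, `neutral` and `isBalancedChunks` are ours, not part of the scaffold.

```scala
// (unmatched ')' on the left, unmatched '(' on the right)
type Parens = (Int, Int)
val neutral: Parens = (0, 0)

def seqOp(acc: Parens, c: Char): Parens =
  val (extraClose, extraOpen) = acc
  if c == '(' then (extraClose, extraOpen + 1)
  else if c == ')' then
    if extraOpen > 0 then (extraClose, extraOpen - 1) // closes an earlier '('
    else (extraClose + 1, extraOpen)                  // needs a '(' further left
  else acc

def combOp(l: Parens, r: Parens): Parens =
  val matched = math.min(l._2, r._1) // left's '(' closed by right's ')'
  (l._1 + r._1 - matched, l._2 + r._2 - matched)

// Simulated parallel aggregate: fold chunks independently, then combine.
def isBalancedChunks(str: List[Char], chunkSize: Int): Boolean =
  val partials = str.grouped(chunkSize).map(_.foldLeft(neutral)(seqOp)).toList
  partials.reduceOption(combOp).getOrElse(neutral) == neutral
```

Because `combOp` is associative and `(0, 0)` is its neutral element, every way of chunking the string gives the same final pair.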

Tabulation

Quite often, we find ourselves with collections whose elements are determined by a pattern. For example, the array of even numbers up to 6, Array(0, 2, 4, 6), is the function x => 2 * x evaluated at indices 0 to 3. To generate such collections, the Scala library provides the simple function Array.tabulate (similarly List.tabulate, etc.), with the following interface:

  /** Returns an array containing values of a given function over a range of integer
   *  values starting from 0.
   *
   *  @param  n   The number of elements in the array
   *  @param  f   The function computing element values
   *  @return An `Array` consisting of elements `f(0),f(1), ..., f(n - 1)`
   */
  def tabulate[T: ClassTag](n: Int)(f: Int => T): Array[T]

It accepts the size of the new array and a generating function. The ClassTag is required because JVM arrays record their element type at runtime, so creating an Array[T] requires knowing T at runtime. You can optionally read a discussion about it here.

So, great! tabulate exists. Let’s take a quick look at the parallel version, ParArray.tabulate, whose source code can be found here:

  /** Produces a $coll containing values of a given function over a range of integer values starting from 0.
   *  @param  n   The number of elements in the $coll
   *  @param  f   The function computing element values
   *  @return A $coll consisting of elements `f(0), ..., f(n -1)`
   */
  def tabulate[A](n: Int)(f: Int => A): CC[A] = {
    val b = newBuilder[A]
    b.sizeHint(n)
    var i = 0
    while (i < n) {
      b += f(i)
      i += 1
    }
    b.result()
  }

Don’t be intimidated by it! We don’t have much to do with it, except noticing… it does nothing in parallel. But my parallelization!

Too bad. Well, we have to write it ourselves 😊

Tabulating Sequentially 🧪

From scratch, or using the library implementation as a hint, try to write a sequential Array.tabulate with data structures and operations familiar to you:

extension (a: Array.type)
  def seqTabulate[A: ClassTag](n: Int)(f: Int => A): Array[A] =
    ???

src/main/scala/parallelism/Tabulate.scala

The extension on Array.type is there so you can follow this up by using your new function as Array.seqTabulate(n)(f) just like the Scala library!
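Here is one possible sketch that mirrors the library's imperative loop. It allocates the array, which is where the `ClassTag` is needed, and then fills it index by index. A functional alternative would be `(0 until n).map(f).toArray`. Treating a negative `n` as 0 is our choice.

```scala
import scala.reflect.ClassTag

extension (a: Array.type)
  def seqTabulate[A: ClassTag](n: Int)(f: Int => A): Array[A] =
    val result = new Array[A](math.max(n, 0)) // the ClassTag lets us create an Array[A]
    var i = 0
    while i < n do
      result(i) = f(i)
      i += 1
    result
```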

Parallelizing 🧪

Finally, parallelize your tabulation! This may be quite easy if you wrote it completely functionally, or a bit more work if you wrote it imperatively like the library, but that extra work comes with its own benefit. See Many Paths Down the Optimization Abyss below for a comparison of performance.

extension (p: ParArray.type) {
  def parTabulate[A: ClassTag](n: Int)(f: Int => A): ParArray[A] =
    ???
}

src/main/scala/parallelism/Tabulate.scala
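If the parallel collections library is not available, the same idea can be sketched with plain `Thread`s. This is a swapped-in technique, not the `ParArray` version the exercise asks for. The index range is split into a fixed number of contiguous chunks, and each thread fills its own chunk of the result array. The name `threadedTabulate` and the default of 4 threads are our assumptions, and the result is a plain `Array` rather than a `ParArray`.

```scala
import scala.reflect.ClassTag

def threadedTabulate[A: ClassTag](n: Int, numThreads: Int = 4)(f: Int => A): Array[A] =
  val result = new Array[A](math.max(n, 0))
  val chunk = (result.length + numThreads - 1) / numThreads // ceiling division
  val threads = (0 until numThreads).map { t =>
    val lo = math.min(t * chunk, result.length)
    val hi = math.min(lo + chunk, result.length)
    // each thread writes only to indices lo until hi, so threads never overlap
    new Thread(() => (lo until hi).foreach(i => result(i) = f(i)))
  }
  threads.foreach(_.start())
  threads.foreach(_.join()) // wait until every chunk has been written
  result
```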

Zipping through

With your functional parTabulate ready to go, implement a function zipWith, which takes two arrays and computes a new array by zipping them and applying a function to each pair of corresponding elements.

As an example, consider vector addition:

def vectorAdd(a: Array[Int], b: Array[Int]) =
  a.zipWith((l: Int, r: Int) => l + r)(b)

src/main/scala/parallelism/Tabulate.scala

val a = Array(1, 2, 3)
val b = Array(4, 5, 6)
vectorAdd(a, b) // == Array(5, 7, 9)

Complete the following skeleton:

extension [A](seq: Array[A])
  def zipWith[B, C: ClassTag](f: (A, B) => C)(other: Array[B]): Array[C] =
    ???

src/main/scala/parallelism/Tabulate.scala
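Here is a possible sketch that uses the library's sequential `Array.tabulate` as a stand-in for your `parTabulate`. Element `i` of the result is `f` applied to element `i` of each array. We assume that, like `zip`, the result should be as long as the shorter input.

```scala
import scala.reflect.ClassTag

extension [A](seq: Array[A])
  def zipWith[B, C: ClassTag](f: (A, B) => C)(other: Array[B]): Array[C] =
    // like zip: stop at the end of the shorter array
    val n = math.min(seq.length, other.length)
    Array.tabulate(n)(i => f(seq(i), other(i)))

def vectorAdd(a: Array[Int], b: Array[Int]): Array[Int] =
  a.zipWith((l: Int, r: Int) => l + r)(b)
```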

Many Paths Down the Optimization Abyss 🔥

Given the following three implementations of parTabulate:

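import scala.collection.parallel.CollectionConverters.*
import scala.collection.parallel.mutable.ParArray
import scala.reflect.ClassTag

extension (p: ParArray.type)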
  def parTabulateArrayMap[A: ClassTag](n: Int)(f: Int => A): ParArray[A] =
    (0 until n).toArray.par.map(f)

  def parTabulateBinarySplit[A: ClassTag](n: Int)(f: Int => A): ParArray[A] =
    val destination = new Array[A](n)
    tabulateToArray(destination, f, 0, n)
    destination.par

  def parTabulateMapReduce[A: ClassTag](n: Int)(f: Int => A): ParArray[A] =
    // does not take advantage of the fact that we know the output is an Array of the same size
    // we may want to use this if the output type was arbitrary instead
    (0 until n).par.map(i => Array(f(i))).reduce(_ ++ _).par

def tabulateToArray[A: ClassTag](
    destination: Array[A],
    f: Int => A,
    from: Int,
    to: Int
): Unit =
  val lim = 500
  if to - from < lim then
    // just run sequentially
    (from until to).foreach(i => destination(i) = f(i))
  else
    // fork
    Vector(
      from -> (from + to) / 2,
      (from + to) / 2 -> to
    ).par.map((from, to) => tabulateToArray(destination, f, from, to))

src/main/scala/parallelism/Tabulate.scala

and the following performance results:

[info] Benchmark                           Mode  Cnt     Score     Error  Units
[info] TabulateBench.arrayMapTabulate     thrpt   25  9435.925 ± 720.130  ops/s
[info] TabulateBench.binarySplitTabulate  thrpt   25  8156.671 ± 458.419  ops/s
[info] TabulateBench.mapReduceTabulate    thrpt   25   775.462 ±   8.016  ops/s

where ops/s stands for operations per second, higher is better. Cnt stands for count, the number of times the benchmarks were repeated, and Mode is the chosen way of measuring the performance (throughput).

The tests were run with an array size of 10,000 and the tabulation function as f: Int => Int = x => 2 * x.

The benchmarking was done with the Java Microbenchmark Harness (jmh).

Discuss the different implementations in groups and with the TAs to come up with possible hypotheses to explain their performance differences. Run the different versions yourself with different examples to test your hypotheses.

You can also run the JMH benchmarks yourself inside sbt with the following command:

Jmh/run

To run a faster (but less reliable) benchmark, you can use -f 1 -w 1s -r 1s (run Jmh/run -h to find out what that means!).

Work and depth ⭐️

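As a brief reminder, work and depth are two measures of the complexity of a parallel algorithm. Work is the total cost of all operations performed, i.e. the running time on a single processor. Depth is the cost of the longest chain of operations that must run one after the other, i.e. the running time with unlimited processors.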

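We’ll make use of two parallel constructs: par2(a, b), which evaluates a and b in parallel and returns both results as a pair, and parN(n)(f), which runs f(0), ..., f(n - 1) in parallel.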

For each of the following implementations, give a big-$\mathcal{O}$ expression for work and depth. You may assume that creating and starting threads has cost $\mathcal{O}(1)$, and that vector operations such as concatenating two vectors, taking a slice from a vector, or splitting a vector all have cost $\mathcal{O}(1)$.

Assume that all lists, arrays, and vectors have length $n$. For higher-order functions, assume that calling the argument function f has work $W_f$ and depth $D_f$ independently of the input.
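To experiment with the code below, you need `par2` and `parN`, which are not defined on this page. Here is a minimal hypothetical sketch with plain `Thread`s; the course's own definitions may differ. Both helpers wait with `join`, which also makes the threads' results and writes visible to the caller.

```scala
// Evaluate a on the current thread and b on a new thread; return both.
def par2[A, B](a: => A, b: => B): (A, B) =
  var resultB: Option[B] = None
  val thread = new Thread(() => { resultB = Some(b) })
  thread.start()
  val resultA = a
  thread.join() // wait for b; join also publishes resultB to this thread
  (resultA, resultB.get)

// Run f(0), ..., f(n - 1), each on its own thread, and wait for all of them.
def parN(n: Int)(f: Int => Unit): Unit =
  val threads = (0 until n).map(i => new Thread(() => f(i)))
  threads.foreach(_.start())
  threads.foreach(_.join())
```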

map

On vectors

extension [A](self: Vector[A])
  def map_vector[B](f: A => B): Vector[B] = self match
    case Vector()  => Vector()
    case Vector(x) => Vector(f(x))
    case _ =>
      val (la, ra) = self.splitAt(self.size / 2)
      val (lb, rb) = par2(la.map_vector(f), ra.map_vector(f))
      lb ++ rb

src/main/scala/parallelism/WorkDepth.scala

On arrays, using par2

extension [A](src: Array[A])
  def map_array_par2[B](f: A => B)(dst: Array[B]) =
    require(dst.length == src.length)

    def rec(from: Int, until: Int): Unit =
      require(0 <= from && from <= until && until <= src.length)
      if until == from then
        ()
      else if until == from + 1 then
        dst(from) = f(src(from))
      else
        val mid = from + (until - from) / 2
        par2(rec(from, mid), rec(mid, until))

    rec(0, src.length)
    dst

src/main/scala/parallelism/WorkDepth.scala

On arrays, using n-way parallelism

extension [A](src: Array[A])
  def map_array_parN[B](f: A => B)(dst: Array[B]) =
    require(dst.length == src.length)
    parN(src.length)(i => dst(i) = f(src(i)))
    dst

src/main/scala/parallelism/WorkDepth.scala

On lists

extension [A](ls: List[A])
  def map_list[B](f: A => B): List[B] = ls match
    case Nil => Nil
    case ha :: ta =>
      val (hb, tb) = par2(f(ha), ta.map_list(f))
      hb :: tb

src/main/scala/parallelism/WorkDepth.scala

sum

On vectors

extension (self: Vector[Int])
  def sum_vector: Int = self match
    case Vector()  => 0
    case Vector(x) => x
    case _ =>
      val (l, r) = self.splitAt(self.size / 2)
      val (sl, sr) = par2(l.sum_vector, r.sum_vector)
      sl + sr

src/main/scala/parallelism/WorkDepth.scala
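As an illustration of the method, here is the analysis of sum_vector under the assumptions above, where splitAt and par2 cost $\mathcal{O}(1)$. Each call does constant work besides its two recursive calls on halves of the vector. The two calls run in parallel, so only one of them counts towards the depth:

$$W(n) = 2\,W(n/2) + \mathcal{O}(1) \quad\Longrightarrow\quad W(n) = \mathcal{O}(n)$$

$$D(n) = D(n/2) + \mathcal{O}(1) \quad\Longrightarrow\quad D(n) = \mathcal{O}(\log n)$$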

On arrays, using two-way parallelism

extension (self: Array[Int])
  def sum_array: Int =
    def sum_range(from: Int, until: Int): Int =
      require(0 <= from && from <= until && until <= self.length)
      if until == from then 0
      else if until == from + 1 then self(from)
      else
        val mid = from + (until - from) / 2
        val (sl, sr) = par2(sum_range(from, mid), sum_range(mid, until))
        sl + sr
    sum_range(0, self.length)

src/main/scala/parallelism/WorkDepth.scala

Matrix multiplication on arrays

case class Matrix(rows: Array[Array[Int]], nR: Int, nC: Int):
  require(rows.length == nR && rows.forall(_.length == nC))
  override def toString =
    rows.map(_.map(_.toString).mkString(" ")).mkString(";\n")
object Matrix:
  def apply(nR: Int, nC: Int): Matrix =
    Matrix(Array.ofDim[Int](nR, nC), nR, nC)

def matmul(m1: Matrix, m2: Matrix): Matrix =
  require(m1.nC == m2.nR)
  val mul = Matrix(m1.nR, m2.nC)
  parN(mul.nR): r =>
    parN(mul.nC): c =>
      (0 until m1.nC).foreach: i =>
        mul.rows(r)(c) += m1.rows(r)(i) * m2.rows(i)(c)
  mul

src/main/scala/parallelism/WorkDepth.scala

Cumulative sums

The cumulative sum of a sequence $x = x_1, \ldots, x_n$ is the sequence $s$ defined by $s_1 = x_1$ and $s_{k} = s_{k-1} + x_{k}$ for $1 < k \le n$. For example, the cumulative sum of 1, 5, 3, 8, -1 is 1, 6, 9, 17, 16.

Sequential implementation

Implement a sequential version of cumsum:

def cumsum_sequential(v: Vector[Int]): Vector[Int] =
  ???

src/main/scala/parallelism/WorkDepth.scala
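One possible sketch uses `scanLeft`, which returns every intermediate accumulator of a left fold, starting with the initial value. Dropping that initial 0 leaves exactly $s_1, \ldots, s_n$.

```scala
def cumsum_sequential(v: Vector[Int]): Vector[Int] =
  // scanLeft(0) yields 0, s_1, ..., s_n; tail drops the leading 0
  v.scanLeft(0)(_ + _).tail
```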

Parallel implementation

The definition of cumsum sounds rather sequential, but it turns out to be parallelizable! Take a moment to pause and think about how that might work.

Ask yourself: if I have a vector $v = v_1 {+\mkern{-5mu}+} v_2$, what information do I need to compute the $v_2$ part of the cumulative sum?

Solution

Knowing the sum $s$ of the elements of $v_1$ is enough: we can then compute the cumulative sum of the elements of $v_2$, starting from $s$ instead of 0.

Consider pre-computing these sums in a separate data structure!

The final algorithm is quite similar to the parentheses matching algorithm above.

The following partial implementation uses a two-pass algorithm to solve this problem. First, we define an intermediate data structure which we call a “sum tree”: its leaves are numbers, and each branch caches the sum of the elements found in its subtree:

enum SumTree:
  case Empty
  case Leaf(n: Int)
  case Branch(_sum: Int, left: SumTree, right: SumTree)

  def sum = this match
    case Empty              => 0
    case Leaf(n)            => n
    case Branch(_sum, _, _) => _sum

src/main/scala/parallelism/WorkDepth.scala

mkSumTree

Second, we write a parallel function to create a sum tree from a vector. Complete the following template, making sure to maximize parallelism:

import SumTree.*
def mkSumTree(v: Vector[Int]): SumTree = v match
  case Vector()  => Empty
  case Vector(x) => Leaf(x)
  case _         =>
    ???

src/main/scala/parallelism/WorkDepth.scala

cumsum_sumtree

Third, we write a function to compute the cumulative sums of the leaves of a sum tree, reusing the partial sums cached in the branches of the tree. Complete the following template, making sure to maximize parallelism:

def cumsum_sumtree(st: SumTree, leftSum: Int = 0): Vector[Int] = st match
  case Empty           => Vector()
  case Leaf(s)         => Vector(leftSum + s)
  case Branch(s, l, r) =>
    ???

src/main/scala/parallelism/WorkDepth.scala

cumsum_parallel

Finally, we put these two functions together:

def cumsum_par2(v: Vector[Int]): Vector[Int] =
  cumsum_sumtree(mkSumTree(v))

src/main/scala/parallelism/WorkDepth.scala
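For reference, here is one possible completion of the two templates. It includes a minimal hypothetical `par2` built on `Thread`, so that the block runs on its own; the course's `par2` may be defined differently. `mkSumTree` builds both halves in parallel and caches their total in the branch. `cumsum_sumtree` passes `leftSum` unchanged to the left subtree and `leftSum + l.sum` to the right one, so the two subtrees can also be processed in parallel.

```scala
// Hypothetical par2: a runs on the current thread, b on a new thread.
def par2[A, B](a: => A, b: => B): (A, B) =
  var rb: Option[B] = None
  val t = new Thread(() => { rb = Some(b) })
  t.start()
  val ra = a
  t.join()
  (ra, rb.get)

enum SumTree:
  case Empty
  case Leaf(n: Int)
  case Branch(_sum: Int, left: SumTree, right: SumTree)

  def sum: Int = this match
    case SumTree.Empty           => 0
    case SumTree.Leaf(n)         => n
    case SumTree.Branch(s, _, _) => s

import SumTree.*

def mkSumTree(v: Vector[Int]): SumTree = v match
  case Vector()  => Empty
  case Vector(x) => Leaf(x)
  case _ =>
    val (l, r) = v.splitAt(v.size / 2)
    val (lt, rt) = par2(mkSumTree(l), mkSumTree(r))
    Branch(lt.sum + rt.sum, lt, rt) // cache the subtree's total

def cumsum_sumtree(st: SumTree, leftSum: Int = 0): Vector[Int] = st match
  case Empty           => Vector()
  case Leaf(s)         => Vector(leftSum + s)
  case Branch(_, l, r) =>
    // everything left of r sums to leftSum + l.sum, read from the cache
    val (cl, cr) = par2(cumsum_sumtree(l, leftSum), cumsum_sumtree(r, leftSum + l.sum))
    cl ++ cr

def cumsum_par2(v: Vector[Int]): Vector[Int] =
  cumsum_sumtree(mkSumTree(v))
```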

What is the complexity (work and depth) of this algorithm?