
Effects, memoization

Welcome to week 11 of CS-214 — Software Construction!

As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones. Exercises or questions marked 🧪 are intended to build up to concepts used in this week’s lab.

You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.

We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).


The beauty of functional programming is that it provides referential transparency, the property that the results of a function depend only on its inputs. But many things can break referential transparency:

  • mutable state (variables and mutable data structures),
  • exceptions and other forms of non-local control flow,
  • input and output.

Traditionally, these are called effects. In the coming weeks we’ll see safe ways to encapsulate all these and recover referential transparency, using monads, but this week we’ll focus instead on special cases where local uses of these features do not affect referential transparency.

In fact, we have already seen one example: tail recursion elimination! Converting a recursive function to a loop that uses a mutable variable does not change anything for the callers of that function, but saves space on the stack and prevents stack overflows. This transformation is so safe that the compiler performs it automatically, just as we would do it by hand.

This week we’ll focus on this and two more examples:

  • avoiding recursion, by converting between recursive functions and loops;
  • exceptional control flow, using exceptions and boundary/break;
  • memoization, which caches results in a mutable map.

Avoiding recursion

In your past experience, you probably used loops to implement most of the algorithms you encountered. However, during this course, we have exclusively used recursion to implement looping behaviour. Is recursion useless? Does it only complicate things for the sake of it? Let’s find out! To do so, we will rewrite recursive functions as loops, and vice versa.

Converting tail-recursive functions to loops ⭐️

Let us start by converting tail recursive functions to loops. As a reminder, a tail recursive function is a recursive function in which the recursive call(s) is (are) the last thing(s) to be executed in the function. For example, the following function is tail recursive:

@tailrec
def fTR(x: Int, acc: Int): Int =
  if (x == 0) acc
  else fTR(x - 1, acc + 1)

Indeed, the last thing to be executed in the body of fTR is either acc or the recursive call fTR(x - 1, acc + 1).

The @tailrec annotation tells the compiler that the function is supposed to be tail recursive; the compiler will emit an error if that is not the case. This is useful to make sure that you did not make a mistake when writing the function. When the compiler detects that a function is tail recursive, either because of the annotation or because it infers it on its own, it compiles the function to a loop to improve performance: a loop does not need to use space on the stack for each recursive call.

On the other hand, the following function is not tail recursive:

def f(x: Int): Int =
  if (x == 0) 0
  else 1 + f(x - 1)

In this case, the last thing executed when x != 0 is the addition 1 + f(x - 1): the recursive call must complete first, and only then can the addition be computed. The function is therefore not tail recursive.
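To see the annotation in action, here is a minimal sketch (ours, not part of the scaffold): annotating this non-tail-recursive f with @tailrec makes the compiler reject it. The exact wording of the error may vary between compiler versions:

import scala.annotation.tailrec

@tailrec // error: cannot rewrite recursive call: it is not in tail position
def f(x: Int): Int =
  if (x == 0) 0
  else 1 + f(x - 1) // the addition still runs after the call returns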

Now that we have a better understanding of tail recursion, let us convert some tail recursive functions to loops. For this exercise, you will need to convert the functions reverseAppend, foldLeft and sum from the exercises of week 1. The goal is to understand the process that the compiler follows by replicating it manually, following a mechanical procedure.

Hint

Here is the procedure you can follow to convert a tail recursive function to a loop. We illustrate it with the tail recursive function fTR defined above.

  • Add a while true do loop:

    def fTRLoop(x: Int, acc: Int = 0): Int =
      while true do
    
  • Create a mutable variable for each parameter of the function and assign the values of the parameters to them.

    def fTRLoop(x: Int, acc: Int = 0): Int =
      var xVar = x
      var accVar = acc
      while true do
    
  • Create a mutable variable that will contain the return value of the function. Here, it is the acc parameter. This parameter has a default value, so we can remove the parameter and assign the default value to the variable.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
    
  • Put the body of the function as the body of the loop.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
        if xVar == 0 then acc
        else fTR(xVar - 1, acc + 1)
    
  • Replace the base case by a return statement returning the accumulator variable.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
        if xVar == 0 then
          return acc
        else fTR(xVar - 1, acc + 1)
    
  • Replace the recursive call by an assignment to the parameter variables.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
        if xVar == 0 then
          return acc
        acc = acc + 1
        xVar = xVar - 1
    
  • However, the compiler will not be happy with this code: when type checking, it infers type Unit for the while loop expression, while the function must return an Int. To solve this issue, the last statement should be an expression of type Int, or throw an exception, depending on your preference. Note, however, that this last statement is unreachable and will never be executed!

    def fTRLoop(x: Int): Int =
      var acc = 0
      var xVar = x
      while true do
        if xVar == 0 then
          return acc
        acc = acc + 1
        xVar = xVar - 1
      throw new AssertionError("Unreachable")
    

And there you go, the loop version of the tail recursive function!

def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
    if xVar == 0 then
      return acc
    acc = acc + 1
    xVar = xVar - 1
  throw new AssertionError("Unreachable")

Here is a good resource about tail recursion elimination (the process of converting tail recursive functions to loops): Debunking the “expensive procedure calls” myth.

Now that you know the procedure, convert the following functions:

@tailrec
def reverseAppend(l1: List[Int], l2: List[Int]): List[Int] =
  if l1.isEmpty then l2
  else reverseAppend(l1.tail, l1.head :: l2)

def reverseAppendLoop(l1: List[Int], l2: List[Int]): List[Int] =
    ???

src/main/scala/tailRecursion/lists.scala

@tailrec
def foldLeft(l: List[Int], acc: Int)(f: (Int, Int) => Int): Int =
  if l.isEmpty then acc
  else foldLeft(l.tail, f(acc, l.head))(f)

def foldLeftLoop(l: List[Int], startValue: Int)(f: (Int, Int) => Int): Int =
    ???

src/main/scala/tailRecursion/lists.scala

@tailrec
def sum(l: List[Int], acc: Int = 0): Int =
  if l.isEmpty then acc
  else sum(l.tail, acc + l.head)

def sumLoop(l: List[Int]): Int =
    ???

src/main/scala/tailRecursion/lists.scala

Now we know how to mechanically transform a tail recursive function into a loop. For the rest of this exercise set, feel free to write either loops or tail recursive functions, depending on what you prefer (when appropriate, of course).

Now the question is: can we write any recursive function as a tail recursive one? If so, why would we bother with general recursion at all? Let us find out!

@tailrec
def reverseAppend(l1: List[Int], l2: List[Int]): List[Int] =
  if l1.isEmpty then l2
  else reverseAppend(l1.tail, l1.head :: l2)

def reverseAppendLoop(l1: List[Int], l2: List[Int]): List[Int] =
  var current = l1
  var reversed = l2
  while true do
    if current.isEmpty then return reversed
    reversed = current.head :: reversed
    current = current.tail
  throw new AssertionError("Unreachable")

src/main/scala/tailRecursion/lists.scala

@tailrec
def foldLeft(l: List[Int], acc: Int)(f: (Int, Int) => Int): Int =
  if l.isEmpty then acc
  else foldLeft(l.tail, f(acc, l.head))(f)

def foldLeftLoop(l: List[Int], startValue: Int)(f: (Int, Int) => Int): Int =
  var current = l
  var accumulator = startValue
  while true do
    if current.isEmpty then return accumulator
    accumulator = f(accumulator, current.head)
    current = current.tail
  throw new AssertionError("Unreachable")

src/main/scala/tailRecursion/lists.scala

@tailrec
def sum(l: List[Int], acc: Int = 0): Int =
  if l.isEmpty then acc
  else sum(l.tail, acc + l.head)

def sumLoop(l: List[Int]): Int =
  var acc = 0
  var current = l
  while true do
    if current.isEmpty then return acc
    acc += current.head
    current = current.tail
  throw new AssertionError("Unreachable")

src/main/scala/tailRecursion/lists.scala

foldt

Let’s recall the foldt function of the SE exercises:

extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] = l match
    case a :: b :: tl => op(a, b) :: tl.pairs(op)
    case _            => l
  def foldt(z: T)(op: (T, T) => T): T = l match
    case Nil       => z
    case List(t)   => t
    case _ :: tail => l.pairs(op).foldt(z)(op)

How would you write the function foldt with loops? You can start from the following template:

extension [T](l: List[T])
  def foldt(z: T)(op: (T, T) => T): T =
    ???

src/main/scala/tailRecursion/lists.scala

Extra exercise: how would you write pairs with a while loop?

Here is an example solution:

extension [T](l: List[T])
  def foldt(z: T)(op: (T, T) => T): T =
    var list = l

    while true do
      list match
        // if list.size > 1
        case _ :: _ :: tail =>
          list = list.pairs(op)
        // if list.size == 1
        case a :: Nil => return a
        // if list.size == 0
        case Nil => return z

    throw new AssertionError("Unreachable")

src/main/scala/tailRecursion/lists.scala

Here is an example solution for pairs:

extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] =
    var ret: List[T] = Nil
    var list = l

    while true do
      list match
        case a :: b :: tail =>
          ret = op(a, b) :: ret
          list = tail
        case _ =>
          return ret.reverse ++ list

    throw new AssertionError("Unreachable")

src/main/scala/tailRecursion/lists.scala

groupBy

Now, you will implement a function groupBy by yourself, without using the standard groupBy method.

Implement two versions of groupBy: one using a for loop, and one using foldRight. Here is one possible solution:

def groupByForeach[T, S](f: T => S)(xs: List[T]): Map[S, List[T]] =
  var result = Map.empty[S, List[T]]
  for x <- xs do
    val key = f(x)
    val value = result.getOrElse(key, Nil)
    result = result + (key -> (x :: value))
  result

def groupByFoldRight[T, S](f: T => S)(xs: List[T]): Map[S, List[T]] =
  xs.foldRight(Map.empty[S, List[T]]) { (x, result) =>
    val key = f(x)
    val value = result.getOrElse(key, Nil)
    result + (key -> (x :: value))
  }

src/main/scala/tailRecursion/groupBy.scala
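One subtle difference between the two versions is the order of elements within each group: the foreach version prepends each element while traversing from the left, so each group comes out reversed, whereas the foldRight version visits elements from the right and therefore preserves the original order (as the standard groupBy does). A quick check (groupByDemo is a hypothetical entry point, not part of the scaffold):

@main def groupByDemo(): Unit =
  val xs = List(1, 2, 3, 4)
  println(groupByForeach((x: Int) => x % 2)(xs)(1))   // List(3, 1)
  println(groupByFoldRight((x: Int) => x % 2)(xs)(1)) // List(1, 3)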

foldLeft versus foldRight

We commonly use foldLeft and foldRight to shorten simple recursive functions.

Can foldLeft be rewritten as a loop? How about foldRight? In both cases, write the code, or explain why it cannot be rewritten that way.

foldLeft can be rewritten as a loop because it processes elements of a collection from left to right in order. This order of traversal and accumulation fits naturally into an iterative loop structure:

def foldLeftForeach[B, A](z: B)(op: (B, A) => B)(xs: List[A]): B =
  var result = z
  xs foreach (x => result = op(result, x))
  result

src/main/scala/tailRecursion/fold.scala

foldRight cannot be rewritten as easily, because it processes the collection from right to left: the operator is first applied to the last element. On a singly-linked list, this makes recursion more natural.
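That said, if an extra pass over the list is acceptable, foldRight can be expressed as a loop by reversing the list first and then accumulating as in foldLeft. A sketch (foldRightLoop is our name, not part of the scaffold):

def foldRightLoop[A, B](xs: List[A], z: B)(op: (A, B) => B): B =
  var acc = z
  var rest = xs.reverse // one extra O(n) pass and allocation
  while rest.nonEmpty do
    acc = op(rest.head, acc)
    rest = rest.tail
  acc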

Tail recursion modulo context 🔥

We saw in the first exercise that a tail recursive function can be mechanically transformed into a loop. This transformation is mostly useful because it is performed automatically by the compiler.

In this exercise, we will explore a way to rewrite some non-tail recursive functions into tail-recursive ones.

To illustrate this technique, let us consider the map function on List[Int]:

def map(l: List[Int], f: Int => Int): List[Int] =
  if l.isEmpty then Nil
  else f(l.head) :: map(l.tail, f)

src/main/scala/tailRecursion/lists.scala

The only thing that happens after the call is the creation of a Cons instance. To create it, we need to know the head (which we know before the recursive call) but also the tail (which is computed recursively). So the tail is the difficult part: when the recursive call completes, the callee returns the tail. Does that suggest a solution?

Hint

The trick could be to shift responsibilities around so that the caller begins the construction of the Cons, and the callee finishes that construction by storing the computed tail.

To do so, we will create a list type with a mutable tail. This way, we can construct the Cons before making the recursive call, and transfer to the recursive call the responsibility of filling in the tail.

Now that you have the idea, try to implement the mapTRWorker function:

enum MutableList:
  case Nil
  case Cons(val hd: Int, var tail: MutableList)

import MutableList.*

def mapTR(l: MutableList, f: Int => Int): MutableList =
  l match
    case Nil => Nil
    case Cons(hd, tl) =>
      val acc: Cons = Cons(f(hd), Nil)
      mapTRWorker(tl, f, acc)
      acc

// @tailrec uncomment when working on the exercise
def mapTRWorker(
    l: MutableList,
    f: Int => Int,
    acc: MutableList.Cons
): Unit =
  ???

src/main/scala/tailRecursion/lists.scala

enum MutableList:
  case Nil
  case Cons(val hd: Int, var tail: MutableList)

import MutableList.*

def mapTR(l: MutableList, f: Int => Int): MutableList =
  l match
    case Nil => Nil
    case Cons(hd, tl) =>
      val acc: Cons = Cons(f(hd), Nil)
      mapTRWorker(tl, f, acc)
      acc

// @tailrec uncomment when working on the exercise
def mapTRWorker(
    l: MutableList,
    f: Int => Int,
    acc: MutableList.Cons
): Unit =
  l match
    case Nil => ()
    case Cons(h, t) =>
      val next = Cons(f(h), Nil)
      acc.tail = next
      mapTRWorker(t, f, next)

src/main/scala/tailRecursion/lists.scala

This technique is often called “destination-passing style”. It is notably used by the map function from the standard Scala library:

  final override def map[B](f: A => B): List[B] = {
    if (this eq Nil) Nil else {
      val h = new ::[B](f(head), Nil)
      var t: ::[B] = h
      var rest = tail
      while (rest ne Nil) {
        val nx = new ::(f(rest.head), Nil)
        t.next = nx
        t = nx
        rest = rest.tail
      }
      releaseFence()
      h
    }
  }

You are not expected to understand everything that is going on here, but you can recognise the same pattern: a new cell with head f(head) and an empty tail is created with val h = new ::[B](f(head), Nil); then, inside the while loop, each new cell is created with val nx = new ::(f(rest.head), Nil), and the tail of the previous cell is filled in with t.next = nx.

This function is more complicated and uses some internal structures specific to the implementation of the class scala.collection.List, but the idea is the same as the one you implemented.

Today, some languages, such as OCaml, can perform this transformation automatically, as described in this paper. Scala cannot yet, so recursion is still relevant :)

If you are interested to learn more about it, here are some resources:

Looping on Trees

In our quest to find a case in which recursion really is easier to use than loops, we will now look at trees. We will use the following definition of binary trees:

enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])

src/main/scala/tailRecursion/trees.scala

Sum of leaves - rotation

Let us start with a simple function that computes the sum of a tree’s leaves. The recursive version is the following:

def sumRec(t: Tree[Int]): Int =
  t match
    case Leaf(value)       => value
    case Node(left, right) => sumRec(left) + sumRec(right)

src/main/scala/tailRecursion/trees.scala

On right line trees ⭐️

In the 2023 midterm, we saw the concept of right line trees. As a reminder, a right line tree is a tree in which each node is either a leaf, or has a leaf child on the left. The following function checks whether a tree is a right line tree:

def isRightLineTree(t: Tree[Int]): Boolean =
  t match
    case Leaf(_)              => true
    case Node(Leaf(_), right) => isRightLineTree(right)
    case _                    => false

src/main/scala/tailRecursion/trees.scala

Can you see the similarity between a right line tree and a list in the context of tail recursive functions?

Before writing any code, think about this: what can (a + b) + c = a + (b + c) mean on trees? Can we exploit this to write a loop (or tail recursive) function?

Hint

This equation corresponds to a right rotation on trees. The property of + that it expresses is associativity.

In our context, it means that the tree can be rearranged to compute the sum of leaves in a different way without affecting the result.
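To make this concrete, here is a sketch of a single right rotation (rotateRight is our name, not part of the scaffold):

import Tree.*

// Node(Node(a, b), c) becomes Node(a, Node(b, c)),
// mirroring (a + b) + c = a + (b + c).
def rotateRight[T](t: Tree[T]): Tree[T] = t match
  case Node(Node(a, b), c) => Node(a, Node(b, c))
  case _                   => t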

Let us write an imperative version (or tail recursive, as you prefer) of the sum function that works only for right line trees. Do not forget to add the appropriate Scala call to check that the tree is indeed a right line tree before computing the sum 🙂:

def sumRightLineTree(tr: Tree[Int]): Int =
  ???

src/main/scala/tailRecursion/trees.scala

What would happen if the operation were not associative, like, for example, subtraction?

Hint

You can take inspiration from the sum function on lists that we saw in the first exercise of this session. Think about the similarity between the structure of a list and that of a right line tree.

def sumRightLineTree(tr: Tree[Int]): Int =
  require(isRightLineTree(tr))
  var acc = 0
  var t = tr
  while true do
    t match
      case Leaf(value) =>
        acc += value
        return acc
      case Node(Leaf(value), right) =>
        acc += value
        t = right
      case _ => // cannot happen thanks to the require clause
        return acc
  acc

def sumRightLineTreeTailRec(tr: Tree[Int], acc: Int = 0): Int =
  require(isRightLineTree(tr))
  tr match
    case Leaf(value) =>
      acc + value
    case Node(Leaf(value), right) =>
      sumRightLineTreeTailRec(right, acc + value)
    case _ => // cannot happen thanks to the require clause
      acc

src/main/scala/tailRecursion/trees.scala

Using rotations ⭐️

A right rotation is an operation on a tree that produces a new tree with fewer leaves on the left-hand side. Can we use this operation to compute the sum of the leaves of an arbitrary tree, while reusing the idea of the sum we implemented on right line trees? Let’s find out!

Implement the sumRotate function that computes the sum of leaves’ values using right rotations:

def sumRotate(tr: Tree[Int], acc: Int): Int =
  ???

src/main/scala/tailRecursion/trees.scala

Which property must the operation applied to the leaves satisfy for this to work?

def sumRotate(tr: Tree[Int], acc: Int): Int =
  tr match
    case Leaf(value)               => acc + value
    case Node(Leaf(value), right)  => sumRotate(right, acc + value)
    case Node(Node(ll, lr), right) => sumRotate(Node(ll, Node(lr, right)), acc)

src/main/scala/tailRecursion/trees.scala

The sum using rotations works correctly because addition is associative: the sum of the leaves is agnostic to the shape of the tree.
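To see what goes wrong without associativity, consider a hypothetical subRec, the subtraction analogue of sumRec (not part of the scaffold): a rotation reassociates the operation and changes the result.

import Tree.*

def subRec(t: Tree[Int]): Int = t match
  case Leaf(value)       => value
  case Node(left, right) => subRec(left) - subRec(right)

@main def rotationDemo(): Unit =
  val before = Node(Node(Leaf(1), Leaf(2)), Leaf(3)) // (1 - 2) - 3
  val after  = Node(Leaf(1), Node(Leaf(2), Leaf(3))) // its right rotation: 1 - (2 - 3)
  println(subRec(before)) // -4
  println(subRec(after))  // 2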

Sum of leaves - DFS

Now, let us write an imperative version of the sum function. Before writing any code, think about it carefully. Over what elements would you iterate? How do you make sure you visit all the nodes? How would you keep track of the nodes you still need to visit?

def sumLoop(t: Tree[Int]): Int =
  ???

src/main/scala/tailRecursion/trees.scala

Hint

As you might have realised, this is not straightforward. The main issue is that you need to keep track of the nodes to visit. What data structure would you use to store the nodes you encounter and will visit later?

Spoiler

You should indeed use a stack to keep track of the nodes you have to visit. You can use the mutable Stack class from the Scala library.
import scala.collection.mutable.Stack

def sumLoop(t: Tree[Int]): Int =
  var sum = 0
  val toVisit = Stack(t)
  while toVisit.nonEmpty do
    toVisit.pop() match
      case Leaf(value) =>
        sum += value
      case Node(left, right) =>
        toVisit.push(right)
        toVisit.push(left)
  sum

src/main/scala/tailRecursion/trees.scala

Reduce on tree 🔥

We will now take a look at another function on trees: reduce. As a reminder, reduce is defined recursively as follows on trees:

def reduce[T](tr: Tree[T], f: (T, T) => T): T =
  tr match
    case Leaf(value)       => value
    case Node(left, right) => f(reduce(left, f), reduce(right, f))

src/main/scala/tailRecursion/trees.scala

We will write an imperative version of this function.

To get started, let us implement a mutable Stack structure, like the one you used in the previous exercise. Our MStack is based on a List and will extend the following trait:

trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean

case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    ???
  def pop(): A =
    ???
  def isEmpty: Boolean =
    ???
  def size: Int =
    ???
  def contains(a: A): Boolean =
    ???

src/main/scala/tailRecursion/trees.scala

Now let us implement a post order traversal on trees. This function will return the subtrees in post order, which means first the left child, then the right child, then the node itself. For example, the post order traversal of the following tree:

val tree =
  Node(
    Node(
      Leaf(1),
      Leaf(2)
    ),
    Leaf(3)
  )

is the following list:

List(
  Leaf(1),
  Leaf(2),
  Node(Leaf(1), Leaf(2)),
  Leaf(3),
  Node(Node(Leaf(1), Leaf(2)), Leaf(3))
)

Now, implement the postOrderTraversal function using a while loop and the MStack type that you just implemented. Think hard before writing the function. How do you keep track of the nodes you will visit? How do you ensure that you add the nodes in the correct order?

def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  ???

src/main/scala/tailRecursion/trees.scala

This post order traversal should be enough to implement reduce!

Hint

You’ll need an intermediate data structure to keep track of partially reduced results while you go over the post order. You can use a Map that associates tree nodes to the result of reduce on them, or you can use a Stack, with a bit more thinking about the order in which nodes appear in the post order.

def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  ???

src/main/scala/tailRecursion/trees.scala

trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean

case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    l = a :: l
  def pop(): A =
    val a = l.head
    l = l.tail
    a
  def isEmpty: Boolean =
    l.isEmpty
  def size: Int =
    l.size
  def contains(a: A): Boolean =
    l.contains(a)

src/main/scala/tailRecursion/trees.scala

def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  var toVisit = MStack[Tree[T]]()
  toVisit.push(tr)
  var postOrderNodes: List[Tree[T]] = Nil
  while !toVisit.isEmpty do
    val n = toVisit.pop()
    postOrderNodes = n :: postOrderNodes
    n match
      case Node(left, right) =>
        toVisit.push(left)
        toVisit.push(right)
      case Leaf(_) =>
  postOrderNodes

src/main/scala/tailRecursion/trees.scala
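As a quick sanity check (postOrderCheck is a hypothetical entry point, not part of the scaffold), the traversal of the example tree from above produces exactly the list shown earlier:

import Tree.*

@main def postOrderCheck(): Unit =
  val tree = Node(Node(Leaf(1), Leaf(2)), Leaf(3))
  assert(postOrderTraversal(tree) ==
    List(Leaf(1), Leaf(2), Node(Leaf(1), Leaf(2)), Leaf(3), tree))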

For reduceLoop, here are two versions: one that uses a Map

def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  var cache: Map[Tree[T], T] = Map()

  for t <- postOrderTraversal(tr) do
    t match
      case Leaf(v) => cache = cache + (t -> v)
      case Node(left, right) =>
        val leftValue = cache(left)
        val rightValue = cache(right)
        cache = cache + (t -> f(leftValue, rightValue))
  cache(tr)

src/main/scala/tailRecursion/trees.scala

… and one that uses a stack (note that, in the Map-based algorithm above, entries could also be removed from the map as they are used):

import scala.collection.mutable.Stack

def reduceWithStack[T](tr: Tree[T], f: (T, T) => T): T =
  val stack = Stack.empty[T]
  for t <- postOrderTraversal(tr) do
    t match
      case Leaf(value) => stack.push(value)
      case Node(left, right) =>
        val (r, l) = (stack.pop(), stack.pop())
        stack.push(f(l, r))
  stack.pop()

src/main/scala/tailRecursion/trees.scala

As a final, 🔥🔥 exercise: could you further reduce the memory usage by merging both traversals?

Map on tree

Now that you implemented reduce, you can implement map using the same principles.

No solution provided, share on Ed!

Proof of correctness of reduce on trees 🔥

We will now revisit the reduce function that uses the post order traversal from the exercise Reduce on tree. If you did not do it, here is the implementation:

Solution
trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean

case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    l = a :: l
  def pop(): A =
    val a = l.head
    l = l.tail
    a
  def isEmpty: Boolean =
    l.isEmpty
  def size: Int =
    l.size
  def contains(a: A): Boolean =
    l.contains(a)

src/main/scala/tailRecursion/trees.scala

def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  var toVisit = MStack[Tree[T]]()
  toVisit.push(tr)
  var postOrderNodes: List[Tree[T]] = Nil
  while !toVisit.isEmpty do
    val n = toVisit.pop()
    postOrderNodes = n :: postOrderNodes
    n match
      case Node(left, right) =>
        toVisit.push(left)
        toVisit.push(right)
      case Leaf(_) =>
  postOrderNodes

src/main/scala/tailRecursion/trees.scala

def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  var cache: Map[Tree[T], T] = Map()

  for t <- postOrderTraversal(tr) do
    t match
      case Leaf(v) => cache = cache + (t -> v)
      case Node(left, right) =>
        val leftValue = cache(left)
        val rightValue = cache(right)
        cache = cache + (t -> f(leftValue, rightValue))
  cache(tr)

src/main/scala/tailRecursion/trees.scala

If you are interested in program verification and proofs, two courses are given at EPFL in this area:

Post order traversal

Let us start by proving the correctness of the post order traversal algorithm. In words, the algorithm is correct if the produced list contains all the nodes of the tree, and if the order of the nodes is indeed a post order traversal (i.e., children appear in the list at a smaller index than their parent). In particular, the list should end with the root.

Your task is to write the above postcondition in Scala code, and a loop invariant for the postOrderTraversal function that proves that it is indeed satisfied at the end. Be careful: the invariant must take the state of the stack into account.

object postOrderTraversalProof:
  def invariantPostOrder[T](currentList: List[Tree[T]], toVisit: MStackTrait[Tree[T]], root: Tree[T]): Boolean =
    currentList.forall(tr =>
      tr match
        case Leaf(_) => true
        case n @ Node(left, right) =>
          ((currentList.contains(left) && currentList.indexOf(left) < currentList.indexOf(n)) || toVisit.contains(
            left
          )) &&
          ((currentList.contains(right) && currentList.indexOf(right) < currentList.indexOf(n)) || toVisit.contains(
            right
          ))
    ) && (currentList.isEmpty && toVisit.size == 1 && toVisit.contains(root) || currentList.last == root)

  def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
    var toVisit = MStack[Tree[T]]()
    toVisit.push(tr)
    var postOrderNodes: List[Tree[T]] = Nil
    assert(invariantPostOrder(postOrderNodes, toVisit, tr))
    while !toVisit.isEmpty do
      toVisit.pop() match
        case n @ Leaf(t) =>
          postOrderNodes = n :: postOrderNodes
        case n @ Node(left, right) =>
          postOrderNodes = n :: postOrderNodes
          toVisit.push(right)
          toVisit.push(left)
      assert(invariantPostOrder(postOrderNodes, toVisit, tr))
    assert(invariantPostOrder(postOrderNodes, toVisit, tr))
    postOrderNodes

src/main/scala/tailRecursion/trees.scala

reduce

Now that you have proved the correctness of the post order traversal, you can prove the correctness of the reduce function. The postcondition of reduce in our case is that the cache contains the root, and that the associated value is equal to reduce(root, f).

Your task is to write the above postcondition in Scala code, and a loop invariant for the reduce function that proves that it is indeed satisfied at the end. You can write one invariant encoding the validity of the cache (i.e., that all values it contains are indeed correct with respect to the key and the function f), and one invariant that encodes the correctness of how the cache is updated in the loop.

object reduceProof:
  def reduce[T](tr: Tree[T], f: (T, T) => T): T =
    tr match
      case Leaf(value)       => value
      case Node(left, right) => f(reduce(left, f), reduce(right, f))

  def invariantReduce[T](currentCache: Map[Tree[T], T], f: (T, T) => T): Boolean =
    currentCache.keySet.forall(k => currentCache(k) == reduce(k, f))

  def forInvariant[T](
      postOrderList: List[Tree[T]],
      currentIndex: Int,
      currentCache: Map[Tree[T], T]
  ): Boolean =
    postOrderList.take(currentIndex).forall(k => currentCache.contains(k))

  def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
    var cache: Map[Tree[T], T] = Map()
    assert(invariantReduce(cache, f))
    for (t, idx) <- postOrderTraversal(tr).zipWithIndex do
      assert(invariantReduce(cache, f))
      assert(forInvariant(postOrderTraversal(tr), idx, cache))
      t match
        case Leaf(v) => cache = cache + (t -> v)
        case Node(left, right) =>
          val leftValue = cache(left)
          val rightValue = cache(right)
          cache = cache + (t -> f(leftValue, rightValue))
      assert(invariantReduce(cache, f))
      assert(forInvariant(postOrderTraversal(tr), idx, cache))
    assert(invariantReduce(cache, f))
    assert(cache.contains(tr))
    cache(tr)

src/main/scala/tailRecursion/trees.scala

Exceptional control flow

An exceptional contains method ⭐️

  1. Consider the following two implementations of contains:

    extension [T](l: List[T])
      final def containsRec(t0: T): Boolean =
        l match
          case Nil      => false
          case hd :: tl => hd == t0 || tl.containsRec(t0)
    

    src/main/scala/exceptions/Exceptions.scala

    extension [T](l: List[T])
      final def containsFold(t0: T): Boolean =
        l.foldRight(false)((hd, found) => found || hd == t0)
    

    src/main/scala/exceptions/Exceptions.scala

    Is one of them preferable? Why?

  2. Which mechanism do you know to interrupt a computation before it completes? Use it to rewrite contains using foreach.

    Hint

    Use an exception! They work just the same in Scala as in Java.

    What advantages does this approach have?

  1. containsRec is much better: it stops as soon as it finds the element!

  2. Exceptions allow us to leverage .foreach:

    extension [T](l: List[T])
      final def containsExn(t0: T): Boolean =
        case object Found extends Exception
        try
          for hd <- l if hd == t0 do throw Found
          false
        catch
          case Found => true
    

    src/main/scala/exceptions/Exceptions.scala

Avoiding accidental escape: boundary/break ⭐️

Exceptions are great, but they risk escaping: if you forget to catch an exception raised for control flow, it will propagate to the caller of your function, and cause havoc there.

  1. Read the boundary/break documentation.

  2. Use a boundary to reimplement contains a fourth time.

  3. 🔥 Which of these four implementations of contains is fastest? Make a guess, then confirm it by writing a JMH benchmark.

Value-carrying exceptions

  1. Define a custom error type to hold values. Use it to write an exception-based implementation of find.

  2. Use boundary/break instead of a custom error type.

import scala.util.boundary

extension [T](l: List[T])
  final def containsBoundary(t0: T): Boolean =
    boundary:
      for hd <- l if hd == t0 do boundary.break(true)
      false

src/main/scala/exceptions/Exceptions.scala

extension [T](l: List[T])
  final def findExn(p: T => Boolean): Option[T] =
    case class FoundWith(t: T) extends Exception

    try
      for hd <- l if p(hd) do throw FoundWith(hd)
      None
    catch
      case FoundWith(t) => Some(t)

src/main/scala/exceptions/Exceptions.scala

import scala.util.boundary

extension [T](l: List[T])
  final def findBoundary(p: T => Boolean): Option[T] =
    boundary:
      for hd <- l if p(hd) do boundary.break(Some(hd))
      None

src/main/scala/exceptions/Exceptions.scala

As for performance, the best way to answer this question is to benchmark! But we can still make some educated guesses:

  • boundary/break is implemented using exceptions under the hood, so containsExn and containsBoundary should be similar.

  • containsRec is tail recursive, so it will compile to a clean loop: it should be very fast.

  • containsFold does not exit early, so it should be much slower: not because of the implementation of fold, but because it processes the whole list instead of stopping as soon as it finds the element.

To be sure, we can use the following benchmark:

@Warmup(iterations = 10, time = 100, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@BenchmarkMode(Array(Mode.AverageTime))
@Fork(1)
class SimpleBenchmark:
  @Benchmark
  def using_exn: Unit = SimpleBenchmark.l.containsExn(5000000L)

  @Benchmark
  def using_fold: Unit = SimpleBenchmark.l.containsFold(5000000L)

  @Benchmark
  def using_rec: Unit = SimpleBenchmark.l.containsRec(5000000L)

  @Benchmark
  def using_boundary: Unit = SimpleBenchmark.l.containsBoundary(5000000L)

object SimpleBenchmark:
  val l: List[Long] = (1L to 10000000L).toList

src/main/scala/exceptions/Benchmarks.scala

And here is one example run:

Benchmark                       Mode  Cnt    Score    Error  Units
SimpleBenchmark.using_boundary  avgt   50   24.121 ±  0.453  ms/op
SimpleBenchmark.using_exn       avgt   50   24.767 ±  0.509  ms/op
SimpleBenchmark.using_fold      avgt   50  125.344 ± 22.122  ms/op
SimpleBenchmark.using_rec       avgt   50   20.448 ±  0.224  ms/op

Memoization

Briefly, memoization is the process of augmenting a function with a mutable cache that records the output of the function every time it is called. If the function is subsequently called again with a previously-seen input, the result can be returned from cache instead of being recomputed.

A step-by-step example ⭐️

To see why memoization may be useful, consider a simple example: the Fibonacci function, which we studied previously:

def fib(n: Int): Int =
  if n <= 1 then 1 else fib(n - 1) + fib(n - 2)

src/main/scala/memo/Fib.scala

To compute fib(4), we make two recursive calls: one to fib(3), and one to fib(2). To compute fib(3), we again make two recursive calls: one to fib(2), and one to fib(1). Without special precautions, we end up computing fib(2) twice. Other parts of the computation are similarly repeated.

       fib(4)
=== (  fib(3)                     + fib(2)           )
=== ( (fib(2)           + fib(1)) + (fib(1) + fib(0)))
=== (((fib(1) + fib(0)) + 1     ) + (1      + 1     ))
=== (((1      + 1     ) + 1     ) + (1      + 1     ))

Interestingly, the cost of computing fib(n) grows exactly as fib(n): if it takes $T(k)$ steps to compute fib(k), then the cost of computing fib(n) is $T(n) = T(n - 1) + T(n - 2)$.

Memoization

All this redundant computation is unnecessary. Instead, as our first attempt to address this problem, we can create a cache to store fib’s results:

import scala.collection.mutable.Map

def fibMemo(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  def loop(idx: Int): Int =
    cache.getOrElseUpdate(
      idx,
      if idx <= 1 then 1
      else loop(idx - 1) + loop(idx - 2)
    )
  loop(n)

src/main/scala/memo/Fib.scala

  1. Can you convince yourself that this function behaves identically to the version without a cache?

  2. How large does the cache grow (i.e., how many entries get created in the cache) as we evaluate fib(k)? What entries does it contain when the computation completes?

Subproblem graph

To save space, we need to understand the structure of the subproblem graph of the Fibonacci function. The subproblem graph is a graph where:

  • each node is an input to the function, and
  • there is an edge from input a to input b if computing the function on a directly requires the result of the function on b.

For example, the nodes of the computation graph of fib(4) are 4, 3, 2, 1, 0 and its edges are 4 → 3, 4 → 2, 3 → 2, 3 → 1, 2 → 1, 2 → 0.

Here is one representation of the graph (notice that there are no edges from 1 to 0: they are both leaves):

      ┌─────┬─────┐
      │     v     v
4  →  3  →  2  →  1      0
│     ^     ^│    ^      ^
└─────┴─────┘└────┴──────┘

Dynamic programming

The subproblem graph captures precisely the notion of dependency: we cannot compute the output of a function on a given input node unless we know the outputs of the function on all the nodes it points to. Given this:

  • we can compute results bottom-up, starting from the leaves, in an order where every dependency is ready before it is needed;
  • we can discard a cached result as soon as all the nodes that depend on it have been computed.

The key questions to be able to do dynamic programming are: How long do we need to remember cached values? Which computation order minimizes this time?

To answer, we proceed in three steps:

  1. Find a traversal of the subproblem graph, starting from the leaves, such that dependencies are always computed before their parents. This is called a reverse topological sort of the subproblem graph.

  2. Rewrite our algorithm to construct the memoization cache iteratively, in the order given by step 1.

    For Fibonacci, the order is very simple: 0, 1, 2, 3, 4, …. In other words, to compute fib(n), it is sufficient to know the values of all fib(k) where k < n. The result of step 2 is, hence, as follows:

    import scala.collection.mutable.Map
    
    def fibIter(n: Int): Int =
      val cache: Map[Int, Int] = Map()
      for idx <- 0 to n do
        cache(idx) =
          if idx <= 1 then 1 else cache(idx - 1) + cache(idx - 2)
      cache(n)
    

    src/main/scala/memo/Fib.scala

  3. Discard entries from the memoization cache as soon as they are not used any more. In the case of the Fibonacci function, we only need to keep the last two entries:

    import scala.collection.mutable.Map
    
    def fibIterOpt(n: Int): Int =
      val cache: Map[Int, Int] = Map()
      for idx <- 0 to n do
        cache(idx) =
          if idx <= 1 then 1 else cache(idx - 1) + cache(idx - 2)
        cache.remove(idx - 2)
      cache(n)
    

    src/main/scala/memo/Fib.scala

    And we can, as a last cleanup step, entirely eliminate the cache, keeping only two variables:

    def fibIterFinal(n: Int): Int =
      var f0 = 1
      var f1 = 1
      for idx <- 2 to n do
        val f = f0 + f1
        f0 = f1
        f1 = f
      f1
    

    src/main/scala/memo/Fib.scala
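As a quick sanity check (fibCheck is a hypothetical entry point, not part of the scaffold), all the versions above should agree:

@main def fibCheck(): Unit =
  for n <- 0 to 20 do
    val expected = fib(n)
    assert(fibMemo(n) == expected)
    assert(fibIter(n) == expected)
    assert(fibIterOpt(n) == expected)
    assert(fibIterFinal(n) == expected)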

First application: $\binom{n}{k}$ “$n$ choose $k$”

The function choose(n, k) computes how many ways there are to pick k elements among n, without considering order and without allowing repetitions. This choice can be done in two ways:

  • either we pick the first element, and it remains to choose k - 1 elements among the remaining n - 1;
  • or we leave the first element out, and it remains to choose k elements among the remaining n - 1.

This gives the equation choose(n, k) = choose(n - 1, k - 1) + choose(n - 1, k).

  1. Implement choose as a recursive function using this equation. Mind the base cases!

    def choose(n: Int, k: Int): Int =
      ???
    

    src/main/scala/memo/Choose.scala

  2. Draw the subproblem graph of choose(5, 3) as a tree. Each node should have a pair of numbers, since the function takes two arguments. Do you notice repeated work?

  3. Write a memoized implementation of choose. The cache should map pairs of numbers (inputs) to single numbers (outputs):

    def chooseMemo(n: Int, k: Int): Int =
      ???
    

    src/main/scala/memo/Choose.scala

  4. Redraw the subproblem graph, but this time lay it out as an array with 6 columns and 4 rows: place node (i, j) at position x = i, y = j. What do you notice about the structure of the graph? Propose a reverse topological ordering of it.

  5. Replace the Map-based cache with a two-dimensional array, and rewrite the memoized algorithm to build the cache iteratively, without recursion.

  6. Is the whole cache needed at all times? Rewrite the algorithm to use less memory.

  1. :

    def choose(n: Int, k: Int): Int =
      if k <= 0 || k >= n then 1
      else choose(n - 1, k - 1) + choose(n - 1, k)
    

    src/main/scala/memo/Choose.scala

  2. The tracing technique that we’ve seen in lecture produces a tree that is exactly the subproblem graph.

  3. :

import scala.collection.mutable.Map

def chooseMemo(n: Int, k: Int): Int =
  val cache: Map[(Int, Int), Int] = Map()

  def loop(n: Int, k: Int): Int =
    cache.getOrElseUpdate(
      (n, k),
      if k <= 0 || k >= n then 1
      else loop(n - 1, k - 1) + loop(n - 1, k)
    )

  loop(n, k)

src/main/scala/memo/Choose.scala

The following drawing shows the subproblem graph of nChooseK(5, 2), with each repeated computation replaced with [cached].

                     nChooseK(5, 2)
                    /             \
          nChooseK(4, 1)           nChooseK(4, 2)
          /         \             /             \
nChooseK(3, 0) nChooseK(3, 1)    [cached]       nChooseK(3, 2)
               /       \                        /       \
       nChooseK(2, 0) nChooseK(2, 1)    nChooseK(2, 1)   [cached]
                      /   \             /   \
            nChooseK(1, 0) [cached]  [cached]  [cached]

  4. Each cell points to its neighbor directly below and the one diagonally below to the left.

  5.

    def chooseIter(n: Int, k: Int): Int =
      val dp = Array.ofDim[Int](n + 1, k + 1)
    
      for nn <- 0 to n do
        for kk <- 0 to Math.min(nn, k) do
          dp(nn)(kk) =
            if kk <= 0 || kk >= nn then 1
            else dp(nn - 1)(kk - 1) + dp(nn - 1)(kk)
    
      if k <= 0 || k >= n then 1
      else dp(n)(k)
    

    src/main/scala/memo/Choose.scala

  6. Here is a first solution, directly adapted from the previous one.

    def chooseIterFinal(n: Int, k: Int): Int =
      var prev = Array.empty[Int]
    
      for nn <- 0 to n do
        var nxt = Array.ofDim[Int](Math.min(nn, k) + 1)
        for kk <- 0 until nxt.length do
          nxt(kk) =
            if kk <= 0 || kk >= nn then 1
            else prev(kk - 1) + prev(kk)
        prev = nxt
    
      if k <= 0 || k >= n then 1
      else prev(k)
    

    src/main/scala/memo/Choose.scala

    We can adjust it further by keeping only one column:

    def chooseIterFinalOpt(n: Int, k: Int): Int =
      if k <= 0 || k >= n then 1
      else
        var col = Array.fill(math.min(n, k) + 1)(1)
        for
          nn <- 2 to n
          kk <- math.min(k, nn - 1) until 0 by -1
        do
          col(kk) = col(kk - 1) + col(kk)
        col(k)
    

    src/main/scala/memo/Choose.scala

    … or by iterating along diagonals:

    def chooseIterFinalGC(n: Int, k: Int): Int =
      if k <= 0 || k >= n then 1
      else
        val arrLen = n - k + 1
        val diag = Array.fill(arrLen)(1)
        for
          _ <- 1 to k
          i <- 1 until arrLen
        do
          diag(i) = diag(i) + diag(i - 1)
        diag(n - k)
    

    src/main/scala/memo/Choose.scala

    This matrix we’re building is called “Pascal’s triangle” — search for this term and “dynamic programming” to read more about this problem!
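As with Fibonacci, a quick sanity check (chooseCheck is a hypothetical entry point, not part of the scaffold) is to compare the iterative versions against the recursive definition:

@main def chooseCheck(): Unit =
  for n <- 0 to 12 do
    for k <- 0 to n do
      val expected = choose(n, k)
      assert(chooseIter(n, k) == expected)
      assert(chooseIterFinal(n, k) == expected)
      assert(chooseIterFinalOpt(n, k) == expected)
      assert(chooseIterFinalGC(n, k) == expected)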

More applications: memoizing every previous CS214 problem ⭐️

Train yourself to add memoization to functions by revisiting previous CS-214 problems. Particularly relevant are coinChange and des chiffres et des lettres.

No solution provided. The functions should compute the same results, and you can check that they are properly memoized by tracing them and confirming that the same result does not appear twice.

Benchmarking 🔥

The original solution to the Anagrams lab was not exactly fast. Memoize the recursive part of the anagrams function, and measure the resulting speed improvements. How much faster does it get?

No solution provided. Share your results on Ed!

Tower of Hanoi

A popular item in dentists’ offices, children’s museums, and on “10 original gift ideas for the holidays” lists is the game called “Tower of Hanoi”.

The game has three pegs, and 7 disks of increasing size, each with a hole in its center. In the initial configuration, the disks are stacked from largest to smallest on the leftmost peg:

      -|-              |               |
     --|--             |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------        |               |
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====

The aim is to move all the disks to the rightmost peg, with one rule: a larger disk may never rest on top of a smaller disk. Hence, this is a valid move:

       |               |               |
     --|--             |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------        |              -|-
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====

… and so is this:

       |               |               |
       |               |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------      --|--            -|-
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====

… but after this the only valid moves are $2 \to 1$ (moving the disk from peg 2 to peg 1) as well as $2 \to 0$ and $1 \to 0$.

A solution of the game is a list of moves $i \to j$ that moves all disks from peg 0 to peg 2. Here is a solution for the case of three disks:

  -|-      |       |
 --|--     |       |
---|---    |       |
=Left= =Middle= =Right=

Left → Right
   |       |       |
 --|--     |       |
---|---    |      -|-
=Left= =Middle= =Right=

Left → Middle
   |       |       |
   |       |       |
---|---  --|--    -|-
=Left= =Middle= =Right=

Right → Middle
   |       |       |
   |      -|-      |
---|---  --|--     |
=Left= =Middle= =Right=

Left → Right
   |       |       |
   |      -|-      |
   |     --|--  ---|---
=Left= =Middle= =Right=

Middle → Left
   |       |       |
   |       |       |
  -|-    --|--  ---|---
=Left= =Middle= =Right=

Middle → Right
   |       |       |
   |       |     --|--
  -|-      |    ---|---
=Left= =Middle= =Right=

Left → Right
   |       |      -|-
   |       |     --|--
   |       |    ---|---
=Left= =Middle= =Right=

These diagrams are generated by calling viewMoves(hanoi(3), 3). This function is provided in the code supplement to this exercise set.

  1. Any good text editor or IDE should have an implementation of tower of Hanoi (if you’re using Emacs, simply use M-x hanoi to start it). Use it to familiarize yourself with the game.

  2. Write a function that computes a solution to the problem with $n$ disks. Check below for an important hint, or skip the hint if you prefer a 🔥 exercise. In any case, remember the fundamental question of recursive problems: how do I express a solution to my problem in terms of smaller subproblems?

    enum Peg:
      case Left, Middle, Right
    
    case class Move(from: Peg, to: Peg)
    

    src/main/scala/memo/Hanoi.scala

    def hanoi(n: Int): Seq[Move] =
      ???
    

    src/main/scala/memo/Hanoi.scala

    Hint

    You need to solve a more general problem: how to move $n$ disks from peg $a$ to peg $b$. Can you solve the problem with 7 disks if you know how to move the first 6 disks to the middle peg?

    def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] =
      ???
    

    src/main/scala/memo/Hanoi.scala

  3. Can this program benefit from memoization?

  1. The solution is to switch to Emacs.

  2. Here’s a Scala implementation:

    def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] =
      if n == 0 then Vector()
      else
        hanoiHelper(src, third, dst, n - 1) ++
          Vector(Move(src, dst)) ++
          hanoiHelper(third, dst, src, n - 1)
    

    src/main/scala/memo/Hanoi.scala

    def hanoi(n: Int): Seq[Move] =
      hanoiHelper(Peg.Left, Peg.Right, Peg.Middle, n)
    

    src/main/scala/memo/Hanoi.scala

    … and Wikipedia has all the details.

  3. Yes; in fact, it can benefit from more than just memoizing based on the 4 inputs src, dst, third, and n, because solutions can be renamed: if we have a solution to transfer 6 disks from peg 0 to peg 1, then we immediately have a solution for peg 1 to peg 2, for example. This reduces the number of distinct subproblems solved from exponential to linear in n, since the two recursive calls in hanoiHelper both have height n - 1.
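Here is one way this idea could look in code (a sketch; hanoiMemo, canonical, and rename are our names, not part of the scaffold): memoize on the height only, solve each height once for a canonical peg assignment, and rename pegs afterwards. Note that producing the move list itself still takes exponential time, since a solution for n disks contains 2^n - 1 moves; what becomes linear is the number of distinct subproblems solved.

import scala.collection.mutable

def hanoiMemo(n: Int): Seq[Move] =
  // One canonical solution per height: move n disks from Left to Right.
  val cache = mutable.Map.empty[Int, Seq[Move]]

  // Reinterpret a canonical solution for another assignment of pegs.
  def rename(moves: Seq[Move], src: Peg, dst: Peg, third: Peg): Seq[Move] =
    val table = Map(Peg.Left -> src, Peg.Right -> dst, Peg.Middle -> third)
    moves.map(m => Move(table(m.from), table(m.to)))

  def canonical(n: Int): Seq[Move] =
    cache.getOrElseUpdate(n,
      if n == 0 then Vector()
      else
        rename(canonical(n - 1), Peg.Left, Peg.Middle, Peg.Right) ++
          Vector(Move(Peg.Left, Peg.Right)) ++
          rename(canonical(n - 1), Peg.Middle, Peg.Right, Peg.Left)
    )

  canonical(n)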

A memoizing fixpoint combinator 🔥

Take a look back at the combinator exercise from the polymorphism week.

As we have seen, the first step of memoization always looks the same: starting from a function def f(input: …) = …, we write the following:

def fMemo(input: …) =
  val cache: Map[…] = Map()
  def loop(input: …) =
    cache.getOrElseUpdate(input,
      …(body)…)
  loop(input)

It would be nice to be able to abstract over this pattern. Define a higher-order function memo to do so:

def memo[A, B](f: (A, A => B) => B)(a: A): B =
  ???

src/main/scala/memo/Combinator.scala

This function should be such that we can define fib and choose as follows:

val fib = memo: (n: Int, f: Int => Int) =>
  if n <= 1 then 1 else f(n - 1) + f(n - 2)

src/main/scala/memo/Combinator.scala

val choose = memo[(Int, Int), Int] {
  case ((n, k), f) =>
    if k <= 0 || k >= n then 1
    else f((n - 1, k - 1)) + f((n - 1, k))
}

src/main/scala/memo/Combinator.scala

Hint

Start from the fixpoint combinator from the previous exercise:

def fixpoint[A, B](f: (A, A => B) => B)(a: A): B =
  def loop(a: A): B = f(a, loop)
  f(a, loop)

src/main/scala/memo/Combinator.scala

import scala.collection.mutable

def memo[A, B](f: (A, A => B) => B)(a: A): B =
  val cache = mutable.Map.empty[A, B]
  def loop(a: A): B =
    cache.getOrElseUpdate(a, f(a, loop))
  loop(a)

src/main/scala/memo/Combinator.scala
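As a quick usage check (memoDemo is a hypothetical entry point, not part of the scaffold): note that since the cache is created inside the curried application, each outer call to fib or choose gets a fresh cache, shared only across the recursive calls of that one computation, just like in fibMemo.

@main def memoDemo(): Unit =
  println(fib(40))          // fast: each subproblem is computed once
  println(choose((30, 15))) // 155117520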