Effects, memoization
Welcome to week 11 of CS-214 — Software Construction!
As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones. Exercises or questions marked 🧪 are intended to build up to concepts used in this week’s lab.
You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.
We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).
The beauty of functional programming is that it provides referential transparency, the property that the results of a function depend only on its inputs. But many things can break referential transparency:
- Randomness
- Exceptions
- State (global variables, mutable objects)
- Non-deterministic parallelism and concurrency
- I/O (reading from and writing to disk, the network, the terminal, …)
Traditionally, these are called effects. In the coming weeks we’ll see safe ways to encapsulate all these and recover referential transparency, using monads, but this week we’ll focus instead on special cases where local uses of these features do not affect referential transparency.
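As a tiny illustration (the names `counter` and `impure` are ours, not part of the exercises), here is a function whose result depends on hidden mutable state, and which is therefore not referentially transparent:

```scala
// Hypothetical example: `impure` breaks referential transparency because
// its result depends on the mutable `counter`, not only on its argument.
var counter = 0

def impure(x: Int): Int =
  counter += 1
  x + counter

val a = impure(1) // 2
val b = impure(1) // 3: same input, different result
```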
In fact, we have already seen one example: tail recursion elimination! Converting a recursive function to a loop that uses a mutable variable changes nothing for the callers of that function, but saves space on the stack and prevents stack overflows. This transformation is so safe that the compiler performs it automatically, just as we would do it by hand.
This week we’ll focus on this and two more examples:
- Using exceptions for control flow.
- Using state to store previously-computed function results, a technique called memoization.
Avoiding recursion
In your past experience, you probably used loops to implement most of the algorithms you encountered. During this course, however, we have used recursion exclusively to implement looping behaviour. Is recursion useless? Does it only complicate things for the sake of it? Let’s find out! To do so, we will rewrite recursive functions as loops, and vice versa.
Converting tail-recursive functions to loops ⭐️
Let us start by converting tail recursive functions to loops. As a reminder, a tail recursive function is a recursive function in which the recursive call(s) are the last thing(s) executed in the function. For example, the following function is tail recursive:
@tailrec
def fTR(x: Int, acc: Int): Int =
if (x == 0) acc
else fTR(x - 1, acc + 1)
Indeed, the last thing to be executed in the body of fTR
is either acc
or fTR(x - 1, acc + 1)
, which is the recursive call.
The @tailrec
annotation tells the compiler that the function is supposed to be tail recursive. The compiler will then emit an error if that is not the case. This is useful to make sure that you did not make a mistake when writing the function. When the compiler detects that a function is tail recursive, either because of the annotation or because it infers it on its own, it compiles the function to a loop to improve performance. Indeed, a loop does not need to use space on the stack for each recursive call.
On the other hand, the following function is not tail recursive:
def f(x: Int): Int =
if (x == 0) 0
else 1 + f(x - 1)
In this case, the last thing executed when x != 0
is the addition 1 + f(x - 1)
. This means that the recursive call must first complete before the result of the addition can be computed, so the function is not tail recursive.
Now that we have a better understanding of tail recursion, let us convert some tail recursive functions to loops. For this exercise, you will need to convert the functions reverseAppend
, foldLeft
and sum
from the exercises of week 1. The goal is to understand the process that the compiler follows by replicating it manually, following a mechanical procedure.
Hint
Here is the procedure you can follow to convert a tail recursive function to a loop. We illustrate it with the tail recursive function fTR
defined above.
-
Add a while true do loop:
def fTRLoop(x: Int, acc: Int = 0): Int =
  while true do
-
Create a mutable variable for each parameter of the function and assign the values of the parameters to them.
def fTRLoop(x: Int, acc: Int = 0): Int =
  var xVar = x
  var accVar = acc
  while true do
-
Create a mutable variable that will contain the return value of the function. Here it is the acc
parameter. This parameter has a default value, so we can remove the parameter and assign the default value to the variable.
def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
-
Put the body of the function as the body of the loop.
def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
    if xVar == 0 then acc
    else fTR(xVar - 1, acc + 1)
-
Replace the base case by a return statement returning the accumulator variable.
def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
    if xVar == 0 then return acc
    else fTR(xVar - 1, acc + 1)
-
Replace the recursive call by an assignment to the parameter variables.
def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then return acc
    acc = acc + 1
    xVar = xVar - 1
-
However, the compiler will not be happy with this code. Indeed, when type checking is performed, it will evaluate the while loop expression to be of type Unit. To solve this issue, the last statement should be of type Int, or throw an exception, depending on your preference. Please note, however, that this last statement is unreachable and will never be executed!
def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then return acc
    acc = acc + 1
    xVar = xVar - 1
  throw new AssertionError("Unreachable")
And there you go, the loop version of the tail recursive function!
def fTRLoop(x: Int): Int =
var acc = 0
var xVar = x
while true do
if xVar == 0 then
return acc
acc = acc + 1
xVar = xVar - 1
throw new AssertionError("Unreachable")
Here is a good resource about tail recursion elimination (the process of converting tail recursive functions to loops): Debunking the “expensive procedure calls” myth.
Now that you know the procedure, convert the following functions:
@tailrec
def reverseAppend(l1: List[Int], l2: List[Int]): List[Int] =
if l1.isEmpty then l2
else reverseAppend(l1.tail, l1.head :: l2)
def reverseAppendLoop(l1: List[Int], l2: List[Int]): List[Int] =
???
src/main/scala/tailRecursion/lists.scala
@tailrec
def foldLeft(l: List[Int], acc: Int)(f: (Int, Int) => Int): Int =
if l.isEmpty then acc
else foldLeft(l.tail, f(acc, l.head))(f)
def foldLeftLoop(l: List[Int], startValue: Int)(f: (Int, Int) => Int): Int =
???
src/main/scala/tailRecursion/lists.scala
@tailrec
def sum(l: List[Int], acc: Int = 0): Int =
if l.isEmpty then acc
else sum(l.tail, acc + l.head)
def sumLoop(l: List[Int]): Int =
???
src/main/scala/tailRecursion/lists.scala
We now know how to mechanically transform a tail recursive function into a loop. For the rest of this exercise set, feel free to write either loops or tail recursive functions, depending on what you prefer (when appropriate, of course).
Now the question is: can we write any recursive function as tail recursive? If so, why do we bother with recursive functions? Let us find out!
@tailrec
def reverseAppend(l1: List[Int], l2: List[Int]): List[Int] =
if l1.isEmpty then l2
else reverseAppend(l1.tail, l1.head :: l2)
def reverseAppendLoop(l1: List[Int], l2: List[Int]): List[Int] =
var current = l1
var reversed = l2
while true do
if current.isEmpty then return reversed
reversed = current.head :: reversed
current = current.tail
throw new AssertionError("Unreachable")
src/main/scala/tailRecursion/lists.scala
@tailrec
def foldLeft(l: List[Int], acc: Int)(f: (Int, Int) => Int): Int =
if l.isEmpty then acc
else foldLeft(l.tail, f(acc, l.head))(f)
def foldLeftLoop(l: List[Int], startValue: Int)(f: (Int, Int) => Int): Int =
var current = l
var accumulator = startValue
while true do
if current.isEmpty then return accumulator
accumulator = f(accumulator, current.head)
current = current.tail
throw new AssertionError("Unreachable")
src/main/scala/tailRecursion/lists.scala
@tailrec
def sum(l: List[Int], acc: Int = 0): Int =
if l.isEmpty then acc
else sum(l.tail, acc + l.head)
def sumLoop(l: List[Int]): Int =
var acc = 0
var current = l
while true do
if current.isEmpty then return acc
acc += current.head
current = current.tail
throw new AssertionError("Unreachable")
src/main/scala/tailRecursion/lists.scala
foldt
Let’s recall the foldt
function of the SE exercises:
extension [T](l: List[T])
def pairs(op: (T, T) => T): List[T] = l match
case a :: b :: tl => op(a, b) :: tl.pairs(op)
case _ => l
def foldt(z: T)(op: (T, T) => T): T = l match
case Nil => z
case List(t) => t
case _ :: tail => l.pairs(op).foldt(z)(op)
How would you write the function foldt
with loops? You can start from the following template:
extension [T](l: List[T])
def foldt(z: T)(op: (T, T) => T): T =
???
src/main/scala/tailRecursion/lists.scala
Extra exercise: how would you write pairs
with a while
loop?
Here is an example solution:
extension [T](l: List[T])
def foldt(z: T)(op: (T, T) => T): T =
var list = l
while true do
list match
// if list.size > 1
case _ :: _ :: tail =>
list = list.pairs(op)
// if list.size == 1
case a :: Nil => return a
// if list.size == 0
case Nil => return z
throw new AssertionError("Unreachable")
src/main/scala/tailRecursion/lists.scala
Here is an example solution for pairs
:
extension [T](l: List[T])
def pairs(op: (T, T) => T): List[T] =
var ret: List[T] = Nil
var list = l
while true do
list match
case a :: b :: tail =>
ret = op(a, b) :: ret
list = tail
case _ =>
return ret.reverse ++ list
throw new AssertionError("Unreachable")
src/main/scala/tailRecursion/lists.scala
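As a quick sanity check (our own example, not part of the scaffold), the loop-based pairs and foldt shown above can be exercised like this:

```scala
// Self-contained copy of the loop-based `pairs` and `foldt` solutions.
extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] =
    var ret: List[T] = Nil
    var list = l
    while true do
      list match
        case a :: b :: tail =>
          ret = op(a, b) :: ret
          list = tail
        case _ =>
          return ret.reverse ++ list
    throw new AssertionError("Unreachable")

  def foldt(z: T)(op: (T, T) => T): T =
    var list = l
    while true do
      list match
        case _ :: _ :: _ => list = list.pairs(op)
        case a :: Nil    => return a
        case Nil         => return z
    throw new AssertionError("Unreachable")

// List(1, 2, 3, 4).pairs(_ + _)   == List(3, 7)
// List(1, 2, 3, 4).foldt(0)(_ + _) == 10
```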
groupBy
Now, you will implement a function groupBy
by yourself, without using the standard groupBy
method.
Implement two versions of groupBy:
- One using a mutable var
and a foreach
loop.
- One using foldRight
.
def groupByForeach[T, S](f: T => S)(xs: List[T]): Map[S, List[T]] =
var result = Map.empty[S, List[T]]
for x <- xs do
val key = f(x)
val value = result.getOrElse(key, Nil)
result = result + (key -> (x :: value))
result
def groupByFoldRight[T, S](f: T => S)(xs: List[T]): Map[S, List[T]] =
xs.foldRight(Map.empty[S, List[T]]) { (x, result) =>
val key = f(x)
val value = result.getOrElse(key, Nil)
result + (key -> (x :: value))
}
src/main/scala/tailRecursion/groupBy.scala
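One subtle difference between the two solutions above is the order of elements within each group: the foldRight version preserves the input order, while the foreach version (which prepends while scanning left to right) reverses it. A small check (the example values are ours):

```scala
def groupByForeach[T, S](f: T => S)(xs: List[T]): Map[S, List[T]] =
  var result = Map.empty[S, List[T]]
  for x <- xs do
    val key = f(x)
    result = result + (key -> (x :: result.getOrElse(key, Nil)))
  result

def groupByFoldRight[T, S](f: T => S)(xs: List[T]): Map[S, List[T]] =
  xs.foldRight(Map.empty[S, List[T]]) { (x, result) =>
    val key = f(x)
    result + (key -> (x :: result.getOrElse(key, Nil)))
  }

val xs = List(1, 2, 3, 4)
// Grouping by parity: foldRight keeps 1 before 3 and 2 before 4,
// while the foreach version yields each group in reverse order.
val byFold    = groupByFoldRight[Int, Int](_ % 2)(xs) // == Map(1 -> List(1, 3), 0 -> List(2, 4))
val byForeach = groupByForeach[Int, Int](_ % 2)(xs)   // == Map(1 -> List(3, 1), 0 -> List(4, 2))
```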
foldLeft
versus foldRight
We commonly use foldLeft
and foldRight
to shorten simple recursive functions.
Can foldLeft
be rewritten as a loop? How about foldRight
? In both cases, write the code, or explain why it cannot be rewritten that way.
foldLeft
can be rewritten as a loop because it processes elements of a collection from left to right in order. This order of traversal and accumulation fits naturally into an iterative loop structure:
def foldLeftForeach[B, A](z: B)(op: (B, A) => B)(xs: List[A]): B =
var result = z
xs foreach (x => result = op(result, x))
result
src/main/scala/tailRecursion/fold.scala
foldRight
cannot easily be rewritten as a simple loop, because it combines elements from right to left. The last element must be processed first, so recursion (or reversing the list beforehand) is more natural for this operation.
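That said, one workaround (a sketch of ours, not part of the scaffold) is to reverse the list first and then accumulate with a loop, trading an extra traversal for iteration:

```scala
// foldRight expressed as a loop: reverse the list, then combine from the
// (original) right end towards the left.
def foldRightLoop[A, B](xs: List[A])(z: B)(op: (A, B) => B): B =
  var acc = z
  var rest = xs.reverse // rightmost element of `xs` comes first
  while rest.nonEmpty do
    acc = op(rest.head, acc)
    rest = rest.tail
  acc
```

With a non-commutative operator such as subtraction, this agrees with the standard foldRight: foldRightLoop(List(1, 2, 3))(0)(_ - _) computes 1 - (2 - (3 - 0)) = 2.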
Tail recursion modulo context 🔥
We saw in the first exercise that a tail recursive function can be mechanically transformed into a loop. This transformation is mostly useful because it is performed automatically by the compiler.
In this exercise, we will explore a way to rewrite some non-tail recursive functions into tail-recursive ones.
To illustrate this technique, let us consider the map
function on List[Int]
:
def map(l: List[Int], f: Int => Int): List[Int] =
if l.isEmpty then Nil
else f(l.head) :: map(l.tail, f)
src/main/scala/tailRecursion/lists.scala
The only thing that happens after the call is the creation of a Cons
instance. To create it, we need to know the head (which we know before the recursive call) but also the tail (which is computed recursively). So the tail is the difficult part: when the recursive call completes, the callee returns the tail. Does that suggest a solution?
Hint
The trick could be to shift responsibilities around so that the caller begins the construction of the Cons, and the callee finishes that construction by storing the computed tail.
To do so, we will create a list type with a mutable tail. This way, we can construct the list before making the recursive call, and transfer the responsibility of filling in the tail to the recursive call.
Now that you have the idea, try to implement the mapTRWorker
function:
enum MutableList:
case Nil
case Cons(val hd: Int, var tail: MutableList)
import MutableList.*
def mapTR(l: MutableList, f: Int => Int): MutableList =
l match
case Nil => Nil
case Cons(hd, tl) =>
val acc: Cons = Cons(f(hd), Nil)
mapTRWorker(tl, f, acc)
acc
// @tailrec uncomment when working on the exercise
def mapTRWorker(
l: MutableList,
f: Int => Int,
acc: MutableList.Cons
): Unit =
???
src/main/scala/tailRecursion/lists.scala
enum MutableList:
case Nil
case Cons(val hd: Int, var tail: MutableList)
import MutableList.*
def mapTR(l: MutableList, f: Int => Int): MutableList =
l match
case Nil => Nil
case Cons(hd, tl) =>
val acc: Cons = Cons(f(hd), Nil)
mapTRWorker(tl, f, acc)
acc
// @tailrec uncomment when working on the exercise
def mapTRWorker(
l: MutableList,
f: Int => Int,
acc: MutableList.Cons
): Unit =
l match
  case Nil => ()
  case Cons(h, t) =>
    val next = Cons(f(h), Nil)
    acc.tail = next
    mapTRWorker(t, f, next) // no cast needed: we keep a reference to the new Cons
src/main/scala/tailRecursion/lists.scala
This technique is often called “destination-passing style”. It is used notably by the map
function from the standard Scala library:
final override def map[B](f: A => B): List[B] = {
if (this eq Nil) Nil else {
val h = new ::[B](f(head), Nil)
var t: ::[B] = h
var rest = tail
while (rest ne Nil) {
val nx = new ::(f(rest.head), Nil)
t.next = nx
t = nx
rest = rest.tail
}
releaseFence()
h
}
}
You are not expected to understand everything that is going on here. However, you can recognise that a new list cell with head f(head)
and an empty tail is created with val h = new ::[B](f(head), Nil)
. Then, the function uses a while
loop in which, again, a new cell with an empty tail is created with val nx = new ::(f(rest.head), Nil)
, and the tail of the current cell is updated with t.next = nx
.
This function is more complicated and uses some internal structures specific to the implementation of the class scala.collection.List
, but the idea is the same as the one you implemented.
Today, some languages, like OCaml, can do this automatically, as described in this paper. Scala cannot do it yet, so recursion is still relevant :)
If you are interested to learn more about it, here are some resources:
Looping on Trees
In our quest to find a case in which recursion really is easier to use than loops, we will now look at trees. We will use the following definition of binary trees:
enum Tree[T]:
case Leaf(value: T)
case Node(left: Tree[T], right: Tree[T])
src/main/scala/tailRecursion/trees.scala
Sum of leaves - rotation
Let us start with a simple function that computes the sum of a tree’s leaves. The recursive version is the following:
def sumRec(t: Tree[Int]): Int =
t match
case Leaf(value) => value
case Node(left, right) => sumRec(left) + sumRec(right)
src/main/scala/tailRecursion/trees.scala
On right line trees ⭐️
In the 2023 midterm, we saw the concept of right line trees. As a reminder, a right line tree is a tree in which each node is either a leaf, or has a leaf child on the left. The following function checks whether a tree is a right line tree:
def isRightLineTree(t: Tree[Int]): Boolean =
t match
case Leaf(_) => true
case Node(Leaf(_), right) => isRightLineTree(right)
case _ => false
src/main/scala/tailRecursion/trees.scala
Can you see the similarity between a right line tree and a list in the context of tail recursive functions?
Before writing any code, think about this: what can (a + b) + c = a + (b + c)
mean on trees? Can we exploit this to write a loop (or tail recursive) function?
Hint
This represents the right rotation on trees. The property of +
being used is associativity.
In our context, it means that the tree can be rearranged to compute the sum of leaves in a different way without affecting the result.
Let us write an imperative (or tail recursive, as you prefer) version of the sum
function that works only for right line trees. Do not forget to add the correct Scala call to ensure that the tree is indeed a right line tree before computing the sum 🙂:
def sumRightLineTree(tr: Tree[Int]): Int =
???
src/main/scala/tailRecursion/trees.scala
What would happen if the operation were not associative, like, for example, subtraction?
Hint
You can take inspiration from the sum function on lists that we saw in the first exercise of this session. Think about the similarity between the structure of a list and that of a right line tree.
def sumRightLineTree(tr: Tree[Int]): Int =
require(isRightLineTree(tr))
var acc = 0
var t = tr
while true do
t match
case Leaf(value) =>
acc += value
return acc
case Node(Leaf(value), right) =>
acc += value
t = right
case _ => // cannot happen thanks to the require clause
return acc
acc
def sumRightLineTreeTailRec(tr: Tree[Int], acc: Int = 0): Int =
require(isRightLineTree(tr))
tr match
case Leaf(value) =>
acc + value
case Node(Leaf(value), right) =>
sumRightLineTreeTailRec(right, acc + value)
case _ => // cannot happen thanks to the require clause
acc
src/main/scala/tailRecursion/trees.scala
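To answer the subtraction question concretely, here is a small demo of ours (restating the Tree definition so it is self-contained): a right rotation preserves an associative reduction such as +, but changes the result of a non-associative one such as -.

```scala
enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])
import Tree.*

def reduce[T](t: Tree[T], f: (T, T) => T): T = t match
  case Leaf(v)    => v
  case Node(l, r) => f(reduce(l, f), reduce(r, f))

// Right rotation: Node(Node(a, b), c) becomes Node(a, Node(b, c)).
val before = Node(Node(Leaf(1), Leaf(2)), Leaf(3))
val after  = Node(Leaf(1), Node(Leaf(2), Leaf(3)))

val sumsAgree = reduce(before, _ + _) == reduce(after, _ + _) // true: 6 == 6
val subsAgree = reduce(before, _ - _) == reduce(after, _ - _) // false: (1 - 2) - 3 = -4, but 1 - (2 - 3) = 2
```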
Using rotations ⭐️
A right rotation is an operation on a tree that gives a new tree with fewer leaves on the left-hand side. Can we use this operation to compute the sum of leaves of an arbitrary tree while reusing the idea of the sum we implemented on right line trees? Let’s find out!
Implement the sumRotate
function that computes the sum of leaves’ values using right rotations:
def sumRotate(tr: Tree[Int], acc: Int): Int =
???
src/main/scala/tailRecursion/trees.scala
Can you name the property that the operation applied to the leaves must satisfy for this to work?
def sumRotate(tr: Tree[Int], acc: Int): Int =
tr match
case Leaf(value) => acc + value
case Node(Leaf(value), right) => sumRotate(right, acc + value)
case Node(Node(ll, lr), right) => sumRotate(Node(ll, Node(lr, right)), acc)
src/main/scala/tailRecursion/trees.scala
The sum using rotation works correctly because the sum of leaves is agnostic to the shape of the tree.
Sum of leaves - DFS
Now, let us write an imperative version of the sum function. Before writing any code, think carefully about it. Over what elements would you iterate? How do you make sure you visit all the nodes? How would you keep track of the nodes you still need to visit?
def sumLoop(t: Tree[Int]): Int =
???
src/main/scala/tailRecursion/trees.scala
Hint
As you might have realised, this is not straightforward. The main issue is that you need to keep track of the nodes to visit. What data structure would you use to store the nodes you encounter and will visit later?
Spoiler
You should indeed use a Stack to keep track of the nodes you have to visit. You can again use the Stack class from the Scala library (scala.collection.mutable.Stack).
def sumLoop(t: Tree[Int]): Int =
var sum = 0
var toVisit = Stack(t)
while toVisit.nonEmpty do
toVisit.pop() match
case Leaf(value) =>
sum += value
case Node(left, right) =>
toVisit.push(right)
toVisit.push(left)
sum
src/main/scala/tailRecursion/trees.scala
Reduce on tree 🔥
We will now take a look at another function on trees: reduce
. As a reminder, reduce
is defined recursively as follows on trees:
def reduce[T](tr: Tree[T], f: (T, T) => T): T =
tr match
case Leaf(value) => value
case Node(left, right) => f(reduce(left, f), reduce(right, f))
src/main/scala/tailRecursion/trees.scala
We will write an imperative version of this function.
To kickstart, let us implement a mutable Stack
structure, just as you used in the previous exercises. Our MStack
is based on a List
and will extend the following trait
:
trait MStackTrait[A]:
def push(a: A): Unit
def pop(): A
def isEmpty: Boolean
def size: Int
def contains(a: A): Boolean
case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
def push(a: A): Unit =
???
def pop(): A =
???
def isEmpty: Boolean =
???
def size: Int =
???
def contains(a: A): Boolean =
???
src/main/scala/tailRecursion/trees.scala
Now let us implement a post order traversal on trees. This function will return the subtrees in post order, which means first the left child, then the right child, then the node itself. For example, the post order traversal of the following tree:
val tree =
Node(
Node(
Leaf(1),
Leaf(2)
),
Leaf(3)
)
is the following list:
List(
Leaf(1),
Leaf(2),
Node(Leaf(1), Leaf(2)),
Leaf(3),
Node(Node(Leaf(1), Leaf(2)), Leaf(3))
)
Now, implement the postOrderTraversal
function using a while loop and the MStack
type that you just implemented. Think hard before writing the function. How do you keep track of the nodes you will visit? How do you ensure that you add the nodes in the correct order?
def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
???
src/main/scala/tailRecursion/trees.scala
This postorder traversal should be enough to implement reduce
!
Hint
You’ll need an intermediate data structure to keep track of partially reduced results while you go over the post order. You can use a Map
that associates tree nodes to the result of reduce
on them, or you can use a Stack
with a bit more thinking about the order in which nodes appear in the post order.
def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
???
src/main/scala/tailRecursion/trees.scala
trait MStackTrait[A]:
def push(a: A): Unit
def pop(): A
def isEmpty: Boolean
def size: Int
def contains(a: A): Boolean
case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
def push(a: A): Unit =
l = a :: l
def pop(): A =
val a = l.head
l = l.tail
a
def isEmpty: Boolean =
l.isEmpty
def size: Int =
l.size
def contains(a: A): Boolean =
l.contains(a)
src/main/scala/tailRecursion/trees.scala
def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
var toVisit = MStack[Tree[T]]()
toVisit.push(tr)
var postOrderNodes: List[Tree[T]] = Nil
while !toVisit.isEmpty do
val n = toVisit.pop()
postOrderNodes = n :: postOrderNodes
n match
case Node(left, right) =>
toVisit.push(left)
toVisit.push(right)
case Leaf(_) =>
postOrderNodes
src/main/scala/tailRecursion/trees.scala
For reduceLoop
, here are two versions: one that uses a Map
…
def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
var cache: Map[Tree[T], T] = Map()
for (t, idx) <- postOrderTraversal(tr).zipWithIndex do
t match
case Leaf(v) => cache = cache + (t -> v)
case Node(left, right) =>
val leftValue = cache(left)
val rightValue = cache(right)
cache = cache + (t -> f(leftValue, rightValue))
cache(tr)
src/main/scala/tailRecursion/trees.scala
… and one that uses a stack (note that elements could also be removed from the map as they are accessed in the algorithm above):
def reduceWithStack[T](tr: Tree[T], f: (T, T) => T): T =
val stack = Stack.empty[T]
for t <- postOrderTraversal(tr) do
t match
case Leaf(value) => stack.push(value)
case Node(left, right) =>
val (r, l) = (stack.pop(), stack.pop())
stack.push(f(l, r))
stack.pop()
src/main/scala/tailRecursion/trees.scala
As a final, 🔥🔥 exercise: could you further reduce the memory usage by merging both traversals?
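For this 🔥🔥 question, here is one possible sketch (our own, not an official solution): interleave the traversal and the reduction in a single loop, tagging each pushed node with whether its children have already been processed, so that no post-order list is ever materialized.

```scala
import scala.collection.mutable.Stack

enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])
import Tree.*

def reduceOnePass[T](tr: Tree[T], f: (T, T) => T): T =
  // (node, expanded): `expanded = true` means both children were visited.
  val toVisit = Stack[(Tree[T], Boolean)]((tr, false))
  val values = Stack.empty[T]
  while toVisit.nonEmpty do
    toVisit.pop() match
      case (Leaf(v), _) => values.push(v)
      case (_, true) => // children already reduced: combine their values
        val (right, left) = (values.pop(), values.pop())
        values.push(f(left, right))
      case (n @ Node(l, r), false) => // expand the node first
        toVisit.push((n, true))
        toVisit.push((r, false))
        toVisit.push((l, false))
  values.pop()
```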
Map on tree
Now that you implemented reduce
, you can implement map
using the same principles.
No solution provided, share on Ed!
Proof of correctness of reduce
on trees 🔥
We will now revisit the reduce
function that uses the post order traversal from the exercise Reduce on tree. If you did not do it, here is the implementation:
Solution
trait MStackTrait[A]:
def push(a: A): Unit
def pop(): A
def isEmpty: Boolean
def size: Int
def contains(a: A): Boolean
case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
def push(a: A): Unit =
l = a :: l
def pop(): A =
val a = l.head
l = l.tail
a
def isEmpty: Boolean =
l.isEmpty
def size: Int =
l.size
def contains(a: A): Boolean =
l.contains(a)
src/main/scala/tailRecursion/trees.scala
def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
var toVisit = MStack[Tree[T]]()
toVisit.push(tr)
var postOrderNodes: List[Tree[T]] = Nil
while !toVisit.isEmpty do
val n = toVisit.pop()
postOrderNodes = n :: postOrderNodes
n match
case Node(left, right) =>
toVisit.push(left)
toVisit.push(right)
case Leaf(_) =>
postOrderNodes
src/main/scala/tailRecursion/trees.scala
def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
var cache: Map[Tree[T], T] = Map()
for (t, idx) <- postOrderTraversal(tr).zipWithIndex do
t match
case Leaf(v) => cache = cache + (t -> v)
case Node(left, right) =>
val leftValue = cache(left)
val rightValue = cache(right)
cache = cache + (t -> f(leftValue, rightValue))
cache(tr)
src/main/scala/tailRecursion/trees.scala
If you are interested in program verification and proofs, two courses are given at EPFL in this area:
- Formal Verification by Pr. Viktor Kunčak
- Interactive Theorem Proving CS by Pr. Clément Pit-Claudel
Post order traversal
Let us start by proving the correctness of the post order traversal algorithm. In words, the algorithm is correct if the produced list contains all the nodes of the tree, and if the order of the nodes is indeed a post order traversal (i.e., children appear in the list at smaller indices than their parent). In particular, the list should end with the root.
Your task is to write the above postcondition in Scala code, as well as a loop invariant for the postOrderTraversal
function that proves it is indeed satisfied at the end. Be careful: the invariant must take the state of the stack into account.
object postOrderTraversalProof:
def invariantPostOrder[T](currentList: List[Tree[T]], toVisit: MStackTrait[Tree[T]], root: Tree[T]): Boolean =
currentList.forall(tr =>
tr match
case Leaf(_) => true
case n @ Node(left, right) =>
((currentList.contains(left) && currentList.indexOf(left) < currentList.indexOf(n)) || toVisit.contains(
left
)) &&
((currentList.contains(right) && currentList.indexOf(right) < currentList.indexOf(n)) || toVisit.contains(
right
))
) && (currentList.isEmpty && toVisit.size == 1 && toVisit.contains(root) || currentList.last == root)
def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
var toVisit = MStack[Tree[T]]()
toVisit.push(tr)
var postOrderNodes: List[Tree[T]] = Nil
assert(invariantPostOrder(postOrderNodes, toVisit, tr))
while !toVisit.isEmpty do
toVisit.pop() match
case n @ Leaf(t) =>
postOrderNodes = n :: postOrderNodes
case n @ Node(left, right) =>
postOrderNodes = n :: postOrderNodes
toVisit.push(right)
toVisit.push(left)
assert(invariantPostOrder(postOrderNodes, toVisit, tr))
assert(invariantPostOrder(postOrderNodes, toVisit, tr))
postOrderNodes
src/main/scala/tailRecursion/trees.scala
reduce
Now that you proved the correctness of the post order traversal, you can prove the correctness of the reduce
function. The postcondition of reduce
in our case is that the cache contains the root, and that this value is equal to reduce(root, f)
.
Your task is to write the above postcondition in Scala code and a loop invariant for the reduce
function that proves it is indeed satisfied at the end. You can write one invariant encoding the validity of the cache, i.e., that all the values it contains are correct with respect to their key and the function f
, and one invariant that encodes the correctness of how the cache is updated in the loop.
object reduceProof:
def reduce[T](tr: Tree[T], f: (T, T) => T): T =
tr match
case Leaf(value) => value
case Node(left, right) => f(reduce(left, f), reduce(right, f))
def invariantReduce[T](currentCache: Map[Tree[T], T], f: (T, T) => T): Boolean =
currentCache.keySet.forall(k => currentCache(k) == reduce(k, f))
def forInvariant[T](
postOrderList: List[Tree[T]],
currentIndex: Int,
currentCache: Map[Tree[T], T]
): Boolean =
postOrderList.take(currentIndex).forall(k => currentCache.contains(k))
def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
var cache: Map[Tree[T], T] = Map()
assert(invariantReduce(cache, f))
for (t, idx) <- postOrderTraversal(tr).zipWithIndex do
assert(invariantReduce(cache, f))
assert(forInvariant(postOrderTraversal(tr), idx, cache))
t match
case Leaf(v) => cache = cache + (t -> v)
case Node(left, right) =>
val leftValue = cache(left)
val rightValue = cache(right)
cache = cache + (t -> f(leftValue, rightValue))
assert(invariantReduce(cache, f))
assert(forInvariant(postOrderTraversal(tr), idx, cache))
assert(invariantReduce(cache, f))
assert(cache.contains(tr))
cache(tr)
src/main/scala/tailRecursion/trees.scala
Exceptional control flow
An exceptional contains
method ⭐️
-
Consider the following two implementations of contains
:
extension [T](l: List[T])
  final def containsRec(t0: T): Boolean = l match
    case Nil => false
    case hd :: tl => hd == t0 || tl.containsRec(t0)
src/main/scala/exceptions/Exceptions.scala
extension [T](l: List[T])
  final def containsFold(t0: T): Boolean =
    l.foldRight(false)((hd, found) => found || hd == t0)
src/main/scala/exceptions/Exceptions.scala
Is one of them preferable? Why?
-
Which mechanism do you know to interrupt a computation before it completes? Use it to rewrite contains
using foreach
.
Hint
Use an exception! They work just the same in Scala as in Java.
What advantages does this approach have?
-
containsRec
is much better: it stops as soon as it finds a result!
-
Exceptions allow us to leverage .foreach
:
extension [T](l: List[T])
  final def containsExn(t0: T): Boolean =
    case object Found extends Exception
    try
      for hd <- l if hd == t0 do throw Found
      false
    catch case Found => true
src/main/scala/exceptions/Exceptions.scala
Avoiding accidental escape: boundary
/break
⭐️
Exceptions are great, but they risk escaping: if you forget to catch an exception raised for control flow, it will propagate to the caller of your function, and cause havoc there.
-
Read the boundary/break documentation.
-
Use a boundary
to reimplement contains
a fourth time.
-
🔥 Which of these four implementations of
contains
is fastest? Make a guess, then confirm it by writing a JMH benchmark.
Value-carrying exceptions
-
Define a custom error type to hold values. Use it to write an exception-based implementation of find
.
-
Use
boundary
/break
instead of a custom error type.
import scala.util.boundary
extension [T](l: List[T])
final def containsBoundary(t0: T): Boolean =
boundary:
for hd <- l if hd == t0 do boundary.break(true)
false
src/main/scala/exceptions/Exceptions.scala
extension [T](l: List[T])
final def findExn(p: T => Boolean): Option[T] =
case class FoundWith(t: T) extends Exception
try
for hd <- l if p(hd) do throw FoundWith(hd)
None
catch
case FoundWith(t) => Some(t)
src/main/scala/exceptions/Exceptions.scala
import scala.util.boundary
extension [T](l: List[T])
final def findBoundary(p: T => Boolean): Option[T] =
boundary:
for hd <- l if p(hd) do boundary.break(Some(hd))
None
src/main/scala/exceptions/Exceptions.scala
As for performance, the best way to answer this question is to benchmark! But we can still make some educated guesses:
-
boundary
/break
is implemented using exceptions under the hood, so containsExn
and containsBoundary
should be similar.
-
containsRec
is tail recursive, so it will compile to a clean loop: it should be very fast.
-
containsFold
does not exit early, so it should be much slower: not because of the implementation of fold
, but because it processes the whole list instead of stopping when it finds the element.
To be sure, we can use the following benchmark:
@Warmup(iterations = 10, time = 100, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 50, time = 100, timeUnit = TimeUnit.MILLISECONDS)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@BenchmarkMode(Array(Mode.AverageTime))
@Fork(1)
class SimpleBenchmark:
  @Benchmark
  def using_exn: Unit = SimpleBenchmark.l.containsExn(5000000L)

  @Benchmark
  def using_fold: Unit = SimpleBenchmark.l.containsFold(5000000L)

  @Benchmark
  def using_rec: Unit = SimpleBenchmark.l.containsRec(5000000L)

  @Benchmark
  def using_boundary: Unit = SimpleBenchmark.l.containsBoundary(5000000L)

object SimpleBenchmark:
  val l: List[Long] = (1L to 10000000L).toList
src/main/scala/exceptions/Benchmarks.scala
And here is one example run:
Benchmark Mode Cnt Score Error Units
SimpleBenchmark.using_boundary avgt 50 24.121 ± 0.453 ms/op
SimpleBenchmark.using_exn avgt 50 24.767 ± 0.509 ms/op
SimpleBenchmark.using_fold avgt 50 125.344 ± 22.122 ms/op
SimpleBenchmark.using_rec avgt 50 20.448 ± 0.224 ms/op
Memoization
Briefly, memoization is the process of augmenting a function with a mutable cache that records the output of the function every time it is called. If the function is subsequently called again with a previously-seen input, the result can be returned from cache instead of being recomputed.
A step-by-step example ⭐️
To see why memoization may be useful, consider a simple example: the Fibonacci function, which we studied previously:
def fib(n: Int): Int =
  if n <= 1 then 1 else fib(n - 1) + fib(n - 2)
src/main/scala/memo/Fib.scala
To compute fib(4), we make two recursive calls: one to fib(3), and one to fib(2). To compute fib(3), we again make two recursive calls: one to fib(2), and one to fib(1). Without special precautions, we end up computing fib(2) twice. Other parts of the computation are similarly repeated.
fib(4)
=== ( fib(3)            +  fib(2)           )
=== ( (fib(2) + fib(1)) + (fib(1) + fib(0)) )
=== (((fib(1) + fib(0)) + 1) + (1 + 1))
=== (((1 + 1) + 1) + (1 + 1))
Interestingly, the cost of computing fib(n) grows exactly as fib(n): if it takes $T(k)$ steps to compute fib(k), then the cost of computing fib(n) is $T(n) = T(n - 1) + T(n - 2)$.
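We can check this growth experimentally by instrumenting the naive function to count its calls (a self-contained sketch; fibCalls is our own helper name). Since the call count $C(n)$ satisfies $C(n) = C(n - 1) + C(n - 2) + 1$, it equals exactly $2 \cdot \mathtt{fib}(n) - 1$:

```scala
// Returns both fib(n) and the number of calls made to compute it.
def fibCalls(n: Int): (Int, Long) =
  if n <= 1 then (1, 1L)
  else
    val (a, ca) = fibCalls(n - 1)
    val (b, cb) = fibCalls(n - 2)
    (a + b, ca + cb + 1)

@main def demoFibCalls(): Unit =
  for n <- 0 to 20 do
    val (f, c) = fibCalls(n)
    assert(c == 2L * f - 1) // the call count grows exactly like fib itself
  println("ok")
```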
Memoization
All this redundant computation is unnecessary. Instead, as our first attempt to address this problem, we can create a cache to store fib’s results:
import scala.collection.mutable.Map

def fibMemo(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  def loop(idx: Int): Int =
    cache.getOrElseUpdate(
      idx,
      if idx <= 1 then 1
      else loop(idx - 1) + loop(idx - 2)
    )
  loop(n)
src/main/scala/memo/Fib.scala
- Can you convince yourself that this function behaves identically to the version without a cache?
- How large does the cache grow (i.e., how many entries get created in the cache) as we evaluate fib(k)? What entries does it contain when the computation completes?
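For the first question, one way to build confidence is an exhaustive comparison on small inputs (a self-contained sketch; both versions are redefined locally, under the names fibNaive and fibCached, so the snippet runs on its own):

```scala
import scala.collection.mutable.Map

def fibNaive(n: Int): Int =
  if n <= 1 then 1 else fibNaive(n - 1) + fibNaive(n - 2)

def fibCached(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  def loop(idx: Int): Int =
    cache.getOrElseUpdate(
      idx,
      if idx <= 1 then 1 else loop(idx - 1) + loop(idx - 2)
    )
  loop(n)

@main def checkFibMemo(): Unit =
  // The cached version must agree with the naive one on every input.
  for n <- 0 to 25 do assert(fibNaive(n) == fibCached(n))
  println("ok")
```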
Subproblem graph
To save space, we need to understand the structure of the subproblem graph of the Fibonacci function. The subproblem graph is a graph where:
- Each node is a possible input to the function.
- There is an edge a → b if the computation of f(a) uses the result of f(b).
For example, the nodes of the computation graph of fib(4) are 4, 3, 2, 1, 0 and its edges are 4 → 3, 4 → 2, 3 → 2, 3 → 1, 2 → 1, 2 → 0.
Here is one representation of the graph (notice that there are no edges from 1 to 0: they are both leaves):
┌─────┬─────┐
│ v v
4 → 3 → 2 → 1 0
│ ^ ^│ ^ ^
└─────┴─────┘└────┴──────┘
Dynamic programming
The subproblem graph captures precisely the notion of dependency: we cannot compute the output of a function on a given input node unless we know the outputs of the function on all the nodes it points to. Given this:
- Traditional recursion simply recomputes these dependencies every time they are needed.
- Memoization computes every value at most once, but stores it forever, which can use a lot of memory.
- Dynamic programming reorders the computation to save memory. Every value is computed at most once and then discarded when we know future computations will not need it.
The key questions for dynamic programming are: how long do we need to remember cached values, and which computation order minimizes this time?
To answer, we proceed in three steps:
- Find a traversal of the subproblem graph, starting from the leaves, such that dependencies are always computed before their parents. This is called a reverse topological sort of the subproblem graph.
- Rewrite our algorithm to construct the memoization cache iteratively, in the order given by step 1. For Fibonacci, the order is very simple: 0, 1, 2, 3, 4, …. In other words, to compute fib(n), it is sufficient to know the values of all fib(k) where k < n. The result of step 2 is, hence, as follows:
import scala.collection.mutable.Map

def fibIter(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  for idx <- 0 to n do
    cache(idx) =
      if idx <= 1 then 1
      else cache(idx - 1) + cache(idx - 2)
  cache(n)
src/main/scala/memo/Fib.scala
- Discard entries from the memoization cache as soon as they are not used any more. In the case of the Fibonacci function, we only need to keep the last two entries:
import scala.collection.mutable.Map

def fibIterOpt(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  for idx <- 0 to n do
    cache(idx) =
      if idx <= 1 then 1
      else cache(idx - 1) + cache(idx - 2)
    cache.remove(idx - 2)
  cache(n)
src/main/scala/memo/Fib.scala
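To see the memory saving concretely, we can instrument this version to track the largest cache size it ever reaches (a self-contained sketch; fibIterOptTraced and maxSize are our own additions):

```scala
import scala.collection.mutable.Map

// Like fibIterOpt, but also reports the maximum cache size observed.
def fibIterOptTraced(n: Int): (Int, Int) =
  val cache: Map[Int, Int] = Map()
  var maxSize = 0
  for idx <- 0 to n do
    cache(idx) =
      if idx <= 1 then 1
      else cache(idx - 1) + cache(idx - 2)
    cache.remove(idx - 2)
    maxSize = math.max(maxSize, cache.size)
  (cache(n), maxSize)

@main def demoFibIterOpt(): Unit =
  val (result, maxSize) = fibIterOptTraced(30)
  assert(result == 1346269)
  assert(maxSize <= 2) // never more than the last two entries
  println("ok")
```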
And we can, as a last cleanup step, entirely eliminate the cache (and the now-unneeded import), keeping only two variables:
def fibIterFinal(n: Int): Int =
  var f0 = 1
  var f1 = 1
  for idx <- 2 to n do
    val f = f0 + f1
    f0 = f1
    f1 = f
  f1
src/main/scala/memo/Fib.scala
First application: $\binom{n}{k}$ “$n$ choose $k$”
The function choose(n, k) computes how many ways there are to pick k elements among n, without considering order and without allowing repetitions. This choice can be made in two ways:
- Pick the first element, then choose k - 1 elements among the remaining n - 1.
- Do not pick the first element, and hence choose k elements among the remaining n - 1.
- Implement choose as a recursive function using this equation. Mind the base cases!
def choose(n: Int, k: Int): Int = ???
src/main/scala/memo/Choose.scala
- Draw the subproblem graph of choose(5, 3) as a tree. Each node should have a pair of numbers, since the function takes two arguments. Do you notice repeated work?
- Write a memoized implementation of choose. The cache should map pairs of numbers (inputs) to single numbers (outputs):
def chooseMemo(n: Int, k: Int): Int = ???
src/main/scala/memo/Choose.scala
- Redraw the subproblem graph, but this time lay it out as an array with 6 columns and 4 rows: place node (i, j) at position x = i, y = j. What do you notice about the structure of the graph? Propose a reverse topological ordering of it.
- Replace the Map-based cache with a two-dimensional array, and rewrite the memoized algorithm to build the cache iteratively, without recursion.
- Is the whole cache needed at all times? Rewrite the algorithm to use less memory.
-
def choose(n: Int, k: Int): Int =
  if k <= 0 || k >= n then 1
  else choose(n - 1, k - 1) + choose(n - 1, k)
src/main/scala/memo/Choose.scala
- The tracing technique that we’ve seen in lecture produced a tree that is exactly the subproblem graph.
-
import scala.collection.mutable.Map

def chooseMemo(n: Int, k: Int): Int =
  val cache: Map[(Int, Int), Int] = Map()
  def loop(n: Int, k: Int): Int =
    cache.getOrElseUpdate(
      (n, k),
      if k <= 0 || k >= n then 1
      else loop(n - 1, k - 1) + loop(n - 1, k)
    )
  loop(n, k)
src/main/scala/memo/Choose.scala
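A quick sanity check against known binomial coefficients (a self-contained sketch; the function is redefined locally so the snippet runs on its own):

```scala
import scala.collection.mutable.Map

def chooseMemo(n: Int, k: Int): Int =
  val cache: Map[(Int, Int), Int] = Map()
  def loop(n: Int, k: Int): Int =
    cache.getOrElseUpdate(
      (n, k),
      if k <= 0 || k >= n then 1
      else loop(n - 1, k - 1) + loop(n - 1, k)
    )
  loop(n, k)

@main def checkChooseMemo(): Unit =
  assert(chooseMemo(5, 2) == 10)
  assert(chooseMemo(5, 3) == 10)  // binomial symmetry
  assert(chooseMemo(10, 5) == 252)
  assert(chooseMemo(6, 0) == 1)
  println("ok")
```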
The following drawing shows the subproblem graph of nChooseK(5, 2), with each repeated computation replaced with [cached].
                nChooseK(5, 2)
               /              \
    nChooseK(4, 1)            nChooseK(4, 2)
    /            \            /            \
nChooseK(3, 0) nChooseK(3, 1) [cached]   nChooseK(3, 2)
               /            \            /            \
     nChooseK(2, 0) nChooseK(2, 1)  nChooseK(2, 1)  [cached]
                    /          \         /     \
          nChooseK(1, 0)  [cached]  [cached]  [cached]
- Each cell points to its neighbor directly below and the one diagonally below to the left.
-
def chooseIter(n: Int, k: Int): Int =
  val dp = Array.ofDim[Int](n + 1, k + 1)
  for nn <- 0 to n do
    for kk <- 0 to Math.min(nn, k) do
      dp(nn)(kk) =
        if kk <= 0 || kk >= nn then 1
        else dp(nn - 1)(kk - 1) + dp(nn - 1)(kk)
  if k <= 0 || k >= n then 1 else dp(n)(k)
src/main/scala/memo/Choose.scala
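The iterative version can be checked against the recursive one on a grid of small inputs (a self-contained sketch with both functions repeated locally; chooseRec is our name for the recursive solution above):

```scala
def chooseRec(n: Int, k: Int): Int =
  if k <= 0 || k >= n then 1
  else chooseRec(n - 1, k - 1) + chooseRec(n - 1, k)

def chooseIter(n: Int, k: Int): Int =
  val dp = Array.ofDim[Int](n + 1, k + 1)
  for nn <- 0 to n do
    for kk <- 0 to Math.min(nn, k) do
      dp(nn)(kk) =
        if kk <= 0 || kk >= nn then 1
        else dp(nn - 1)(kk - 1) + dp(nn - 1)(kk)
  if k <= 0 || k >= n then 1 else dp(n)(k)

@main def checkChooseIter(): Unit =
  // Exhaustive agreement on all (n, k) with 0 <= k <= n <= 12.
  for n <- 0 to 12; k <- 0 to n do
    assert(chooseIter(n, k) == chooseRec(n, k))
  println("ok")
```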
- Here is one first solution, directly adapted from the previous one.
def chooseIterFinal(n: Int, k: Int): Int =
  var prev = Array.empty[Int]
  for nn <- 0 to n do
    val nxt = Array.ofDim[Int](Math.min(nn, k) + 1)
    for kk <- 0 until nxt.length do
      nxt(kk) =
        if kk <= 0 || kk >= nn then 1
        else prev(kk - 1) + prev(kk)
    prev = nxt
  if k <= 0 || k >= n then 1 else prev(k)
src/main/scala/memo/Choose.scala
We can adjust it further by keeping only one column:
def chooseIterFinalOpt(n: Int, k: Int): Int =
  if k <= 0 || k >= n then 1
  else
    val col = Array.fill(math.min(n, k) + 1)(1)
    for
      nn <- 2 to n
      kk <- math.min(k, nn - 1) until 0 by -1
    do col(kk) = col(kk - 1) + col(kk)
    col(k)
src/main/scala/memo/Choose.scala
… or by iterating along diagonals:
def chooseIterFinalGC(n: Int, k: Int): Int =
  if k <= 0 || k >= n then 1
  else
    val arrLen = n - k + 1
    val diag = Array.fill(arrLen)(1)
    for
      _ <- 1 to k
      i <- 1 until arrLen
    do diag(i) = diag(i) + diag(i - 1)
    diag(n - k)
src/main/scala/memo/Choose.scala
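The diagonal version is the least obvious of the three, so it is worth checking against the recursive definition on all small inputs (a self-contained sketch; chooseDiag is our local copy of chooseIterFinalGC, and chooseRec of the recursive solution):

```scala
def chooseRec(n: Int, k: Int): Int =
  if k <= 0 || k >= n then 1
  else chooseRec(n - 1, k - 1) + chooseRec(n - 1, k)

def chooseDiag(n: Int, k: Int): Int =
  if k <= 0 || k >= n then 1
  else
    val arrLen = n - k + 1
    val diag = Array.fill(arrLen)(1)
    // Each pass computes prefix sums, moving one diagonal down
    // Pascal's triangle; after k passes, diag(i) == C(i + k, k).
    for
      _ <- 1 to k
      i <- 1 until arrLen
    do diag(i) = diag(i) + diag(i - 1)
    diag(n - k)

@main def checkChooseDiag(): Unit =
  for n <- 0 to 12; k <- 0 to n do
    assert(chooseDiag(n, k) == chooseRec(n, k))
  println("ok")
```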
This matrix we’re building is called “Pascal’s triangle” — search for this term and “dynamic programming” to read more about this problem!
More applications: memoizing every previous CS-214 problem ⭐️
Train yourself to add memoization to functions by revisiting previous CS-214 problems. Particularly relevant are coinChange and des chiffres et des lettres.
No solution provided. The functions should compute the same results, and you can check that they are properly memoized by tracing them and confirming that the same result does not appear twice.
Benchmarking 🔥
The original solution to the Anagrams lab was not exactly fast. Memoize the recursive part of the anagrams
function, and measure the resulting speed improvements. How much faster does it get?
No solution provided. Share your results on Ed!
Tower of Hanoi
A popular item in dentists’ offices, children’s museums, and on “10 original gift ideas for the holidays” lists is the game called “Tower of Hanoi”.
The game has three pegs and 7 disks of increasing size, each with a hole in its center. In the initial configuration, the disks are stacked from largest to smallest on the leftmost peg:
      -|-               |                |
     --|--              |                |
    ---|---             |                |
   ----|----            |                |
  -----|-----           |                |
 ------|------          |                |
-------|-------         |                |
==== PEG 0 ====  ==== PEG 1 ====  ==== PEG 2 ====
The aim is to move all the disks to the rightmost peg, with one rule: a larger disk may never rest on top of a smaller disk. Hence, this is a valid move:
       |                |                |
     --|--              |                |
    ---|---             |                |
   ----|----            |                |
  -----|-----           |                |
 ------|------          |                |
-------|-------         |               -|-
==== PEG 0 ====  ==== PEG 1 ====  ==== PEG 2 ====
… and so is this:
       |                |                |
       |                |                |
    ---|---             |                |
   ----|----            |                |
  -----|-----           |                |
 ------|------          |                |
-------|-------       --|--             -|-
==== PEG 0 ====  ==== PEG 1 ====  ==== PEG 2 ====
… but after this the only valid moves are $2 \to 1$ (moving the disk from peg 2 to peg 1) as well as $2 \to 0$ and $1 \to 0$.
A solution of the game is a list of moves $i \to j$ that moves all disks from peg 0 to peg 2. Here is a solution for the case of three disks:
  -|-        |         |
 --|--       |         |
---|---      |         |
 =Left=   =Middle=  =Right=
Left → Right
   |         |         |
 --|--       |         |
---|---      |        -|-
 =Left=   =Middle=  =Right=
Left → Middle
   |         |         |
   |         |         |
---|---    --|--      -|-
 =Left=   =Middle=  =Right=
Right → Middle
   |         |         |
   |        -|-        |
---|---    --|--       |
 =Left=   =Middle=  =Right=
Left → Right
   |         |         |
   |        -|-        |
   |       --|--    ---|---
 =Left=   =Middle=  =Right=
Middle → Left
   |         |         |
   |         |         |
  -|-      --|--    ---|---
 =Left=   =Middle=  =Right=
Middle → Right
   |         |         |
   |         |       --|--
  -|-        |      ---|---
 =Left=   =Middle=  =Right=
Left → Right
   |         |        -|-
   |         |       --|--
   |         |      ---|---
 =Left=   =Middle=  =Right=
These diagrams are generated by calling viewMoves(hanoi(3), 3). This function is provided in the code supplement to this exercise set.
- Any good text editor or IDE should have an implementation of the Tower of Hanoi (if you’re using Emacs, simply use M-x hanoi to start it). Use it to familiarize yourself with the game.
- Write a function that computes a solution to the problem with $n$ disks. Check below for an important hint, or skip the hint if you prefer a 🔥 exercise. In any case, remember the fundamental question of recursive problems: how do I express a solution to my problem in terms of smaller subproblems?
enum Peg:
  case Left, Middle, Right

case class Move(from: Peg, to: Peg)
src/main/scala/memo/Hanoi.scala
def hanoi(n: Int): Seq[Move] = ???
src/main/scala/memo/Hanoi.scala
Hint
You need to solve a more general problem: how to move $n$ disks from peg $a$ to peg $b$. Can you solve the problem with 7 disks if you know how to move the first 6 disks to the middle peg?
def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] = ???
src/main/scala/memo/Hanoi.scala
- Can this program benefit from memoization?
- The solution is to switch to Emacs.
- Here’s a Scala implementation:
def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] =
  if n == 0 then Vector()
  else
    hanoiHelper(src, third, dst, n - 1)
      ++ Vector(Move(src, dst))
      ++ hanoiHelper(third, dst, src, n - 1)
src/main/scala/memo/Hanoi.scala
def hanoi(n: Int): Seq[Move] =
  hanoiHelper(Peg.Left, Peg.Right, Peg.Middle, n)
src/main/scala/memo/Hanoi.scala
… and Wikipedia has all the details.
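As a sanity check, the classic result is that moving $n$ disks takes exactly $2^n - 1$ moves; the recursive solution matches this (a self-contained sketch with the definitions repeated locally):

```scala
enum Peg:
  case Left, Middle, Right

case class Move(from: Peg, to: Peg)

def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] =
  if n == 0 then Vector()
  else
    hanoiHelper(src, third, dst, n - 1)
      ++ Vector(Move(src, dst))
      ++ hanoiHelper(third, dst, src, n - 1)

def hanoi(n: Int): Seq[Move] =
  hanoiHelper(Peg.Left, Peg.Right, Peg.Middle, n)

@main def checkHanoi(): Unit =
  // length(n) = 2 * length(n - 1) + 1 solves to 2^n - 1.
  for n <- 0 to 12 do
    assert(hanoi(n).length == (1 << n) - 1)
  println("ok")
```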
- Yes; in fact, it can benefit from more than just memoizing based on the 4 inputs src, dst, third, and n, because solutions can be renamed: if we have a solution to transfer 6 disks from peg 0 to peg 1, then we immediately have a solution for peg 1 to peg 2, for example. This reduces the complexity of the code from exponential to linear, since the two recursive calls in hanoiHelper are with the same height n - 1.
A memoizing fixpoint combinator 🔥
Take a look back at the combinator exercise from the polymorphism week. As we have seen, the first step of memoization always looks the same: starting from a function def f(input: …) = …, we write the following:
def fMemo(input: …) =
  val cache: Map[…] = Map()
  def loop(input: …) =
    cache.getOrElseUpdate(input,
      …(body)…)
  loop(input)
It would be nice to be able to abstract over this pattern. Define a higher-order function memo to do so:
def memo[A, B](f: (A, A => B) => B)(a: A): B =
  ???
src/main/scala/memo/Combinator.scala
This function should be such that we can define fib and choose as follows:
val fib = memo: (n: Int, f: Int => Int) =>
  if n <= 1 then 1 else f(n - 1) + f(n - 2)
src/main/scala/memo/Combinator.scala
val choose = memo[(Int, Int), Int] {
  case ((n, k), f) =>
    if k <= 0 || k >= n then 1
    else f((n - 1, k - 1)) + f((n - 1, k))
}
src/main/scala/memo/Combinator.scala
Hint
Start from the fixpoint combinator from the previous exercise:
def fixpoint[A, B](f: (A, A => B) => B)(a: A): B =
  def loop(a: A): B = f(a, loop)
  f(a, loop)
src/main/scala/memo/Combinator.scala
import scala.collection.mutable

def memo[A, B](f: (A, A => B) => B)(a: A): B =
  val cache = mutable.Map.empty[A, B]
  def loop(a: A): B =
    cache.getOrElseUpdate(a, f(a, loop))
  loop(a)
src/main/scala/memo/Combinator.scala
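To check that the combinator behaves as intended, here is a self-contained usage example (repeating the definitions so the snippet runs on its own):

```scala
import scala.collection.mutable

def memo[A, B](f: (A, A => B) => B)(a: A): B =
  val cache = mutable.Map.empty[A, B]
  def loop(a: A): B =
    cache.getOrElseUpdate(a, f(a, loop))
  loop(a)

val fib = memo: (n: Int, f: Int => Int) =>
  if n <= 1 then 1 else f(n - 1) + f(n - 2)

val choose = memo[(Int, Int), Int] {
  case ((n, k), f) =>
    if k <= 0 || k >= n then 1
    else f((n - 1, k - 1)) + f((n - 1, k))
}

@main def checkMemoCombinator(): Unit =
  assert(fib(30) == 1346269)   // fast: each subproblem computed once per call
  assert(choose((10, 5)) == 252)
  println("ok")
```

Note that, as with fibMemo earlier, the cache lives only for the duration of one top-level call: calling fib(30) twice recomputes from scratch, but within one call each subproblem is computed at most once.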