Effects, memoization
Welcome to week 11 of CS-214: Software Construction!
As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones. Exercises or questions marked 🧪 are intended to build up to concepts used in this week’s lab.
You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.
We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).
The beauty of functional programming is that it provides referential transparency, the property that the results of a function depend only on its inputs. But many things can break referential transparency:
- Randomness
- Exceptions
- State (global variables, mutable objects)
- Non-deterministic parallelism and concurrency
- I/O (reading from and writing to disk, the network, the terminal, …)
Traditionally, these are called effects. In the coming weeks we’ll see safe ways to encapsulate all these and recover referential transparency, using monads, but this week we’ll focus instead on special cases where local uses of these features do not affect referential transparency.
In fact, we have already seen one example: tail recursion elimination! Converting a tail-recursive function to a loop that uses a mutable variable does not change anything for the callers of that function, but it saves space on the stack and prevents stack overflows. This transformation is so safe that the compiler does it automatically, just as we would do it by hand.
This week we’ll focus on this and two more examples:
- Using exceptions for control flow.
- Using state to store previously-computed function results, a technique called memoization.
Avoiding recursion
In your past experience, you probably used loops to implement most of the algorithms you encountered. However, in this course we have used recursion exclusively to implement looping behaviour. Is recursion useless? Is it only complicating things for the sake of it? Let’s find out! To do so, we will rewrite recursive functions as loops and vice versa.
Converting tail-recursive functions to loops ⭐️
Let us start by converting tail recursive functions to loops. As a reminder, a tail recursive function is a recursive function in which the recursive calls are the last thing executed in the function. For example, the following function is tail recursive:
import scala.annotation.tailrec
@tailrec
def fTR(x: Int, acc: Int): Int =
  if (x == 0) acc
  else fTR(x - 1, acc + 1)
Indeed, the last thing executed in the body of fTR is either returning acc or the call fTR(x - 1, acc + 1), which is the recursive call.
The @tailrec annotation tells the compiler that the function is supposed to be tail recursive. The compiler then emits an error if that is not the case. This helps you make sure that you did not make a mistake when writing the function. When the compiler detects that a function is tail recursive, either because of the annotation or because it can infer it, it compiles the function to a loop to improve performance: a loop does not need to use space on the stack for each recursive call.
On the other hand, the following function is not tail recursive:
def f(x: Int): Int =
  if (x == 0) 0
  else 1 + f(x - 1)
In this case, the last thing executed when x != 0 is the addition 1 + f(x - 1). This means that the recursive call must complete first, and only then can the result of the addition be computed. The function is therefore not tail recursive.
Now that we have a better understanding of tail recursion, let us convert some tail recursive functions to loops. For this exercise, you will need to convert the functions reverseAppend, foldLeft and sum from the exercises of week 1. The goal is to understand the process that the compiler follows by replicating it manually, using a mechanical procedure.
Hint
Here is the procedure you can follow to convert a tail recursive function to a loop. We illustrate it with the tail recursive function fTR defined above.
- Add a while true do loop:
def fTRLoop(x: Int, acc: Int = 0): Int =
  while true do
- Create a mutable variable for each parameter of the function and assign the values of the parameters to them:
def fTRLoop(x: Int, acc: Int = 0): Int =
  var xVar = x
  var accVar = acc
  while true do
- Create a mutable variable that will contain the return value of the function. Here it is the acc parameter. This parameter has a default value, so we can remove the parameter and assign the default value to the variable:
def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
- Put the body of the function as the body of the loop, using the new variables:
def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
    if xVar == 0 then acc
    else fTR(xVar - 1, acc + 1)
- Replace the base case by a return statement returning the accumulator variable:
def fTRLoop(x: Int): Int =
  var xVar = x
  var acc = 0
  while true do
    if xVar == 0 then return acc
    else fTR(xVar - 1, acc + 1)
- Replace the recursive call by assignments to the parameter variables:
def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then return acc
    acc = acc + 1
    xVar = xVar - 1
- However, the compiler will not accept this code yet. During type checking, it gives the while loop expression the type Unit, but the function must return an Int. To solve this, the last statement should either have type Int or throw an exception, depending on your preference. Note that this last statement is unreachable and will never be executed!
def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then return acc
    acc = acc + 1
    xVar = xVar - 1
  throw new AssertionError("Unreachable")
And there you go, the loop version of the tail recursive function!
def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then
      return acc
    acc = acc + 1
    xVar = xVar - 1
  throw new AssertionError("Unreachable")
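As a quick sanity check, you can compare the loop version against fTR and against the non-tail-recursive f on small inputs. This is a sketch and is not part of the scaffold. The block repeats the definitions above so that it is self-contained. agreeUpTo is a hypothetical helper introduced only for this check.

```scala
import scala.annotation.tailrec

@tailrec
def fTR(x: Int, acc: Int): Int =
  if (x == 0) acc
  else fTR(x - 1, acc + 1)

// Loop version obtained with the procedure above
def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then
      return acc
    acc = acc + 1
    xVar = xVar - 1
  throw new AssertionError("Unreachable")

// Non-tail-recursive version, kept only for comparison
def f(x: Int): Int =
  if (x == 0) 0
  else 1 + f(x - 1)

// Hypothetical helper: do the three versions agree on every input in 0..limit?
def agreeUpTo(limit: Int): Boolean =
  (0 to limit).forall(x => fTRLoop(x) == fTR(x, 0) && fTR(x, 0) == f(x))
```

On a large input such as 10000000, fTRLoop and fTR return without growing the stack. By contrast, f is likely to throw a StackOverflowError with the JVM's default stack size.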
Here is a good resource about tail recursion elimination (the process of converting tail recursive functions to loops): Debunking the “expensive procedure calls” myth.
Now that you know the procedure, convert the following functions:
@tailrec
def reverseAppend(l1: List[Int], l2: List[Int]): List[Int] =
  if l1.isEmpty then l2
  else reverseAppend(l1.tail, l1.head :: l2)
def reverseAppendLoop(l1: List[Int], l2: List[Int]): List[Int] =
  ???
src/main/scala/tailRecursion/lists.scala
@tailrec
def foldLeft(l: List[Int], acc: Int)(f: (Int, Int) => Int): Int =
  if l.isEmpty then acc
  else foldLeft(l.tail, f(acc, l.head))(f)
def foldLeftLoop(l: List[Int], startValue: Int)(f: (Int, Int) => Int): Int =
  ???
src/main/scala/tailRecursion/lists.scala
@tailrec
def sum(l: List[Int], acc: Int = 0): Int =
  if l.isEmpty then acc
  else sum(l.tail, acc + l.head)
def sumLoop(l: List[Int]): Int =
  ???
src/main/scala/tailRecursion/lists.scala
We now know how to mechanically transform a tail recursive function into a loop. For the rest of this exercise set, feel free to write either loops or tail recursive functions, whichever you prefer (when appropriate, of course).
Now the question is: can we write any recursive function as tail recursive? If so, why do we bother with recursive functions? Let us find out!
foldt
Let’s recall the foldt function from the SE exercises:
extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] = l match
    case a :: b :: tl => op(a, b) :: tl.pairs(op)
    case _ => l
  def foldt(z: T)(op: (T, T) => T): T = l match
    case Nil => z
    case List(t) => t
    case _ :: tail => l.pairs(op).foldt(z)(op)
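To see the shape of the computation that a loop version of foldt has to reproduce, you can make every application of op visible by wrapping it in parentheses. The sketch below repeats the definitions above. paren and abcde are hypothetical names used only for this illustration.

```scala
extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] = l match
    case a :: b :: tl => op(a, b) :: tl.pairs(op)
    case _ => l
  def foldt(z: T)(op: (T, T) => T): T = l match
    case Nil => z
    case List(t) => t
    case _ :: tail => l.pairs(op).foldt(z)(op)

// Wraps every application of the operator in parentheses
val paren = (x: String, y: String) => s"($x$y)"
val abcde = List("a", "b", "c", "d", "e")
```

abcde.foldt("")(paren) gives "(((ab)(cd))e)". foldt combines neighbours pairwise, round after round, which builds a balanced tree of applications. In contrast, abcde.reduceLeft(paren) nests everything to the left and gives "((((ab)c)d)e)".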
How would you write the function foldt with loops? You can start from the following template:
extension [T](l: List[T])
  def foldt(z: T)(op: (T, T) => T): T =
    ???
src/main/scala/tailRecursion/lists.scala
Extra exercise: how would you write pairs with a while loop?
groupBy
Now, you will implement a function groupBy by yourself, without using the standard groupBy method.
Implement two versions of groupBy:
- One using a mutable var and a foreach loop.
- One using foldRight.
foldLeft versus foldRight
We commonly use foldLeft and foldRight to shorten simple recursive functions.
Can foldLeft be rewritten as a loop? How about foldRight? In both cases, write the code, or explain why it cannot be rewritten that way.
Tail recursion modulo context 🔥
We saw in the first exercise that a tail recursive function can be mechanically transformed into a loop. This transformation is mostly useful because the compiler performs it automatically.
In this exercise, we will explore a way to rewrite some non-tail recursive functions into tail-recursive ones.
To illustrate this technique, let us consider the map function on List[Int]:
def map(l: List[Int], f: Int => Int): List[Int] =
  if l.isEmpty then Nil
  else f(l.head) :: map(l.tail, f)
src/main/scala/tailRecursion/lists.scala
The only thing that happens after the recursive call is the creation of a Cons instance. To create it, we need to know the head, which we know before the recursive call, and also the tail, which is computed recursively. So the tail is the difficult part: when the recursive call completes, the callee returns the tail. Does that suggest a solution?
Hint
The trick could be to shift responsibilities around so that the caller begins the construction of the Cons, and the callee finishes that construction by storing the computed tail.
To do so, we will create a list type with a mutable tail. This way, we can construct the list cell before making the recursive call, and hand over to the recursive call the responsibility of setting its tail.
Now that you have the idea, try to implement the mapTRWorker function:
enum MutableList:
  case Nil
  case Cons(val hd: Int, var tail: MutableList)
import MutableList.*
def mapTR(l: MutableList, f: Int => Int): MutableList =
  l match
    case Nil => Nil
    case Cons(hd, tl) =>
      val acc: Cons = Cons(f(hd), Nil)
      mapTRWorker(tl, f, acc)
      acc
// @tailrec uncomment when working on the exercise
def mapTRWorker(
    l: MutableList,
    f: Int => Int,
    acc: MutableList.Cons
): Unit =
  ???
src/main/scala/tailRecursion/lists.scala
Looping on Trees
In our quest to find a case in which recursion really is easier to use than loops, we will now look at trees. We will use the following definition of binary trees:
enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])
import Tree.*
src/main/scala/tailRecursion/trees.scala
Sum of leaves - rotation
Let us start with a simple function that computes the sum of a tree’s leaves. The recursive version is the following:
def sumRec(t: Tree[Int]): Int =
  t match
    case Leaf(value) => value
    case Node(left, right) => sumRec(left) + sumRec(right)
src/main/scala/tailRecursion/trees.scala
On right line trees ⭐️
In the 2023 midterm, we saw the concept of right line trees. As a reminder, a right line tree is a tree in which each node is either a leaf, or has a leaf child on the left. The following function checks whether a tree is a right line tree:
def isRightLineTree(t: Tree[Int]): Boolean =
  t match
    case Leaf(_) => true
    case Node(Leaf(_), right) => isRightLineTree(right)
    case _ => false
src/main/scala/tailRecursion/trees.scala
Can you see the similarity between a right line tree and a list in the context of tail recursive functions?
Before writing any code, think about this: what can (a + b) + c = a + (b + c) mean on trees? Can we exploit this to write a loop (or a tail recursive function)?
Hint
This represents the right rotation on trees. This property of + is called associativity.
In our context, it means that the tree can be rearranged to compute the sum of leaves in a different way without affecting the result.
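To make the analogy concrete, here is a sketch of a single right rotation at the root of a tree. rotateRight is a hypothetical helper and is not part of the scaffold. Because + is associative, rotating a tree does not change the sum of its leaves.

```scala
enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])
import Tree.*

// One right rotation at the root: ((a, b), c) becomes (a, (b, c)),
// mirroring (a + b) + c == a + (b + c). Other trees are returned unchanged.
def rotateRight[T](t: Tree[T]): Tree[T] = t match
  case Node(Node(a, b), c) => Node(a, Node(b, c))
  case _ => t

def sumRec(t: Tree[Int]): Int = t match
  case Leaf(value) => value
  case Node(left, right) => sumRec(left) + sumRec(right)
```

Note that rotateRight leaves a tree unchanged when its left child is a leaf, which is exactly the shape of a right line tree at the root.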
Let us write an imperative (or tail recursive, as you prefer) version of the sum function that works only for right line trees. Do not forget to add the correct Scala call to ensure that the tree is indeed a right line tree before computing the sum!
def sumRightLineTree(tr: Tree[Int]): Int =
  ???
src/main/scala/tailRecursion/trees.scala
What would happen if the operation were not associative, for example subtraction?
Hint
You can take inspiration from the sum function on lists that we saw in the first exercise of this session. Think about the similarity between the structure of a list and that of a right line tree.
Using rotations ⭐️
A right rotation is an operation on a tree that produces a new tree with fewer leaves on the left-hand side. Can we use this operation to compute the sum of the leaves of an arbitrary tree, reusing the idea of the sum we implemented on right line trees? Let’s find out!
Implement the sumRotate function that computes the sum of the leaves’ values using right rotations:
def sumRotate(tr: Tree[Int], acc: Int): Int =
  ???
src/main/scala/tailRecursion/trees.scala
Can you name which property the operation done on the leaves must satisfy for this to work?
Sum of leaves - DFS
Now, let us write an imperative version of the sum function. Before writing any code, think carefully about it. Which elements would you iterate over? How do you make sure you visit all the nodes? How would you keep track of the nodes you still need to visit?
def sumLoop(t: Tree[Int]): Int =
  ???
src/main/scala/tailRecursion/trees.scala
Hint
As you might have realised, this is not straightforward. The main issue is that you need to keep track of the nodes to visit. Which data structure would you use to store the nodes you encounter and will visit later?
Spoiler
You should indeed use a stack to keep track of the nodes you still have to visit. You can again use the Stack class from the Scala library.
Reduce on tree 🔥
We will now take a look at another function on trees: reduce. As a reminder, reduce is defined recursively on trees as follows:
def reduce[T](tr: Tree[T], f: (T, T) => T): T =
  tr match
    case Leaf(value) => value
    case Node(left, right) => f(reduce(left, f), reduce(right, f))
src/main/scala/tailRecursion/trees.scala
We will write an imperative version of this function.
To get started, let us implement a mutable Stack structure, like the one you used in the previous exercises. Our MStack is based on a List and extends the following trait:
trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean
case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    ???
  def pop(): A =
    ???
  def isEmpty: Boolean =
    ???
  def size: Int =
    ???
  def contains(a: A): Boolean =
    ???
src/main/scala/tailRecursion/trees.scala
Now let us implement a post order traversal on trees. This function will return the subtrees in post order, which means first the left child, then the right child, then the node itself. For example, the post order traversal of the following tree:
val tree =
  Node(
    Node(
      Leaf(1),
      Leaf(2)
    ),
    Leaf(3)
  )
is the following list:
List(
  Leaf(1),
  Leaf(2),
  Node(Leaf(1), Leaf(2)),
  Leaf(3),
  Node(Node(Leaf(1), Leaf(2)), Leaf(3))
)
Now, implement the postOrderTraversal function using a while loop and the MStack type that you just implemented. Think hard before writing the function. How do you keep track of the nodes you will visit? How do you ensure that you add the nodes in the correct order?
def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  ???
src/main/scala/tailRecursion/trees.scala
This post order traversal should be enough to implement reduce!
Hint
You’ll need an intermediate data structure to keep track of partially reduced results while you go over the post order. You can use a Map that associates tree nodes with the result of reduce on them. Alternatively, you can use a Stack, which takes a bit more thought about the order in which nodes appear in the post order.
def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  ???
src/main/scala/tailRecursion/trees.scala
Map on tree
Now that you have implemented reduce, you can implement map using the same principles.
Proof of correctness of reduce on trees 🔥
We will now revisit the reduce function that uses the post order traversal from the exercise Reduce on tree. If you did not do that exercise, here is the implementation:
Solution
trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean
case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    l = a :: l
  def pop(): A =
    val a = l.head
    l = l.tail
    a
  def isEmpty: Boolean =
    l.isEmpty
  def size: Int =
    l.size
  def contains(a: A): Boolean =
    l.contains(a)
src/main/scala/tailRecursion/trees.scala
def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  var toVisit = MStack[Tree[T]]()
  toVisit.push(tr)
  var postOrderNodes: List[Tree[T]] = Nil
  while !toVisit.isEmpty do
    val n = toVisit.pop()
    postOrderNodes = n :: postOrderNodes
    n match
      case Node(left, right) =>
        toVisit.push(left)
        toVisit.push(right)
      case Leaf(_) =>
  postOrderNodes
src/main/scala/tailRecursion/trees.scala
def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  var cache: Map[Tree[T], T] = Map()
  for (t, idx) <- postOrderTraversal(tr).zipWithIndex do
    t match
      case Leaf(v) => cache = cache + (t -> v)
      case Node(left, right) =>
        val leftValue = cache(left)
        val rightValue = cache(right)
        cache = cache + (t -> f(leftValue, rightValue))
  cache(tr)
src/main/scala/tailRecursion/trees.scala
If you are interested in program verification and proofs, two courses are given at EPFL in this area:
- Formal Verification by Pr. Viktor Kunčak
- Interactive Theorem Proving CS by Pr. Clément Pit-Claudel
Post order traversal
Let us start by proving the correctness of the post order traversal algorithm. In words, the algorithm is correct if the produced list contains all the nodes of the tree, and if the order of the nodes is indeed a post order, i.e., children appear in the list at a smaller index than their parent. In particular, the list should end with the root.
Your task is to write the above postcondition in Scala code, together with a loop invariant for the postOrderTraversal function that proves the postcondition holds at the end. Be careful: the invariant must take the state of the stack into account.
reduce
Now that you have proved the correctness of the post order traversal, you can prove the correctness of the reduce function. Here, the postcondition of reduce is that the cache contains the root, and that the value associated with the root is equal to reduce(root, f).
Your task is to write the above postcondition in Scala code, together with a loop invariant for the reduce function that proves the postcondition holds at the end. You can write two invariants. The first encodes the validity of the cache, i.e., that every value it contains is correct with respect to its key and the function f. The second encodes the correctness of how the cache is updated in the loop.
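One way to write postconditions and invariants “in Scala code” is to use assert calls at the start of each loop iteration and after the loop. The sketch below shows this style on a simpler, hypothetical sumLoopChecked function, not on postOrderTraversal or reduceLoop. Keep in mind that these checks run at execution time: they test the invariant on the inputs you try, but they do not prove it.

```scala
// Hypothetical example: summing a list with a checked invariant and postcondition
def sumLoopChecked(l: List[Int]): Int =
  var rest = l
  var acc = 0
  while rest.nonEmpty do
    // Invariant: what was accumulated plus what remains equals the total
    assert(acc + rest.sum == l.sum)
    acc += rest.head
    rest = rest.tail
  // On exit the invariant still holds and the loop condition is false...
  assert(acc + rest.sum == l.sum && rest.isEmpty)
  // ...which together imply the postcondition
  assert(acc == l.sum)
  acc
```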
Exceptional control flow
An exceptional contains method ⭐️
- Consider the following two implementations of contains:
extension [T](l: List[T])
  final def containsRec(t0: T): Boolean = l match
    case Nil => false
    case hd :: tl => hd == t0 || tl.containsRec(t0)
src/main/scala/exceptions/Exceptions.scala
extension [T](l: List[T])
  final def containsFold(t0: T): Boolean =
    l.foldRight(false)((hd, found) => found || hd == t0)
src/main/scala/exceptions/Exceptions.scala
Is one of them preferable? Why?
- Which mechanism do you know for interrupting a computation before it completes? Use it to rewrite contains using foreach.
Hint
Use an exception! Exceptions work the same way in Scala as in Java.
What advantages does this approach have?
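Here is the same technique illustrated on a different function: a hypothetical productShortCircuit that multiplies elements using foreach, but stops as soon as it sees a 0. The exception ZeroFound is an assumption of this sketch. It mixes in scala.util.control.NoStackTrace to skip capturing a stack trace, which is unnecessary when an exception is used only for control flow.

```scala
import scala.util.control.NoStackTrace

// Hypothetical exception used only to jump out of foreach
case object ZeroFound extends Exception with NoStackTrace

def productShortCircuit(l: Iterable[Int]): Int =
  var acc = 1
  try
    l.foreach { x =>
      if x == 0 then throw ZeroFound // no need to look at the remaining elements
      acc *= x
    }
    acc
  catch case ZeroFound => 0
```

Because the exception stops the traversal early, the function even terminates on an infinite LazyList, provided it contains a 0.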
Avoiding accidental escape: boundary/break ⭐️
Exceptions are great, but they risk escaping: if you forget to catch an exception raised for control flow, it will propagate to the caller of your function, and cause havoc there.
- Read the boundary/break documentation.
- Use a boundary to reimplement contains a fourth time.
- 🔥 Which of these four implementations of contains is fastest? Make a guess, then confirm it by writing a JMH benchmark.
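For reference, here is a sketch of boundary/break, which has been in the Scala standard library since Scala 3.3. It uses a function other than contains: a hypothetical firstNegativeIndex that returns the index of the first negative element, or -1 if there is none. break(i) exits the enclosing boundary with the value i. Since break can only be called where the boundary’s label is in scope, the jump cannot accidentally escape to callers.

```scala
import scala.util.boundary, boundary.break

// Hypothetical example: index of the first negative element of l, or -1
def firstNegativeIndex(l: List[Int]): Int =
  boundary:
    for (x, i) <- l.zipWithIndex do
      if x < 0 then break(i) // leaves the enclosing boundary, returning i
    -1 // reached only if the loop finished without breaking
```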
Value-carrying exceptions
- Define a custom exception type that holds values. Use it to write an exception-based implementation of find.
- Use boundary/break instead of a custom exception type.
Memoization
Briefly, memoization is the process of augmenting a function with a mutable cache that records the output of the function every time it is called. If the function is subsequently called again with a previously-seen input, the result can be returned from cache instead of being recomputed.
A step-by-step example ⭐️
To see why memoization may be useful, consider a simple example: the Fibonacci function, which we studied previously:
def fib(n: Int): Int =
  if n <= 1 then 1 else fib(n - 1) + fib(n - 2)
src/main/scala/memo/Fib.scala
To compute fib(4), we make two recursive calls: one to fib(3), and one to fib(2). To compute fib(3), we again make two recursive calls: one to fib(2), and one to fib(1). Without special precautions, we end up computing fib(2) twice. Other parts of the computation are repeated in the same way.
fib(4)
=== (fib(3) + fib(2))
=== ((fib(2) + fib(1)) + (fib(1) + fib(0)))
=== (((fib(1) + fib(0)) + 1) + (1 + 1))
=== (((1 + 1) + 1) + (1 + 1))
Interestingly, the cost of computing fib(n) grows like fib(n) itself. If it takes $T(k)$ steps to compute fib(k), then computing fib(n) takes $T(n) = T(n - 1) + T(n - 2) + c$ steps, where the constant $c$ accounts for the comparison and the addition.
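You can check this claim with a hypothetical instrumented copy of fib that counts its own calls. The global counter calls exists only for this measurement. Each call to fibCounted(n) counts itself plus the calls made by its two subcalls. The count therefore follows the same recurrence as fib, plus one, and works out to 2 * fib(n) - 1.

```scala
def fib(n: Int): Int =
  if n <= 1 then 1 else fib(n - 1) + fib(n - 2)

// Hypothetical global counter, only used to measure the number of calls
var calls = 0

def fibCounted(n: Int): Int =
  calls += 1
  if n <= 1 then 1 else fibCounted(n - 1) + fibCounted(n - 2)

// Runs fibCounted(n) from a fresh counter and returns how many calls it made
def countCalls(n: Int): Int =
  calls = 0
  fibCounted(n)
  calls
```

For example, countCalls(4) is 9, matching the nine calls in the expansion of fib(4) above.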
Memoization
All this redundant computation is unnecessary. As a first attempt to address the problem, we can create a cache to store fib’s results:
import scala.collection.mutable.Map
def fibMemo(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  def loop(idx: Int): Int =
    cache.getOrElseUpdate(
      idx,
      if idx <= 1 then 1
      else loop(idx - 1) + loop(idx - 2)
    )
  loop(n)
src/main/scala/memo/Fib.scala
- Can you convince yourself that this function behaves identically to the version without a cache?
- How large does the cache grow (i.e., how many entries are created in it) as we evaluate fib(k)? Which entries does it contain when the computation completes?
Subproblem graph
To save space, we need to understand the structure of the subproblem graph of the Fibonacci function. The subproblem graph is a graph where:
- Each node is a possible input to the function.
- There is an edge a → b if the computation of f(a) uses the result of f(b).
For example, the nodes of the subproblem graph of fib(4) are 4, 3, 2, 1 and 0, and its edges are 4 → 3, 4 → 2, 3 → 2, 3 → 1, 2 → 1 and 2 → 0.
Here is one representation of the graph (notice that there are no edges from 1 to 0: they are both leaves):
+-----------+
|           v
4 --> 3 --> 2 --> 1
      |     |     ^
      |     v     |
      |     0     |
      +-----------+
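You can also compute the subproblem graph mechanically: start from the root input and follow dependencies. The sketch below assumes that the dependencies of each input are given as a function deps. fibDeps and subproblemGraph are hypothetical helpers, not part of the scaffold. The exploration visits each node once and records every edge a → b.

```scala
import scala.collection.mutable

// Inputs whose results fib(n) needs (hypothetical helper)
def fibDeps(n: Int): List[Int] =
  if n <= 1 then Nil else List(n - 1, n - 2)

// Explores the subproblem graph reachable from root and returns
// its nodes and its edges a -> b, meaning "f(a) uses f(b)"
def subproblemGraph(root: Int, deps: Int => List[Int]): (Set[Int], List[(Int, Int)]) =
  val visited = mutable.Set[Int]()
  val edges = mutable.ListBuffer[(Int, Int)]()
  def visit(a: Int): Unit =
    if visited.add(a) then // add returns false if a was already visited
      for b <- deps(a) do
        edges += (a -> b)
        visit(b)
  visit(root)
  (visited.toSet, edges.toList)
```

Because deps is a parameter, the same exploration should also work for a two-argument function like choose below, once its input is encoded as a single value such as a pair.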
Dynamic programming
The subproblem graph captures precisely the notion of dependency: we cannot compute the output of a function on a given input node unless we know the outputs of the function on all the nodes it points to. Given this:
- Traditional recursion simply recomputes these dependencies every time they are needed.
- Memoization computes every value at most once, but stores it forever, which can use a lot of memory.
- Dynamic programming reorders the computation to save memory. Every value is computed at most once and then discarded when we know future computations will not need it.
The key questions to be able to do dynamic programming are: How long do we need to remember cached values? Which computation order minimizes this time?
To answer, we proceed in three steps:
- Find a traversal of the subproblem graph, starting from the leaves, such that dependencies are always computed before their parents. This is called a reverse topological sort of the subproblem graph.
- Rewrite our algorithm to construct the memoization cache iteratively, in the order given by step 1.
For Fibonacci, the order is very simple: 0, 1, 2, 3, 4, … In other words, to compute fib(n), it is sufficient to know the values of all fib(k) where k < n. The result of step 2 is therefore as follows:
import scala.collection.mutable.Map
def fibIter(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  for idx <- 0 to n do
    cache(idx) = if idx <= 1 then 1 else cache(idx - 1) + cache(idx - 2)
  cache(n)
src/main/scala/memo/Fib.scala
- Discard entries from the memoization cache as soon as they are no longer used. In the case of the Fibonacci function, we only need to keep the last two entries:
import scala.collection.mutable.Map
def fibIterOpt(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  for idx <- 0 to n do
    cache(idx) = if idx <= 1 then 1 else cache(idx - 1) + cache(idx - 2)
    cache.remove(idx - 2) // fib(idx - 2) is never needed again
  cache(n)
src/main/scala/memo/Fib.scala
And we can, as a last cleanup step, entirely eliminate the cache, keeping only two variables:
def fibIterFinal(n: Int): Int =
  // Before iteration idx: f0 == fib(idx - 2) and f1 == fib(idx - 1)
  var f0 = 1
  var f1 = 1
  for idx <- 2 to n do
    val f = f0 + f1
    f0 = f1
    f1 = f
  f1
src/main/scala/memo/Fib.scala
First application: $\binom{n}{k}$ “$n$ choose $k$”
The function choose(n, k) computes how many ways there are to pick k elements among n, without considering order and without allowing repetitions. This choice can be made in two ways:
- Pick the first element, then choose k - 1 elements among the remaining n - 1.
- Do not pick the first element, and hence choose k elements among the remaining n - 1.
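In symbols, these two cases give the following recurrence, known as Pascal’s rule, valid for $0 < k < n$. The base cases are left to the exercise below:

$$
\binom{n}{k} = \binom{n-1}{k-1} + \binom{n-1}{k}
$$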
- Implement choose as a recursive function using this decomposition. Mind the base cases!
def choose(n: Int, k: Int): Int =
  ???
src/main/scala/memo/Choose.scala
- Draw the subproblem graph of choose(5, 3) as a tree. Each node should have a pair of numbers, since the function takes two arguments. Do you notice repeated work?
- Write a memoized implementation of choose. The cache should map pairs of numbers (inputs) to single numbers (outputs):
def chooseMemo(n: Int, k: Int): Int =
  ???
src/main/scala/memo/Choose.scala
- Redraw the subproblem graph, but this time lay it out as an array with 6 columns and 4 rows: place node (i, j) at position x = i, y = j. What do you notice about the structure of the graph? Propose a reverse topological ordering of it.
- Replace the Map-based cache with a two-dimensional array, and rewrite the memoized algorithm to build the cache iteratively, without recursion.
- Is the whole cache needed at all times? Rewrite the algorithm to use less memory.
More applications: memoizing every previous CS-214 problem ⭐️
Train yourself to add memoization to functions by revisiting previous CS-214 problems. Particularly relevant are coinChange and des chiffres et des lettres.
Benchmarking 🔥
The original solution to the Anagrams lab was not exactly fast. Memoize the recursive part of the anagrams function, and measure the resulting speed improvement. How much faster does it get?
Tower of Hanoi
A popular item in dentist offices, children’s museums, and on “10 original gift ideas for the holidays” lists is the game called “Tower of Hanoi”.
The game has three pegs, and 7 disks of increasing size, each with a hole in their center. In the initial configuration, the disks are stacked from largest to smallest on the leftmost peg:
      -|-              |               |
     --|--             |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------        |               |
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====
The aim is to move all the disks to the rightmost peg, with one rule: a larger disk may never rest on top of a smaller disk. Hence, this is a valid move:
       |               |               |
     --|--             |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------        |              -|-
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====
… and so is this:
       |               |               |
       |               |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------      --|--            -|-
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====
… but after this the only valid moves are $2 \to 1$ (moving the disk from peg 2 to peg 1), $2 \to 0$ and $1 \to 0$.
A solution of the game is a list of moves $i \to j$ that moves all disks from peg 0 to peg 2. Here is a solution for the case of three disks:
   -|-        |         |
  --|--       |         |
 ---|---      |         |
 =Left=   =Middle=   =Right=
Left → Right
    |         |         |
  --|--       |         |
 ---|---      |        -|-
 =Left=   =Middle=   =Right=
Left → Middle
    |         |         |
    |         |         |
 ---|---    --|--      -|-
 =Left=   =Middle=   =Right=
Right → Middle
    |         |         |
    |        -|-        |
 ---|---    --|--       |
 =Left=   =Middle=   =Right=
Left → Right
    |         |         |
    |        -|-        |
    |       --|--    ---|---
 =Left=   =Middle=   =Right=
Middle → Left
    |         |         |
    |         |         |
   -|-      --|--    ---|---
 =Left=   =Middle=   =Right=
Middle → Right
    |         |         |
    |         |       --|--
   -|-        |      ---|---
 =Left=   =Middle=   =Right=
Left → Right
    |         |        -|-
    |         |       --|--
    |         |      ---|---
 =Left=   =Middle=   =Right=
These diagrams are generated by calling viewMoves(hanoi(3), 3). This function is provided in the code supplement to this exercise set.
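Before writing hanoi, it can help to have a way to check a candidate solution. The sketch below is a hypothetical isValidSolution helper; it is not provided in the code supplement. It starts with n disks on the left peg and replays the moves. It rejects any move that takes a disk from an empty peg or places it on a smaller disk. At the end, it requires all disks to be on the right peg. The block repeats the Peg and Move definitions so that it is self-contained.

```scala
enum Peg:
  case Left, Middle, Right

case class Move(from: Peg, to: Peg)

// Hypothetical checker. Disks are numbered 1 (smallest) to n (largest);
// each peg is a list whose head is its top disk.
def isValidSolution(moves: Seq[Move], n: Int): Boolean =
  var pegs: Map[Peg, List[Int]] =
    Map(Peg.Left -> (1 to n).toList, Peg.Middle -> Nil, Peg.Right -> Nil)
  // forall stops at the first illegal move
  val allLegal = moves.forall { case Move(from, to) =>
    pegs(from) match
      case Nil => false // no disk to move
      case disk :: rest =>
        // legal if the target peg is empty or its top disk is larger
        val legal = pegs(to).headOption.forall(top => disk < top)
        if legal then pegs = pegs.updated(from, rest).updated(to, disk :: pegs(to))
        legal
  }
  allLegal && pegs(Peg.Right) == (1 to n).toList
```

For example, the seven moves shown in the three-disk diagrams above form a valid solution for n = 3. Without its last move, the sequence is not a solution.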
- Any good text editor or IDE should have an implementation of the Tower of Hanoi (if you’re using Emacs, simply use M-x hanoi to start it). Use it to familiarize yourself with the game.
- Write a function that computes a solution to the problem with $n$ disks. Check below for an important hint, or skip the hint if you prefer a 🔥 exercise. In any case, remember the fundamental question of recursive problems: how do I express a solution to my problem in terms of smaller subproblems?
enum Peg:
  case Left, Middle, Right
case class Move(from: Peg, to: Peg)
src/main/scala/memo/Hanoi.scala
def hanoi(n: Int): Seq[Move] =
  ???
src/main/scala/memo/Hanoi.scala
Hint
You need to solve a more general problem: how to move $n$ disks from peg $a$ to peg $b$. Can you solve the problem with 7 disks if you know how to move the 6 smallest disks to the middle peg?
def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] =
  ???
src/main/scala/memo/Hanoi.scala
- Can this program benefit from memoization?
A memoizing fixpoint combinator 🔥
Take a look back at the combinator exercise from the polymorphism week.
So far, the first step of memoization has always looked the same. Starting from a function def f(input: …) = …, we write the following:
def fMemo(input: …) =
  val cache: Map[…] = Map()
  def loop(input: …) =
    cache.getOrElseUpdate(input,
      …(body)…)
  loop(input)
It would be nice to abstract over this pattern. Define a higher-order function memo to do so:
def memo[A, B](f: (A, A => B) => B)(a: A): B =
  ???
src/main/scala/memo/Combinator.scala
This function should allow us to define fib and choose as follows:
val fib = memo: (n: Int, f: Int => Int) =>
  if n <= 1 then 1 else f(n - 1) + f(n - 2)
src/main/scala/memo/Combinator.scala
val choose = memo[(Int, Int), Int] {
  case ((n, k), f) =>
    if k <= 0 || k >= n then 1
    else f((n - 1, k - 1)) + f((n - 1, k))
}
src/main/scala/memo/Combinator.scala
Hint
Start from the fixpoint combinator from the previous exercise:
def fixpoint[A, B](f: (A, A => B) => B)(a: A): B =
  def loop(a: A): B = f(a, loop)
  f(a, loop)
src/main/scala/memo/Combinator.scala