Mutation
Welcome to week 10 of CS-214 — Software Construction!
This exercise set is intended to help you explore mutation and reasoning about program correctness.
As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones.
You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.
We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).
Proofs of imperative algorithms (Hoare logic)
In this section, we will take a look at how we can make sure that imperative programs are actually correct.
This is related to what you did in the Specifications exercises, when you translated properties into code. Here you will perform a similar task, but you will also write invariants, which are properties that stay true throughout the execution of a program and help to prove that some properties hold.
Absolute value ⭐️
We will start with the absolute value function on Int, with the following definition:
def abs(x: Int): Int = {
  if x < 0 then -x else x
} ensuring (res =>
  ???
)
src/main/scala/hoare/hoare.scala
What does this function guarantee? What condition must the result of the absolute value function satisfy mathematically? Complete the ensuring clause above.
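A detail worth keeping in mind while writing the postcondition: Int is a 32-bit two's-complement type, so negating Int.MinValue overflows and gives back Int.MinValue. The sketch below (the names absInt, absBig, intResult and bigResult are ours, not part of the scaffold) compares the definition above on Int with the same definition on BigInt, which cannot overflow.

```scala
// Same definition as abs above, on 32-bit Int (hypothetical name absInt).
def absInt(x: Int): Int =
  if x < 0 then -x else x

// Same definition on arbitrary-precision BigInt (hypothetical name absBig).
def absBig(x: BigInt): BigInt =
  if x < 0 then -x else x

// Int.MinValue is -2147483648; its negation 2147483648 does not fit in an Int,
// so it wraps around to Int.MinValue and the result stays negative.
val intResult: Int = absInt(Int.MinValue)

// On BigInt, the result is the mathematically expected 2147483648.
val bigResult: BigInt = absBig(BigInt(Int.MinValue))
```

A postcondition that is true mathematically can therefore fail on this single Int input; one option is to exclude it with a require clause, another is to work with BigInt.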
The max function ⭐️
Let us start with the max function on List[Int] and Array[Int]. You can choose which version you want to implement. We require that the list or the array is not empty.
def maxLoopArray(a: Array[Int]): Int =
  require(!a.isEmpty)
  ???
src/main/scala/hoare/hoare.scala
def maxLoopList(l: List[Int]): Int =
  require(!l.isEmpty)
  ???
src/main/scala/hoare/hoare.scala
Solution
As you need a function to work on for the rest of the exercise, here is a working implementation of max:
def maxLoopArray(a: Array[Int]): Int =
  require(!a.isEmpty)
  var index = 1
  var max = a(0)
  while index < a.length do
    if a(index) > max then max = a(index)
    index += 1
  max
src/main/scala/hoare/hoare.scala
def maxLoopList(l: List[Int]): Int =
  require(!l.isEmpty)
  var list = l.tail
  var max = l.head
  while !list.isEmpty do
    if list.head > max then
      max = list.head
    list = list.tail
  max
src/main/scala/hoare/hoare.scala
We now ask you to write a loop invariant for the max function. As a reminder, a loop invariant is a boolean condition that holds before the loop starts and again at the end of every iteration. In particular, it still holds when the loop exits; combined with the negation of the loop condition, it therefore helps to prove the correctness of the loop.
You can use the take and forall functions on List, knowing that:
assert(l.take(n+1) == l.take(n) ++ List(l(n))) // for 0 <= n < l.length
assert(l.take(0).isEmpty)
assert(l.forall(x => true))
These functions also exist on Array and work in a similar way.
def maxLoopListWithInvariant(l: List[Int]): Int =
  require(!l.isEmpty)
  ???
src/main/scala/hoare/hoare.scala
def maxLoopArrayWithInvariant(a: Array[Int]): Int =
  require(a.size > 0)
  ???
src/main/scala/hoare/hoare.scala
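To illustrate the shape of an invariant-annotated loop without giving away the max exercise, here is a sketch on a different, hypothetical function sumLoopArray that sums an array. The invariant says that the accumulator equals the sum of the prefix visited so far, a.take(index).sum. It is checked with assert before the loop and at the end of every iteration. At exit, together with the negated loop condition, it yields the postcondition.

```scala
// Hypothetical example (not part of the scaffold): summing an array with a loop invariant.
def sumLoopArray(a: Array[Int]): Int = {
  var index = 0
  var sum = 0
  // Invariant: index stays in bounds and sum is the sum of the first `index` elements.
  def invariant: Boolean =
    0 <= index && index <= a.length && sum == a.take(index).sum
  assert(invariant) // holds before the loop: take(0) is empty and sums to 0
  while index < a.length do
    sum += a(index)
    index += 1
    assert(invariant) // re-established at the end of each iteration
  // On exit: invariant && !(index < a.length) gives index == a.length,
  // hence sum == a.take(a.length).sum == a.sum.
  sum
}.ensuring(res => res == a.sum)
```

The same pattern (state the invariant, assert it before the loop and after each iteration, then combine it with the exit condition) applies to the max functions above.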
The find function
Let us now take a look at the find function on List. The find function takes a list and a predicate and returns the first element of the list that satisfies the predicate. If no element satisfies the predicate, it returns None. We will use the following implementation:
def find(l: List[Int], p: Int => Boolean): Option[Int] = {
  def loop(l: List[Int]): Option[Int] =
    var li = l
    while !li.isEmpty do
      assert(true)
      if p(li.head) then
        return Some(li.head)
      li = li.tail
      assert(true)
    None
  loop(l)
}
.ensuring(res => ???)
src/main/scala/hoare/hoare.scala
We want to prove that find is correct. Can you think of its postcondition? If the function returns None, what should be true about the list? And if it returns Some(x), what should be true about x and the list?
First, write the postcondition of the function in the ensuring clause.
Then, write an invariant for the while loop. What should be true about the list at each iteration? Do we need a condition on the part of the list that was already visited?
Add your invariant to the asserts in the loop.
I can’t believe it can sort 🔥
In this exercise, we will analyze a sorting algorithm puzzle from a recent short paper. Here it is, implemented in Scala:
def swap(a: Array[Int], i: Int, j: Int): Unit =
  val tmp = a(i)
  a(i) = a(j)
  a(j) = tmp

def ICantBelieveItCanSort(a: Array[Int]): Unit =
  for i <- 0 until a.length do
    for j <- 0 until a.length do
      if a(i) < a(j) then
        swap(a, i, j)
src/main/scala/hoare/sorting.scala
It looks suspiciously similar to other sorting algorithms; for comparison, here is an insertion sort:
def insertionSort(a: Array[Int]): Unit =
  for i <- 1 until a.length do
    for j <- i until 0 by -1 do
      if a(j) < a(j - 1) then
        swap(a, j, j - 1)
src/main/scala/hoare/sorting.scala
… yet it’s not the same. In fact, this ICantBelieveItCanSort routine is not a good sorting algorithm: don’t use it! But it’s a good reasoning puzzle… why does it work? The authors write:
There is nothing good about this algorithm. It is slow – the algorithm obviously runs in $\Theta(n^2)$ time, whether worst-case, average-case or best-case. It unnecessarily compares all pairs of positions, twice. There seems to be no intuition behind it, and its correctness is not entirely obvious. You certainly do not want to use it as a first example to introduce students to sorting algorithms. It is not stable, does not work well for external sorting, cannot sort inputs arriving online, and does not benefit from partially sorted inputs.
Your task is to understand why this algorithm correctly sorts its input array, and then to write a proof of correctness for it. Coming up with the complete proof is hard; simply transcribing the proof provided by the authors would already be an excellent exercise.
Hint
Tracing the code may help you understand how it works.
Traced implementation
def highlight(a: Array[Int], is: Int*) =
  a.toList.zipWithIndex.map((x, k) => if is.contains(k) then f"[$x]" else f" $x ").mkString(", ")

def ICantBelieveItCanSort_traced(a: Array[Int]): Unit =
  for i <- 0 until a.length do
    println(f">> i=$i")
    for j <- 0 until a.length do
      if a(i) < a(j) then
        println(f"> j=$j ${highlight(a, i, j)}")
        swap(a, i, j)
        println(f"> j=$j ${highlight(a, i, j)}")
src/main/scala/hoare/sorting.scala
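Before attempting the proof, you may want to convince yourself empirically that the algorithm does sort. The sketch below runs ICantBelieveItCanSort on many random arrays and compares each result with the standard library's sorted. The helper name checkOnRandomArrays and its parameters are our own choices. Passing such a test is evidence, not a proof, but it can quickly reject a wrong invariant candidate if you also assert that candidate inside the loops.

```scala
import scala.util.Random

def swap(a: Array[Int], i: Int, j: Int): Unit =
  val tmp = a(i)
  a(i) = a(j)
  a(j) = tmp

def ICantBelieveItCanSort(a: Array[Int]): Unit =
  for i <- 0 until a.length do
    for j <- 0 until a.length do
      if a(i) < a(j) then
        swap(a, i, j)

// Hypothetical test helper: sorts `trials` random arrays (lengths 0 to maxLength,
// values in [-50, 50)) and checks each result against the library's sorted.
def checkOnRandomArrays(trials: Int, maxLength: Int, seed: Long): Boolean =
  val rng = new Random(seed)
  (1 to trials).forall { _ =>
    val input = Array.fill(rng.nextInt(maxLength + 1))(rng.nextInt(100) - 50)
    val expected = input.sorted  // a new sorted array; input is left untouched
    ICantBelieveItCanSort(input) // sorts input in place
    input.sameElements(expected)
  }
```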
Strongest postconditions
In this section, we will take a look at the concept of strongest postconditions to reason about program behaviour, and ultimately correctness.
The strongest postcondition is the most restrictive condition on the output of a program, given that the input satisfies a given condition. Let us say we are analyzing the following function f:
def f(x: BigInt): BigInt = {
  require(x > 0)
  x + 1
} ensuring (res => ???)
src/main/scala/strongestPostcondition/strongest.scala
Note that we are using BigInt, which represents arbitrary-precision integers, so we don’t have to worry about overflows.
We are looking for a condition sp such that f returns a value that satisfies sp whenever the input x satisfies the precondition x > 0. Moreover, we want sp to be the most restrictive (“strongest”) condition possible.
So:
def f(x: BigInt): BigInt = {
  require(x > 0)
  x + 1
} ensuring (res => res > 1)
src/main/scala/strongestPostcondition/strongest.scala
In this example, the strongest postcondition is res => res > 1. Indeed, for all x > 0, f returns a value that is greater than 1. Moreover, every integer greater than 1 is returned for some valid input, so we cannot find a more restrictive condition.
We can also think of the strongest postcondition in terms of sets of input and output values. Let us call P the set of values that satisfy the precondition, and Q a set of values that satisfy the postcondition. Then, the strongest postcondition is the smallest set Q such that for all x in P, f(x) is in Q. In our example, P = {x | x > 0} and Q = {x | x > 1}.
Strongest postcondition calculations are useful to check specifications: if we have a precondition (require) and a postcondition (ensuring) for a given program, then we can check that the postcondition follows from the precondition by checking that the strongest postcondition implies the stated postcondition. Formally, for every x:
assert(if sp(x) then ensuring(x) else true)
or
$$ \forall x: x \in Q \implies x \in S $$
where $S$ is the set of all outputs permitted by the user-supplied ensuring clause.
Strongest postcondition for pure functions
Now that we know what a strongest postcondition is, let us find one for the following functions:
def f1(x: BigInt): BigInt = {
  require(x > 0)
  2 * x
} ensuring (y => ???)
src/main/scala/strongestPostcondition/strongest.scala
def f2(x: BigInt): BigInt = {
  require(x > 2 && x <= 10)
  if x < 5 then BigInt(0) else x
} ensuring (res => ???)
src/main/scala/strongestPostcondition/strongest.scala
def f3(x: BigInt): BigInt = {
  require((x > 0 && x < 9) || (x >= 20 && x < 26))
  if x < 6 then x + 1
  else if x < 23 then 3 * x
  else -2 * x
} ensuring (y => ???)
src/main/scala/strongestPostcondition/strongest.scala
Try to think of arguments as symbolic values rather than concrete ones. This means that you should think of x as a variable containing an integer constrained by the precondition. Then go through the code, keeping for each execution path a set of conditions that hold on that path: at each branch, add the branch condition (or its negation) to the set. When you encounter an assignment or a return, add a condition relating the new value to the previous ones, connected with a conjunction. At the end, take the disjunction of the conditions collected on each path.
This technique can be implemented mechanically, and this is called Symbolic Execution. It can be used in practice to prove postconditions of programs.
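When the precondition admits only finitely many inputs, as for f2, f3, imperativeF1 and imperativeF2, you can check a candidate strongest postcondition by brute force: running the function on every input allowed by the precondition produces exactly the set Q. The sketch below illustrates this on a hypothetical function g (not one of the exercises, so as not to give away the answers). The helpers image and describesExactly are also ours, and the ranges of inputs and outputs to inspect must be chosen by hand.

```scala
// Hypothetical example function whose precondition admits only four inputs.
def g(x: BigInt): BigInt = {
  require(x >= 0 && x <= 3)
  x * x
}

// Hypothetical helper: the set Q of outputs of `f` on the inputs among `candidates`
// that satisfy `pre`. `candidates` must contain every input allowed by `pre`.
def image(candidates: Iterable[BigInt], pre: BigInt => Boolean, f: BigInt => BigInt): Set[BigInt] =
  candidates.filter(pre).map(f).toSet

// Hypothetical helper: does `sp` accept every value of Q and, within the finite
// `window` of values we inspect, accept nothing else?
def describesExactly(q: Set[BigInt], sp: BigInt => Boolean, window: Iterable[BigInt]): Boolean =
  q.forall(sp) && window.forall(y => sp(y) == q.contains(y))

val inputs: Seq[BigInt] = (-10 to 10).map(BigInt(_))
val window: Seq[BigInt] = (-20 to 20).map(BigInt(_))
val q: Set[BigInt] = image(inputs, x => x >= 0 && x <= 3, g)
```

Here q is {0, 1, 4, 9}: the condition y => 0 <= y && y <= 9 is a valid postcondition of g, but describesExactly rejects it because it also accepts 2, 3, 5, and so on. To check that an ensuring clause follows from the precondition, it suffices that q.forall(ensuring) holds.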
Strongest postcondition for imperative functions
Now, let us take a look at imperative functions — try to find the strongest postconditions of these programs!
def imperativeF1(x: BigInt): BigInt = {
  require(x >= -1 && x <= 4)
  var y = x
  var z = BigInt(1)
  if y > 0 then z *= 4
  if y < 4 then z *= 2
  if y % 2 == 0 then z -= 3
  z
} ensuring (z => ???)
src/main/scala/strongestPostcondition/strongest.scala
def imperativeF2(x: BigInt): BigInt = {
  require(x >= -5 && x <= 5)
  var y = x
  var z = BigInt(0)
  while y * y > 0 do
    z += 1
    if y > 0 then y -= 1
    else y += 1
  z
} ensuring (z => ???)
src/main/scala/strongestPostcondition/strongest.scala
Bottom-up merge sort: saving memory with mutation ⭐️
Bottom-up merge sort is a variant of merge sort that, instead of recursively splitting the input, merges sorted runs of increasing width: first runs of length 1, then 2, then 4, and so on. In the functional programming paradigm, we have used the foldt trick. In the imperative paradigm, we will use mutation on Arrays. This variant will be in-place, meaning that we will directly modify the input array instead of creating a new one.
The benefit of having an in-place algorithm is that we do not need to allocate a new array for each merge. This saves memory and is more efficient. However, it is more difficult to implement and to reason about. This algorithm will be a good exercise to practice imperative programming and to see how mutation can be used to save memory. It will allocate a new array only once, at the beginning of the algorithm.
For this exercise, we will sort Arrays of size $2^k$ for some $k \in \mathbb{N}$. This is not a limitation of the algorithm, but it will simplify the implementation.
The idea is the following, with an array a of length $n = 2^k$ for some $k \in \mathbb{N}$:
- Create an array b of the same length $n$.
- Make successive passes over a with run widths $1$, $2$, $4$, …, $n/2$. Denote width the length of the runs in the current pass.
- In each pass:
  - merge each pair of adjacent runs of length width from a into a single run of length 2 * width in b;
  - once all pairs have been merged, copy the content of b into a, double width, and start again until the whole array is merged.
Here is an example of execution of this merge-sort algorithm on the array [15, 4, 18, 11]:
- Allocating b = [0, 0, 0, 0]
- The array a contains runs of length 1: [15, 4, 18, 11]
  - Runs are [15], [4], [18], [11]
- Merging the runs of length 1 into b to produce runs of length 2: [4, 15, 11, 18]
- Writing b into a: [4, 15, 11, 18]
- The array a contains runs of length 2: [4, 15, 11, 18]
  - Runs are [4, 15], [11, 18]
- Merging the runs of length 2 into b to produce runs of length 4: [4, 11, 15, 18]
- Writing b into a: [4, 11, 15, 18]
You can see that the array is sorted at the end.
The implementation will be divided into two parts:
- A function sort that performs the successive passes and merges the runs.
- A function merge that merges two consecutive runs of length width from one array into a run of length 2 * width, written to another array.
You can start from the following template:
def sort(original: Array[Int]): Array[Int] =
  val n = original.length
  ???

def merge(a: Array[Int], b: Array[Int], indexLeft: Int, width: Int) =
  ???
src/main/scala/mutation/mergeSort.scala
merge takes two extra arguments besides the arrays a and b:
- indexLeft corresponds to the index of the first element of the left run.
- width corresponds to the width of a run.
In other words, the left run is a[indexLeft: indexLeft + width), the right run is a[indexLeft + width: indexLeft + 2 * width) and the merged run should be b[indexLeft: indexLeft + 2 * width), where [i:j) means from index i (included) to index j (excluded).
A more efficient implementation would not copy the content of the array but swap the roles of the two arrays, and return either a or b depending on which one is the output array. This would save the copy at the end of each pass. However, this would make the algorithm more difficult to understand and to reason about. So, it is up to you!
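Below is one possible sketch of the two functions, assuming, as above, that the length of the array is a power of two (or zero). It follows the simpler variant that copies b back into the input array after each pass; the local variable names and the loop structure are our own choices.

```scala
// Sketch: merge the runs a[indexLeft, indexLeft + width) and
// a[indexLeft + width, indexLeft + 2 * width) into b[indexLeft, indexLeft + 2 * width).
def merge(a: Array[Int], b: Array[Int], indexLeft: Int, width: Int): Unit =
  val leftEnd = indexLeft + width      // end of the left run (excluded)
  val rightEnd = indexLeft + 2 * width // end of the right run (excluded)
  var i = indexLeft // next unread element of the left run
  var j = leftEnd   // next unread element of the right run
  var k = indexLeft // next free slot in b
  while k < rightEnd do
    // Take from the left run if the right run is exhausted, or if the left
    // element is not larger (taking left on ties keeps the sort stable).
    if j >= rightEnd || (i < leftEnd && a(i) <= a(j)) then
      b(k) = a(i)
      i += 1
    else
      b(k) = a(j)
      j += 1
    k += 1

// Sketch: assumes original.length is a power of two (or 0).
def sort(original: Array[Int]): Array[Int] =
  val n = original.length
  val b = new Array[Int](n) // the only auxiliary allocation
  var width = 1
  while width < n do
    var indexLeft = 0
    while indexLeft < n do
      merge(original, b, indexLeft, width)
      indexLeft += 2 * width
    // Copy the merged runs back so that the next pass reads them from `original`.
    Array.copy(b, 0, original, 0, n)
    width *= 2
  original
```

The Array.copy at the end of each pass is exactly the copy that the more efficient variant avoids by swapping the roles of the two arrays.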
Ideas for the really curious ones ;)
- You can try to implement this more efficient version.
- How would you adapt your algorithm to handle non-power of two arrays?
Logging
Logger
When running code, it can be useful to print what is going on for debugging. In order to do so, let us define a trait Logger:
trait Logger:
  def log(message: String, depth: Int = 0): Unit
src/main/scala/mutation/logger.scala
Unlike in the lecture, where we used a Logger without a depth parameter, here the depth argument is used to indent the message, so that messages are printed in a tree-like structure. For example, if depth is 2, the message is printed with two levels of indentation before it.
Suppose the code is logging the following messages:
logger.log("= Head of the tree")
logger.log("= Sub-element", 1)
logger.log("= Sub-sub-element", 2)
logger.log("= Sub-element", 1)
logger.log("= Sub-sub-element 1", 2)
logger.log("= Sub-sub-element 2", 2)
The output should be:
= Head of the tree
  = Sub-element
    = Sub-sub-element
  = Sub-element
    = Sub-sub-element 1
    = Sub-sub-element 2
You have seen this structure already in log messages of tests. Example:
recursion.ListOpsTests:
  + length: empty list 0.008s
  + length: list with 2 elements 0.0s
  + length: list with 4 elements 0.0s
  …
anagrams.AnagramsSuite:
  + computeOccurrenceList: abcd (3pts) 0.032s
  + computeOccurrenceList: Robert (3pts) 0.0s
  …
Your implementation must comply with these rules to pass the tests. You are free to use any indentation of your liking (a space, two spaces, four spaces, a tab, …).
It will help make the output more readable for the next exercise.
Let’s implement a LoggerBuffered class that accumulates messages in a private mutable buffer field, as follows:
class LoggerBuffered extends Logger:
  // private var ???
  def log(message: String, depth: Int = 0): Unit =
    ???
  def getOutput: String =
    ???
src/main/scala/mutation/logger.scala
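A minimal sketch of LoggerBuffered, assuming two spaces of indentation per depth level and a newline after every message; the trait is repeated so that the snippet compiles on its own. Following the template's hint, the private var holds an immutable Vector of lines that is replaced on each call. A private val holding a mutable StringBuilder would be an equally reasonable choice.

```scala
trait Logger:
  def log(message: String, depth: Int = 0): Unit

class LoggerBuffered extends Logger:
  // Mutable field: every logged line, already indented, in logging order.
  private var lines: Vector[String] = Vector.empty

  // Assumption: two spaces of indentation per depth level.
  def log(message: String, depth: Int = 0): Unit =
    lines = lines :+ ("  " * depth + message)

  // One line per message, each terminated by a newline.
  def getOutput: String = lines.map(_ + "\n").mkString
```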
Adding logs to eval
Let’s consider a simple evaluator for the following ADT:
enum Expr:
  case Constant(a: Int)
  case Add(a: Expr, b: Expr)
  case Sub(a: Expr, b: Expr)
src/main/scala/mutation/logger.scala
and implement the function that evaluates such an expression. For every expression that is evaluated, you need to log every performed operation using the given Logger l:
def eval(e: Expr, l: Logger, depth: Int = 0): Int =
  ???
src/main/scala/mutation/logger.scala
And here is an example: the following snippet
val l = new LoggerBuffered
eval(Add(Add(Constant(1), Constant(2)), Constant(3)), l)
l.getOutput
generates
Add(Constant(1),Constant(2)) + Constant(3) ->
  Constant(1) + Constant(2) ->
    Constant(1) = 1
    Constant(2) = 2
  = 3
  Constant(3) = 3
= 6
The output may be a bit different depending on the strings you print for each operation, but it should be similar. Note, however, that the tests expect a specific output, so failing the tests on this exercise does not necessarily mean that your solution is wrong; it may just be formatted differently.
Here are the specific outputs:
- eval(Constant(0)) should log
  Constant(0) = 0
- eval(Add(Constant(1), Constant(2))) should log
  Constant(1) + Constant(2) ->
    Constant(1) = 1
    Constant(2) = 2
  = 3
  where the two lines Constant(1) = 1 and Constant(2) = 2 are logged with depth 1.
- eval(Sub(Constant(1), Constant(2))) should log
  Constant(1) - Constant(2) ->
    Constant(1) = 1
    Constant(2) = 2
  = -1
  where the two lines Constant(1) = 1 and Constant(2) = 2 are logged with depth 1.
- The first given example also satisfies the expected output.
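One possible sketch of eval that produces the outputs listed above, assuming two spaces per indentation level. Each compound expression logs a header line at the current depth, evaluates its operands one level deeper, and logs its result back at the current depth. Logger and a minimal LoggerBuffered are repeated so that the snippet is self-contained.

```scala
trait Logger:
  def log(message: String, depth: Int = 0): Unit

// Minimal buffered logger for this sketch (two spaces of indentation per level).
class LoggerBuffered extends Logger:
  private var lines: Vector[String] = Vector.empty
  def log(message: String, depth: Int = 0): Unit =
    lines = lines :+ ("  " * depth + message)
  def getOutput: String = lines.map(_ + "\n").mkString

enum Expr:
  case Constant(a: Int)
  case Add(a: Expr, b: Expr)
  case Sub(a: Expr, b: Expr)

import Expr.*

def eval(e: Expr, l: Logger, depth: Int = 0): Int = e match
  case Constant(a) =>
    l.log(s"$e = $a", depth) // a constant is a leaf: one line at the current depth
    a
  case Add(a, b) =>
    l.log(s"$a + $b ->", depth)
    // Operands are evaluated left to right, one level deeper.
    val res = eval(a, l, depth + 1) + eval(b, l, depth + 1)
    l.log(s"= $res", depth)
    res
  case Sub(a, b) =>
    l.log(s"$a - $b ->", depth)
    val res = eval(a, l, depth + 1) - eval(b, l, depth + 1)
    l.log(s"= $res", depth)
    res
```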
Adding logs to a stack interpreter
A practical foldLeft ⭐️
In this exercise, you will write an interpreter for a small stack language. We already saw this concept in the exercises about polymorphism in week 4. This version is a bit different though.
The language offers the following operations:
enum Operation:
  case Push(n: Int)
  case Add
  case Sub

type Program = List[Operation]
src/main/scala/mutation/interpreter.scala
This language then offers constant integer values, as well as addition and subtraction operations.
For example, the arithmetic expression 1 + 2 - 3 + 4 would be expressed by the following Program:
val p = List(Push(1), Push(2), Add, Push(3), Sub, Push(4), Add)
We will write a stack machine based interpreter for this language.
A stack machine interpreter uses a stack and reads the tokens from a Program (i.e., a list of operations) one after the other. For each operation, it performs the following:
- if it is a Push(n) operation, it pushes the value n on the stack;
- if it is an Add operation, it pops two elements off the stack, adds them and pushes the result back on the stack;
- if it is a Sub operation, it pops two elements off the stack, subtracts the top element from the one below it, and pushes the result back on the stack.
After reading the whole program, the stack should contain a single element, which is the result of the program.
This model is used in practice by the WASM virtual machine, which runs WASM bytecode in browsers. It is also used by the JVM to run Java bytecode.
For those who are learning about RISCV in the architecture course: the stack machine model we are using here is an alternative to the register machine model that RISCV was designed for.
Recursive
Let us start by writing a recursive version of the interpreter. To start, you will implement a function that takes a stack (in this case, represented by a List[Int]) and one Operation, and returns a new stack. The returned stack corresponds to the state of the stack after the operation has been executed:
def evalOp(stack: Stack, op: Operation): Stack =
  ???
src/main/scala/mutation/interpreter.scala
Now, write two versions of the eval function, which interprets an entire program starting with an empty stack: one using plain recursion without calling higher-order list functions, and one using foldLeft:
def eval(p: List[Operation]): Stack =
  ???
src/main/scala/mutation/interpreter.scala
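A sketch of evalOp and of the two versions of eval, under two assumptions: the head of the List[Int] is the top of the stack, and Sub subtracts the top element from the one below it, so that the example program computes (1 + 2) - 3 + 4. The sketch throws an exception when an operation finds too few operands. The names evalRec and evalFold are ours, chosen so that both versions can coexist.

```scala
enum Operation:
  case Push(n: Int)
  case Add
  case Sub

type Program = List[Operation]
// Assumption: the head of the list is the top of the stack.
type Stack = List[Int]

import Operation.*

def evalOp(stack: Stack, op: Operation): Stack = (op, stack) match
  case (Push(n), _)               => n :: stack
  case (Add, top :: next :: rest) => (next + top) :: rest
  // Assumption: Sub computes (element below the top) - (top).
  case (Sub, top :: next :: rest) => (next - top) :: rest
  case _ => throw IllegalArgumentException(s"not enough operands for $op")

// Plain recursion, without higher-order list functions.
def evalRec(p: Program): Stack =
  def loop(ops: Program, stack: Stack): Stack = ops match
    case Nil        => stack
    case op :: rest => loop(rest, evalOp(stack, op))
  loop(p, Nil)

// The same computation expressed with foldLeft, starting from the empty stack.
def evalFold(p: Program): Stack =
  p.foldLeft(Nil: Stack)(evalOp)
```

Both versions compute the same sequence of stacks: foldLeft threads the accumulated stack through the program exactly as the stack parameter of loop does.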
Now imperative 🔥
Write down the operations that the recursive algorithm performs, step by step, on an example. How does the stack evolve? Can you replicate this behavior with a mutable Stack from the Scala library?
def eval(p: List[Operation]): Stack[Int] =
  ???
src/main/scala/mutation/interpreter.scala
You may want to revisit this part of the exercise next week, where we’ll study a systematic way to transform left folds into loops.
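An imperative sketch using scala.collection.mutable.Stack, with the same operand-order assumption as for the recursive version. The name evalImperative is ours, so that it does not clash with the recursive eval. Instead of building a new stack at every step, a single mutable stack is updated in place.

```scala
import scala.collection.mutable

enum Operation:
  case Push(n: Int)
  case Add
  case Sub

import Operation.*

def evalImperative(p: List[Operation]): mutable.Stack[Int] =
  val stack = mutable.Stack.empty[Int]
  for op <- p do
    op match
      case Push(n) => stack.push(n)
      case Add =>
        val top = stack.pop()
        val next = stack.pop()
        stack.push(next + top)
      case Sub =>
        val top = stack.pop()
        val next = stack.pop()
        stack.push(next - top) // assumption: (element below the top) - (top)
  stack
```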
Adding logs ⭐️
In this step, we will add logging to the interpreter. The goal is to log every operation that is performed on the stack. For example, the following instructions:
Push(1), Push(2), Add, Push(3), Sub, Push(4), Add
should log the following:
Push(1) →
Stack: 1
Push(2) →
Stack: 2 1
Add →
Stack: 3
Push(3) →
Stack: 3 3
Sub →
Stack: 0
Push(4) →
Stack: 4 0
Add →
Stack: 4
It’s up to you how to implement the logging: the only requirement is that you log the operation that is performed and the stack at each step.
You are free to use either the recursive or the imperative version of the interpreter.
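One way to add logging, sketched on top of the foldLeft version. After each operation, the sketch logs the operation at depth 0 and the resulting stack at depth 1, printing the top of the stack first as in the example above. The helper name evalLogged and the choice of depths are assumptions; the other definitions are repeated so that the snippet compiles on its own.

```scala
enum Operation:
  case Push(n: Int)
  case Add
  case Sub

import Operation.*

type Stack = List[Int] // assumption: the head is the top of the stack

trait Logger:
  def log(message: String, depth: Int = 0): Unit

class LoggerBuffered extends Logger:
  private var lines: Vector[String] = Vector.empty
  def log(message: String, depth: Int = 0): Unit =
    lines = lines :+ ("  " * depth + message)
  def getOutput: String = lines.map(_ + "\n").mkString

def evalOp(stack: Stack, op: Operation): Stack = (op, stack) match
  case (Push(n), _)               => n :: stack
  case (Add, top :: next :: rest) => (next + top) :: rest
  case (Sub, top :: next :: rest) => (next - top) :: rest
  case _ => throw IllegalArgumentException(s"not enough operands for $op")

// foldLeft version with logging: after each step, log the operation and the new stack.
def evalLogged(p: List[Operation], l: Logger): Stack =
  p.foldLeft(Nil: Stack) { (stack, op) =>
    val next = evalOp(stack, op)
    l.log(s"$op →")
    l.log("Stack: " + next.mkString(" "), 1) // top of the stack printed first
    next
  }
```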