Week 5: Comprehensions, Variance and Specs
Welcome to week 5 of CS-214 — Software Construction!
As before, exercises or questions marked ⭐️ are the most important, 🔥 are the most challenging, and 🧪 are the most useful for this week’s lab.
We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can pull from the course exercises repository.
Comprehensions
for comprehensions
Scala does not have C-style for loops (with an initializer, a condition, and an increment); instead, it has for comprehensions. Comprehensions are particularly powerful when you need to combine and filter results from multiple collections. Let’s see some examples!
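Here is a small sketch of the combine-and-filter idea. The function names evenSums and evenSumsDesugared are made up for illustration. The first function pairs every element of one list with every element of another and keeps only the even sums. The second shows roughly how the compiler rewrites such a comprehension into flatMap, withFilter and map calls.

```scala
// Every pair (x, y) drawn from the two lists, keeping only even sums.
def evenSums(xs: List[Int], ys: List[Int]): List[Int] =
  for
    x <- xs                 // first generator
    y <- ys                 // second generator, nested inside the first
    if (x + y) % 2 == 0     // filter on the combination
  yield x + y

// Roughly the method calls the compiler turns the comprehension into.
def evenSumsDesugared(xs: List[Int], ys: List[Int]): List[Int] =
  xs.flatMap(x => ys.withFilter(y => (x + y) % 2 == 0).map(y => x + y))
```

For instance, evenSums(List(1, 2), List(1, 2, 3)) returns List(2, 4, 4). Each extra generator adds one level of flatMap, and each if becomes a withFilter.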
Warm-up ⭐️
In previous weeks we saw how to filter, map over, and flatten lists using recursive functions and List API functions. This week, we have a new way to perform these computations: comprehensions.
Just 3 (filters) 🧪
Using a for comprehension, write a function that filters a list of words (represented as Strings), keeping only words of length 3:
def onlyThreeLetterWords(words: List[String]): List[String] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
LOUDER (maps)
Using a for comprehension, write a function that converts a list of words (represented as Strings) to uppercase, using the .toUpperCase() method:
def louder(words: List[String]): List[String] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
For example, List("a", "bc", "def", "ghij") should become List("A", "BC", "DEF", "GHIJ").
When using internationalization-related methods, such as .toUpperCase(locale) or .toLowerCase(locale), always be careful about which locale parameter you use.
If omitted, locale defaults to the user’s current locale. On some computers, this will create issues with our tests: for example, on a Turkish system, calling "TITLE".toLowerCase() will yield "tıtle", not "title" (you can reproduce this behavior by passing locale = Locale("tr")).
Locale-related bugs are a common source of issues in real-world apps: think hard about the right value, and if writing locale-independent code, use the standardized invariant locale Locale.ROOT.
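Here is a minimal sketch that reproduces the Turkish-locale pitfall described above. It uses only java.util.Locale and String.toLowerCase(Locale). It builds the locale with Locale.forLanguageTag("tr") rather than the Locale("tr") constructor, because that constructor is deprecated in recent JDK versions.

```scala
import java.util.Locale

// Turkish lowercases the capital I to a dotless ı (U+0131).
val turkish: Locale = Locale.forLanguageTag("tr")

val lowerTurkish: String = "TITLE".toLowerCase(turkish)    // "tıtle"
val lowerRoot: String = "TITLE".toLowerCase(Locale.ROOT)   // "title"
```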
Echo (flatMaps)
Using a for comprehension, write a function that repeats each word in a list of words n times. Write a comprehension with two separate <- clauses. Either use Iterable.fill(n)(word) to create an iterable containing word n times, or use (1 to n) to iterate n times.
def echo(words: List[String], n: Int): List[String] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
For example, List("a", "bc", "def", "ghij") should become List("a", "a", "bc", "bc", "def", "def", "ghij", "ghij") if n == 2.
What should happen if n == 1? How about n == 0?
Solution
For n == 1, the resulting list is the same as the source list. For n == 0, the resulting list is empty.
All together now 🧪
Using a for comprehension, write a function that converts all words in a list to uppercase, removes all words whose length is not three, and repeats all others n times:
def allTogether(words: List[String], n: Int): List[String] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
For example, List("All", "together", "now") should become List("ALL", "ALL", "ALL", "NOW", "NOW", "NOW") if n is 3.
Cross product
Reimplement the cross-product function from last week using a for comprehension.
Additionally, your function should now take Scala Lists as input, and return a Scala List as output.
def crossProduct[A, B](l1: List[A], l2: List[B]): List[(A, B)] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
Permutations
Reimplement the distinctPairs function from last week using a for comprehension.
def distinctPairs[A](items: Seq[A]): Seq[(A, A)] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
Comprehensions on Option ⭐️
The most intuitive use of a for comprehension is on a list. However, it can be used on any type that defines map, flatMap and, optionally, withFilter. In particular, we can use for comprehensions on Option values.
For this exercise, we provide you with a function that parses a string into an Option[Int]. For example, parseInt("123") returns Some(123), but parseInt("abc") returns None. You do not need to understand the inner workings of the function, as we will learn more about Try later 🔜.
import scala.util.Try

def parseInt(s: String): Option[Int] =
  Try(s.toInt).toOption
comprehensions/src/main/scala/comprehensions/forcomp.scala
Implement the validInts function, which parses a list of strings into integers and returns only the valid integers.
def validInts(strings: List[String]): List[Int] =
  ???
comprehensions/src/main/scala/comprehensions/forcomp.scala
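Here is a sketch of a for comprehension over Option values that is separate from the validInts exercise. The helper addStrings is hypothetical: it parses two strings with parseInt (copied from above) and adds them. As soon as one generator produces None, the whole comprehension yields None.

```scala
import scala.util.Try

def parseInt(s: String): Option[Int] =
  Try(s.toInt).toOption

// Hypothetical helper: adds two numbers written as strings.
// If either string fails to parse, the comprehension short-circuits to None.
def addStrings(a: String, b: String): Option[Int] =
  for
    x <- parseInt(a)
    y <- parseInt(b)
  yield x + y
```

For example, addStrings("1", "41") is Some(42), while addStrings("1", "x") is None.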
Glob matching (ungraded callback to find!)
The real Unix find command takes a “glob pattern” for its -name filter: it supports wildcards like ? and * to allow for partial matches. For example, find -name 2023-*.jpg finds all files whose name starts with 2023- and ends with .jpg. Likewise, find -name 20??-*.jp*g finds all files whose name starts with 20, followed by two arbitrary characters, a dash, an arbitrary sequence of characters, and finally .jp, then anything, then g. This would allow users to find 2002-03-18T11:15.jpg or 2002-03-18 modified.jpeg, for example.
Write a function glob(pattern, input) that returns a boolean indicating whether a glob pattern matches a string (represented as a list of Chars):
def glob(pattern: List[Char], input: List[Char]): Boolean =
  ???
comprehensions/src/main/scala/comprehensions/Glob.scala
The rules are as follows:
- ? matches one arbitrary character
- * matches an arbitrary (possibly empty) sequence of characters
- Other characters match themselves
The whole pattern must match the whole string: partial matches are not allowed.
Implementation guide
This is a tricky problem at first sight, but it admits a very nice recursive solution. Think of it in groups: how can you reduce the problem of matching a string against a pattern to a problem with a smaller string, or a smaller pattern?
Based on your implementation, can you prove that a pattern that does not contain wildcards matches only itself? That * matches all strings? That a pattern with only ?s matches all strings of the same length as the pattern? That repeated *s can be replaced by a single *?
If you plug your new function into your copy of find, you’ll get an even better file searcher! It should be a very straightforward refactoring: just change the function passed to the higher-order find function that we wrote in the last callback.
Des chiffres et des lettres
“Des chiffres et des lettres” is a popular TV show in French-speaking countries.
In this show, contestants take turns guessing the longest word that can be made from a list of letters, and finding a way to arrange a list of numbers into an arithmetic computation to get as close as possible to a target number. For example:
- For the letters G N Q E U T I O L Y C E P H, an excellent contestant would immediately propose the word “POLYTECHNIQUE”.
- For the numbers 2 5 3 9 100 9 and the target number 304, an excellent contestant would suggest 3 * (9 + 9) + 5 * 100 / 2, and exclaim “le compte est bon!” (“the count is right!”).
Des lettres 🧪
Write a function longestWord. Given a wordlist (a collection of words) represented as a Set[String] and a collection of letters represented as a String, it finds the longest words that can be made by reordering a subset of these letters. You can write your function directly, or follow the steps below.
Step-by-step hints
Let’s represent collections of letters by converting them to uppercase and sorting them, so that "Polytechnique" becomes "CEEHILNOPQTUY". We call "Polytechnique" the “original” word, and "CEEHILNOPQTUY" the “scrambled” word.
- Write a function scramble that transforms a single word into its scrambled representation.
def scramble(word: String): String = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Hint
You may find .toUpperCase() and .sorted useful.
- Write a function scrambleList that transforms a wordlist into a Map from scrambled words to sets of original words. For example, Set("Nale", "lean", "wasp", "swap") should become Map("AELN" -> Set("Nale", "lean"), "APSW" -> Set("wasp", "swap")).
def scrambleList(allWords: Set[String]): Map[String, Set[String]] = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Hint
Consider using .groupBy to create the map.
- Write a function exactWord that returns all words of a wordlist that can be formed by using all letters from a given collection of letters.
def exactWord(allWords: Set[String], letters: String): Set[String] = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
- Write a function compatible that checks whether a scrambled word can be formed from a collection of letters. Beware of repeated letters!
def compatible(small: String, large: String): Boolean = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
- Write a function longestWord that returns the longest words that can be formed using some letters from a given collection of letters.
def longestWord(allWords: Set[String], letters: String): Set[String] = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
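The scrambleList hint above suggests .groupBy. Here is a short reminder of how it behaves, using unrelated data: the words are grouped by length rather than by scrambled form, so the exercise stays open. groupBy returns a Map from each key to the elements that produced that key. Each group has the same collection type as the input, so a List produces Lists and a Set produces Sets.

```scala
// Grouping words by their length (not by their scrambled form).
val words: List[String] = List("a", "bc", "de", "fgh")
val byLength: Map[Int, List[String]] = words.groupBy(_.length)

// The same on a Set: each group is a Set too.
val setByLength: Map[Int, Set[String]] = Set("a", "bc", "de").groupBy(_.length)
```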
Des chiffres 🔥
Write a function leCompteEstBon that takes a List of integers and a target number, and returns an arithmetic expression that evaluates to the target number, if one exists. The allowed operators are +, - (only when the result is nonnegative), *, and / (only when the division is exact), as well as parentheses. For the purpose of this exercise, assume that we are only interested in exact results, using all provided integers. The result should be an expression of type Expr:
The Expr type
We have provided you with an Expr type as a starting point:
trait Expr:
  val value: Option[Int]
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Notice that each Expr reports its own value as an Option[Int]. This is because evaluating an expression can fail: some operators (- and /) are only defined on some inputs. This type has two direct subclasses, numbers and binary operators:
case class Num(n: Int) extends Expr:
  val value = Some(n)
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
abstract class Binop extends Expr:
  val e1, e2: Expr // Subexpressions
  def op(n1: Int, n2: Int): Option[Int] // How to evaluate this operator
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
The value of a Binop is defined thus, using for to unpack options:
val value: Option[Int] =
  for
    n1 <- e1.value
    n2 <- e2.value
    r <- op(n1, n2)
  yield r
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Finally, the four arithmetic operators are subclasses of Binop:
case class Add(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    Some(n1 + n2)
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
case class Sub(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    if n1 < n2 then None else Some(n1 - n2)
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
case class Mul(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    Some(n1 * n2)
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
case class Div(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    if n2 != 0 && n1 % n2 == 0 then Some(n1 / n2) else None
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Note how - and / sometimes return None.
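To see what the for in the value of a Binop does, here is a standalone sketch. The names combine, combineDesugared and exactDiv are hypothetical and not part of the provided code. The sketch takes the two subexpression values as plain Options and the operator as a function. It also shows roughly the flatMap/map chain the compiler generates. A None anywhere (a missing operand or a failed operator) makes the whole result None.

```scala
// Standalone version of the Binop value computation, with the operator passed in.
def combine(v1: Option[Int], v2: Option[Int], op: (Int, Int) => Option[Int]): Option[Int] =
  for
    n1 <- v1
    n2 <- v2
    r <- op(n1, n2)
  yield r

// Roughly what the comprehension above is rewritten into.
def combineDesugared(v1: Option[Int], v2: Option[Int], op: (Int, Int) => Option[Int]): Option[Int] =
  v1.flatMap(n1 => v2.flatMap(n2 => op(n1, n2).map(r => r)))

// Exact division, mirroring Div.op above.
def exactDiv(n1: Int, n2: Int): Option[Int] =
  if n2 != 0 && n1 % n2 == 0 then Some(n1 / n2) else None
```

For example, combine(Some(6), Some(3), exactDiv) is Some(2), whereas combine(Some(7), Some(2), exactDiv) is None because the division is not exact.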
Our solution is short, and uses a combination of for comprehensions and recursion. Take the time to think through how you might divide this problem into smaller subproblems.
Possible approaches
There are broadly two possible approaches: a top-down one, which splits the input set into two halves, and combines them with an operator; or a bottom-up one, which combines numbers into increasingly larger expression trees.
Step-by-step hints (top-down)
- Write a recursive function partitions which generates all partitions of a list into two non-overlapping sublists.
def partitions[A](l: List[A]): List[(List[A], List[A])] = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Hint
This function must decide, for each element, whether it goes into the left or the right partition.
- Write a recursive function allTrees that generates all possible trees of expressions from a set of numbers using partitions.
def allTrees(ints: List[Int]): List[Expr] = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Hint
At each step, this function should call itself recursively twice, once per subset, and generate one tree per operator.
- Write a recursive function leCompteEstBon that finds an expression among the possible ones that matches the target number, or returns None if the target cannot be achieved:
def leCompteEstBon(ints: List[Int], target: Int): Option[Expr] = ???
comprehensions/src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
Since the steps above are only suggestions, we have provided only integration tests (tests for leCompteEstBon), and no unit tests for intermediate functions. Write a few unit tests of your own to make sure you understand what each function does before starting! (And feel free to share them with other students on Ed!)
You may find that your function takes a long time to return in our tests. In that case, write your own small tests to make sure that it works, then study it by running it on examples to understand where it wastes time. Can you think of optimizations?
Hint
Focus on allTrees. Does it really have to return all trees? For example, given the set List(2, 2), is it valuable to return both Add(Num(2), Num(2)) and Mul(Num(2), Num(2))? Similarly, on the set List(2, 3, 4), is it valuable to keep both 2 * 3 + 4 and 3 * 4 - 2? Our solution keeps just one of each, and this speeds it up from multiple seconds per problem to just a few milliseconds.
Tracing code execution
Polish notation (from week 1) ⭐️
Use tracing to gain a better understanding of the following function, which computes the value of a Polish-notation expression:
enum Operand[+T]:
  case Add extends Operand[Nothing]
  case Mul extends Operand[Nothing]
  case Num(t: T)

type OpStack[T] = List[Operand[T]]
comprehensions/src/main/scala/comprehensions/Tracing.scala
def polishEval(ops: OpStack[Int]): (Int, OpStack[Int]) =
  ops match
    case Nil => throw IllegalArgumentException()
    case op :: afterOp =>
      op match
        case Operand.Num(n) =>
          (n, afterOp)
        case Operand.Add =>
          val (l, afterL) = polishEval(afterOp)
          val (r, afterR) = polishEval(afterL)
          (l + r, afterR)
        case Operand.Mul =>
          val (l, afterL) = polishEval(afterOp)
          val (r, afterR) = polishEval(afterL)
          (l * r, afterR)
comprehensions/src/main/scala/comprehensions/Tracing.scala
McCarthy ⭐️
What does the following function do? Use tracing annotations to figure it out!
def mc(n: Int): Int =
  if n > 100 then n - 10
  else mc(mc(n + 11))
comprehensions/src/main/scala/comprehensions/Tracing.scala
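If you want to try tracing by hand first, here is a minimal println-based sketch. mcTraced and its depth parameter are our own additions, and the course’s tracing helpers may look different. Each call prints its argument, indented by recursion depth, followed by the value it returns. This makes the nested recursion visible.

```scala
// Hypothetical traced copy of mc: prints each call, indented by recursion depth,
// followed by the value that call returns.
def mcTraced(n: Int, depth: Int = 0): Int =
  val indent = "  " * depth
  println(s"${indent}mc($n)")
  val result =
    if n > 100 then n - 10
    else mcTraced(mcTraced(n + 11, depth + 1), depth + 1)
  println(s"${indent}=> $result")
  result
```

Try calling mcTraced with a few arguments below and above 100 and compare the traces. The same pattern (an extra depth parameter and two println calls) works for the Takeuchi function below.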
Takeuchi 🔥
What do the following functions do? Use tracing annotations to figure it out!
def t(x: Int, y: Int, z: Int): Int =
  if x <= y then y
  else
    t(
      t(x - 1, y, z),
      t(y - 1, z, x),
      t(z - 1, x, y)
    )
comprehensions/src/main/scala/comprehensions/Tracing.scala
This function and the previous one were both important in the development of computer-assisted proofs, as they exhibit complex, nested recursion patterns. This 1991 paper discusses the two examples above (and the problems it lists as “open problems” have now been solved!).
Tree folds
Two functions, pairs and foldt, are defined below. Use examples, pictures, the substitution method, refactoring, or tracing to figure out what foldt does.
extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] = l match
    case a :: b :: tl => op(a, b) :: tl.pairs(op)
    case _ => l
  def foldt(z: T)(op: (T, T) => T): T = l match
    case Nil => z
    case List(t) => t
    case _ :: tail => l.pairs(op).foldt(z)(op)
comprehensions/src/main/scala/comprehensions/Tracing.scala
Armed with that knowledge, can you figure out what algorithm the function ms below implements? Is it efficient? How does it differ from another version of the same algorithm that you saw previously?
extension (l: List[Int])
  def ms: List[Int] =
    l.map(List(_)).foldt(Nil)(merge)
comprehensions/src/main/scala/comprehensions/Tracing.scala
(merge is the usual function that takes two sorted lists and merges them in sorted order.)
def merge(xs: List[Int], ys: List[Int]): List[Int] =
  (xs, ys) match
    case (Nil, _) => ys
    case (_, Nil) => xs
    case (x :: xs1, y :: ys1) =>
      if x < y then x :: merge(xs1, ys)
      else y :: merge(xs, ys1)
comprehensions/src/main/scala/comprehensions/recursion.scala
Tracing performance issues
Sometimes, tracing can even help find performance issues: the longer a calculation takes, the more slowly tracing statements will appear on screen. Can you use this insight to demonstrate that something is wrong with each of the following functions?
- Reverse
def badReverse[T](l: List[T], acc: List[T] = Nil): List[T] =
  l match
    case Nil => acc.reverse
    case h :: t => badReverse(t, acc ++ List(h))
comprehensions/src/main/scala/comprehensions/Tracing.scala
- Map
def badMap[T1, T2](l: List[T1], f: T1 => T2): List[T2] =
  if l.length == 0 then Nil
  else f(l.head) :: badMap(l.tail, f)
comprehensions/src/main/scala/comprehensions/Tracing.scala
Tracing comprehensions
Performance issues
The following function has surprisingly poor performance. Why?
def badZip[T1, T2](l1: List[T1], l2: List[T2]): Seq[(T1, T2)] =
  for i <- (0 to math.min(l1.length, l2.length) - 1)
  yield (l1(i), l2(i))
comprehensions/src/main/scala/comprehensions/Tracing.scala
A subtlety in desugaring and tracing comprehensions 🔥
A keystone of the CS214 debugging guide is to “understand the system”. Sometimes observing the system’s final output is enough, but other times adding instrumentation helps. It pays to be careful, however: any time you change the system that you are debugging, you risk changing its behavior, too. Let’s see an example.
As we’ve seen previously, tracing can help not just with recursion, but also with comprehensions: to insert a tracing instruction into a comprehension, we use _ = println(…).
- Here is an instrumented comprehension. What does the tracing code print, and in which order?
def traceIfTrue(b: Boolean, label: String) =
  if b then println(label)
  b

def filter_tracedIfTrue(l: Seq[Int]): Seq[Int] =
  for
    n <- l
    if traceIfTrue(n % 5 == 0, f"$n is a multiple of 5!")
    if traceIfTrue(n % 3 == 0, f" $n is also a multiple of 3!")
    if traceIfTrue(n >= 10, f" $n is also greater than 10!")
  yield n
comprehensions/src/main/scala/comprehensions/Tracing.scala
Confirm your guess by running this code in a worksheet.
- Here is a different, arguably more readable way to trace the same comprehension. What will it print?
def filter_traced(l: Seq[Int]): Seq[Int] =
  for
    n <- l
    if n % 5 == 0
    _ = println(f"$n is a multiple of 5!")
    if n % 3 == 0
    _ = println(f" $n is also a multiple of 3!")
    if n >= 10
    _ = println(f" $n is also greater than 10!")
  yield n
comprehensions/src/main/scala/comprehensions/Tracing.scala
As before, confirm your guess by running this code in a worksheet.
- What can you conclude?
Non-structural recursion: numbers
Factorial ⭐️
Define a recursive function that returns the factorial of the input integer.
def factorial(n: Int): Int =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
Fast exponentiation
Fast exponentiation is a technique to optimize the exponentiation of numbers:
$b^{2n} = (b^2)^n = (b^n)^2$
$b^{2n+1} = b \cdot b^{2n}$
Define a function that implements this fast exponentiation.
def fastExp(base: Int, exp: Int): Int =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
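As a worked example of the two identities (our own illustration, not taken from the exercise), here is how $3^{13}$ unfolds. Each step either halves an even exponent or peels off one factor of $b$ from an odd one:

$$
\begin{aligned}
3^{13} &= 3 \cdot 3^{12} \\
3^{12} &= (3^{6})^2 \\
3^{6} &= (3^{3})^2 \\
3^{3} &= 3 \cdot 3^{2} \\
3^{2} &= (3^{1})^2 \\
3^{1} &= 3 \cdot 3^{0} = 3 \cdot 1
\end{aligned}
$$

Evaluating from the bottom up gives $3^2 = 9$, $3^3 = 27$, $3^6 = 729$, $3^{12} = 531441$ and $3^{13} = 1594323$. That takes six multiplications (three squarings and three multiplications by 3), whereas multiplying by 3 repeatedly would take twelve.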
Base-N encoding ⭐️
In this exercise, your task is to implement a function that encodes a given integer into its representation in a given base. The result should be a list of integers.
For instance:
- Encode the number 9 to base-2: decimalToBaseN(9, 2) should return List(1, 0, 0, 1).
- Encode the number 20 to base-16: decimalToBaseN(20, 16) should return List(1, 4).
Your task is to implement the decimalToBaseN function.
def decimalToBaseN(number: Int, base: Int, acc: List[Int] = Nil): List[Int] =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
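One way to see the encoding is repeated division by the base. This is a worked example, not a prescribed implementation. Each remainder is the next digit, starting from the least significant one. For 20 in base 16:

$$
\begin{aligned}
20 &= 16 \cdot 1 + 4 \\
1 &= 16 \cdot 0 + 1
\end{aligned}
$$

The remainders come out as 4, then 1, so the digits, most significant first, are List(1, 4). Prepending each new remainder to an accumulator produces them in that order directly; this is one possible use of the acc parameter. The same process on 9 in base 2 yields the remainders 1, 0, 0, 1.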
Non-structural recursion: collections
Coin Change ⭐️
You are given a list of coin denominations coins: List[Int], which are the different values a coin can have, and a target amount. You need to determine the number of distinct ways to make up the target amount using any combination of the provided coin denominations. You can use each coin denomination an unlimited number of times.
def coinChange(coins: List[Int], amount: Int): Int =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
Let’s illustrate the coin change problem with an example.
Suppose you have the following coin denominations: List(1, 2, 5), and you want to make change for the amount 5. How many different ways are there to make change for 5 using these coins?
Here are the different ways to make change for 5:
5 = 5 (use one 5-coin)
5 = 2 + 2 + 1 (use two 2-coins and one 1-coin)
5 = 2 + 1 + 1 + 1 (use one 2-coin and three 1-coins)
5 = 1 + 1 + 1 + 1 + 1 (use five 1-coins)
So there are 4 different ways to make change for 5 using the denominations List(1, 2, 5).
Merge sort ⭐️
Merge sort is a popular and efficient comparison-based sorting algorithm that follows the divide-and-conquer paradigm. It works by recursively dividing a list into smaller sublists, sorting them, and then merging them back together into a single sorted list.
Roughly, merge sort proceeds in the following way:
- Split the list into two sublists.
- Recursively sort each sublist.
- Merge the two sorted sublists back into one sorted list.
Split
Define a function that splits a list into two halves, whose length difference is less than or equal to 1.
There are multiple ways to implement such a function, and you may use functions from the Scala standard library.
For example, if the list has n elements, the first list can have the first $\lfloor n / 2 \rfloor$ elements and the second the rest.
Alternatively, when counting elements from 1, the first list can have all items that were at odd positions, and the second one all items that were at even positions. For example, the list a b c d e f g will be split into a c e g and b d f.
def split[A](l: List[A]): (List[A], List[A]) =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
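Here is a sketch of the two splitting strategies described above. It uses the standard-library functions splitAt, zipWithIndex and partition. The names splitHalf and splitAlternate are hypothetical. Both strategies keep the length difference between the halves at most 1.

```scala
// First ⌊n/2⌋ elements, then the rest.
def splitHalf[A](l: List[A]): (List[A], List[A]) =
  l.splitAt(l.length / 2)

// Items at odd positions (counting from 1) vs. items at even positions.
// zipWithIndex counts from 0, so odd positions have even indices.
def splitAlternate[A](l: List[A]): (List[A], List[A]) =
  val (odd, even) = l.zipWithIndex.partition { case (_, i) => i % 2 == 0 }
  (odd.map(_._1), even.map(_._1))
```

For the list a b c d e f g, splitAlternate gives a c e g and b d f, and splitHalf gives a b c and d e f g.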
Merge
Write a merge function that takes two sorted lists and returns a sorted list containing all elements of both lists.
def merge(xs: List[Int], ys: List[Int]): List[Int] =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
Sort
Implement the mergeSort function.
If the list has one or zero elements, what should the function do?
def mergeSort(xs: List[Int]): List[Int] =
  ???
comprehensions/src/main/scala/comprehensions/recursion.scala
Exercises from the slides
N-Queens
Write a function isSafe which tests whether a queen placed in the given column col is safe from all the other queens already placed.
It is assumed that the new queen is placed in the next available row after the other placed queens (in other words, in row queens.length).
def isSafe(col: Int, queens: List[Int]): Boolean =
  ???
comprehensions/src/main/scala/comprehensions/NQueens.scala
Variance
The following exercises are intended to help you hone your understanding of covariance and contravariance.
Do these exercises on paper first. Using the compiler is a great way to confirm your understanding, but starting from the code will not help you develop your intuition and understanding of these rules, since the compiler will do all the work for you. After completing a first draft on paper, you can either check our solution, or check your solutions on your computer.
Subtyping of variant types ⭐️
Recall that:
- Lists are covariant in their only type parameter.
- Functions are contravariant in the argument, and covariant in the result.
Consider the following hierarchies:
abstract class Fruit
class Banana extends Fruit
class Apple extends Fruit
abstract class Liquid
class Juice extends Liquid
Consider also the following typing relationships for A, B, C, D: A <: B and C <: D.
Fill in the subtyping relation between the types below. Bear in mind that it might be that neither type is a subtype of the other.
| Left hand side | ?: | Right hand side |
|---|---|---|
| List[Banana] | | List[Fruit] |
| List[A] | | List[B] |
| Banana => Juice | | Fruit => Juice |
| Banana => Juice | | Banana => Liquid |
| A => C | | B => D |
| List[Banana => Liquid] | | List[Fruit => Juice] |
| List[A => D] | | List[B => C] |
| (Fruit => Juice) => Liquid | | (Banana => Liquid) => Liquid |
| (B => C) => D | | (A => D) => D |
| Fruit => (Juice => Liquid) | | Banana => (Liquid => Liquid) |
| B => (C => D) | | A => (D => D) |
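To check the two rules without giving away any row of the table, here is a small sketch using Any and String, where String <: Any. Each assignment compiles only because of the variance rule named in its comment.

```scala
// Contravariant in the argument, covariant in the result:
// Any => String <: String => Any
val describe: Any => String = x => s"<$x>"
val general: String => Any = describe

// Lists are covariant: List[String] <: List[Any]
val strings: List[String] = List("a", "b")
val anys: List[Any] = strings
```

Swapping either assignment (for example, assigning a String => Any value to a variable of type Any => String) is rejected by the compiler.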
Testing subtyping relationships
Scala’s compiler automatically checks variance rules. Can we make it show us what it inferred? Write a function assertSubtype so that assertSubtype[Int, String] causes a compilation error, but assertSubtype[String, Object] compiles.
Understanding variance ⭐️
Variance rules are designed to keep type casts safe: any time the compiler rejects code for variance reasons, it does so to avoid runtime errors that could happen if these protections were not in place.
Here are some examples. For each, think of what might go wrong if the compiler didn’t disallow this pattern.
Contravariant field
Scala rejects the following definition. Why?
case class C[-A](a: A)
Error message:
case class C[-A](a: A) // Error
// ^ contravariant type A occurs in covariant position in type A of value a
Double negation
Scala rejects the following definition. Why?
trait C[-A]:
  def app(f: A => Int): Int
Error message:
trait C[-A]:
  def app(f: A => Int): Int // Error
  // ^ contravariant type A occurs in covariant position in type A => Int of parameter f
Free and bound functions
Scala rejects the following definition:
trait F[+A]:
  def f(a: A): A
… yet it accepts the following:
def f[A](a: A): A = a
variance/src/main/scala/variance/CounterExamples.scala
Why?
Error message:
trait F[+A]:
  def f(a: A): A // Error
  // ^ covariant type A occurs in contravariant position in type A of parameter a
Extension methods
Scala rejects the following definition:
trait Foldable1[+A]:
  def fold(a: A)(f: (A, A) => A): A
… yet it accepts the following:
trait Foldable2[+A]
extension [A](t: Foldable2[A])
  def fold(a: A)(f: (A, A) => A): A = ???
variance/src/main/scala/variance/CounterExamples.scala
Why?
Error message:
trait Foldable1[+A]:
  def fold(a: A)(f: (A, A) => A): A // Error
  // ^ ^ covariant type A occurs in contravariant position in type (A, A) => A of parameter f
  // ^ covariant type A occurs in contravariant position in type A of parameter a
Implementing classes with variance
Here is a simple trait for stacks:
trait Stack[T]:
  /** Peek at the top of this stack */
  def peek(): Option[T]
  /** Create a new stack with one more entry, at the top */
  def push(t: T): Stack[T]
  /** Separate the top entry from the rest of the stack */
  def pop(): (Option[T], Stack[T])
variance/src/main/scala/variance/Variance.scala
1. Write an implementation of this trait.
2. Write a function that takes a list of stacks and collects the top of each stack into a new stack.
def joinStacks[T](l: List[Stack[T]]): Stack[T] = ???
variance/src/main/scala/variance/Variance.scala
🔜 An operation similar to this one is often called “join” in parallel programming. We will encounter it in the parallelism week, and we’ll study a structure similar to the Box below when we study Futures.
3. What happens if you try to call joinStacks with a mix of stacks with different element types (different Ts)? Can you even put two such stacks together in a list? Why or why not?
def mkStackInt(): Stack[Int] = ???
def mkStackString(): Stack[String] = ???
// Does this work?
// val tops = joinStacks(List(mkStackInt(), mkStackString()))
variance/src/main/scala/variance/Variance.scala
4. Assume that we want to change Stack[T] to allow for multiple stacks with different T types to be placed together in a list. Is there a variance annotation for T that will allow this? Which one makes more sense?
Try to perform the change by adding an annotation on T in Stack[T]. What problem do you run into? Why is this error reasonable? (What would go wrong otherwise?) Change the signature of push to make it work.
5. Is it always possible to turn an invariant trait into a covariant one? Consider the following example:
trait Drawer[T]:
  def get(): T
  def put(t: T): Drawer[T]

case class IncrementingDrawer(i: Int) extends Drawer[Int]:
  def get() = i - 1
  def put(j: Int) = IncrementingDrawer(j + 1)
variance/src/main/scala/variance/Variance.scala
Can you make Drawer covariant in T without breaking IncrementingDrawer?
6. From question 5, it is clear that covariance puts restrictions on what classes that implement a covariant trait can do (the same is true of contravariance). Based on this, can you think of a good reason not to make a container class (like a List or a Stack) covariant?
7. Here is another trait, similar to stacks, but storing at most one element. First give T an adequate variance annotation as in the Stack case (this will require other changes, of course), then implement the trait.
trait Box[T]:
  /** Peek at the value inside the box */
  def unbox(): Option[T]
  /** Create a new box with the contents */
  def replace(t: T): Box[T]
  /** Create a new box by applying `f` to the contents of this one */
  def map[T2](f: T => T2): Box[T2]
variance/src/main/scala/variance/Variance.scala
8. Both of these traits have a way to inspect one element of the container. Change both of them to extend the following HasTop trait, which captures this property.
trait HasTop[+T]:
  /** Peek at one value inside this container */
  def top: Option[T]
variance/src/main/scala/variance/Variance.scala
9. Drawer[T] is required to be invariant to be compatible with IncrementingDrawer, but HasTop[+T] is covariant in T. Can Drawer[T] implement HasTop[T] without breaking IncrementingDrawer?
10. Rewrite the join operation to work on lists of HasTop instances, rather than stacks. Does this join still work if you remove the covariance annotation on T in Box and Stack? Why?
def joinTops[T](l: List[HasTop[T]]): List[T] = ???
variance/src/main/scala/variance/Variance.scala
A variance puzzle 🔥
This exercise goes beyond what will be on the exam. It’s hard to get an intuition for it: instead, if you want to complete it, you may find our complete writeup on variance rules useful.
Assume the following two classes or traits…
C[-A, +B]
F[ A, +B] extends C[A, B]
… and the following subtyping relationships…
A <: B
X <: Y
… what are the relations between the following pairs?
- C[A, Y] and F[B, X]
- F[A, X] and C[A, Y]
Check in a worksheet if you’re not sure that your answer is correct!
Exercises from the slides
Array Typing Problem
Arrays in Java are covariant, but covariant array typing causes problems. Do you know why?
To see why, let’s explore with an example:
Consider a class IntSet that defines binary trees storing sets of integers. A set of integers is either an empty set represented by an empty tree, or a non-empty set stored in a tree that consists of one integer and two subtrees. In Java, IntSet can be implemented as follows:
sealed interface IntSet permits Empty, NonEmpty {}

final class Empty implements IntSet {}

final class NonEmpty implements IntSet {
    private final int elem;
    private final IntSet left, right;

    public NonEmpty(int elem, IntSet left, IntSet right) {
        this.elem = elem;
        this.left = left;
        this.right = right;
    }
}
In Scala, IntSet would be written as follows:
sealed trait IntSet
case class Empty() extends IntSet
case class NonEmpty(elem: Int, left: IntSet, right: IntSet) extends IntSet
Now consider the Java code below:
NonEmpty[] a = new NonEmpty[]{
    new NonEmpty(1, new Empty(), new Empty())};
IntSet[] b = a;
b[0] = new Empty();
NonEmpty s = a[0];
It looks like in the last line, we assigned an Empty value to a variable of type NonEmpty! What went wrong?
The problematic array example would be written as follows in Scala:
val a: Array[NonEmpty] = Array(NonEmpty(1, Empty(), Empty()))
val b: Array[IntSet] = a
b(0) = Empty()
val s: NonEmpty = a(0)
When you try out this example, what do you observe?
- A type error in line 1
- A type error in line 2
- A type error in line 3
- A type error in line 4
- A program that compiles and throws an exception at run-time
- A program that compiles and runs without exception
Liskov Substitution Principle
The Liskov Substitution Principle (LSP) tells us when a type can be a subtype of another:
If A <: B, then everything one can do with a value of type B one should also be able to do with a value of type A.
Assume the following type hierarchy and two function types:
trait Fruit
class Apple extends Fruit
class Orange extends Fruit
type A = Fruit => Orange
type B = Apple => Fruit
According to the LSP, which of the following should be true?
- A <: B
- B <: A
- A and B are unrelated.
Prepend
Consider adding a prepend method to List, which prepends a given element, yielding a new list.
- A first implementation of prepend could look like this:
trait List[+T]:
  def prepend(elem: T): List[T] = Node(elem, this)
But that does not type-check. Why?
- prepend turns List into a mutable class.
- prepend fails variance checking.
- prepend’s right-hand side contains a type error.
- Given the second implementation of prepend:
trait List[+T]:
  def prepend[U >: T](elem: U): List[U] = Node(elem, this)
What is the result type of the following function (assuming Apple <: Fruit and Orange <: Fruit)?
def f(xs: List[Apple], x: Orange) = xs.prepend(x)
Possible answers:
- does not type check
- List[Apple]
- List[Orange]
- List[Fruit]
- List[Any]
Specifications
Reasoning about your code requires specifying what it is supposed to do.
In Scala, it’s common to use require and ensuring to specify pre- and postconditions for functions, respectively.
The require method checks a given condition (usually an input validation) and throws an IllegalArgumentException if the condition is not met. On the other hand, ensuring is used to validate the result of a function: it takes a predicate that the result must satisfy, and throws an AssertionError if the result does not satisfy it.
For example:
val eps = 0.00001f
def sqrt(x: Double): Double = {
require(x >= 0)
Math.sqrt(x)
} ensuring (res =>
(x - res * res) <= eps && (x - res * res) >= -eps
)
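To see both checks fire, here is a small sketch on a hypothetical function half (not part of the course code). half respects its contract; badHalf keeps the same contract but has a wrong body, so its ensuring clause fails at run time.

```scala
// Hypothetical example: integer halving with a pre- and a postcondition.
def half(n: Int): Int = {
  require(n >= 0, "n must be non-negative") // precondition: IllegalArgumentException if violated
  n / 2
} ensuring (res => res * 2 <= n && n <= res * 2 + 1) // postcondition: AssertionError if violated

// Same contract, deliberately wrong body, to observe the postcondition failing.
def badHalf(n: Int): Int = {
  require(n >= 0, "n must be non-negative")
  n / 3
} ensuring (res => res * 2 <= n && n <= res * 2 + 1)
```

In a worksheet, half(9) returns 4, half(-1) throws an IllegalArgumentException ("requirement failed: n must be non-negative"), and badHalf(9) throws an AssertionError, since 3 is not 9 divided by 2 rounded down.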
The following exercises will make writing such specs second nature!
Debugging with tests and specs
Bad combinatorics ⭐️
Consider the following incorrect function:
/** Construct all subsets of `s` of size `k` (INCORRECT!) */
def badCombinations[T](set: Set[T], k: Int): Set[Set[T]] = {
if k == 0 then Set(Set())
if set.isEmpty then Set()
else
for
item <- set
rest = set - item
subset <- Iterable.concat(
for s <- badCombinations(rest, k - 1) yield rest + item,
for s <- badCombinations(rest, k) yield rest
)
yield subset
}
specs/src/main/scala/specs/Tracing.scala
- Read this function carefully.
- Which inputs are meaningful for this function? Write a require clause to rule out meaningless arguments. Check that your require clause works by calling this function with invalid arguments in a worksheet and confirming that it raises an exception.
- Unit tests are a great way to ensure that corner cases are correctly handled. Create a new test suite class in src/test/scala/specs/SpecsSuite.scala (you can consult a previous-week exercise set or lab for inspiration, or the munit documentation), then write a unit test that asserts that this function computes the right result when called with an empty set as input (use assertEquals and Set()).
- Make sure that your test runs and reports a failure (it should, since this function is incorrect). You can use sbt testOnly -- *your-test-name* to run your test.
- Find the bug that leads to an empty set being returned, and confirm that the unit test now passes.
- What does the documentation of this function promise? Write an ensuring clause to capture this formally. Is this a complete specification? (A specification is complete if any function that respects it is a correct implementation of the original requirements, here the documentation string.)
- Find an input for which this function violates the postcondition that you just wrote. Confirm your conjecture by calling this function in a worksheet and checking that the ensuring clause throws an exception.
- Correct the remaining bugs in this function. You may find it useful to add tracing instructions to observe its behavior.
- Bonus question: Is your corrected function efficient? Look carefully at the work it does on a simple example. Do you spot any redundant work?
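For the tracing step, one lightweight option is a wrapper that prints each call's label together with its result. The helper traced below is hypothetical (not a course or library function), and it is shown on a simple recursive sum rather than on badCombinations, so the debugging is still left to you.

```scala
// Hypothetical helper: evaluates body, prints label and result, returns the result unchanged.
def traced[A](label: String)(body: => A): A =
  val res = body
  println(s"$label = $res")
  res

// Example on a small recursive function (deliberately not badCombinations).
def sumList(l: List[Int]): Int = traced(s"sumList($l)") {
  l match
    case Nil     => 0
    case x :: xs => x + sumList(xs)
}
```

Because the body is evaluated before printing, sumList(List(1, 2)) prints the innermost call first: sumList(List()) = 0, then sumList(List(2)) = 2, then sumList(List(1, 2)) = 3.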
Bad transactions
The following code defines a class BankAccount and the associated operations deposit, withdraw and transfer:
protected class BankAccount(private var _balance: Double):
  import AccOpResult.*

  def balance = _balance

  private def updateBalance(amount: Double): AccOpResult =
    val oldBalance = balance
    _balance = amount
    Ok(oldBalance)

  /** Deposits the specified amount into the bank account.
    *
    * @param amount
    *   The amount to be deposited. Must be non-negative.
    */
  def deposit(amount: Double): AccOpResult = {
    updateBalance(balance + amount)
  }

  /** Withdraws the specified amount from the bank account.
    *
    * @param amount
    *   The amount to be withdrawn. Must be non-negative.
    */
  def withdraw(amount: Double): AccOpResult = {
    if balance >= amount then
      updateBalance(balance - amount)
    else
      InsufficientFund(balance)
  }

  /** Transfers the specified amount from this bank account to `that` bank
    * account.
    *
    * @param amount
    *   The amount to be transferred. Must be non-negative.
    */
  def transfer(that: BankAccount, amount: Double): (AccOpResult, AccOpResult) = {
    if this.balance >= amount then
      (this.withdraw(amount), that.deposit(amount))
    else
      (InsufficientFund(this.balance), Ok(that.balance))
  }
specs/src/main/scala/specs/BankAccount.scala
The type AccOpResult records a “snapshot” of the balance before the operation, together with whether the operation succeeded:
enum AccOpResult:
  case Ok(oldBalance: Double)
  case InsufficientFund(oldBalance: Double)
specs/src/main/scala/specs/BankAccount.scala
- Write the preconditions and postconditions for the functions deposit, withdraw and transfer using require and ensuring.
- Can you identify the bugs in the code that violate these conditions?
Translating specs from English
Ambiguous specs
The following statements are ambiguous. For each of them, explain the ambiguity (either in words, or by giving at least two different, unambiguous expressions in Scala code that are compatible with the original sentence):
- “Access to the system is restricted to senior engineers and developers.”
- “The argument to this function must be a supertype of a serializable type that implements IsFinite”
- “Only finite Double lists are supported.”
- “Any sibling of a node that makes a request must be ready to process answers.”
- “Objects passed to this method must be unlockable.”
A particularly good source of examples of ambiguities is newspapers; just look up “Ambiguous newspaper headlines” on Google.
Translating simple specs ⭐️
Translate the following English statements to mathematical statements or to Scala code. Aim to be as succinct and clear as possible; do not optimize for performance. For some statements, it may not be possible to write Scala code (which ones?).
- The list of integers l is sorted in ascending order.
- All values in map m1 are keys in map m2.
- Even numbers are always immediately followed by odd numbers in list l.
- There are no negative numbers in the output of function f.
- Functions f and g compute the same output when given a positive number.
- List l1 is a permutation of list l2.
- Number p is prime.
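As an illustration of the expected style, here are two statements that are not part of the exercise, translated into Scala. Both function names are hypothetical; note how a “for all” over a finite collection becomes forall (and a “there exists” would become exists).

```scala
// Hypothetical statement: "No value appears twice in the list l."
def allDistinct[T](l: List[T]): Boolean =
  l.toSet.size == l.size

// Hypothetical statement: "Every string in the list l is non-empty."
def allNonEmpty(l: List[String]): Boolean =
  l.forall(_.nonEmpty)
```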
Paths and cycles
Translate the following English statements to Scala code:
A path is a nonempty sequence of edges (pairs of points) such that the endpoint of each edge is the starting point of the next edge.
A cycle is a path whose starting point equals its endpoint.
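Statements of the form “each element is related to the next one” can often be written by pairing a list with its own tail. The sketch below applies this to a hypothetical word-chain condition (each word starts with the last letter of the previous word), not to the path definition itself.

```scala
// Hypothetical example: a non-empty list of non-empty words where each word
// starts with the last letter of the previous word.
def isWordChain(words: List[String]): Boolean =
  words.nonEmpty &&
    words.forall(_.nonEmpty) &&
    words.zip(words.tail).forall((w1, w2) => w1.last == w2.head)
```

For example, isWordChain(List("apple", "egg", "goat")) holds, while isWordChain(List("apple", "goat")) does not.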
Course policies
This course has a reasonably complex grading policy: one lab is dropped, we accept medical reasons for skipping a lab, callbacks combine with labs, we have a midterm and a final, etc.
Translate the grading policy of this course (found in Overall Grade) into a Scala function, OverallGrade. First, start with the easy version:
def OverallGrade(labScore: Double, midtermScore: Double, finalScore: Double): Double =
  ???
specs/src/main/scala/specs/Specs.scala
🔥 Then, try the full grading policy.
Finding incompleteness in specs ⭐️
The following ensuring specifications are incomplete: they do not detect all incorrect behaviors.
For each one, replace the existing function body with an incorrect implementation that still obeys the ensuring clause, then find an input that proves that your modified function is incorrect, and confirm that the ensuring clause is incomplete by checking that it does not throw an AssertionError on that input.
- def filterWithIncompleteSpec[T](l: List[T], p: T => Boolean) = {
    l.filter(p)
  } ensuring (res => res.forall(p))
specs/src/main/scala/specs/Specs.scala
- def mapWithIncompleteSpec[T, Q](ts: List[T])(f: T => Q) = {
    ts.map(f)
  } ensuring (qs =>
    qs.length == ts.length &&
      qs.forall(q => ts.exists(t => f(t) == q))
  )
specs/src/main/scala/specs/Specs.scala
- def flattenWithIncompleteSpec[T](tss: List[List[T]]) = {
    tss.flatten
  } ensuring (ts =>
    ts.length == tss.map(_.length).sum &&
      tss.forall(ts.containsSlice(_))
  )
specs/src/main/scala/specs/Specs.scala
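Here is the same process on a hypothetical example that is not one of the three above: a reverse whose postcondition only checks the length. wrongReverse keeps that postcondition but returns its input unchanged; the input List(1, 2, 3) shows that it is incorrect, yet its ensuring clause does not fail.

```scala
// Hypothetical example. The spec only says "same length", which does not capture "reversed".
def reverseWithIncompleteSpec[T](l: List[T]): List[T] = {
  l.reverse
} ensuring (res => res.length == l.length)

// Incorrect body (returns the input unchanged) that still satisfies the same ensuring clause.
def wrongReverse[T](l: List[T]): List[T] = {
  l
} ensuring (res => res.length == l.length)
```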