
Comprehensions

Welcome to week 5 of CS-214 — Software Construction!

As before, exercises or questions marked ⭐️ are the most important, 🔥 are the most challenging, and 🧪 are most useful for this week’s lab.

We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).

for comprehensions

Scala does not have for loops; instead, it has for comprehensions. Comprehensions are particularly powerful when you need to combine and filter results from multiple collections. Let’s see some examples!
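As a first taste, here is a minimal sketch of a comprehension that combines two collections and filters the result. The lists and the name pairsSummingToTen are made up for this illustration.

```scala
// Hypothetical data, chosen only for illustration.
val xs = List(1, 4, 7)
val ys = List(3, 6, 9)

// Two generators (one per collection) and a guard that filters the pairs.
val pairsSummingToTen: List[(Int, Int)] =
  for
    x <- xs
    y <- ys
    if x + y == 10
  yield (x, y)
```

Each `<-` clause iterates over one collection, and the `if` clause keeps only the combinations that satisfy the condition.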

Warm-up ⭐️

In previous weeks we saw how to filter, map over, and flatten lists using recursive functions and List API functions. This week, we have a new way to perform these computations: comprehensions.

Just 3 (filters) 🧪

Using a for comprehension, write a function that filters a list of words (represented as Strings), keeping only words of length 3:

def onlyThreeLetterWords(words: List[String]): List[String] =
  ???

src/main/scala/comprehensions/forcomp.scala

LOUDER (maps)

Using a for comprehension, write a function that converts a list of words (represented as Strings) to uppercase, using the .toUpperCase() method:

def louder(words: List[String]): List[String] =
  ???

src/main/scala/comprehensions/forcomp.scala

For example List("a", "bc", "def", "ghij") should become List("A", "BC", "DEF", "GHIJ").

When using internationalization-related methods, such as .toUpperCase(locale) or .toLowerCase(locale), always be careful about which locale parameter you use.

If omitted, locale defaults to the user’s current locale. On some computers, this will create issues with our tests: for example, on a Turkish system, calling "TITLE".toLowerCase() will yield "tıtle", not "title" (you can reproduce this behavior by passing locale = Locale("tr")).

Locale-related bugs are a common source of issues in real-world apps — think hard about the right value, and if writing locale-independent code, use the standardized invariant locale Locale.ROOT.
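The difference can be observed directly with a small sketch. It uses Locale.forLanguageTag("tr") rather than the Locale constructor, which is deprecated in recent JDKs; the value names are made up for this illustration.

```scala
import java.util.Locale

// Turkish locale, built from its language tag.
val turkish = Locale.forLanguageTag("tr")

// Locale.ROOT gives locale-independent results.
val invariant = "TITLE".toLowerCase(Locale.ROOT)

// In Turkish, uppercase I lowercases to a dotless ı (U+0131).
val inTurkish = "TITLE".toLowerCase(turkish)
```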

Echo (flatMaps)

Using a for comprehension, write a function that repeats each word in a list of words n times. Write a comprehension with two separate <- clauses, and use either Iterable.fill(n)(word) to create an iterable containing n copies of word, or (1 to n) to iterate n times.

def echo(words: List[String], n: Int): List[String] =
  ???

src/main/scala/comprehensions/forcomp.scala

For example List("a", "bc", "def", "ghij") should become List("a", "a", "bc", "bc", "def", "def", "ghij", "ghij") if n == 2.

What should happen if n == 1? How about n == 0?

Solution

For n == 1, the resulting list is the same as the source list. For n == 0, the resulting list is empty.
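The n == 0 case follows from how the two suggested building blocks behave on their own, as this small sketch shows (the value names are made up for this illustration):

```scala
// Iterable.fill(n)(x) produces n copies of x, and nothing when n == 0.
val filled: List[String] = Iterable.fill(2)("hi").toList
val noneFilled: List[String] = Iterable.fill(0)("hi").toList

// (1 to n) has n elements, and is an empty range when n == 0.
val counted: List[Int] = (1 to 2).toList
val noneCounted: List[Int] = (1 to 0).toList
```

A generator over an empty collection contributes no iterations, so the comprehension yields nothing for that word.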

All together now 🧪

Using a for comprehension, write a function that converts all words in a list to upper case, removes all words whose length is not three, and repeats all others n times:

def allTogether(words: List[String], n: Int): List[String] =
  ???

src/main/scala/comprehensions/forcomp.scala

For example List("All", "together", "now") should become List("ALL", "ALL", "ALL", "NOW", "NOW", "NOW") if n is 3.

Cross product

Reimplement the cross-product function from last week using a for comprehension.

Additionally, your function should now take a Scala List as input, and return a Scala List as output.

def crossProduct[A, B](l1: List[A], l2: List[B]): List[(A, B)] =
  ???

src/main/scala/comprehensions/forcomp.scala

Triangles in a directed graph

Reimplement the triangles function from last week using a for comprehension and Scala Lists.

def triangles(edges: DirectedGraph): List[(NodeId, NodeId, NodeId)] =
  ???

src/main/scala/comprehensions/forcomp.scala

Glob matching (ungraded callback to find!)

The real Unix find command takes a “glob pattern” for its -name filter: it supports wildcards like ? and * to allow for partial matches.

Write a function glob(pattern, input) that returns a boolean indicating whether a glob pattern matches a string (represented as a list of Chars):

def glob(pattern: List[Char], input: List[Char]): Boolean =
  ???

src/main/scala/comprehensions/Glob.scala

The rules are as follows:

The whole pattern must match the whole string: partial matches are not allowed.

Implementation guide

This is a tricky problem at first sight, but it admits a very nice recursive solution. Think of it in groups: how can you reduce the problem of matching a string against a pattern to a problem with a smaller string, or a smaller pattern?

Based on your implementation, can you prove that a pattern that does not contain wildcards matches only itself? That “*” matches everything? That a pattern with only ? matches all strings of the same length as the pattern? That repeated * can be replaced by single *?

If you plug in your new function into your copy of find, you’ll get an even better file-searcher! It should be a very straightforward refactoring — just change the function passed to the higher-order find function that we wrote in the last callback.

Des chiffres et des lettres

“Des chiffres et des lettres” is a popular TV show in French-speaking countries.

In this show, contestants take turns guessing the longest word that can be made from a list of letters, and finding a way to arrange a list of numbers into an arithmetic computation to get as close as possible to a target number. For example:

Des lettres 🧪

Write a function longestWord which, given a wordlist (a collection of words) represented as a List[String] and a collection of letters represented as a String, finds the longest word that can be made by reordering a subset of these letters. You can write your function directly, or follow the steps below.

We have not provided tests for this function. Before starting, write a few unit tests. What does it mean for the result of longestWord to be correct?

Reveal step-by-step hints

Let’s represent collections of letters by converting them to uppercase and sorting them, so that "Polytechnique" becomes "CEEHILNOPQTUY". We call "Polytechnique" the “original” word, and "CEEHILNOPQTUY" the “scrambled” word.

  1. Write a function scramble that transforms a single word into its scrambled representation.

    def scramble(word: String): String =
      ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

    Hint

    You may find .toUpperCase() and .sorted useful.

  2. Write a function scrambleList that transforms a wordlist into a Map from scrambled words to sets of original words.

    For example, Set("Nale", "lean", "wasp", "swap") should become Map("AELN" -> Set("Nale", "lean"), "APSW" -> Set("wasp", "swap"))

    def scrambleList(allWords: Set[String]): Map[String, Set[String]] =
      ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

    Hint

    Consider using .groupBy to create the map.

  3. Write a function exactWord that returns all words of a wordlist that can be formed by using all letters from a given collection of letters.

    def exactWord(allWords: Set[String], letters: String): Set[String] =
      ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

  4. Write a function compatible that checks whether a scrambled word can be formed from a collection of letters. Beware of repeated letters!

    def compatible(small: String, large: String): Boolean =
      ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

  5. Write a function longestWord that returns the longest word that can be formed using some letters from a given collection of letters.

    def longestWord(allWords: Set[String], letters: String): Set[String] =
      ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

The lab uses a different data structure to check whether a multiset of letters contains another.
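The hint for step 2 mentions .groupBy. Here is a minimal sketch of how it behaves, using a deliberately different key (word length) and a made-up word list:

```scala
// Hypothetical word list, grouped by length rather than by scrambled form.
val words = Set("lean", "wasp", "bee", "ant")

// groupBy applies the key function to each element and collects the
// elements that share a key into one Set per key.
val byLength: Map[Int, Set[String]] = words.groupBy(_.length)
```

Calling groupBy on a Set[String] returns a Map whose values are again Set[String], which matches the shape of the return type of scrambleList.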

Des chiffres 🔥

Write a function leCompteEstBon that takes a List of integers and a target number, and returns an arithmetic expression that evaluates to the target number, if one exists. The allowable operators are +, - (only when the result is nonnegative), *, and / (only when the division is exact), as well as parentheses. For the purpose of this exercise, assume that we are only interested in exact results, using all provided integers. The result should be an expression of type Expr:

The Expr type

We have provided you with an Expr type as a starting point:

trait Expr:
  val value: Option[Int]

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

Notice that each Expr reports its own value as an Option[Int], since evaluating some operators can fail. This type has two direct subclasses, numbers and binary operators:

case class Num(n: Int) extends Expr:
  val value = Some(n)

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

abstract class Binop extends Expr:
  val e1, e2: Expr // Subexpressions
  def op(n1: Int, n2: Int): Option[Int] // How to evaluate this operator

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

The value of a binop is defined thus, using for to unpack options:

val value: Option[Int] =
  for
    n1 <- e1.value
    n2 <- e2.value
    r <- op(n1, n2)
  yield r

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala
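To make the “unpacking” concrete, here is a sketch showing that a comprehension over Options behaves like a chain of flatMap and map calls. The helper safeDiv and the function names viaFor and viaFlatMap are hypothetical and introduced only for this illustration.

```scala
// Hypothetical helper: division that fails on a zero or inexact divisor.
def safeDiv(a: Int, b: Int): Option[Int] =
  if b != 0 && a % b == 0 then Some(a / b) else None

// A comprehension over Options...
def viaFor(x: Option[Int], y: Option[Int]): Option[Int] =
  for
    a <- x
    b <- y
    r <- safeDiv(a, b)
  yield r

// ...behaves like this chain of flatMap and map calls.
def viaFlatMap(x: Option[Int], y: Option[Int]): Option[Int] =
  x.flatMap(a => y.flatMap(b => safeDiv(a, b).map(r => r)))
```

As soon as one step produces None, the remaining steps are skipped and the whole comprehension evaluates to None.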

Finally, the four arithmetic operators are subclasses of Binop:

case class Add(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    Some(n1 + n2)

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

case class Sub(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    if n1 < n2 then None else Some(n1 - n2)

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

case class Mul(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    Some(n1 * n2)

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

case class Div(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    if n2 != 0 && n1 % n2 == 0 then Some(n1 / n2) else None

src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

Note how - and / sometimes return None.
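To see how a None propagates through a tree, here is a self-contained sketch that repeats the definitions above (leaving out Mul for brevity) and evaluates three small trees. The example trees and their names are made up for this illustration.

```scala
trait Expr:
  val value: Option[Int]

case class Num(n: Int) extends Expr:
  val value = Some(n)

abstract class Binop extends Expr:
  val e1, e2: Expr // Subexpressions
  def op(n1: Int, n2: Int): Option[Int] // How to evaluate this operator
  val value: Option[Int] =
    for
      n1 <- e1.value
      n2 <- e2.value
      r <- op(n1, n2)
    yield r

case class Add(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) = Some(n1 + n2)

case class Sub(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    if n1 < n2 then None else Some(n1 - n2)

case class Div(e1: Expr, e2: Expr) extends Binop:
  def op(n1: Int, n2: Int) =
    if n2 != 0 && n1 % n2 == 0 then Some(n1 / n2) else None

// 2 + 6 / 3: every step succeeds.
val exact = Add(Num(2), Div(Num(6), Num(3)))
// 2 + 7 / 2: the inexact division makes the whole tree fail.
val inexact = Add(Num(2), Div(Num(7), Num(2)))
// 2 - 5: the result would be negative.
val negative = Sub(Num(2), Num(5))
```

A failure anywhere in a subtree makes every enclosing expression evaluate to None as well.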

Our solution is short, and uses a combination of for comprehensions and recursion. Take the time to think through how you might divide this problem into smaller subproblems.

Reveal possible approaches

There are broadly two possible approaches: a top-down one, which splits the input set into two halves, and combines them with an operator; or a bottom-up one, which combines numbers into increasingly larger expression trees.

Reveal step-by-step hints (top-down)
  1. Write a recursive function partitions which generates all partitions of a list into two non-overlapping sublists.

    def partitions[A](l: List[A]): List[(List[A], List[A])] =
        ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

    Hint

    This function must decide, for each element, whether it goes into the left or the right partition.

  2. Write a recursive function allTrees that generates all possible trees of expressions from a set of numbers using partitions.

    def allTrees(ints: List[Int]): List[Expr] =
        ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

    Hint

    At each step, this function should call itself recursively twice, once per subset, and generate one tree per operator.

  3. Write a recursive function leCompteEstBon that finds an expression among the possible ones that match the target number, or returns None if the target cannot be achieved:

    def leCompteEstBon(ints: List[Int], target: Int): Option[Expr] =
        ???
    

    src/main/scala/comprehensions/DesChiffresEtDesLettres.scala

Since the steps above are only suggestions, we have provided only integration tests (tests for leCompteEstBon), and no unit tests for intermediate functions. Write a few unit tests to make sure you understand what each function does before starting! (And feel free to share them with other students on Ed!)

You may find that your function takes a long time to return on our tests. In that case, write your own small tests to make sure that it works, then study it by running it on examples to understand where it wastes time. Can you think of optimizations?

Hint

Focus on allTrees. Does it really have to return all trees? For example, given the set List(2, 2), is it valuable to return both Add(Num(2), Num(2)) and Mul(Num(2), Num(2))? Similarly, on the set List(2, 3, 4), is it valuable to keep both 2 * 3 + 4 and 3 * 4 - 2? Our solution keeps just one of each, and this speeds it up from multiple seconds per problem to just a few milliseconds.
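One possible way to keep a single candidate per value is the standard distinctBy method. This is only a sketch of that idea, not necessarily how the reference solution does it. It uses labelled pairs as a hypothetical stand-in for expression trees.

```scala
// Hypothetical stand-in for expression trees: a label and the value it evaluates to.
val candidates = List(("2 + 2", 4), ("2 * 2", 4), ("2 - 2", 0))

// distinctBy keeps the first element seen for each key (here, the value).
val onePerValue = candidates.distinctBy(_._2)
```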

Tracing comprehensions

Performance issues

The following function has surprisingly poor performance. Why?

def badZip[T1, T2](l1: List[T1], l2: List[T2]): Seq[(T1, T2)] =
  for i <- (0 to math.min(l1.length, l2.length) - 1)
  yield (l1(i), l2(i))

src/main/scala/comprehensions/Tracing.scala
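As a point of comparison (a sketch, not necessarily the intended answer): List is a singly linked list, so l1(i) has to walk i cells on every call. A version built on the standard library's zip traverses each list only once. The name zipOnce is made up for this illustration.

```scala
// Same signature as badZip, but built on zip, which walks both lists
// once instead of indexing into them repeatedly.
def zipOnce[T1, T2](l1: List[T1], l2: List[T2]): Seq[(T1, T2)] =
  l1.zip(l2)
```

Like badZip, this stops at the end of the shorter list.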

A subtlety in desugaring and tracing comprehensions 🔥

A keystone of the CS214 debugging guide is to “understand the system”. Sometimes observing the system’s final output is enough, but other times adding instrumentation helps. It pays to be careful, however: any time you change the system that you are debugging, you risk changing its behavior, too. Let’s see an example.

As we’ve seen previously, tracing can help not just with recursion, but also with comprehensions: to insert a tracing instruction into a comprehension, we use _ = println(…).
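Here is a minimal sketch of that pattern on a comprehension with a single generator. To make the trace easy to inspect, it appends to a hypothetical log buffer instead of calling println; the names log and doubled are made up for this illustration.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical log buffer, standing in for println so the trace can be inspected.
val log = ListBuffer.empty[String]

val doubled: List[Int] =
  for
    x <- List(1, 2, 3)
    _ = log.append(s"saw $x") // tracing clause: runs once per element
  yield x * 2
```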

  1. Here is an instrumented comprehension. What does the tracing code print, and in which order?

    def traceIfTrue(b: Boolean, label: String) =
      if b then println(label)
      b
    
    def filter_tracedIfTrue(l: Seq[Int]): Seq[Int] =
      for
        n <- l
        if traceIfTrue(n % 5 == 0, f"$n is a multiple of 5!")
        if traceIfTrue(n % 3 == 0, f"  $n is also a multiple of 3!")
        if traceIfTrue(n >= 10, f"    $n is also greater than 10!")
      yield n
    

    src/main/scala/comprehensions/Tracing.scala

    Confirm your guess by running this code in a worksheet.

  2. Here is a different, arguably more readable way to trace the same comprehension. What will it print?

    def filter_traced(l: Seq[Int]): Seq[Int] =
      for
        n <- l
        if n % 5 == 0
        _ = println(f"$n is a multiple of 5!")
        if n % 3 == 0
        _ = println(f"  $n is also a multiple of 3!")
        if n >= 10
        _ = println(f"    $n is also greater than 10!")
      yield n
    

    src/main/scala/comprehensions/Tracing.scala

    As before, confirm your guess by running this code in a worksheet.

  3. What can you conclude?

Non-structural recursion: numbers

Factorial ⭐️

Define a recursive function which returns the factorial of the input integer.

def factorial(n: Int): Int =
  ???

src/main/scala/comprehensions/recursion.scala

Fast exponentiation

Fast exponentiation is a technique to optimize the exponentiation of numbers:

b²ⁿ = (b²)ⁿ = (bⁿ)²
b²ⁿ⁺¹ = b * b²ⁿ
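As a worked example, here is how the two rules unfold for base 3 and exponent 10 (numbers chosen only for illustration):

```latex
\begin{aligned}
3^{10} &= (3^{5})^{2} \\
3^{5}  &= 3 \cdot 3^{4} \\
3^{4}  &= (3^{2})^{2} \\
3^{2}  &= (3^{1})^{2} \\
3^{1}  &= 3 \cdot 3^{0} = 3
\end{aligned}
```

Reading the equations from the bottom up gives 3, 9, 81, 243, and finally 59049, using far fewer multiplications than multiplying by 3 nine times in a row.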

Define a function that implements this fast exponentiation.

def fastExp(base: Int, exp: Int): Int =
  ???

src/main/scala/comprehensions/recursion.scala

Base-N encoding ⭐️

In this exercise, your task is to implement a function that encodes a given integer into its representation in a given base. The result should be a list of integers.

For instance:

  1. Encode the number 9 to base-2: decimalToBaseN(9, 2) should return List(1, 0, 0, 1).
  2. Encode the number 20 to base-16: decimalToBaseN(20, 16) should return List(1, 4).
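The second example can be checked by hand with repeated division by the base. The sketch below only spells out that arithmetic; the value names are made up for this illustration.

```scala
// Repeated division by the base, shown by hand for 20 in base 16.
val lastDigit = 20 % 16   // the least significant digit
val rest = 20 / 16        // what remains to be encoded
val nextDigit = rest % 16 // the next digit, to the left of lastDigit
val remaining = rest / 16 // nothing left once this reaches 0
```

The remainders come out least significant digit first, so they appear in the result in reverse order: 1, then 4.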

Your task is to implement the decimalToBaseN function.

def decimalToBaseN(number: Int, base: Int, acc: List[Int] = Nil): List[Int] =
    ???

src/main/scala/comprehensions/recursion.scala

Non-structural recursion: collections

Coin Change ⭐️

You are given a list of coin denominations coins: List[Int], which are the different values a coin can have, and a target amount. You need to determine the number of distinct ways to make up the target amount using any combination of the provided coin denominations. You can use each coin denomination an unlimited number of times.

def coinChange(coins: List[Int], amount: Int): Int =
    ???

src/main/scala/comprehensions/recursion.scala

Let’s illustrate the coin change problem with an example.

Suppose you have the following coin denominations: List(1, 2, 5), and you want to make change for the amount 5. How many different ways are there to make change for 5 using these coins?

Here are the different ways to make change for 5:

5 = 5 (use one 5-coin)
5 = 2 + 2 + 1 (use two 2-coins and one 1-coin)
5 = 2 + 1 + 1 + 1 (use one 2-coin and three 1-coins)
5 = 1 + 1 + 1 + 1 + 1 (use five 1-coins)

So, there are a total of 4 different ways to make change for 5 using the denominations List(1, 2, 5).
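This count can be cross-checked with a brute-force comprehension. The sketch below is hard-wired to this one example (coins 1, 2, and 5, amount 5) and is not a general solution; the name ways is made up for this illustration.

```scala
// Brute-force cross-check for coins List(1, 2, 5) and amount 5 only:
// try every count of each coin that could possibly fit.
val ways: Seq[(Int, Int, Int)] =
  for
    ones <- 0 to 5
    twos <- 0 to 2
    fives <- 0 to 1
    if ones * 1 + twos * 2 + fives * 5 == 5
  yield (ones, twos, fives)
```

Each resulting triple records how many 1-coins, 2-coins, and 5-coins are used, so the triples correspond one-to-one to the ways listed above.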

Merge sort ⭐️

Merge sort is a popular and efficient comparison-based sorting algorithm that follows the divide-and-conquer paradigm. It works by recursively dividing a list into smaller sublists, sorting them, and then merging them back together into a single sorted list.

Roughly, merge sort proceeds in the following way:

Split

Define a function that splits a list into two halves, whose length difference is less than or equal to 1.

There are multiple ways to implement such a function. You may use a library function.

For example, if the list has n elements, the first list has the first $\lfloor n / 2 \rfloor$ elements and the second has the rest.

Or, when counting elements from 1, the first list has all items that were at odd positions, and the second one has all items that were at even positions. For example, the list a b c d e f g will be split into a c e g and b d f.
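Both conventions can be tried on the example list with standard library methods. This is one possible sketch among several, with value names made up for this illustration.

```scala
val letters = List('a', 'b', 'c', 'd', 'e', 'f', 'g')

// Convention 1: the first floor(n / 2) elements, then the rest.
val halves: (List[Char], List[Char]) = letters.splitAt(letters.length / 2)

// Convention 2: positions counted from 1, so index 0 is position 1 (odd).
val indexed = letters.zipWithIndex
val oddPositions = indexed.collect { case (c, i) if i % 2 == 0 => c }
val evenPositions = indexed.collect { case (c, i) if i % 2 == 1 => c }
```

In both cases the lengths of the two halves differ by at most one.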

def split[A](l: List[A]): (List[A], List[A]) =
  ???

src/main/scala/comprehensions/recursion.scala

Merge

Write a merge function that takes two sorted lists and returns a sorted list containing all elements of both lists.

def merge(xs: List[Int], ys: List[Int]): List[Int] =
  ???

src/main/scala/comprehensions/recursion.scala

Sort

Implement the mergeSort function.

If the list has one or zero elements, what should the function do?

def mergeSort(xs: List[Int]): List[Int] =
    ???

src/main/scala/comprehensions/recursion.scala

Exercises from the slides

N-Queens

Write a function isSafe which tests whether a queen placed in column col is safe from the queens that have already been placed. It is assumed that the new queen is placed in the next available row after the other placed queens (in other words: in row queens.length).

def isSafe(col: Int, queens: List[Int]): Boolean =
  ???

src/main/scala/comprehensions/NQueens.scala