
Polymorphism

Welcome to week 4 of CS-214 — Software Construction!

The exercise set is intended to help you practice lists and polymorphism.

As usual, ⭐️ indicates the most important exercises and questions; 🔥, the most challenging ones; and 🔜, the ones that are useful to prepare for future lectures. Exercises or questions marked 🧪 are intended to build up to concepts used in this week’s lab.

You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.

We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).

Warm-up: Polymorphic Lists ⭐️

In previous exercises and labs, we used IntList for lists whose elements are integers. This week, we’ll move to polymorphic lists.

Reminder: Algebraic Data Types

In week 3, we learned that algebraic data types can be created with the enum construct. Check the previous lecture or this for more details.

Polymorphic lists can be defined as an algebraic data type in the following way:

enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

src/main/scala/poly/MyList.scala

Covariance

The + before A indicates that MyList is covariant in A. Check this for more details, or ignore it for now: we will cover it later!

We’ll use the above MyList type in this exercise.

Check Yourself

How would you define the isEmpty, head, and tail methods on such polymorphic lists?

Solution
def isEmpty: Boolean = this match
  case Nil => true
  case _   => false

def head: A = this match
  case Nil        => throw EmptyListException()
  case Cons(x, _) => x

def tail: MyList[A] = this match
  case Nil         => throw EmptyListException()
  case Cons(_, xs) => xs

src/main/scala/poly/MyList.scala
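For reference, here is a minimal self-contained sketch of how the three methods above can sit inside the enum body, where the cases Nil and Cons are in scope. It assumes a hypothetical EmptyListException class, since the scaffold's own definition is not shown here.

```scala
// Hypothetical exception type; the scaffold may define it differently.
class EmptyListException extends Exception("empty list")

enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

  // True only for the Nil case.
  def isEmpty: Boolean = this match
    case Nil => true
    case _   => false

  // First element; fails on the empty list.
  def head: A = this match
    case Nil        => throw EmptyListException()
    case Cons(x, _) => x

  // Everything after the first element; fails on the empty list.
  def tail: MyList[A] = this match
    case Nil         => throw EmptyListException()
    case Cons(_, xs) => xs

import MyList.*
```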

Functions on Polymorphic Lists ⭐️

Part 1: Higher-order functions

Common Operators on Lists

In previous weeks, we have implemented higher-order functions on IntLists. For example, map was defined as:

def map(l: IntList)(f: Int => Int): IntList =
  l match
    case IntNil         => IntNil
    case IntCons(x, xs) => IntCons(f(x), map(xs)(f))

src/main/scala/poly/MyList.scala

In contrast, the generic version of the function map (whose type consists of type parameters instead of concrete types like IntList and Int => Int) has the following signature:

def map[A, B](l: MyList[A])(f: A => B): MyList[B]
  1. Based on the example above, write a generic signature for filter, foldRight, reduceRight, forall, exists, zip, and zipWith.

    Check Yourself
    def map[A, B](l: MyList[A])(f: A => B): MyList[B] =
      ???
    
    def filter[A](l: MyList[A])(p: A => Boolean): MyList[A] =
      ???
    
    def foldRight[A, B](l: MyList[A])(f: (A, B) => B, base: B): B =
      ???
    
    def reduceRight[A](l: MyList[A])(f: (A, A) => A): A =
      ???
    
    def forall[A](l: MyList[A])(p: A => Boolean): Boolean =
      ???
    
    def exists[A](l: MyList[A])(p: A => Boolean): Boolean =
      ???
    
    def zip[A, B](l1: MyList[A], l2: MyList[B]): MyList[(A, B)] =
      ???
    
    def zipWith[A, B, C](l1: MyList[A], l2: MyList[B])(op: (A, B) => C): MyList[C] =
      ???
    

    src/main/scala/poly/MyListOps.scala

  2. In previous exercises, we had separate implementations for foldRight and foldRightList (we had to handle the cases of returning an integer and returning an IntList separately).

    Do we need to define a similar foldRightList on polymorphic lists?

    Check Yourself

    No, type variable B can be instantiated to MyList[Int].

  3. Implement these eight higher-order functions (map plus all other ones above) on MyList using pattern matching.

def map[A, B](l: MyList[A])(f: A => B): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => Cons(f(x), map(xs)(f))

def filter[A](l: MyList[A])(p: A => Boolean): MyList[A] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => if p(x) then Cons(x, filter(xs)(p)) else filter(xs)(p)

def foldRight[A, B](l: MyList[A])(f: (A, B) => B, base: B): B =
  l match
    case Nil         => base
    case Cons(x, xs) => f(x, foldRight(xs)(f, base))

def reduceRight[A](l: MyList[A])(f: (A, A) => A): A =
  l match
    case Nil          => throw new IllegalArgumentException("Empty list!")
    case Cons(x, Nil) => x
    case Cons(x, xs)  => f(x, reduceRight(xs)(f))

def forall[A](l: MyList[A])(p: A => Boolean): Boolean =
  l match
    case Nil         => true
    case Cons(x, xs) => p(x) && forall(xs)(p)

def exists[A](l: MyList[A])(p: A => Boolean): Boolean =
  l match
    case Nil         => false
    case Cons(x, xs) => p(x) || exists(xs)(p)

def zip[A, B](l1: MyList[A], l2: MyList[B]): MyList[(A, B)] =
  (l1, l2) match
    case (Cons(x, xs), Cons(y, ys)) => Cons((x, y), zip(xs, ys))
    case _                          => Nil

def zipWith[A, B, C](l1: MyList[A], l2: MyList[B])(op: (A, B) => C): MyList[C] =
  (l1, l2) match
    case (Cons(x, xs), Cons(y, ys)) => Cons(op(x, y), zipWith(xs, ys)(op))
    case _                          => Nil

src/main/scala/poly/MyListOps.scala
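To see concretely why no separate foldRightList is needed, the sketch below uses the same foldRight twice: once with B instantiated to MyList[Int] to rebuild a list, and once with B = Int to compute a number. The names doubleAll and sumAll are illustrative only.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

// Same foldRight as in the solution above.
def foldRight[A, B](l: MyList[A])(f: (A, B) => B, base: B): B =
  l match
    case Nil         => base
    case Cons(x, xs) => f(x, foldRight(xs)(f, base))

// B = MyList[Int]: the fold rebuilds a list, doubling each element.
def doubleAll(l: MyList[Int]): MyList[Int] =
  foldRight[Int, MyList[Int]](l)((x, acc) => Cons(2 * x, acc), Nil)

// B = Int: the very same foldRight computes a number instead.
def sumAll(l: MyList[Int]): Int =
  foldRight[Int, Int](l)((x, acc) => x + acc, 0)
```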

Using List APIs

Use the list APIs to:

  1. Implement function elementsAsStrings which converts every element of a list to a string (you may need the .toString function):

    def elementsAsStrings[A](l: MyList[A]): MyList[String] =
      ???
    

    src/main/scala/poly/MyListOps.scala

  2. Reimplement functions from previous exercises on polymorphic lists:

    def length[A](l: MyList[A]): Int =
      ???
    
    def takeWhilePositive(l: MyList[Int]): MyList[Int] =
      ???
    
    def last[A](l: MyList[A]): A =
      ???
    

    src/main/scala/poly/MyListOps.scala

  3. Adapt the string functions capitalizeString and wordCount to operate on lists of characters:

    val capitalizeString: MyList[Char] => MyList[Char] =
      TODO
    
    def wordCount(l: MyList[Char]): Int =
      ???
    

    src/main/scala/poly/MyListOps.scala

    Beware: the solution we gave in week 1 for wordCount doesn’t express itself naturally as a fold… 🔥 try to look for a different one!

    Strings and Lists

    Both String and List[Char] or MyList[Char] represent sequences of characters. However, it’s usually more efficient and convenient to use String for text processing and manipulation in Scala, because String has optimized storage for text and rich APIs tailored for text operations.

    Later this year, we will see a more general trait that covers both Lists and Strings. This will allow us to write unified code for both.

  1. def elementsAsStrings[A](l: MyList[A]): MyList[String] =
      map(l)(_.toString())
    

    src/main/scala/poly/MyListOps.scala

  2. def length[A](l: MyList[A]): Int =
      foldRight(l)((_, acc) => 1 + acc, 0)
    
    def takeWhilePositive(l: MyList[Int]): MyList[Int] =
      foldRight(l)(
        (x, acc) => if x > 0 then Cons(x, acc) else Nil,
        Nil
      )
    
    def last[A](l: MyList[A]): A =
      reduceRight(l)((x, acc) => acc)
    

    src/main/scala/poly/MyListOps.scala

  3. val capitalizeString: MyList[Char] => MyList[Char] = map(_)(c => c.toUpper)
    
    case class WordCountState(count: Int, lastWasWS: Boolean)
    
    def wordCount(l: MyList[Char]): Int =
      foldLeft(l)(
        WordCountState(0, true),
        (state: WordCountState, c: Char) =>
          val cIsWS = c.isWhitespace
          val count = state.count + (if state.lastWasWS && !cIsWS then 1 else 0)
          WordCountState(count, cIsWS)
      ).count
    

    src/main/scala/poly/MyListOps.scala
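One possible alternative for wordCount (a sketch, not the official solution) avoids threading state through a fold: pair every character with its predecessor using zipWith, mark the positions where a word starts, and add up the marks with foldRight. The helpers fromString and wordCountZip are hypothetical names introduced for this example.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

def zipWith[A, B, C](l1: MyList[A], l2: MyList[B])(op: (A, B) => C): MyList[C] =
  (l1, l2) match
    case (Cons(x, xs), Cons(y, ys)) => Cons(op(x, y), zipWith(xs, ys)(op))
    case _                          => Nil

def foldRight[A, B](l: MyList[A])(f: (A, B) => B, base: B): B =
  l match
    case Nil         => base
    case Cons(x, xs) => f(x, foldRight(xs)(f, base))

// Hypothetical helper: builds a MyList[Char] from a String.
def fromString(s: String): MyList[Char] =
  s.foldRight(Nil: MyList[Char])((c, acc) => Cons(c, acc))

// A word starts at each non-whitespace character whose predecessor is
// whitespace; a space is prepended so the first character has a predecessor.
def wordCountZip(l: MyList[Char]): Int =
  val starts: MyList[Int] =
    zipWith(Cons(' ', l), l)((prev, c) => if prev.isWhitespace && !c.isWhitespace then 1 else 0)
  foldRight(starts)((x: Int, acc: Int) => x + acc, 0)
```

Since zipWith stops at the shorter list, the extra last character of Cons(' ', l) is simply ignored.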

Part 2: More functions: flatMap and cross-product

flatMap

You may have come across flatMap, a powerful higher-order function that can be used to transform and flatten container datatypes, such as lists.

def flatMap[A, B](l: MyList[A])(f: A => MyList[B]): MyList[B] =
  ???

src/main/scala/poly/MyListOps.scala

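The idea of flatMap(l)(f) is to apply f to every element of l, which yields one list per element, and then to concatenate all of these lists into a single list.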

For example,

object FlatMapExamples:
  val numbers: MyList[Int] = Cons(2, Cons(3, Nil))

  val mapped = map(numbers)((n: Int) =>
    Cons(n, Cons(n * 2, Nil))
  )
  // For simplicity, we write Cons as `::` in the results.
  // Result: (2 :: 4 :: Nil) :: (3 :: 6 :: Nil) :: Nil

  val flatMapped = flatMap(numbers)((n: Int) =>
    Cons(n, Cons(n * 2, Nil))
  )
  // Result: 2 :: 4 :: 3 :: 6 :: Nil

src/main/scala/poly/MyListOps.scala

  1. Implement flatMap. You may use the append function that we included in the starting code.

  2. Implement flatten using flatMap. flatten takes a list of lists and returns the concatenation of all of them:

    def flatten[A](l: MyList[MyList[A]]): MyList[A] =
      ???
    

    src/main/scala/poly/MyListOps.scala

def flatMap[A, B](l: MyList[A])(f: A => MyList[B]): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => f(x) ++ flatMap(xs)(f)

src/main/scala/poly/MyListOps.scala

def flatten[A](l: MyList[MyList[A]]): MyList[A] =
  flatMap(l)(identity)

src/main/scala/poly/MyListOps.scala
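The relationship between the three functions can be checked directly: flatMap(l)(f) should equal flatten(map(l)(f)). In this sketch, append is a hypothetical stand-in for the append/++ helper of the starting code.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

// Hypothetical stand-in for the starting code's append.
def append[A](l1: MyList[A], l2: MyList[A]): MyList[A] =
  l1 match
    case Nil         => l2
    case Cons(x, xs) => Cons(x, append(xs, l2))

def map[A, B](l: MyList[A])(f: A => B): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => Cons(f(x), map(xs)(f))

// Apply f to each element, then concatenate the resulting lists.
def flatMap[A, B](l: MyList[A])(f: A => MyList[B]): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => append(f(x), flatMap(xs)(f))

def flatten[A](l: MyList[MyList[A]]): MyList[A] =
  flatMap(l)(identity)
```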

cross-product

The cross-product function, often referred to as the Cartesian product, produces all possible pairs (combinations) of elements from two lists.

def crossProduct[A, B](l1: MyList[A], l2: MyList[B]): MyList[(A, B)] =
  ???

src/main/scala/poly/MyListOps.scala

For example, given a list of main dishes and a list of side dishes, we can use crossProduct to generate all possible meal combinations:

object CrossProductExamples:
  val mains = Cons("Burger", Cons("Pizza", Cons("Pasta", Nil)))
  val sides = Cons("Salad", Cons("Soup", Nil))

  val meals = crossProduct(mains, sides)
  // Result:
  // ("Burger", "Salad") :: ("Burger", "Soup") :: ("Pizza", "Salad") ::
  // ("Pizza", "Soup") :: ("Pasta", "Salad") :: ("Pasta", "Soup") :: Nil

src/main/scala/poly/MyListOps.scala

def crossProduct[A, B](l1: MyList[A], l2: MyList[B]): MyList[(A, B)] =
  flatMap(l1)(a => map(l2)(b => (a, b)))

src/main/scala/poly/MyListOps.scala

Triangles in Directed Graphs

Consider a directed graph given by its set of (directed) edges stored as a list of pairs of nodes:

type NodeId = Int
type DirectedEdge = (NodeId, NodeId)
type DirectedGraph = MyList[DirectedEdge]
type Triangle = (NodeId, NodeId, NodeId)

src/main/scala/poly/DirectedGraph.scala

Define the triangles function that finds all cycles of length 3, with three distinct nodes, in the given graph.

def triangles(edges: DirectedGraph): MyList[Triangle] =
  ???

src/main/scala/poly/DirectedGraph.scala

Hint

You can make use of flatMap, map and filter.

Each cycle should appear only once. For instance, given the edges:

Cons((1, 2), Cons((2, 3), Cons((3, 1), Nil)))

You should return exactly one of the three following possibilities:

(1, 2, 3), (2, 3, 1), (3, 1, 2)

You are free to decide which of the three you return.

def triangles(edges: DirectedGraph): MyList[Triangle] =
  flatMap(edges)(e0 =>
    e0 match
      case (a, b) if a < b =>
        flatMap(edges): e1 => // Colon syntax: same as flatMap(edges)(e1 => ...)
          e1 match
            case (`b`, c) if a < c && b != c => // b != c excludes self-loops such as (2, 2)
              map(filter(edges)(e3 =>
                e3 match
                  case (`c`, `a`) => true
                  case _          => false
              ))(_ => (a, b, c))
            case _ => Nil
      case _ => Nil
  )

src/main/scala/poly/DirectedGraph.scala
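Here is a self-contained variant of triangles, written as a sketch that uses filter for the checks instead of pattern guards. It keeps the type aliases from above and re-declares the list helpers so it compiles on its own. The b != c condition matters: without it, a self-loop such as (2, 2) together with edges (1, 2) and (2, 1) would be reported as the "triangle" (1, 2, 2).

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

type NodeId = Int
type DirectedEdge = (NodeId, NodeId)
type DirectedGraph = MyList[DirectedEdge]
type Triangle = (NodeId, NodeId, NodeId)

def append[A](l1: MyList[A], l2: MyList[A]): MyList[A] =
  l1 match
    case Nil         => l2
    case Cons(x, xs) => Cons(x, append(xs, l2))

def map[A, B](l: MyList[A])(f: A => B): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => Cons(f(x), map(xs)(f))

def filter[A](l: MyList[A])(p: A => Boolean): MyList[A] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => if p(x) then Cons(x, filter(xs)(p)) else filter(xs)(p)

def flatMap[A, B](l: MyList[A])(f: A => MyList[B]): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => append(f(x), flatMap(xs)(f))

// Each triangle is reported once, starting from its smallest node a:
// follow an edge a -> b, then an edge b -> c, and keep (a, b, c) only if
// the closing edge c -> a exists. a < b and a < c make a the smallest node;
// b != c rules out self-loops, so the three nodes are distinct.
def triangles(edges: DirectedGraph): MyList[Triangle] =
  flatMap(edges) { e0 =>
    val (a, b) = e0
    val cs: MyList[NodeId] =
      map(filter(edges)(e1 => e1._1 == b && a < b && a < e1._2 && e1._2 != b))(_._2)
    flatMap(cs) { c =>
      map(filter(edges)(e2 => e2._1 == c && e2._2 == a))(_ => (a, b, c))
    }
  }
```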

Option Type

In last week’s exercises, we used a custom type LookupResult for the result of a lookup in a context:

enum LookupResult:
  case Ok(v: Int)
  case NotFound

It’s always good to explore the Scala standard library. After all, why use a custom type when there is something suitable in the standard library?

Can you find a suitable type for LookupResult?

One suitable choice is already given by the title: the Option type!

Is there any other suitable container in the standard library?

Tuple, Either.
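As an illustration, assume a context represented as a list of (name, value) pairs (a hypothetical representation; last week's context may be defined differently). A lookup can then return Option[Int] instead of LookupResult, with Some(v) playing the role of Ok(v) and None that of NotFound. Either works too, and can additionally describe why the lookup failed.

```scala
// Hypothetical context representation: a list of (name, value) bindings.
type Context = List[(String, Int)]

// Option in place of LookupResult: Some(v) ~ Ok(v), None ~ NotFound.
def lookupOption(ctx: Context, name: String): Option[Int] =
  ctx.find(_._1 == name).map(_._2)

// Either can carry an explanation in the failure case.
def lookupEither(ctx: Context, name: String): Either[String, Int] =
  lookupOption(ctx, name).toRight(s"unbound name: $name")
```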

Part 1. Basic Usage

The most basic use of the Option type is as the return type of functions that may not always return a valid value.

Implement findFirstEvenNumber to return the first even number in the list, or None if there isn’t one.

def findFirstEvenNumber(l: List[Int]): Option[Int] =
  ???

src/main/scala/poly/RevisitOption.scala

def findFirstEvenNumber(l: List[Int]): Option[Int] =
  l.find(_ % 2 == 0)

def findFirstEvenNumber_ByHand(l: List[Int]): Option[Int] =
  l match
    case Nil                    => None
    case hd :: _ if hd % 2 == 0 => Some(hd)
    case _ :: tl                => findFirstEvenNumber_ByHand(tl)

src/main/scala/poly/RevisitOption.scala

Part 2. Drawing Parallels with List in Standard Library

Notice that Option also has map, flatMap, filter just like List. Do you know why?

Hint

An option is like a list with at most one element.

In this part, we use the List (scala.collection.immutable.List) from the standard library.

You can compare the definitions of map, flatMap and filter in the standard library’s List with those in Option. Do the definitions line up? What is the difference between the definitions on scala.collection.immutable.List and ours on the custom polymorphic lists poly.MyList?
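The hint can be checked in a few lines: converting an Option with toList gives a list of length 0 or 1, and map, filter and flatMap agree on both views.

```scala
val some: Option[Int] = Some(3)
val none: Option[Int] = None

// Viewed as lists, Some(3) has one element and None has none.
val someAsList = some.toList // List(3)
val noneAsList = none.toList // List()

// map, filter and flatMap give the same results on both views.
val inc = (x: Int) => x + 1
val sameMap = some.map(inc).toList == some.toList.map(inc)
val sameFilter = some.filter(_ > 5).toList == some.toList.filter(_ > 5)
val sameFlatMap =
  some.flatMap(x => Some(x * 2)).toList == some.toList.flatMap(x => List(x * 2))
```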

  1. Implement parseStringToInt and findSquareRoot. Then, define findSquareRootFromString to chain these two functions: parse a string, then find the square root of the result.

    def parseStringToInt(s: String): Option[Int] =
      ???
    
    def findSquareRoot(n: Int): Option[Double] =
      ???
    
    def findSquareRootFromString(s: String): Option[Double] =
      ???
    

    src/main/scala/poly/RevisitOption.scala

  2. 🔜 Given a list of strings representing integers:

    val numberStrings: List[String] = List("1", "2", "star", "4")
    

    Try to use map to convert them into integers. What issues do you face?

    Now, use the member method flatMap of scala.collection.immutable.List and the parseStringToInt function to safely convert them.

    val numbers =
      TODO
    

    src/main/scala/poly/RevisitOption.scala

    Check Yourself 🔥

    Can you do the same trick using our custom lists poly.MyList and our definition of flatMap instead? Why or why not?

    Solution

    No.

    We can mix List and Option this easily only because, in the standard library, both List and Option are subtypes of IterableOnce, and the signatures of many useful methods are expressed in terms of this supertype. For example, the signature of flatMap in List is def flatMap[B](f: A => IterableOnce[B]): List[B]. Our flatMap, in contrast, requires a function of type A => MyList[B], and an Option is not a MyList.

    We will cover this more advanced API at two points later in the course: first to introduce comprehensions, and then more generally monads.

  1. def parseStringToInt(s: String): Option[Int] =
      s.toIntOption
    
    def findSquareRoot(n: Int): Option[Double] =
      if n >= 0 then Some(Math.sqrt(n)) else None
    
    def findSquareRootFromString(s: String): Option[Double] =
      parseStringToInt(s).flatMap(findSquareRoot)
    

    src/main/scala/poly/RevisitOption.scala

  2. val numbers =
      numberStrings.flatMap(parseStringToInt)
    

    src/main/scala/poly/RevisitOption.scala
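The sketch below contrasts the approaches on numberStrings. For our own lists, one workaround is to convert each Option into a MyList first; optionToMyList is a hypothetical helper, and append stands in for the starting code's append.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

def parseStringToInt(s: String): Option[Int] = s.toIntOption

val numberStrings: List[String] = List("1", "2", "star", "4")

// numberStrings.map(_.toInt) would throw a NumberFormatException on "star".
// Mapping with parseStringToInt is safe, but the failures stay in the result:
val options: List[Option[Int]] = numberStrings.map(parseStringToInt)

// The standard-library flatMap accepts A => IterableOnce[B], and Option is
// an IterableOnce, so failed parses simply disappear:
val numbers: List[Int] = numberStrings.flatMap(parseStringToInt)

// Our flatMap only accepts A => MyList[B], so an Option must be converted first.
def append[A](l1: MyList[A], l2: MyList[A]): MyList[A] =
  l1 match
    case Nil         => l2
    case Cons(x, xs) => Cons(x, append(xs, l2))

def flatMap[A, B](l: MyList[A])(f: A => MyList[B]): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => append(f(x), flatMap(xs)(f))

// Hypothetical helper: Some(x) becomes a one-element list, None an empty one.
def optionToMyList[A](o: Option[A]): MyList[A] =
  o match
    case Some(x) => Cons(x, Nil)
    case None    => Nil

val myStrings: MyList[String] = Cons("1", Cons("2", Cons("star", Cons("4", Nil))))
val myNumbers: MyList[Int] = flatMap(myStrings)(s => optionToMyList(parseStringToInt(s)))
```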

FoldLeft and Tail Recursion 🧪

We say that a function is tail recursive if every recursive call it makes is the last thing it does on its code path, so that nothing remains to be computed after the call returns. For example, the following function is not tail recursive:

def length0(l: MyList[Int]): Int = l match
  case Nil         => 0
  case Cons(x, xs) => 1 + length0(xs)

src/main/scala/poly/MyListOps.scala

Indeed, after calling itself recursively, the function length0 adds one to its own result.

In contrast, the inner loop function below is tail recursive:

def lengthTR(l: MyList[Int]): Int =
  def length(l: MyList[Int], prefixLength: Int): Int = l match
    case Nil         => prefixLength
    case Cons(x, xs) => length(xs, prefixLength + 1)
  length(l, 0)

src/main/scala/poly/MyListOps.scala

Indeed, it does nothing further after calling itself with an incremented prefixLength. This property allows the compiler to eliminate the recursion completely: the function named length above is in fact compiled into a simple loop.

🔜 We will learn much more about tail recursion in week 11.

Reasoning about tail recursion

Use the substitution method to evaluate length0 and lengthTR on various inputs. Do they return the same thing? Can you conjecture an equation relating the inner length function to length0?

They return the same results. The following equation holds: ∀ l n. length(l, n) = n + length0(l).
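A quick way to gain confidence in the equation is to trace one substitution by hand and then test the equation on a range of inputs. In this sketch, the inner length of lengthTR is lifted to the top level so it can be called directly, and upTo is a hypothetical helper that builds the list 1, 2, ..., k.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

def length0(l: MyList[Int]): Int = l match
  case Nil         => 0
  case Cons(x, xs) => 1 + length0(xs)

// The inner helper of lengthTR, lifted to the top level.
def length(l: MyList[Int], prefixLength: Int): Int = l match
  case Nil         => prefixLength
  case Cons(x, xs) => length(xs, prefixLength + 1)

// Substitution trace on Cons(1, Cons(2, Nil)):
//   length0(Cons(1, Cons(2, Nil)))
//     = 1 + length0(Cons(2, Nil))
//     = 1 + (1 + length0(Nil))
//     = 1 + (1 + 0) = 2
//   length(Cons(1, Cons(2, Nil)), 0)
//     = length(Cons(2, Nil), 1)
//     = length(Nil, 2) = 2

// Hypothetical helper: the list 1, 2, ..., k.
def upTo(k: Int): MyList[Int] =
  (1 to k).foldRight(Nil: MyList[Int])((x, acc) => Cons(x, acc))

// Check length(l, n) == n + length0(l) for small lists and prefixes.
val equationHolds: Boolean =
  (0 to 5).forall(k => (0 to 3).forall(n => length(upTo(k), n) == n + length0(upTo(k))))
```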

Sum

  1. Is the following function tail-recursive?

    def sum0(l: MyList[Int]): Int = l match
      case Nil         => 0
      case Cons(x, xs) => x + sum0(xs)
    

    src/main/scala/poly/MyListOps.scala

  2. What happens if you uncomment and run the following test, which runs sum0 on a list with 50000 elements?

    // test("sum0: large list"):
    //   assertEquals(sum0(manyNumbers1), N)
    

    src/test/scala/poly/MyListOpsTest.scala

  3. Can you think of a tail-recursive way to write sum?

    def sum1(l: MyList[Int]): Int =
      // @tailrec // Uncomment this line.
      def sum(l: MyList[Int], acc: Int): Int =
        ???
      sum(l, 0)
    

    src/main/scala/poly/MyListOps.scala

    In Scala, the `@tailrec` annotation (imported from scala.annotation.tailrec) is a directive for the compiler, indicating that the annotated method must be tail-recursive. If the method is not tail-recursive, the compiler reports a compile-time error.

  4. What happens if you run sum1 with a very long list?

  1. Not tail-recursive.

  2. A “stack overflow”: each recursive call consumes a frame on the call stack, until the stack runs out of space and the JVM throws a StackOverflowError.

  3. Here is one version:

    def sum1(l: MyList[Int]): Int =
      @tailrec
      def sum(l: MyList[Int], acc: Int): Int =
        l match
          case Nil         => acc
          case Cons(x, xs) => sum(xs, acc + x)
      sum(l, 0)
    

    src/main/scala/poly/MyListOps.scala

  4. This time there’s no stack overflow, because the compiler was able to eliminate recursion.
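To try sum1 on a very long list, the list itself must be built without deep recursion. The sketch below uses a hypothetical tail-recursive helper fill; running sum0 on the same input would likely overflow the stack under default JVM settings, so it is not called here.

```scala
import scala.annotation.tailrec

enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

// Hypothetical helper: a list of n copies of x, built by a tail-recursive loop.
def fill[A](n: Int, x: A): MyList[A] =
  @tailrec
  def loop(k: Int, acc: MyList[A]): MyList[A] =
    if k == 0 then acc else loop(k - 1, Cons(x, acc))
  loop(n, Nil)

// Same sum1 as in the solution: the accumulator carries the partial sum.
def sum1(l: MyList[Int]): Int =
  @tailrec
  def sum(l: MyList[Int], acc: Int): Int =
    l match
      case Nil         => acc
      case Cons(x, xs) => sum(xs, acc + x)
  sum(l, 0)
```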

FoldLeft

Like foldRight, foldLeft combines all the elements of a list starting from a base value; unlike foldRight, it processes the list from the leftmost (head) element to the rightmost element.

The main practical difference between foldLeft and foldRight is that foldLeft is typically implemented using tail recursion, whereas foldRight (as defined above) is not tail recursive.

  1. Define foldLeft:

    // @tailrec // Uncomment this line.
    def foldLeft[A, B](l: MyList[A])(base: B, f: (B, A) => B): B =
      ???
    

    src/main/scala/poly/MyListOps.scala

  2. Define sum0Fold using foldRight, define sum1Fold using foldLeft:

    def sum0Fold(l: MyList[Int]): Int =
      ???
    def sum1Fold(l: MyList[Int]): Int =
      ???
    

    src/main/scala/poly/MyListOps.scala

  3. Reimplement reverseAppend using foldLeft:

    def reverseAppend[A](l1: MyList[A], l2: MyList[A]): MyList[A] =
      ???
    

    src/main/scala/poly/MyListOps.scala

  4. Implement countEven and totalLength using foldLeft. countEven takes a list of integers and returns the number of even integers in the list; totalLength takes a list of strings and returns the sum of the strings’ lengths.

    val countEven: MyList[Int] => Int =
      TODO
    
    val totalLength: MyList[String] => Int =
      TODO
    

    src/main/scala/poly/MyListOps.scala

  1. @tailrec
    def foldLeft[A, B](l: MyList[A])(base: B, f: (B, A) => B): B =
      l match
        case Nil         => base
        case Cons(x, xs) => foldLeft(xs)(f(base, x), f)
    

    src/main/scala/poly/MyListOps.scala

  2. def sum0Fold(l: MyList[Int]): Int =
      foldRight(l)((x, acc) => x + acc, 0)
    def sum1Fold(l: MyList[Int]): Int =
      foldLeft(l)(0, (acc, x) => x + acc)
    

    src/main/scala/poly/MyListOps.scala

  3. def reverseAppend[A](l1: MyList[A], l2: MyList[A]): MyList[A] =
      foldLeft(l1)(l2, (acc, x) => Cons(x, acc))
    

    src/main/scala/poly/MyListOps.scala

  4. val countEven: MyList[Int] => Int =
      (l: MyList[Int]) =>
        foldLeft(l)(
          0,
          (acc, x) =>
            (if x % 2 == 0 then 1 else 0) + acc
        )
    
    val totalLength: MyList[String] => Int =
      l => foldLeft(map(l)(_.length))(0, _ + _)
    

    src/main/scala/poly/MyListOps.scala
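The difference in processing order becomes visible with an operation whose grouping matters, such as building a bracketed string. This sketch repeats the two fold definitions from above (without the @tailrec annotation, which is not needed for this small input).

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

// foldRight: f is applied after the recursive call returns.
def foldRight[A, B](l: MyList[A])(f: (A, B) => B, base: B): B =
  l match
    case Nil         => base
    case Cons(x, xs) => f(x, foldRight(xs)(f, base))

// foldLeft: the accumulator is updated first; the recursive call is the last step.
def foldLeft[A, B](l: MyList[A])(base: B, f: (B, A) => B): B =
  l match
    case Nil         => base
    case Cons(x, xs) => foldLeft(xs)(f(base, x), f)

val nums: MyList[Int] = Cons(1, Cons(2, Cons(3, Nil)))

// Bracketing each intermediate result reveals the grouping each fold uses.
val right = foldRight[Int, String](nums)((x, acc) => s"($x+$acc)", "0")
val left  = foldLeft[Int, String](nums)("0", (acc, x) => s"($acc+$x)")
```

Here right is "(1+(2+(3+0)))" and left is "(((0+1)+2)+3)": for an associative and commutative operation like + on Int the two agree, which is why sum0Fold and sum1Fold return the same result.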

Currying and Composition

Reminder

You can check the previous exercises for currying and composition in Week 1: Higher-order Functions.

CurriedZipWith

Use map and zip to implement the curried version curriedZipWith of zipWith.

Defining polymorphic function values

Reference: Polymorphic Function Types.

// A polymorphic method:
def foo[A](xs: List[A]): List[A] = ???

// A polymorphic function value:
val bar = [A] => (xs: List[A]) => foo(xs)

src/main/scala/poly/MyListOps.scala

bar has type [A] => List[A] => List[A]. This type describes function values which take a type A as a parameter, then take a list of type List[A], and return a list of the same type List[A].
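A runnable version of the example, assuming foo simply reverses its argument (the original leaves its body as ???), shows how a polymorphic function value is applied: first to a type argument, then to a value.

```scala
// Hypothetical body for foo: the original leaves it unimplemented.
def foo[A](xs: List[A]): List[A] = xs.reverse

// A polymorphic function value wrapping the method.
val bar = [A] => (xs: List[A]) => foo(xs)

// Type argument first, then the list.
val ints: List[Int] = bar[Int](List(1, 2, 3))
val strs: List[String] = bar[String](List("a", "b"))
```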

val curriedZipWith: [A, B, C] => ((A, B) => C) => MyList[A] => MyList[B] => MyList[C] =
  TODO

src/main/scala/poly/MyListOps.scala

val curriedZipWith: [A, B, C] => ((A, B) => C) => MyList[A] => MyList[B] => MyList[C] =
  [A, B, C] =>
    (op: (A, B) => C) =>
      (l1: MyList[A]) =>
        (l2: MyList[B]) =>
          map(zip(l1, l2))(t => op(t._1, t._2))

src/main/scala/poly/MyListOps.scala
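The benefit of currying shows up when arguments are supplied one at a time: fixing the operation, and then the first list, yields reusable functions. The sketch repeats the definitions it needs; addLists and addTens are illustrative names.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

def map[A, B](l: MyList[A])(f: A => B): MyList[B] =
  l match
    case Nil         => Nil
    case Cons(x, xs) => Cons(f(x), map(xs)(f))

def zip[A, B](l1: MyList[A], l2: MyList[B]): MyList[(A, B)] =
  (l1, l2) match
    case (Cons(x, xs), Cons(y, ys)) => Cons((x, y), zip(xs, ys))
    case _                          => Nil

val curriedZipWith: [A, B, C] => ((A, B) => C) => MyList[A] => MyList[B] => MyList[C] =
  [A, B, C] =>
    (op: (A, B) => C) =>
      (l1: MyList[A]) =>
        (l2: MyList[B]) =>
          map(zip(l1, l2))(t => op(t._1, t._2))

// Fix the operation: an ordinary curried function on integer lists.
val addLists: MyList[Int] => MyList[Int] => MyList[Int] =
  curriedZipWith[Int, Int, Int](_ + _)

// Fix the operation and the first list as well.
val addTens: MyList[Int] => MyList[Int] =
  addLists(Cons(10, Cons(20, Cons(30, Nil))))
```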

Polymorphic Composition 🔥

  1. In previous exercises we defined a function compose to compose functions f: Int => Double and g: Double => String. Generalize this function to arbitrary pairs of types, using polymorphic argument types.

  2. What is the neutral element for the generalized compose?

  3. In previous exercises, we defined andLifter and notLifter for functions on Int. To make it more general, we can define andLifter for functions of arbitrary input types:

    def andLifter[A](f: A => Boolean, g: A => Boolean): A => Boolean =
      a => f(a) && g(a)
    

    src/main/scala/poly/fun.scala

    … and we can generalize further! Look at the following three functions, together with andLifter above; do they have anything in common?

    def orLifter[A](f: A => Boolean, g: A => Boolean): A => Boolean =
      a => f(a) || g(a)
    def sumLifter[A](f: A => Int, g: A => Int): A => Int =
      a => f(a) + g(a)
    def listConcatLifter[A, B](f: A => MyList[B], g: A => MyList[B]): A => MyList[B] =
      a => f(a) ++ g(a)
    

    src/main/scala/poly/fun.scala

    Write a binaryLifter higher-order function to capture the common pattern above, and use it to rewrite all four lifters that we’ve seen up to this point.

    def binaryLifter[A, B, C](f: A => B, g: A => B)(op: (B, B) => C): A => C =
      ???
    

    src/main/scala/poly/fun.scala

    def andLifter1[A](f: A => Boolean, g: A => Boolean) =
      ???
    
    def orLifter1[A](f: A => Boolean, g: A => Boolean) =
      ???
    
    def sumLifter1[A](f: A => Int, g: A => Int) =
      ???
    
    def listConcatLifter1[A, B](f: A => MyList[B], g: A => MyList[B]) =
      ???
    

    src/main/scala/poly/fun.scala

  4. Similarly, we can implement a unaryLifter to generate lifters like notLifter. Can you tell which function unaryLifter essentially is?

  1. def compose[A, B, C](f: B => C, g: A => B): A => C =
      (a: A) => f(g(a))
    

    src/main/scala/poly/fun.scala

  2. def id[A](x: A): A = x
    

    src/main/scala/poly/fun.scala

  3. def binaryLifter[A, B, C](f: A => B, g: A => B)(op: (B, B) => C): A => C =
      a => op(f(a), g(a))
    

    src/main/scala/poly/fun.scala

  4. The unaryLifter can be implemented as:

    def unaryLifter[A, B, C](f: A => B)(op: B => C): A => C =
      a => op(f(a))
    

    It is essentially compose, curried and with its arguments swapped: unaryLifter(f)(op) behaves like compose(op, f).
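Putting the pieces together, here is one way to rewrite the four lifters with binaryLifter, and a notLifter with unaryLifter, along with a check that id is neutral for compose. The ++ extension is a hypothetical stand-in for the one provided in the starting code, and notLifter1 is an illustrative name.

```scala
enum MyList[+A]:
  case Nil
  case Cons(x: A, xs: MyList[A])

import MyList.*

// Hypothetical stand-in for the starting code's ++ on MyList.
extension [A](l: MyList[A])
  def ++(r: MyList[A]): MyList[A] = l match
    case Nil         => r
    case Cons(x, xs) => Cons(x, xs ++ r)

def compose[A, B, C](f: B => C, g: A => B): A => C = a => f(g(a))
def id[A](x: A): A = x

def binaryLifter[A, B, C](f: A => B, g: A => B)(op: (B, B) => C): A => C =
  a => op(f(a), g(a))
def unaryLifter[A, B, C](f: A => B)(op: B => C): A => C =
  a => op(f(a))

// Each lifter only chooses how to combine the two results.
def andLifter1[A](f: A => Boolean, g: A => Boolean): A => Boolean =
  binaryLifter(f, g)((x, y) => x && y)
def orLifter1[A](f: A => Boolean, g: A => Boolean): A => Boolean =
  binaryLifter(f, g)((x, y) => x || y)
def sumLifter1[A](f: A => Int, g: A => Int): A => Int =
  binaryLifter(f, g)((x, y) => x + y)
def listConcatLifter1[A, B](f: A => MyList[B], g: A => MyList[B]): A => MyList[B] =
  binaryLifter(f, g)((x, y) => x ++ y)

// notLifter via unaryLifter, i.e. compose((b: Boolean) => !b, f).
def notLifter1[A](f: A => Boolean): A => Boolean =
  unaryLifter(f)(b => !b)
```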