Monads

Welcome to week 14 of CS-214 — Software Construction!

As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones.

You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.

We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).

What are monads? ⭐️

A monad can be thought of as a wrapper for values, providing useful ways to chain together operations on those values, and to handle things like side effects and context or state.

You have already worked with monads: for example both the List collection and the Option class in Scala are monads.

A monad has two fundamental operations: unit and flatMap:

trait Monad[M[_]]:
  def unit[A](a: A): M[A]
  def flatMap[A, B](ma: M[A], f: A => M[B]): M[B]

(In the lecture we saw flatMap as an extension method; here, for a change, we see it as a method on a Monad trait.)

M[_] means “a type taking one type parameter”. It can equivalently be written M <: [X] =>> Any, meaning “a subtype of type-level functions that take one type and return a type.”

unit[A] is “boxing” an element of type A inside a monad M[A]. flatMap is applying a transformation to the element(s) in the monad and combining the results.
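To make these two operations concrete, here is a minimal sketch of an instance of the trait above for the standard List type. The name ListMonad is ours and purely illustrative; it is not part of the scaffold.

```scala
// The trait from above, repeated so that this sketch is self-contained.
trait Monad[M[_]]:
  def unit[A](a: A): M[A]
  def flatMap[A, B](ma: M[A], f: A => M[B]): M[B]

// Illustrative instance for the standard List type (hypothetical name).
object ListMonad extends Monad[List]:
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](ma: List[A], f: A => List[B]): List[B] = ma.flatMap(f)

// "Boxing" a single value.
val boxed = ListMonad.unit(3) // List(3)

// Transforming each element into a list and combining the results.
val expanded =
  ListMonad.flatMap(List(1, 2, 3), (x: Int) => List(x, x * 10)) // List(1, 10, 2, 20, 3, 30)
```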

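Monads follow three laws:

Left identity:
unit(a).flatMap(f) === f(a)

Right identity:
m.flatMap(unit) === m

Associativity:
m.flatMap(f).flatMap(g) === m.flatMap(x => f(x).flatMap(g))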

  1. Try to think of types in Scala you know that implement these two functions and respect these laws. What do these functions enable? Do they have a different name?

  2. How would you implement map and flatten using only unit and flatMap? Try to first reason with List. How does this apply to the other types you thought of in the previous question?

  3. Can any monad be used in a for-comprehension or are there more pre-requisites? If so, can you change these monads to be used in a for-comprehension?
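Coming back to the three monad laws (left identity, right identity, associativity): they can be spot-checked on concrete values. The sketch below does so for the standard Option type, with Some playing the role of unit. The functions f and g and the value m are arbitrary choices made for illustration, and a check on a few values is of course not a proof.

```scala
// Hypothetical sample functions, chosen only for illustration.
def f(x: Int): Option[Int] = if x > 0 then Some(x * 2) else None
def g(x: Int): Option[String] = Some(s"value: $x")

// For Option, unit simply wraps the value in Some.
def unit[A](a: A): Option[A] = Some(a)

val m: Option[Int] = Some(21)

// Left identity
val leftIdentity = unit(21).flatMap(f) == f(21)
// Right identity
val rightIdentity = m.flatMap(x => unit(x)) == m
// Associativity
val associativity = m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g))
```

All three values evaluate to true for these inputs.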

Some fundamental types of monads

Let’s try implementing some monads. For all exercises in this part, we will use the following trait:

trait Monad[M[_]]:
  def unit[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]

src/main/scala/monads/Monads.scala

(Note that we made flatMap a regular method of the Monad trait instead of an extension method, so as to put both methods in a single trait.)

Identity monad

We’ll start with the simplest possible monad: the identity monad. This monad doesn’t do much, besides simply containing a value and satisfying the monadic laws.

  1. Implement the unit function for the identity monad
def unit[A](a: A): IdentityM[A] =
  ???

src/main/scala/monads/Monads.scala

  2. Implement the flatMap function for the identity monad:
def flatMap[A, B](ma: IdentityM[A])(f: A => IdentityM[B]): IdentityM[B] =
  ???

src/main/scala/monads/Monads.scala

Option as a monad ⭐️

You are already familiar with this kind of monad, since you have used it throughout this course! It is called Maybe in Haskell and Option in Scala, and its values are either None or Some(x). This kind of monad provides a way to deal with undefined results.

This means we can chain operations on this monad that may or may not fail, but without having to throw exceptions: for example, operations on Ints that include divisions by a value which may or may not be zero.
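As an illustration of such a chain, here is a small sketch using the standard Option type. The helpers safeDiv and compute are hypothetical names introduced for this example only.

```scala
// Hypothetical helper: division that returns None instead of throwing.
def safeDiv(a: Int, b: Int): Option[Int] =
  if b == 0 then None else Some(a / b)

// Computes (100 / x) / y; a division by zero at any step makes the whole result None.
def compute(x: Int, y: Int): Option[Int] =
  safeDiv(100, x).flatMap(q => safeDiv(q, y))

val ok = compute(5, 2)     // Some(10)
val failed = compute(0, 2) // None
```

No exception is thrown in either case: the failure is just a value that flows through the chain.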

Use the following OptionM class to represent computations that may fail:

enum OptionM[+A]:
  case Some(a: A)
  case None

src/main/scala/monads/Monads.scala

  1. Implement unit for the option monad:
def unit[A](a: A): OptionM[A] =
  ???

src/main/scala/monads/Monads.scala

  2. Implement flatMap for the option monad:
def flatMap[A, B](ma: OptionM[A])(f: A => OptionM[B]): OptionM[B] =
  ???

src/main/scala/monads/Monads.scala

State monad ⭐️

Let’s move on to a state monad. A state monad allows us to attach state information to a calculation. The state monad contains a function that takes a state and returns an intermediate value and a new state value.

Here is an example of a state monad. Consider evaluating an expression e1 ++ e2 where e1 and e2 use and modify a mutable variable. For example, let this mutable variable be an integer counter. Then, this is a valid expression:

var counter: Int = 0

{
  counter = counter + 1; List(counter, counter + 10)
} ++ {
  counter = counter + 5; List(counter + 2, counter + 100)
}

To represent this in a functional way, we want to give meaning to both e1 and e2. Because these expressions depend on the counter, we will represent them as a function Int => .... Because they change the state, they should give Int as one part of the return value. But they are expressions of type List[Int], so they should also return that. Thus, we model each of these expressions inside {...} as functions of the type Int => (Int, List[Int]).

The first expression is the function that behaves as:

count =>
  (count + 1, List(count + 1, count + 11))

The second expression is the function that behaves as:

count =>
  (count + 5, List(count + 7, count + 105))
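Before introducing any abstraction, the two functions above can be combined by hand, passing the counter returned by the first one to the second one. The sketch below does exactly that; runBoth is a hypothetical helper name, and the initial counter value 0 matches the mutable example.

```scala
// The two expressions from above, as plain functions
// from a counter to (new counter, result).
val e1: Int => (Int, List[Int]) =
  count => (count + 1, List(count + 1, count + 11))
val e2: Int => (Int, List[Int]) =
  count => (count + 5, List(count + 7, count + 105))

// Hypothetical helper: run e1, feed its final counter to e2,
// and concatenate the two resulting lists.
def runBoth(initial: Int): (Int, List[Int]) =
  val (c1, list1) = e1(initial)
  val (c2, list2) = e2(c1)
  (c2, list1 ++ list2)

val result = runBoth(0) // (6, List(1, 11, 8, 106))
```

This manual plumbing of the counter is the part that a monadic flatMap is meant to hide.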

We introduce a type State[A] to represent functions S => (S, A), where S is the type of the state (here: Int). Then, both of the previous expressions have type State[List[Int]]. We define flatMap on State[A] and then translate the expression of the form e1 ++ e2 into

e1.flatMap(list1 => e2.flatMap(list2 => unit(list1 ++ list2)))

or, using for expressions:

for list1 <- e1
    list2 <- e2
yield list1 ++ list2

This shows how state monads are useful for situations where we have a mutable state that may change when we perform operations, but also have a return value. Now it’s your turn!

Use the following StateM type to represent stateful computations:

case class StateM[S, A](run: S => (S, A))

object StateM:
  // Get state at this point
  def get[S](): StateM[S, S] =
    StateM(s => (s, s))

  // Replace the state
  def put[S](s: S): StateM[S, Unit] =
    StateM(_ => (s, ()))

  // Update the state
  def modify[S](f: S => S): StateM[S, Unit] =
    StateM((s: S) => (f(s), ()))

src/main/scala/monads/Monads.scala
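To get a feel for these three helpers, it can help to run each of them on a concrete initial state. The sketch below repeats the definitions so that it is self-contained; the initial state 5 is an arbitrary choice.

```scala
// Definitions repeated from above so that the sketch is self-contained.
case class StateM[S, A](run: S => (S, A))

object StateM:
  def get[S](): StateM[S, S] = StateM(s => (s, s))
  def put[S](s: S): StateM[S, Unit] = StateM(_ => (s, ()))
  def modify[S](f: S => S): StateM[S, Unit] = StateM((s: S) => (f(s), ()))

// Each result is a pair (new state, returned value).
val afterGet = StateM.get[Int]().run(5)            // (5, 5): state unchanged, state returned
val afterPut = StateM.put(42).run(5)               // (42, ()): state replaced
val afterModify = StateM.modify[Int](_ + 1).run(5) // (6, ()): state updated
```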

  1. Implement unit for the state monad:
def unit[A](a: A): StateM[S, A] =
  ???

src/main/scala/monads/Monads.scala

  2. Implement flatMap for the state monad:
def flatMap[A, B](ma: StateM[S, A])(f: A => StateM[S, B]): StateM[S, B] =
  ???

src/main/scala/monads/Monads.scala

Reader monad

A reader monad allows a computation to depend on values from a shared context. It is somewhat similar to a state monad, except here our computation doesn’t actually modify our state.

Use the following ReaderM type to represent computations that read from the environment:

case class ReaderM[R, A](run: R => A)

object ReaderM:
  // Get current context
  def ask[R](): ReaderM[R, R] =
    ReaderM(r => r)

  // Do operation in modified context
  def local[R, A](f: (R => R), m: ReaderM[R, A]): ReaderM[R, A] =
    ReaderM(r => m.run(f(r)))

src/main/scala/monads/Monads.scala
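The following sketch shows how ask and local behave when run against a concrete context. The Config type and the values readDepth and deeper are hypothetical, introduced only for this illustration.

```scala
// Definitions repeated from above so that the sketch is self-contained.
case class ReaderM[R, A](run: R => A)

object ReaderM:
  def ask[R](): ReaderM[R, R] = ReaderM(r => r)
  def local[R, A](f: (R => R), m: ReaderM[R, A]): ReaderM[R, A] =
    ReaderM(r => m.run(f(r)))

// Hypothetical context type, used only for this illustration.
case class Config(verbose: Boolean, depth: Int)

// A computation that reads one field of the context.
val readDepth: ReaderM[Config, Int] = ReaderM(cfg => cfg.depth)

// The same computation, run in a context whose depth is one larger.
val deeper: ReaderM[Config, Int] =
  ReaderM.local((cfg: Config) => cfg.copy(depth = cfg.depth + 1), readDepth)

val config = Config(verbose = false, depth = 3)
val same = ReaderM.ask[Config]().run(config) // the context itself
val plain = readDepth.run(config)            // 3
val nested = deeper.run(config)              // 4
```

Note that local does not change config itself: the modified context is only visible to the computation passed to local.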

  1. Implement unit for the reader monad:
def unit[A](a: A): ReaderM[R, A] =
  ???

src/main/scala/monads/Monads.scala

  2. Implement flatMap for the reader monad:
def flatMap[A, B](ma: ReaderM[R, A])(f: A => ReaderM[R, B]): ReaderM[R, B] =
  ???

src/main/scala/monads/Monads.scala

Writer monad

A writer monad is used when we are chaining operations but want to keep track of some accumulating auxiliary output as we go. Here, our auxiliary output is represented in a generic way by a monoid. A monoid is simply a set that has an identity element and an associative binary operation:

trait Monoid[A]:
  def munit: A
  def mcombine(x: A)(y: A): A

src/main/scala/monads/Monads.scala

This could be as simple as a string that we use to log information about the operations done: the identity element is the empty string, and string concatenation is our associative binary operation. All of the operations we want to chain will compute two things: the actual operation to be done, and a string to add to the logs.
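The string example just described can be sketched as an instance of the Monoid trait. The name StringMonoid and the log messages are ours; how instances are declared in the scaffold may differ.

```scala
// Trait repeated from above so that the sketch is self-contained.
trait Monoid[A]:
  def munit: A
  def mcombine(x: A)(y: A): A

// Illustrative instance (hypothetical name): strings under concatenation.
object StringMonoid extends Monoid[String]:
  def munit: String = ""
  def mcombine(x: String)(y: String): String = x + y

// Combining two log fragments.
val log = StringMonoid.mcombine("parsed; ")("evaluated; ") // "parsed; evaluated; "

// The identity element leaves the other operand unchanged.
val unchanged = StringMonoid.mcombine(StringMonoid.munit)("evaluated; ") // "evaluated; "
```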

Use the following WriterM type to represent computations that produce output:

case class WriterM[W, A](run: (W, A))

src/main/scala/monads/Monads.scala

  1. Implement unit for the writer monad:
def unit[A](a: A): WriterM[W, A] =
  ???

src/main/scala/monads/Monads.scala

  2. Implement flatMap for the writer monad:
def flatMap[A, B](ma: WriterM[W, A])(f: A => WriterM[W, B]): WriterM[W, B] =
  ???

src/main/scala/monads/Monads.scala

Try

In previous exercises you have learned about Try. A Try[A] is either a Success(a: A) or a Failure(t: Throwable). Try is very similar to Option from above: a Success(a) is just like a Some(a), and a Failure is a more informative version of None (it records what went wrong).

And, just like with Option, we can define monadic combinators on Try. Accordingly, Try values can be used in for comprehensions:

for x <- Success("Hello, ")
    y <- Success("World!")
yield x + y // Success(Hello, World!)

src/main/scala/playground.worksheet.sc

for x <- Success("Hello, ")
    y <- Failure(Exception("Woops"))
yield x + y // Failure(Exception: Woops)

src/main/scala/playground.worksheet.sc

But Try, in Scala, does more than this: its constructor Try takes a by-name (lazily evaluated) argument a: => A and converts any (non-fatal) JVM exception thrown while evaluating it into a Failure object. This exception-catching behavior also exists in flatMap:

Try(throw Exception("Woops")) // Failure(Exception: Woops)

src/main/scala/playground.worksheet.sc

Success(1).flatMap(x => throw Exception("Woops")) // Failure(Exception: Woops)

src/main/scala/playground.worksheet.sc

Notice how neither of these lets the exception escape: instead, exceptions are “reified” as Failure values.

What consequence does this have on monad laws? Let’s find out!

Use the following TryM class to represent computations that may fail:

enum TryM[+A]:
  case Success(a: A)
  case Failure(ex: Throwable)

src/main/scala/monads/Try.scala

TryM on pure values

In this exercise, assume that all computations are pure and do not throw exceptions.

  1. Implement unit, assuming that no exceptions are thrown:
def unit[A](a: A): TryM[A] =
  ???

src/main/scala/monads/Try.scala

  2. Implement flatMap, assuming that no exceptions are thrown:
def flatMap[A, B](ma: TryM[A])(f: A => TryM[B]): TryM[B] =
  ???

src/main/scala/monads/Try.scala

  3. Show that these two functions respect the usual monad laws, assuming that no exceptions are thrown.

TryM on impure values

Let us now generalize the problem to handle impure computation. We want unit and flatMap to capture exceptions and turn them into values, so the first thing we need to do is to change the signature of unit:

  1. Implement unit, making sure to capture any exceptions.
def unit[A](a: => A): TryM[A] =
  ???

src/main/scala/monads/Try.scala

  2. Implement flatMap, making sure to capture any exceptions:
def flatMap[A, B](ma: TryM[A])(f: A => TryM[B]): TryM[B] =
  ???

src/main/scala/monads/Try.scala

  3. Show that these definitions violate one or more of the monad laws:

    1. Translate each monad law into tests:

      def assertAssociative[A, B, C](t: M[A], f: A => M[B], g: B => M[C]) =
        ???
      

      src/test/scala/TrySuite.scala

      def assertLeftIdentity[A, B](x: => A, f: A => M[B]) =
        ???
      

      src/test/scala/TrySuite.scala

      def assertRightIdentity[A](t: M[A]) =
        ???
      

      src/test/scala/TrySuite.scala

    2. Construct a failing test using one of these laws and a function or expression that throws an exception:

      test("One of the monads laws fails for effectful `Try`.".fail) {
        // Find an invocation of one of the lemmas above that fails.
      }
      

      src/test/scala/TrySuite.scala

Proving monad laws ⭐️

The State monad is a powerful construct in functional programming used for managing state in a purely functional way. It encapsulates state transformations and allows for stateful computations without mutating state directly.

Given your implementation of the State monad, let’s discuss the left unit law:

For any value a and function f that returns a monad:

monad.unit(a).flatMap(f) == f(a)

Write on paper an equational proof for the left unit law using the provided implementation of the State monad.

Functors

Monads are often defined with a map method in addition to flatMap:

trait M[T] {
  def flatMap[U](f: T => M[U]): M[U]

  def map[U](f: T => U): M[U]
}

Where map and flatMap are related by the following law:

Monad/Functor Consistency:
m.map(f) === m.flatMap(x => unit(f(x)))
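As a quick sanity check, this consistency law can be observed on the standard List type, for which unit wraps a value in a single-element list. The value m and the function f below are arbitrary choices for illustration.

```scala
// For List, unit builds a single-element list.
def unit[A](a: A): List[A] = List(a)

val m = List(1, 2, 3)
val f: Int => String = x => s"#$x"

// Both sides of the Monad/Functor Consistency law.
val viaMap = m.map(f)                       // List("#1", "#2", "#3")
val viaFlatMap = m.flatMap(x => unit(f(x))) // same list as viaMap
```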

We now introduce a new algebraic structure called a functor. We say that a type F is a Functor if F[T] has a map method with the following signature:

trait F[T] {
  def map[U](f: T => U): F[U]
}

And there is a unit method for F with the following signature:

def unit[T](x: T): F[T]

Such that map and unit fulfill the following laws:

Identity:
m.map(x => x) === m

Associativity:
m.map(h).map(g) === m.map(x => g(h(x)))

Prove that any Monad with a map method that fulfills the Monad/Functor Consistency law is also a Functor.

Throwback: evaluator ⭐️

In this exercise, we will explore section 2 of the paper Monads for functional programming. Reading the paper is not needed to continue with the exercises. The goal is to demonstrate how monads can promote ease of programming.

Evaluator - error handling with monads

You have implemented a calculator in a previous lab. Recall that each operation corresponded to a different case in an eval function. For example, the division case was:

def eval(expr: Expr): Int = expr match
  // …
  case Division(numerator, denominator) =>
    eval(numerator) / eval(denominator)
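For reference, here is a self-contained sketch of such an evaluator. The definition of Expr (constructor and field names) is an assumption made for this example; the actual definition in the lab and in the scaffold may differ.

```scala
// Assumed expression type; the scaffold's actual definition may differ.
enum Expr:
  case Constant(value: Int)
  case Add(left: Expr, right: Expr)
  case Division(numerator: Expr, denominator: Expr)

import Expr.*

// Naive evaluator: no special handling of division by zero.
def eval(expr: Expr): Int = expr match
  case Constant(value) => value
  case Add(left, right) => eval(left) + eval(right)
  case Division(numerator, denominator) =>
    eval(numerator) / eval(denominator)

val fine = eval(Division(Constant(10), Add(Constant(2), Constant(3)))) // 2

// Dividing by zero throws; scala.util.Try is used here only to observe that.
val failure = scala.util.Try(eval(Division(Constant(1), Constant(0))))
```

Here failure is a Failure wrapping an ArithmeticException.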

When adding the division to the calculator, the division by zero case must be handled, otherwise an ArithmeticException is thrown. The way that you tackled this problem was to “box” the result in a case class:

sealed trait Result
object DivByZero extends Result
case class Ok(value: Int) extends Result

src/main/scala/monads/Evaluator.scala

This leads to this version of the Division case:

def eval(expr: Expr): Result = expr match
  // …
  case Division(numerator, denominator) =>
    (eval(numerator), eval(denominator)) match
      case (Ok(v1), Ok(0))  => DivByZero
      case (Ok(v1), Ok(v2)) => Ok(v1 / v2)
      case (_, _)           => DivByZero

Recalling the calculator lab, how did the changes needed to handle the DivByZero case affect the eval function?

Solution

All other cases of the evaluator had to be updated to support DivByZero, even though they are not directly related to the change and should therefore, ideally, not have required any development effort.

Specifically, the add case had to be updated to handle the DivByZero case:

def evaluate(e: BasicExpr): EvalResult =
  e match
    case Add(e1, e2) =>
      (evaluate(e1), evaluate(e2)) match
        case (Ok(v1), Ok(v2)) => Ok(v1 + v2)
        case (_, _)           => DivByZero
  // …

A similar procedure had to be applied to all other cases.

The downside is clear: to add support for some new operation, all cases need to be modified, even the ones that are not related to the new feature!

Modularize the evaluator - combine

Let’s extract the logic dealing with DivByZero and Ok into a new function, to make the evaluator more modular. Let’s call that function combine. This function takes two results of type Result and a function f: (Int, Int) => Result, and applies f to the two values if both results are of the Ok type, or returns DivByZero if either of the two results is DivByZero.

Let’s implement this function:

def combine(x: Result, y: Result, f: (Int, Int) => Result) =
    ???

src/main/scala/monads/Evaluator.scala

Now let’s modify eval to use this new combine function:

def eval(e: Expr): Result = e match
  case Constant(a) => ???
  case Add(a, b) => ??? // combine(???)
  case Division(num, den) => ??? // combine(???)

src/main/scala/monads/Evaluator.scala

Are there parallels between the combine function and the flatMap function? What are the common points? What are the differences?

Combine using flatMap

Our function combine depends on two values of type Result. Can we express it in terms of a simpler function that takes only one argument of type Result, namely flatMap?

Let’s implement the flatMap function for our Result class. When done, implement the function combineFlatMap which does the same thing as combine but by using flatMap.

def flatMap(x: Result, f: Int => Result) =
  x match
    case DivByZero => ???
    case Ok(value) => ???

def combineFlatMap(x: Result, y: Result, f: (Int, Int) => Result) =
  ??? /* calls to flatMap */

src/main/scala/monads/Evaluator.scala

Monadic version of eval using flatMap

Do we need both combine and flatMap? No: flatMap alone is enough.

Let us now write a version of eval without combine, using only flatMap:

def evalWithoutCombine(e: Expr): Result = e match
  case Constant(a) => ???
  case Add(a, b) => ???
  case Division(num, den) => ???

src/main/scala/monads/Evaluator.scala

We now have a version of the eval function that returns a monad instead of a simple Int. This allows us to handle the error in the division case without modifying the other cases, namely the addition. Now let’s explore other features that we can add to our calculator and how using monads can help.

Adding more features to monadic eval

At this point, we modified the eval function to use flatMap. We had to modify every case in the function, so did we really save some development time? We will now add other features to our calculator and we will see how the work we did helps for future modifications.

For each new feature, we will compare the modifications needed for the eval function in the two cases: with and without monads. We will see how monads help keep both the modifications and the final code base clean and concise.

The two stateful features we will add to the evaluator are:

  1. Keeping track of how many additions have been performed.
  2. Keeping track of what operations have been performed, i.e. logging.

Keeping track of the number of additions

We start with the version of eval that simply returns an Int and throws an exception upon a division by zero. To avoid many changes, we will use the following type definition:

type Result = Int

Going back to the Ok/DivByZero case class is a matter of changing this type definition and modifying the eval function accordingly.

Without monads

Implement the addition tracker without monads by modifying the following eval function:

def evalStateful(e: Expr, addNumber: Int = 0): (Result, Int) =
  ???

src/main/scala/monads/Evaluator.scala

Think about how you would have done it in an Object Oriented programming language like Java or Python.

You would keep a global variable that would contain the counter that you would increment in the add case. It is interesting to note that this would introduce challenges when working in a concurrent or parallel environment. This variable would need to be thread-local or protected by a lock.
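As a rough illustration of this mutable approach, here is a sketch written in Scala. The class name CountingEvaluator and the definition of Expr are assumptions made for this example; the counter is kept in an object field rather than in a true global variable.

```scala
// Assumed expression type; the scaffold's actual definition may differ.
enum Expr:
  case Constant(value: Int)
  case Add(left: Expr, right: Expr)
  case Division(numerator: Expr, denominator: Expr)

import Expr.*

// Imperative sketch: the counter is mutable state shared by all calls to eval.
// Not thread-safe: concurrent calls would race on `additions`.
final class CountingEvaluator:
  var additions: Int = 0

  def eval(e: Expr): Int = e match
    case Constant(value) => value
    case Add(left, right) =>
      additions += 1
      eval(left) + eval(right)
    case Division(numerator, denominator) =>
      eval(numerator) / eval(denominator)
```

For instance, after creating an evaluator with CountingEvaluator() and evaluating Add(Add(Constant(1), Constant(2)), Constant(3)), which returns 6, its additions field is 2.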

Functional programming enforces immutability, which prevents us from doing it that way. Instead, we are forced to pass the counter around, and therefore to modify the code heavily. This does, however, protect us from concurrency issues.

As we will see later on, Monads will help us to get some of this OO conciseness back.

With monads

Now let’s implement the same feature using monads.

Start by thinking about what the result type would be. What monad should we use?

A tuple (Result, Int) is not the way to go. Think about what it would imply for flatMap.

Solution

It is not appropriate because flatMap would need to change. We want to preserve the following definition of flatMap:

def flatMap(m: Monad, f: Result => Monad): Monad

In this case, the idea is to use a function as the return type:

type Result = Int
type State = Int
type Monad = State => (Result, State)

src/main/scala/monads/Evaluator.scala

Implement the unit, flatMap and eval functions for the monad type you picked:

def unit(a: Result): Monad =
  ???

extension (m: Monad)
  def flatMap(f: Result => Monad): Monad =
    ???

src/main/scala/monads/Evaluator.scala

Then modify the eval function to keep track of the number of additions as before, but using this monad as return type:

def eval(t: Expr): Monad = t match
  case Constant(a) => unit(a)
  case Add(a, b) =>
    eval(a).flatMap(a =>
      ???
    )
  case Division(a, b) =>
    eval(a).flatMap(a =>
      ???
    )

src/main/scala/monads/Evaluator.scala

Compare the two approaches

Which one required more work? How far is your implementation of the eval function from the evalStateful one?

Another question to ask yourself when adding features like these is: how maintainable is the code?

Maintainability is indeed crucial in a software project, as it is very likely that new features will be added, bugs fixed, or libraries updated in the future. It is therefore a key aspect when writing software.

Logging

In this part, we will add a logging feature to the evaluator. This time, implementing the version without monads is optional. Feel free to do it if you want to see how monads help in this case.

We are ignoring the DivByZero case for the logging feature for now.

Here is the definition of the monad that we are going to use; implement the unit and flatMap functions:

type Logs = List[String]
type Monad[A] = (A, Logs)

def unit[A](a: A): Monad[A] =
  ???

extension [A](m: Monad[A])
  def flatMap[B](f: A => Monad[B]): Monad[B] =
    ???

src/main/scala/monads/Evaluator.scala

Note that the monad is now generically typed.

Here are two helper functions to manipulate the log and create message strings: out and line.

def out(output: String): Monad[Unit] =
  ((), List(output))
def line(t: Expr, res: Int) =
  f"eval($t) ⇐ $res"

src/main/scala/monads/Evaluator.scala

Now implement the eval function so that it also logs the performed operations, using the monad type defined above. You can use the out and line functions to log the operations.

def eval(t: Expr): Monad[Int] = t match
  case Constant(a) => ???
  case Add(a, b) => ???
  case Division(a, b) => ???

src/main/scala/monads/Evaluator.scala

DivByZero and logging 🔥

Now we will add the DivByZero case to the logging feature. This exercise is optional and no solutions are provided. Feel free to write new tests based on existing ones.

Logging and counting additions 🔥

Another optional exercise: combine the counting of additions and the logging features. No solutions are provided. Feel free to write new tests based on existing ones.