Monads
Welcome to week 14 of CS-214 — Software Construction!
As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones.
You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.
We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).
What are monads? ⭐️
A monad can be thought of as a wrapper for values, providing useful ways to chain together operations on those values, and to handle things like side effects and context or state.
You have already worked with monads: for example, both the List collection and the Option class in Scala are monads.
A monad has two fundamental operations: unit and flatMap:
trait Monad[M[_]]:
  def unit[A](a: A): M[A]
  def flatMap[A, B](ma: M[A], f: A => M[B]): M[B]
(In the lecture we saw flatMap as an extension method; here, for a change, we see it as a method on a Monad trait.)
M[_] means “a type taking one type parameter”. It can equivalently be written M <: [X] =>> Any, meaning “a subtype of type-level functions that take one type and return a type.”
unit[A] “boxes” an element of type A inside a monad M[A]. flatMap applies a transformation to the element(s) in the monad and combines the results.
Monads follow three laws:
- Associativity: m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g))
- Left identity: unit(x).flatMap(f) == f(x)
- Right identity: m.flatMap(unit) == m
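Testing a few cases is not a proof, but it can help build intuition for the laws. The sketch below spot-checks all three on Scala's built-in Option, for which unit is simply Some; the functions f and g and the sample values are arbitrary choices made for illustration.

```scala
// Arbitrary example functions returning Option (chosen only for illustration)
val f: Int => Option[Int] = x => if x > 0 then Some(x * 2) else None
val g: Int => Option[String] = y => Some(s"got $y")

// For Option, unit simply wraps a value in Some
def unit[A](a: A): Option[A] = Some(a)

val samples: List[Option[Int]] = List(Some(3), Some(-1), None)

// Associativity
val assoc = samples.forall(m =>
  m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g)))

// Left identity, checked for a few values of x
val leftId = List(3, -1, 0).forall(x => unit(x).flatMap(f) == f(x))

// Right identity, passing unit as a function value
val rightId = samples.forall(m => m.flatMap(unit[Int]) == m)

println((assoc, leftId, rightId)) // (true,true,true)
```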
- Try to think of types in Scala you know that implement these two functions and respect these laws. What do these functions enable? Do they have a different name?
- How would you implement map and flatten using only unit and flatMap? Try to first reason with List. How does this apply to the other types you thought of in the previous question?
- Can any monad be used in a for-comprehension, or are there more prerequisites? If there are, can you adapt these monads so that they can be used in a for-comprehension?
Some fundamental types of monads
Let’s try implementing some monads. For all exercises in this part, we will use the following trait:
trait Monad[M[_]]:
  def unit[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
src/main/scala/monads/Monads.scala
(Note that we made flatMap a regular method of the Monad trait instead of an extension method, so as to put both methods in a single trait. This version of flatMap also takes its two arguments in separate parameter lists.)
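To illustrate how this trait is meant to be instantiated (this is not one of the exercises), here is a sketch of an instance for Scala's List: unit builds a one-element list and flatMap delegates to List's own flatMap. The object name ListMonad is hypothetical.

```scala
trait Monad[M[_]]:
  def unit[A](a: A): M[A]
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]

// Hypothetical instance for List, delegating to the built-in List.flatMap
object ListMonad extends Monad[List]:
  def unit[A](a: A): List[A] = List(a)
  def flatMap[A, B](ma: List[A])(f: A => List[B]): List[B] =
    ma.flatMap(f)

// All pairs (x, y), built only from unit and flatMap
val pairs = ListMonad.flatMap(List(1, 2))(x =>
  ListMonad.flatMap(List("a", "b"))(y => ListMonad.unit((x, y))))

println(pairs) // List((1,a), (1,b), (2,a), (2,b))
```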
Identity monad
We’ll start with the simplest possible monad: the identity monad. This monad doesn’t do much, besides simply containing a value and satisfying the monadic laws.
- Implement the unit function for the identity monad:
def unit[A](a: A): IdentityM[A] =
  ???
src/main/scala/monads/Monads.scala
- Implement the flatMap function for the identity monad:
def flatMap[A, B](ma: IdentityM[A])(f: A => IdentityM[B]): IdentityM[B] =
  ???
src/main/scala/monads/Monads.scala
Option as a monad ⭐️
You are already familiar with this kind of monad because you have already used it in this course! It is called Maybe in Haskell and Option in Scala, and its values are either None or Some(x). This kind of monad provides a way to deal with undefined results. This means we can chain operations that may or may not fail without having to throw exceptions: for example, operations on Ints that include divisions by a value that may be zero.
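To make this concrete with Scala's built-in Option (rather than the OptionM type introduced below), here is a sketch of chained divisions. The helpers safeDiv and compute are hypothetical and only serve as an illustration.

```scala
// Hypothetical helper: integer division that returns None instead of throwing
def safeDiv(a: Int, b: Int): Option[Int] =
  if b == 0 then None else Some(a / b)

// Chain two divisions; the result is None as soon as one step fails
def compute(x: Int, y: Int, z: Int): Option[Int] =
  for
    q1 <- safeDiv(x, y)
    q2 <- safeDiv(q1, z)
  yield q2 + 1

println(compute(100, 5, 2)) // Some(11)
println(compute(100, 0, 2)) // None
```

No exception is thrown for the division by zero: the failure is represented by the value None and simply propagates through the rest of the for-comprehension.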
Use the following OptionM class to represent computations that may fail:
enum OptionM[+A]:
  case Some(a: A)
  case None
src/main/scala/monads/Monads.scala
- Implement unit for the option monad:
def unit[A](a: A): OptionM[A] =
  ???
src/main/scala/monads/Monads.scala
- Implement flatMap for the option monad:
def flatMap[A, B](ma: OptionM[A])(f: A => OptionM[B]): OptionM[B] =
  ???
src/main/scala/monads/Monads.scala
State monad ⭐️
Let’s move on to a state monad. A state monad allows us to attach state information to a calculation. The state monad contains a function that takes a state and returns an intermediate value and a new state value.
Here is an example of where a state monad is useful. Consider evaluating an expression e1 ++ e2 where e1 and e2 use and modify a mutable variable. For example, let this mutable variable be an integer counter. Then, this is a valid expression:
var counter: Int = 0
val result = {
  counter = counter + 1; List(counter, counter + 10)
} ++ {
  counter = counter + 5; List(counter + 2, counter + 100)
}
To represent this in a functional way, we want to give meaning to both e1 and e2. Because these expressions depend on the counter, we will represent them as functions of the form Int => .... Because they change the state, they should return an Int as one part of their result. But they are expressions of type List[Int], so they should also return a list. Thus, we model each of these expressions inside {...} as a function of type Int => (Int, List[Int]).
The first expression is the function that behaves as:
count =>
  (count + 1, List(count + 1, count + 11))
The second expression is the function that behaves as:
count =>
  (count + 5, List(count + 7, count + 105))
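One way to convince yourself that these two functions capture the original expression is to run them by hand, passing the state returned by the first one as the input of the second one. The sketch below does this with plain functions (no StateM or flatMap yet) and compares the result with the imperative version; the names e1, e2, s1 and so on are only illustrative.

```scala
// Imperative version: both blocks read and update the same mutable counter
var counter: Int = 0
val imperative = {
  counter = counter + 1; List(counter, counter + 10)
} ++ {
  counter = counter + 5; List(counter + 2, counter + 100)
}

// Functional version: each block maps the current counter to (new counter, list)
val e1: Int => (Int, List[Int]) = count => (count + 1, List(count + 1, count + 11))
val e2: Int => (Int, List[Int]) = count => (count + 5, List(count + 7, count + 105))

// Thread the state by hand: the output state of e1 is the input state of e2
val (s1, list1) = e1(0)
val (s2, list2) = e2(s1)
val functional = list1 ++ list2

println(imperative)    // List(1, 11, 8, 106)
println(functional)    // List(1, 11, 8, 106)
println((counter, s2)) // (6,6)
```

The manual plumbing of s1 into e2 is exactly the kind of bookkeeping that flatMap on a state monad is meant to hide.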
We introduce a type State[A] to represent functions S => (S, A), where S is the type of the state (here: Int).
Then, both of the previous expressions have type State[List[Int]]. We define flatMap on State[A] and then translate the expression of the form e1 ++ e2 into:
e1.flatMap(list1 => e2.flatMap(list2 => unit(list1 ++ list2)))
or, using for expressions:
for list1 <- e1
    list2 <- e2
yield list1 ++ list2
This shows how state monads are useful for situations where we have a mutable state that may change when we perform operations, but also have a return value. Now it’s your turn!
Use the following StateM type to represent stateful computations:
case class StateM[S, A](run: S => (S, A))

object StateM:
  // Get state at this point
  def get[S](): StateM[S, S] =
    StateM(s => (s, s))
  // Replace the state
  def put[S](s: S): StateM[S, Unit] =
    StateM(_ => (s, ()))
  // Update the state
  def modify[S](f: S => S): StateM[S, Unit] =
    StateM((s: S) => (f(s), ()))
src/main/scala/monads/Monads.scala
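Since get, put and modify are themselves StateM values, their behavior can be observed by calling run with an initial state, without any flatMap. The sketch below repeats the definitions above in condensed form so that it is self-contained; each result is a pair (new state, returned value).

```scala
case class StateM[S, A](run: S => (S, A))

object StateM:
  def get[S](): StateM[S, S] = StateM(s => (s, s))
  def put[S](s: S): StateM[S, Unit] = StateM(_ => (s, ()))
  def modify[S](f: S => S): StateM[S, Unit] = StateM((s: S) => (f(s), ()))

val r1 = StateM.get[Int]().run(5)          // state unchanged, value is the state
val r2 = StateM.put(7).run(0)              // state replaced, value is ()
val r3 = StateM.modify[Int](_ + 1).run(41) // state updated, value is ()

println(r1) // (5,5)
println(r2) // (7,())
println(r3) // (42,())
```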
- Implement unit for the state monad:
def unit[A](a: A): StateM[S, A] =
  ???
src/main/scala/monads/Monads.scala
- Implement flatMap for the state monad:
def flatMap[A, B](ma: StateM[S, A])(f: A => StateM[S, B]): StateM[S, B] =
  ???
src/main/scala/monads/Monads.scala
Reader monad
A reader monad allows a computation to depend on values from a shared context. It is somewhat similar to a state monad, except here our computation doesn’t actually modify our state.
Use the following ReaderM type to represent computations that read from the environment:
case class ReaderM[R, A](run: R => A)

object ReaderM:
  // Get current context
  def ask[R](): ReaderM[R, R] =
    ReaderM(r => r)
  // Do operation in modified context
  def local[R, A](f: (R => R), m: ReaderM[R, A]): ReaderM[R, A] =
    ReaderM(r => m.run(f(r)))
src/main/scala/monads/Monads.scala
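As with the state monad, ask and local can be explored directly through run. In the sketch below, Config is a hypothetical context type and readIndent a hypothetical computation that reads one of its fields; the ReaderM definitions are repeated from above.

```scala
case class ReaderM[R, A](run: R => A)

object ReaderM:
  def ask[R](): ReaderM[R, R] = ReaderM(r => r)
  def local[R, A](f: (R => R), m: ReaderM[R, A]): ReaderM[R, A] =
    ReaderM(r => m.run(f(r)))

// Hypothetical shared context, for illustration only
case class Config(verbose: Boolean, indent: Int)

val readIndent: ReaderM[Config, Int] = ReaderM(cfg => cfg.indent)
val base = Config(verbose = false, indent = 2)

// Run readIndent in a locally modified context
val widened = ReaderM.local[Config, Int](c => c.copy(indent = 4), readIndent)

println(readIndent.run(base))                   // 2
println(widened.run(base))                      // 4
println(ReaderM.ask[Config]().run(base) == base) // true
```

local only changes the context seen by the wrapped computation: running readIndent with base still yields 2, since the context itself is never modified.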
- Implement unit for the reader monad:
def unit[A](a: A): ReaderM[R, A] =
  ???
src/main/scala/monads/Monads.scala
- Implement flatMap for the reader monad:
def flatMap[A, B](ma: ReaderM[R, A])(f: A => ReaderM[R, B]): ReaderM[R, B] =
  ???
src/main/scala/monads/Monads.scala
Writer monad
A writer monad is used when we are chaining operations but want to keep track of some accumulating auxiliary output as we go. Here, our auxiliary output is represented in a generic way by a monoid. A monoid is simply a set that has an identity element and an associative binary operation:
trait Monoid[A]:
  def munit: A
  def mcombine(x: A)(y: A): A
src/main/scala/monads/Monads.scala
This could be as simple as a string that we use to log information about the operations performed: the identity element is the empty string, and string concatenation is our associative binary operation. Each of the operations we want to chain will then compute two things: its actual result, and a string to add to the logs.
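As a sketch, the string monoid described above could be written as an instance of the Monoid trait; StringMonoid and the list-based ListMonoid are hypothetical names. The checks exercise the identity and associativity requirements on a few sample values (which illustrates, but does not prove, the laws).

```scala
trait Monoid[A]:
  def munit: A
  def mcombine(x: A)(y: A): A

// String logs: "" is the identity element, concatenation the associative operation
object StringMonoid extends Monoid[String]:
  def munit: String = ""
  def mcombine(x: String)(y: String): String = x + y

// Hypothetical alternative: logs as a list of lines
class ListMonoid[T] extends Monoid[List[T]]:
  def munit: List[T] = Nil
  def mcombine(x: List[T])(y: List[T]): List[T] = x ++ y

val m = StringMonoid
val (a, b, c) = ("start; ", "step; ", "done")

val associative = m.mcombine(m.mcombine(a)(b))(c) == m.mcombine(a)(m.mcombine(b)(c))
val hasIdentity = m.mcombine(m.munit)(a) == a && m.mcombine(a)(m.munit) == a
val lines = ListMonoid[String]().mcombine(List("x"))(List("y"))

println((associative, hasIdentity)) // (true,true)
println(lines)                      // List(x, y)
```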
Use the following WriterM type to represent computations that produce output:
case class WriterM[W, A](run: (W, A))
src/main/scala/monads/Monads.scala
- Implement unit for the writer monad:
def unit[A](a: A): WriterM[W, A] =
  ???
src/main/scala/monads/Monads.scala
- Implement flatMap for the writer monad:
def flatMap[A, B](ma: WriterM[W, A])(f: A => WriterM[W, B]): WriterM[W, B] =
  ???
src/main/scala/monads/Monads.scala
Try
In previous exercises you have learned about Try. Try[A] is an algebraic data type with two cases: either a Success(a: A) or a Failure(t: Throwable).
Try is very similar to Option from above: a Success(a) is just like a Some(a), and a Failure is a more informative version of None (it records what went wrong).
And, just like with Option, we can define monadic combinators on Try:
- We use Success to inject a pure value of type A into Try[A].
- We use pattern matching to run a computation on the value held in the body of a Success (this is what the flatMap defined on Try does).
Accordingly, Try values can be used in for comprehensions:
import scala.util.{Failure, Success, Try}

for x <- Success("Hello, ")
    y <- Success("World!")
yield x + y // Success(Hello, World!)
src/main/scala/playground.worksheet.sc
for x <- Success("Hello, ")
    y <- Failure(Exception("Woops"))
yield x + y // Failure(Exception: Woops)
src/main/scala/playground.worksheet.sc
But Try, in Scala, does more than this: it has a special constructor Try that takes a lazily evaluated (by-name) argument => a and converts any non-fatal JVM exception thrown while evaluating it into a Failure object. This exception-catching behavior also exists in flatMap:
Try(throw Exception("Woops")) // Failure(Exception: Woops)
src/main/scala/playground.worksheet.sc
Success(1).flatMap(x => throw Exception("Woops")) // Failure(Exception: Woops)
src/main/scala/playground.worksheet.sc
Notice how neither of these lets the exception escape: instead, exceptions are “reified” as Failure values.
What consequence does this have on monad laws? Let’s find out!
Use the following TryM class to represent computations that may fail:
enum TryM[+A]:
  case Success(a: A)
  case Failure(ex: Throwable)
src/main/scala/monads/Try.scala
TryM on pure values
In this exercise, assume that all computations are pure and do not throw exceptions.
- Implement unit, assuming that no exceptions are thrown:
def unit[A](a: A): TryM[A] =
  ???
src/main/scala/monads/Try.scala
- Implement flatMap, assuming that no exceptions are thrown:
def flatMap[A, B](ma: TryM[A])(f: A => TryM[B]): TryM[B] =
  ???
src/main/scala/monads/Try.scala
- Show that these two functions respect the usual monad laws, assuming that no exceptions are thrown.
TryM on impure values
Let us now generalize the problem to handle impure computations. We want unit and flatMap to capture exceptions and turn them into values, so the first thing we need to do is to change the signature of unit so that it takes its argument by name:
- Implement unit, making sure to capture any exceptions:
def unit[A](a: => A): TryM[A] =
  ???
src/main/scala/monads/Try.scala
- Implement flatMap, making sure to capture any exceptions:
def flatMap[A, B](ma: TryM[A])(f: A => TryM[B]): TryM[B] =
  ???
src/main/scala/monads/Try.scala
- Show that these definitions violate one or more of the monad laws:
  - Translate each monad law into tests:
def assertAssociative[A, B, C](t: M[A], f: A => M[B], g: B => M[C]) = ???
src/test/scala/TrySuite.scala
def assertLeftIdentity[A, B](x: => A, f: A => M[B]) = ???
src/test/scala/TrySuite.scala
def assertRightIdentity[A](t: M[A]) = ???
src/test/scala/TrySuite.scala
  - Construct a failing test using one of these laws and a function or expression that throws an exception:
test("One of the monads laws fails for effectful `Try`.".fail) {
  // Find an invocation of one of the lemmas above that fails.
}
src/test/scala/TrySuite.scala
Proving monad laws ⭐️
The State monad is a powerful construct in functional programming used for managing state in a purely functional way. It encapsulates state transformations and allows for stateful computations without mutating state directly.
Given your implementation of the State monad, let’s discuss the left unit law:
For any value a and function f that returns a monad:
monad.unit(a).flatMap(f) == f(a)
Write on paper an equational proof for the left unit law using the provided implementation of the State monad.
Functors
Monads are often defined with a map method in addition to flatMap:
trait M[T] {
  def flatMap[U](f: T => M[U]): M[U]
  def map[U](f: T => U): M[U]
}
Where map and flatMap are related by the following law:
Monad/Functor Consistency:
m.map(f) === m.flatMap(x => unit(f(x)))
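As a concrete (non-proof) check of this law, the sketch below compares both sides on a List, where unit wraps a value in a one-element list; the function f is an arbitrary example.

```scala
// For List, unit wraps a value in a one-element list
def unit[A](a: A): List[A] = List(a)

val m = List(1, 2, 3)
val f: Int => String = x => "#" * x // arbitrary example function

val viaMap = m.map(f)
val viaFlatMap = m.flatMap(x => unit(f(x)))

println(viaMap)               // List(#, ##, ###)
println(viaMap == viaFlatMap) // true
```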
We now introduce a new algebraic structure: the Functor. We say that a type F is a Functor if F[T] has a map method with the following signature:
trait F[T] {
  def map[U](f: T => U): F[U]
}
And there is a unit method for F with the following signature:
def unit[T](x: T): F[T]
Such that map and unit fulfill the following laws:
- Identity: m.map(x => x) === m
- Associativity: m.map(h).map(g) === m.map(x => g(h(x)))
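The two functor laws can similarly be spot-checked on Scala's Option. The functions h and g below are arbitrary examples, and passing on a couple of samples does not replace a proof.

```scala
val samples: List[Option[Int]] = List(Some(4), None)
val h: Int => Int = _ + 1 // arbitrary example functions
val g: Int => String = _.toString

// Identity: mapping the identity function changes nothing
val identityHolds = samples.forall(m => m.map(x => x) == m)

// Associativity (as stated above): two maps equal one map of the composition
val associativityHolds =
  samples.forall(m => m.map(h).map(g) == m.map(x => g(h(x))))

println((identityHolds, associativityHolds)) // (true,true)
```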
Prove that any Monad with a map method that fulfills the Monad/Functor Consistency law is also a Functor.
Throwback: evaluator ⭐️
In this exercise, we will explore section 2 of the paper Monads for functional programming. Reading the paper is not needed to continue with the exercises. The goal is to demonstrate how monads can promote ease of programming.
Evaluator - error handling with monads
You have implemented a calculator in a previous lab. Recall that each operation corresponded to a different case in an eval function. For example, the division case was:
def eval(expr: Expr): Int = expr match
  // …
  case Division(numerator, denominator) =>
    eval(numerator) / eval(denominator)
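For reference, here is a self-contained sketch of such an Int-returning evaluator. The Expr enum is an assumption reconstructed from the case names used on this page (Constant, Add, Division); the lab's actual definition may differ. The last part shows that a division by zero escapes as an ArithmeticException rather than being a value.

```scala
// Assumed shape of the expression type (reconstructed from the case names on this page)
enum Expr:
  case Constant(value: Int)
  case Add(a: Expr, b: Expr)
  case Division(numerator: Expr, denominator: Expr)

import Expr.*

def eval(expr: Expr): Int = expr match
  case Constant(v)                      => v
  case Add(a, b)                        => eval(a) + eval(b)
  case Division(numerator, denominator) => eval(numerator) / eval(denominator)

val ok = eval(Division(Add(Constant(4), Constant(2)), Constant(3)))

// Division by zero is not a value: it escapes as a JVM exception
val failure: Either[String, Int] =
  try Right(eval(Division(Constant(1), Constant(0))))
  catch case e: ArithmeticException => Left(e.getMessage)

println(ok)      // 2
println(failure) // Left(/ by zero)
```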
When adding division to the calculator, the division-by-zero case must be handled; otherwise an ArithmeticException is thrown. The way you tackled this problem was to “box” the result in a case class:
sealed trait Result
object DivByZero extends Result
case class Ok(value: Int) extends Result
src/main/scala/monads/Evaluator.scala
This leads to the following version of the Division case:
def eval(expr: Expr): Result = expr match
  // …
  case Division(numerator, denominator) =>
    (eval(numerator), eval(denominator)) match
      case (Ok(v1), Ok(0)) => DivByZero
      case (Ok(v1), Ok(v2)) => Ok(v1 / v2)
      case (_, _) => DivByZero
Recalling the calculator lab, how did the changes needed to handle the DivByZero case affect the eval function?
Solution
All other cases of the evaluator had to be updated to support the DivByZero result, even though they are not directly related to the change. Ideally, they should not have required any development effort.
Specifically, the add case had to be updated to handle the DivByZero case:
def evaluate(e: BasicExpr): EvalResult =
  e match
    case Add(e1, e2) =>
      (evaluate(e1), evaluate(e2)) match
        case (Ok(v1), Ok(v2)) => Ok(v1 + v2)
        case (_, _) => DivByZero
    // …
A similar procedure had to be applied to all other cases.
The downside is clear: to add support for some new operation, all cases need to be modified, even the ones that are not related to the new feature!
Modularize the evaluator - combine
Let’s extract the logic dealing with DivByZero and Ok into a new function, to better modularize the code. Let’s call that function combine. This function takes two values of type Result and a function f: (Int, Int) => Result, and combines the two results using f if they are both of the Ok type, or returns DivByZero if either of the two Results is DivByZero.
Let’s implement this function:
def combine(x: Result, y: Result, f: (Int, Int) => Result) =
  ???
src/main/scala/monads/Evaluator.scala
Now let’s modify eval to use this new combine function:
def eval(e: Expr): Result = e match
  case Constant(a) => ???
  case Add(a, b) => ??? // combine(???)
  case Division(num, den) => ??? // combine(???)
src/main/scala/monads/Evaluator.scala
Are there parallels between the combine function and the flatMap function? What are the common points? What are the differences?
Combine using flatMap
Our combine function depends on two values. How can we instead write a function that takes only one argument of type Result, namely a flatMap function, and express combine in terms of it?
Let’s implement the flatMap function for our Result class. When done, implement the function combineFlatMap, which does the same thing as combine but using flatMap.
def flatMap(x: Result, f: Int => Result) =
  x match
    case DivByZero => ???
    case Ok(value) => ???

def combineFlatMap(x: Result, y: Result, f: (Int, Int) => Result) =
  ??? /* calls to flatMap */
src/main/scala/monads/Evaluator.scala
Monadic version of eval using flatMap
Do we need both combine and flatMap? No: let us now write a version of eval that uses only flatMap, without combine:
def evalWithoutCombine(e: Expr): Result = e match
  case Constant(a) => ???
  case Add(a, b) => ???
  case Division(num, den) => ???
src/main/scala/monads/Evaluator.scala
We now have a version of the eval function that returns a monad instead of a simple Int. This allows us to handle the error in the division case without modifying the other cases, namely the addition. Now let’s explore other features that we can add to our calculator and how using monads can help.
Adding more features to monadic eval
At this point, we modified the eval function to use flatMap. We had to modify every case in the function, so did we really save some development time? We will now add other features to our calculator and we will see how the work we did helps for future modifications.
For each new feature, we will compare the modifications needed to the eval function in the two cases: with and without monads. We will see how monads help us keep both the modifications and the final code base clean and concise.
The two stateful features we will add to the evaluator are:
- Keeping track of how many additions have been performed.
- Keeping track of what operations have been performed, i.e. logging.
Keeping track of the number of additions
We start with the version of eval that simply returns an Int and throws an exception upon a division by zero. To avoid many changes, we will use the following type definition:
type Result = Int
Going back to the Ok/DivByZero case class is a matter of changing this type definition and modifying the eval function accordingly.
Without monads
Implement the addition tracker without monads by modifying the following eval function:
def evalStateful(e: Expr, addNumber: Int = 0): (Result, Int) =
  ???
src/main/scala/monads/Evaluator.scala
Think about how you would have done it in an Object Oriented programming language like Java or Python.
You would keep a global variable containing the counter, which you would increment in the add case. Note that this would introduce challenges in a concurrent or parallel environment: the variable would need to be thread-local or protected by a lock.
Pure functional programming avoids mutable state, which prevents us from doing it that way. Instead, we are forced to pass the counter around, which requires heavily modifying the code. In exchange, this protects us from concurrency issues.
As we will see later on, monads will help us get some of this OO conciseness back.
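For comparison, here is a sketch of that imperative approach written in Scala, with a top-level var playing the role of the global counter. The Expr enum is an assumption reconstructed from the case names used on this page, and the names additions and evalCounting are hypothetical.

```scala
// Assumed shape of the expression type (reconstructed from the case names on this page)
enum Expr:
  case Constant(value: Int)
  case Add(a: Expr, b: Expr)
  case Division(numerator: Expr, denominator: Expr)

import Expr.*

// Hypothetical "global" mutable counter, as one might write in Java or Python.
// Not thread-safe: concurrent evaluations would race on this variable.
var additions: Int = 0

def evalCounting(e: Expr): Int = e match
  case Constant(v) => v
  case Add(a, b) =>
    additions += 1 // the only change compared to the plain evaluator
    evalCounting(a) + evalCounting(b)
  case Division(n, d) => evalCounting(n) / evalCounting(d)

val value =
  evalCounting(Add(Add(Constant(1), Constant(2)), Division(Constant(9), Constant(3))))

println((value, additions)) // (6,2)
```

Only the Add case had to change, but the counter is hidden shared state: two evaluations running in parallel would interfere, and a thread-safe variant would need, for example, a java.util.concurrent.atomic.AtomicInteger or a lock.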
With monads
Now let’s implement the same feature using monads.
Start by thinking about what the result type would be. What monad should we use?
A tuple (Result, Int) is not the way to go. Think about what it would imply for flatMap.
Solution
It is not appropriate because flatMap would need to change. We want to preserve the following signature of flatMap:
def flatMap(m: Monad, f: Result => Monad): Monad
In this case, the idea is to use a function as the return type:
type Result = Int
type State = Int
type Monad = State => (Result, State)
src/main/scala/monads/Evaluator.scala
Implement the unit, flatMap and eval functions for the monad type you picked:
def unit(a: Result): Monad =
  ???

extension (m: Monad)
  def flatMap(f: Result => Monad): Monad =
    ???
src/main/scala/monads/Evaluator.scala
Then modify the eval function to keep track of the number of additions as before, but using this monad as return type:
def eval(t: Expr): Monad = t match
  case Constant(a) => unit(a)
  case Add(a, b) =>
    eval(a).flatMap(a =>
      ???
    )
  case Division(a, b) =>
    eval(a).flatMap(a =>
      ???
    )
src/main/scala/monads/Evaluator.scala
Compare the two approaches
Which one required more work? How far is your implementation of the eval function from the evalStateful one?
Another question to ask yourself when adding features like these is: how maintainable is the code?
Maintainability is indeed crucial in a software project, as new features will very likely be added in the future, bugs fixed, libraries updated, and so on. It is therefore a key concern when writing software.
Logging
In this part, we will add a logging feature to the evaluator. This time, implementing the version without monads is optional. Feel free to do it if you want to see how monads help in this case.
We are ignoring the DivByZero case for the logging feature for now.
Here is the definition of the monad that we are going to use; implement the unit and flatMap functions:
type Logs = List[String]
type Monad[A] = (A, Logs)

def unit[A](a: A): Monad[A] =
  ???

extension [A](m: Monad[A])
  def flatMap[B](f: A => Monad[B]): Monad[B] =
    ???
src/main/scala/monads/Evaluator.scala
Note that the monad is now generically typed.
Here are two helper functions to manipulate the log and create message strings: out and line.
def out(output: String): Monad[Unit] =
  ((), List(output))

def line(t: Expr, res: Int) =
  f"eval($t) ⇐ $res"
src/main/scala/monads/Evaluator.scala
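To see what these helpers produce, the sketch below calls them directly. It repeats the type definitions above and assumes an Expr enum reconstructed from the case names used on this page (Constant, Add, Division); the exact strings depend on how Expr prints, so they may differ in the lab's scaffold.

```scala
// Assumed shape of the expression type (reconstructed from the case names on this page)
enum Expr:
  case Constant(value: Int)
  case Add(a: Expr, b: Expr)
  case Division(numerator: Expr, denominator: Expr)

import Expr.*

type Logs = List[String]
type Monad[A] = (A, Logs)

// A computation with no interesting result that emits one log line
def out(output: String): Monad[Unit] =
  ((), List(output))

// Formats a log message for the evaluation of t
def line(t: Expr, res: Int) =
  f"eval($t) ⇐ $res"

val msg = line(Add(Constant(1), Constant(2)), 3)
println(msg)      // eval(Add(Constant(1),Constant(2))) ⇐ 3
println(out(msg)) // ((),List(eval(Add(Constant(1),Constant(2))) ⇐ 3))
```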
Now implement the eval function that also logs the performed operations, using the monad type defined above. You can use the out and line functions to log the operations.
def eval(t: Expr): Monad[Int] = t match
  case Constant(a) => ???
  case Add(a, b) => ???
  case Division(a, b) => ???
src/main/scala/monads/Evaluator.scala
DivByZero and logging 🔥
Now we will add the DivByZero case to the logging feature. This exercise is optional and no solutions are provided. Feel free to write new tests based on existing ones.
Logging and counting additions 🔥
Another optional exercise: combine the counting of additions and the logging features. No solutions are provided. Feel free to write new tests based on existing ones.