Debrief on the calculator lab (2025-10-05)
Here is a post-lab discussion of the calculator lab, in which we look at a few interesting topics that you may have found challenging.
The most difficult part of the lab seems to have been the simplifier, and within the simplifier, the simplify function that combined both constant folding and algebraic simplifications. Let’s look at four things:
- The general pattern of matching on the result of a recursive call, and the performance implications of duplicating calls;
- Why constfold(algebraic(e)) didn’t work, how it could have worked, and what the fix was (writing a separate recursive function);
- What happened if you ordered rules incorrectly (a common mistake) in simplify;
- And finally, how to write constfold, algebraic, and simplify as succinctly as possible, using HOFs.
If you finished the lab successfully and want to save time, you can skip straight to the last section, “A neat and clean implementation using HOFs”. If you’re really in a hurry, skip straight to “The final implementation of simplify”.
Matching on recursive calls
The ‘calculator’ lab was one of our first encounters with a function that pattern-matches on the results of the recursive calls it makes. It looked like this:
def evaluate(e: BasicExpr): EvalResult =
  e match
    case Number(n) => Ok(n)
    case Add(e1, e2) =>
      (evaluate(e1), evaluate(e2)) match
        case (Ok(v1), Ok(v2)) => Ok(v1 + v2)
        case (_, _) => DivByZero
    …
… and like this:
def constfold(e: FullExpr): FullExpr =
  e match
    case Number(_) => e
    case Add(e1, e2) =>
      (constfold(e1), constfold(e2)) match
        case (Number(n1), Number(n2)) => Number(n1 + n2)
        case (s1, s2) => Add(s1, s2)
    …
While this pattern can look disorienting at first, it’s actually a computation approach that we’re very familiar with! It works just like length, for example:
def length(l: IntList): Int =
  if l.isEmpty then 0
  else 1 + length(l.tail)
Specifically, we have:
- Base cases:
    // length
    if l.isEmpty then 0
    // evaluate
    e match
      case Number(n) => Ok(n)
    // constfold
    def constfold(e: FullExpr): FullExpr =
      e match
        case Number(_) => e
- Recursion, with post-processing of the result (to make things clearer, I’ve separated the result and the post-processing phase into variables):
    // length
    val result = length(l.tail)
    val postprocess = (res: Int) => res + 1
    postprocess(result)
    // evaluate
    val result = (evaluate(e1), evaluate(e2))
    val postprocess = (res: (EvalResult, EvalResult)) =>
      res match
        case (Ok(v1), Ok(v2)) => Ok(v1 + v2)
        case (_, _) => DivByZero
    postprocess(result)
    // constfold
    val result = (constfold(e1), constfold(e2))
    val postprocess = (res: (FullExpr, FullExpr)) =>
      res match
        case (Number(n1), Number(n2)) => Number(n1 + n2)
        case (s1, s2) => Add(s1, s2)
    postprocess(result)
This pattern is very common when dealing with enum types: any non-tail-recursive function that works with enums will do some work after its recursive calls, and that work will typically be done using pattern matching.
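Here is the evaluate pattern as one small, runnable sketch. The two enums are assumptions made for illustration (a BasicExpr with only Number, Add and Div, and an EvalResult with only Ok and DivByZero); they may differ from the lab’s definitions:

```scala
// Hypothetical, trimmed-down versions of the lab's types (assumption)
enum BasicExpr:
  case Number(value: Double)
  case Add(e1: BasicExpr, e2: BasicExpr)
  case Div(e1: BasicExpr, e2: BasicExpr)

enum EvalResult:
  case Ok(v: Double)
  case DivByZero

import BasicExpr.*
import EvalResult.*

def evaluate(e: BasicExpr): EvalResult =
  e match
    // Base case: no recursive call
    case Number(n) => Ok(n)
    // Recursive cases: first evaluate both children, then post-process the pair
    case Add(e1, e2) =>
      (evaluate(e1), evaluate(e2)) match
        case (Ok(v1), Ok(v2)) => Ok(v1 + v2)
        case (_, _) => DivByZero // an error in either child propagates upwards
    case Div(e1, e2) =>
      (evaluate(e1), evaluate(e2)) match
        case (Ok(_), Ok(0.0)) => DivByZero // the error is created at this node
        case (Ok(v1), Ok(v2)) => Ok(v1 / v2)
        case (_, _) => DivByZero

evaluate(Add(Number(1), Number(2)))                 // Ok(3.0)
evaluate(Add(Div(Number(1), Number(0)), Number(2))) // DivByZero
```

Each recursive case has the same shape as length: make the recursive calls first, then post-process their results with a match.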
Exercise: Write an extension method foldRightOption[A, B](z: B)(op: (A, B) => Option[B]): Option[B] that works just like foldRight, except that op may fail and return None. Do you notice the pattern that we discussed above?
Solution
extension [A](l: List[A])
  def foldRightOption[B](z: B)(op: (A, B) => Option[B]): Option[B] =
    l match
      case Nil => Some(z)
      case h :: t =>
        t.foldRightOption(z)(op) match
          case None => None
          case Some(acc) => op(h, acc)
calculator/foldRightOption.worksheet.sc
Here is a use case:
def sumStringsOrFail(l: List[String]): Option[Int] =
  l.foldRightOption(0)((s, acc) => s.toIntOption.map(acc + _))
sumStringsOrFail(List()) // Some(0)
sumStringsOrFail(List("1", "two", "3")) // None: "two" is not a number
sumStringsOrFail(List("1", "2", "3", "4")) // Some(10)
calculator/foldRightOption.worksheet.sc
Avoiding exponential blowup
A common mistake when writing functions in the style above is to duplicate a recursive call. For example:
case Mul(e1, e2) =>
  (simplify(e1), simplify(e2)) match
    case (Number(0), _) => Number(0)
    case (_, Number(1)) => simplify(e1) // ⚠ simplifies e1 a second time
    …
Those who made this mistake quickly realized that something was wrong: the function appears to never return. In fact, it does return, and there’s no infinite loop; it just takes a very long time.
In the worst case (for example, when the (_, Number(1)) rule fires at every level of the tree), the code makes 2 calls to itself with the same argument e1 at every step of the recursion. Hence, if calling simplify(e1) takes T units of time, then (ignoring e2) calling simplify(Mul(e1, e2)) takes 2 * T (two calls times T per call). The time to process an expression of depth D (D levels of nesting) is hence $\underbrace{2 \cdot 2 \cdots 2}_{D} = 2^D$.
Contrast this with the following code:
case Mul(e1, e2) =>
  (simplify(e1), simplify(e2)) match
    case (Number(0), _) => Number(0)
    case (s1, Number(1)) => s1
    …
This time there is just one call per child. Hence, if calling simplify(e1) takes T units of time, then (ignoring e2) calling simplify(Mul(e1, e2)) takes T + 1 (one call taking T, plus a constant amount of work). The time to process an expression of depth D is hence $\underbrace{1 + 1 + \cdots + 1}_{D} = D$.
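The difference is easy to measure. The sketch below uses a hypothetical, cut-down setup (an Expr enum with only Number, Var and Mul, plus a global call counter added purely for illustration) and counts how many times each version is invoked on x * 1 * 1 * … * 1, nested d levels deep:

```scala
// Hypothetical, cut-down expression type (assumption, not the lab's FullExpr)
enum Expr:
  case Number(value: Double)
  case Var(name: String)
  case Mul(e1: Expr, e2: Expr)

import Expr.*

// Global call counter, for illustration only
var calls = 0

// Duplicates the recursive call on e1 whenever the (_, Number(1)) rule fires
def simplifySlow(e: Expr): Expr =
  calls += 1
  e match
    case Mul(e1, e2) =>
      (simplifySlow(e1), simplifySlow(e2)) match
        case (Number(0.0), _) => Number(0)
        case (_, Number(1.0)) => simplifySlow(e1) // same argument, simplified again
        case (s1, s2) => Mul(s1, s2)
    case _ => e

// Reuses the already-simplified child instead
def simplifyFast(e: Expr): Expr =
  calls += 1
  e match
    case Mul(e1, e2) =>
      (simplifyFast(e1), simplifyFast(e2)) match
        case (Number(0.0), _) => Number(0)
        case (s1, Number(1.0)) => s1
        case (s1, s2) => Mul(s1, s2)
    case _ => e

// x * 1 * 1 * ... * 1, with d multiplications nested to the left
def nested(d: Int): Expr =
  if d == 0 then Var("x") else Mul(nested(d - 1), Number(1))

// Runs f on e and returns how many calls were made along the way
def countCalls(f: Expr => Expr, e: Expr): Int =
  calls = 0
  f(e)
  calls
```

With this sketch, countCalls(simplifySlow, nested(10)) is 3070 (the count satisfies C(d) = 2 + 2 C(d - 1) with C(0) = 1, so C(d) = 3 · 2^d - 2), while countCalls(simplifyFast, nested(10)) is only 21 (that is, 2d + 1).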
constfold(algebraic(e))
As stated in the handout, writing simplify(e: FullExpr) = algebraic(constfold(e)) did not work, and neither did composing the two passes in the other order. Take a look at these examples to remember why.
algebraic(constfold(‘0 * x - y * (1 - 1)’))
algebraic(‘0 * x - y * 0’)
‘0 - 0’
constfold(algebraic(‘0 * x - y * (1 - 1)’))
constfold(‘0 - y * (1 - 1)’)
‘0 - y * 0’
Debugging
An earlier version of the calculator lab had fewer tests; in that version, the simplest test that failed when implementing simplify as constfold `andThen` algebraic (that is, algebraic(constfold(e))) was this:
random test: simplify 1.0 * (1.0 * (-4.719951559702111 * y1 + -8.927959938264125 * y0 - (0.0 + (-4.719951559702111 * y1 + -8.927959938264125 * y0)) - 0.0 + 0.0 + 0.0 + (0.0 - 0.0)) + 1.0 * 0.0 * 1.0) / 1.0 + (0.0 - 0.0) should be 0.0, but it is actually -0.0
=> Obtained
Neg(
e = Number(
value = 0.0
)
)
=> Diff (- obtained, + expected)
-Neg(
- e = Number(
- value = 0.0
- )
+Number(
+ value = 0.0
)
Let’s assume that we were looking at this example and trying to find the bug. We know how to approach issues like this: simplify and minimize, using the divide and conquer technique!
Set up the environment
Open a worksheet (I used the calculator.worksheet.sc file included in the lab).
Reproduce the issue
- Parse the text into an expression:
    import calculator.FullExpr
    import calculator.FullDriver
    val e: FullExpr = FullDriver.parse("1.0 * (1.0 * (-4.719951559702111 * y1 + -8.927959938264125 * y0 - (0.0 + (-4.719951559702111 * y1 + -8.927959938264125 * y0)) - 0.0 + 0.0 + 0.0 + (0.0 - 0.0)) + 1.0 * 0.0 * 1.0) / 1.0 + (0.0 - 0.0)").get
- Reproduce the bug by checking that simplify is idempotent (simplifying an already simplified expression should not change it):
    import FullDriver.Simplifier.*
    assert(simplify(e) == simplify(simplify(e)))
  At this point, as expected, we get an assertion failure: a failed test.
Simplify and minimize
- Repeatedly remove small bits of the expression (additions of 0, multiplications by 1), making sure that the assertion still fails. I get to the following:
    val e: FullExpr = FullDriver.parse("(-4.719951559702111 * y1 + -8.927959938264125 * y0 - ((-4.719951559702111 * y1 + -8.927959938264125 * y0)) - 0.0)").get
- Simplify further! These long floating-point values look suspicious, so we can get rid of them:
    val e: FullExpr = FullDriver.parse("(y1 + y0 - (y1 + y0) - 0.0)").get
  … and I simplify it further to this:
    val e: FullExpr = FullDriver.parse("y - y - 0.0").get
Observe the system
At this point, we have an example that is simple enough to reason about, so we can try to come up with solutions.
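To double-check that this minimal example really exposes the problem, here is a self-contained sketch that replays it outside the lab. The trimmed-down FullExpr (only Number, Var, Minus and Neg) and the two recursive passes are assumptions modeled on the rules shown in this debrief, not the lab’s code; simplify composes them as algebraic(constfold(e)):

```scala
// Hypothetical, trimmed-down expression type (assumption)
enum FullExpr:
  case Number(value: Double)
  case Var(name: String)
  case Minus(e1: FullExpr, e2: FullExpr)
  case Neg(e: FullExpr)

import FullExpr.*

// Constant folding only
def constfold(e: FullExpr): FullExpr =
  e match
    case Number(_) | Var(_) => e
    case Neg(e1) =>
      constfold(e1) match
        case Number(n) => Number(-n)
        case s => Neg(s)
    case Minus(e1, e2) =>
      (constfold(e1), constfold(e2)) match
        case (Number(n1), Number(n2)) => Number(n1 - n2)
        case (s1, s2) => Minus(s1, s2)

// Algebraic rules only (the Minus rules used in this debrief)
def algebraic(e: FullExpr): FullExpr =
  e match
    case Number(_) | Var(_) => e
    case Neg(e1) => Neg(algebraic(e1))
    case Minus(e1, e2) =>
      (algebraic(e1), algebraic(e2)) match
        case (Number(0.0), s) => Neg(s)
        case (s, Number(0.0)) => s
        case (s1, s2) if s1 == s2 => Number(0)
        case (s1, s2) => Minus(s1, s2)

// The composition that does not work
def simplify(e: FullExpr): FullExpr = algebraic(constfold(e))

// y - y - 0.0, the minimized example
val e = Minus(Minus(Var("y"), Var("y")), Number(0.0))
val once = simplify(e)     // Neg(Number(0.0))
val twice = simplify(once) // Number(-0.0)
```

The first pass leaves Neg(Number(0.0)) behind and only a second pass folds it, so simplify(e) and simplify(simplify(e)) differ: exactly the assertion failure from the worksheet.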
Using fixpoints
Could this approach have worked? Looking at the example above, it looks like an additional layer of simplification may be enough:
def simplify(e: FullExpr): FullExpr =
  constfold(algebraic(constfold(algebraic(e))))
… but why would that work? If we just construct a slightly more deeply nested example, we’ll run into issues again. Adding more layers of simplification fixes that example, until an even more deeply nested one breaks it again, and so on.
As at least one of you noted, the following implementation passed all the tests:
def simplify(e: FullExpr): FullExpr =
  constfold(algebraic(constfold(algebraic(constfold(algebraic(constfold(algebraic(constfold(algebraic(constfold(algebraic(constfold(algebraic(constfold(e)))))))))))))))
We actually had a discussion before releasing the lab about whether we should have tests to catch this case. In the end, we decided to not include deeply nested tests, and hence to allow this (incorrect) solution.
A correct approach, if we must reuse these two functions, is to use a fixpoint to apply them as many times as needed:
def simplify(e: FullExpr) =
  def fixpoint[A](f: A => A)(a: A): A =
    val fa = f(a)
    if fa == a then a
    else fixpoint(f)(fa)
  fixpoint(constfold `andThen` algebraic)(e)
calculator/../../labs/calculator/src/main/scala/calculator/full/FullSimplifier.scala
This is not a good solution! It can be very costly: in the worst case, each iteration simplifies just one node of the tree, at the very bottom (close to the leaves). In that case, the fixpoint will need as many steps as the tree is deep (depth $D$), and each step will take as many operations as the size of the tree (size $S$). In total, this will be $O(D \cdot S)$, which is $O(S^2)$ in the worst case (a degenerate, list-like tree, where $D$ is proportional to $S$).
The right solution
To make sure that we simplify as much as possible, we must run both sets of rules recursively at every level of the tree; that is, we need to write a new recursive function:
def simplify(e: FullExpr): FullExpr =
  e match
    …
    case Minus(e1, e2) =>
      (simplify(e1), simplify(e2)) match
        // First, the constfold step
        case (Number(n1), Number(n2)) => Number(n1 - n2)
        // Second, the algebraic simplification step
        case (Number(0), e) => Neg(e)
        case (e, Number(0)) => e
        case (e1, e2) if e1 == e2 => Number(0)
        // Finally, the fallback
        case (s1, s2) => Minus(s1, s2)
It’s valid to put both sets of rules as cases of a single match because the result of a rule (the right-hand side of its =>) never matches the patterns again at the same node: these simplifications can interact at different levels of the tree, but not at the same level.
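Here is a runnable sketch of this single-traversal simplify. To keep it short, it assumes a trimmed-down FullExpr with only Number, Var, Minus and Neg, and the Neg case (constant folding only) is our own addition; the Minus rules are the ones above:

```scala
// Hypothetical, trimmed-down expression type (assumption)
enum FullExpr:
  case Number(value: Double)
  case Var(name: String)
  case Minus(e1: FullExpr, e2: FullExpr)
  case Neg(e: FullExpr)

import FullExpr.*

def simplify(e: FullExpr): FullExpr =
  e match
    case Number(_) | Var(_) => e
    case Neg(e1) =>
      simplify(e1) match
        case Number(n) => Number(-n) // constfold step
        case s => Neg(s)             // fallback
    case Minus(e1, e2) =>
      (simplify(e1), simplify(e2)) match
        // First, the constfold step
        case (Number(n1), Number(n2)) => Number(n1 - n2)
        // Second, the algebraic simplification step
        case (Number(0.0), s) => Neg(s)
        case (s, Number(0.0)) => s
        case (s1, s2) if s1 == s2 => Number(0)
        // Finally, the fallback
        case (s1, s2) => Minus(s1, s2)

val e = Minus(Minus(Var("y"), Var("y")), Number(0)) // y - y - 0.0
simplify(e)                                         // Number(0.0)
```

On the minimized example, the inner y - y becomes 0 at its own level, and the outer node then sees (Number(0.0), Number(0.0)) and folds it, all in one traversal.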
This version is much more efficient than the previous one! It does a single traversal.
A common mistake in the efficient recursive implementation
There is one subtle mistake that can happen when writing the code in the way above:
def simplify(e: FullExpr): FullExpr =
  e match
    …
    case Minus(e1, e2) => // ⚠ Incorrect code below!
      (simplify(e1), simplify(e2)) match
        // Incorrect: Start with the algebraic step
        case (Number(0), e) => Neg(e)
        case (e, Number(0)) => e
        case (e1, e2) if e1 == e2 => Number(0)
        // Then continue with the constfold step
        case (Number(n1), Number(n2)) => Number(n1 - n2)
        // Finally, the fallback
        case (s1, s2) => Minus(s1, s2)
Exercise: Why does this version not work?
Solution
Here is a counterexample: with this implementation, the expression ‘0 - (1 - 1)’ is not correctly simplified:
simplify(‘0 - (1 - 1)’)
(simplify(‘0’), simplify(‘1 - 1’)) match …
(‘0’, ‘0’) match case (Number(0), e) => Neg(e)
‘-0’
But why? Because of the priority order between match branches! Two branches match (‘0’, ‘0’):
case (Number(0), e) => Neg(e)
case (Number(n1), Number(n2)) => Number(n1 - n2)
The second one is the right one (it’s more specific), so we should list it first.
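The counterexample can be replayed with a small sketch. Again, the trimmed-down FullExpr is an assumption, and simplifyWrongOrder is a hypothetical name for the incorrect version above, with the algebraic cases listed before the constfold case:

```scala
// Hypothetical, trimmed-down expression type (assumption)
enum FullExpr:
  case Number(value: Double)
  case Var(name: String)
  case Minus(e1: FullExpr, e2: FullExpr)
  case Neg(e: FullExpr)

import FullExpr.*

// ⚠ Deliberately wrong: algebraic rules are tried before the constfold rule
def simplifyWrongOrder(e: FullExpr): FullExpr =
  e match
    case Number(_) | Var(_) => e
    case Neg(e1) =>
      simplifyWrongOrder(e1) match
        case Number(n) => Number(-n)
        case s => Neg(s)
    case Minus(e1, e2) =>
      (simplifyWrongOrder(e1), simplifyWrongOrder(e2)) match
        case (Number(0.0), s) => Neg(s) // also matches (Number(0), Number(0))!
        case (s, Number(0.0)) => s
        case (s1, s2) if s1 == s2 => Number(0)
        case (Number(n1), Number(n2)) => Number(n1 - n2) // never reached for (0, 0)
        case (s1, s2) => Minus(s1, s2)

// '0 - (1 - 1)'
simplifyWrongOrder(Minus(Number(0), Minus(Number(1), Number(1)))) // Neg(Number(0.0))
```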
A neat and clean implementation using HOFs
The version posted above is perfectly reasonable, but it’s very verbose. Could we do better? Yes! We can get there by looking for common patterns, like in the find lab, and eliminating redundancy. Here is a part of constfold:
case Add(e1, e2) =>
  (constfold(e1), constfold(e2)) match
    case (Number(n1), Number(n2)) => Number(n1 + n2)
    case (s1, s2) => Add(s1, s2)
case Minus(e1, e2) =>
  (constfold(e1), constfold(e2)) match
    case (Number(n1), Number(n2)) => Number(n1 - n2)
    case (s1, s2) => Minus(s1, s2)
…
… and here is a part of algebraic:
case Add(e1, e2) =>
  (algebraic(e1), algebraic(e2)) match
    case (Number(0), e2) => e2
    case (e1, Number(0)) => e1
    case (s1, s2) => Add(s1, s2)
case Minus(e1, e2) =>
  (algebraic(e1), algebraic(e2)) match
    case (Number(0), e) => Neg(e)
    case (e, Number(0)) => e
    case (e1, e2) if e1 == e2 => Number(0)
    case (s1, s2) => Minus(s1, s2)
There is a common skeleton here:
- Apply the recursive simplification function (constfold or algebraic) to both children, then
- If the result follows a specific pattern, return a custom result, and
- Otherwise, return the same original constructor (Add, Mul) with the simplified children.
With this observation, we can split the process into two phases:
- Simplify children, then
- Apply rules on the current node.
For constfold, we get this:
val withSimplifiedChildren = expr match
  case Add(e1, e2) => Add(constfold(e1), constfold(e2))
  case Minus(e1, e2) => Minus(constfold(e1), constfold(e2))
  …
withSimplifiedChildren match
  case Add(Number(n1), Number(n2)) => Number(n1 + n2)
  case Minus(Number(n1), Number(n2)) => Number(n1 - n2)
  …
  case _ => withSimplifiedChildren
For algebraic, we get this:
val withSimplifiedChildren = expr match
  case Add(e1, e2) => Add(algebraic(e1), algebraic(e2))
  case Minus(e1, e2) => Minus(algebraic(e1), algebraic(e2))
  …
withSimplifiedChildren match
  case Add(Number(0), e2) => e2
  case Add(e1, Number(0)) => e1
  case Minus(Number(0), e) => Neg(e)
  case Minus(e, Number(0)) => e
  case Minus(e1, e2) if e1 == e2 => Number(0)
  …
  case _ => withSimplifiedChildren
Already we have a great improvement: we neatly separate the recursion phase (applying a function to children) from the post-processing phase (simplifying the current node). Notice how all the rules are together in one neat block in each function, instead of being mixed with the recursive calls.
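As a runnable sketch of the two-phase version of constfold, assuming a FullExpr with the seven cases that appear in this debrief (Number, Var, Add, Minus, Mul, Div, Neg; the field names are guesses):

```scala
// Hypothetical expression type with the cases used in this debrief (assumption)
enum FullExpr:
  case Number(value: Double)
  case Var(name: String)
  case Add(e1: FullExpr, e2: FullExpr)
  case Minus(e1: FullExpr, e2: FullExpr)
  case Mul(e1: FullExpr, e2: FullExpr)
  case Div(e1: FullExpr, e2: FullExpr)
  case Neg(e: FullExpr)

import FullExpr.*

def constfold(expr: FullExpr): FullExpr =
  // Phase 1 (recursion only): rebuild the node with simplified children
  val withSimplifiedChildren = expr match
    case Number(_) | Var(_) => expr
    case Add(e1, e2) => Add(constfold(e1), constfold(e2))
    case Minus(e1, e2) => Minus(constfold(e1), constfold(e2))
    case Mul(e1, e2) => Mul(constfold(e1), constfold(e2))
    case Div(e1, e2) => Div(constfold(e1), constfold(e2))
    case Neg(e1) => Neg(constfold(e1))
  // Phase 2 (rules only): look at the current node and nothing deeper
  withSimplifiedChildren match
    case Add(Number(n1), Number(n2)) => Number(n1 + n2)
    case Minus(Number(n1), Number(n2)) => Number(n1 - n2)
    case Mul(Number(n1), Number(n2)) => Number(n1 * n2)
    case Div(Number(n1), Number(n2)) => Number(n1 / n2)
    case Neg(Number(n)) => Number(-n)
    case _ => withSimplifiedChildren
```

Note that phase 2 never calls constfold: every recursive call lives in phase 1.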
Look carefully once more: is there anything in common between these two functions? Yes! The withSimplifiedChildren part is just a map on the tree. Let’s write it that way, then:
def mapExpr(e: FullExpr, f: FullExpr => FullExpr): FullExpr =
  f(e match
    case Number(_) => e
    case Add(e1, e2) => Add(mapExpr(e1, f), mapExpr(e2, f))
    case Minus(e1, e2) => Minus(mapExpr(e1, f), mapExpr(e2, f))
    case Mul(e1, e2) => Mul(mapExpr(e1, f), mapExpr(e2, f))
    case Div(e1, e2) => Div(mapExpr(e1, f), mapExpr(e2, f))
    case Neg(e) => Neg(mapExpr(e, f))
    case Var(_) => e
  )
calculator/../../labs/calculator/src/main/scala/calculator/full/FullSimplifier.scala
All that we do is apply a function to each node, after applying it to its children. We can then fully separate the rules from the recursion…
def constfold1(expr: FullExpr) = expr match
  case Add(Number(n1), Number(n2)) => Number(n1 + n2)
  case Minus(Number(n1), Number(n2)) => Number(n1 - n2)
  case Mul(Number(n1), Number(n2)) => Number(n1 * n2)
  case Div(Number(n1), Number(n2)) => Number(n1 / n2)
  case Neg(Number(n)) => Number(-n)
  case e => e
calculator/../../labs/calculator/src/main/scala/calculator/full/FullSimplifier.scala
… and constfold is just a special case of map!
def constfold(e: FullExpr): FullExpr =
  mapExpr(e, constfold1)
calculator/../../labs/calculator/src/main/scala/calculator/full/FullSimplifier.scala
It’s straightforward to do the same for algebraic.
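For instance, here is a sketch of algebraic in the same style. The Add and Minus rules are the ones listed earlier; the two Mul rules come from the snippet in “Avoiding exponential blowup”. Both the enum and this incomplete rule set are assumptions rather than the lab’s reference solution, and mapExpr is written with a local val, which is equivalent to the version above:

```scala
// Hypothetical expression type (assumption)
enum FullExpr:
  case Number(value: Double)
  case Var(name: String)
  case Add(e1: FullExpr, e2: FullExpr)
  case Minus(e1: FullExpr, e2: FullExpr)
  case Mul(e1: FullExpr, e2: FullExpr)
  case Div(e1: FullExpr, e2: FullExpr)
  case Neg(e: FullExpr)

import FullExpr.*

// Bottom-up traversal: map the children first, then apply f to the rebuilt node
def mapExpr(e: FullExpr, f: FullExpr => FullExpr): FullExpr =
  val withMappedChildren = e match
    case Number(_) | Var(_) => e
    case Add(e1, e2) => Add(mapExpr(e1, f), mapExpr(e2, f))
    case Minus(e1, e2) => Minus(mapExpr(e1, f), mapExpr(e2, f))
    case Mul(e1, e2) => Mul(mapExpr(e1, f), mapExpr(e2, f))
    case Div(e1, e2) => Div(mapExpr(e1, f), mapExpr(e2, f))
    case Neg(e1) => Neg(mapExpr(e1, f))
  f(withMappedChildren)

// Single-node algebraic rules (an incomplete, assumed subset)
def algebraic1(expr: FullExpr): FullExpr = expr match
  case Add(Number(0.0), e2) => e2
  case Add(e1, Number(0.0)) => e1
  case Minus(Number(0.0), e) => Neg(e)
  case Minus(e, Number(0.0)) => e
  case Minus(e1, e2) if e1 == e2 => Number(0)
  case Mul(Number(0.0), _) => Number(0)
  case Mul(e, Number(1.0)) => e
  case e => e

def algebraic(e: FullExpr): FullExpr =
  mapExpr(e, algebraic1)
```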
The final implementation of simplify
So then what does simplify look like in this model? We just need to apply both single-node transformations at each level of the tree, so the complete implementation is just the following:
def simplify(e: FullExpr): FullExpr =
  mapExpr(e, constfold1 `andThen` algebraic1)
calculator/../../labs/calculator/src/main/scala/calculator/full/FullSimplifier.scala
Voilà!
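To check the whole thing end to end, here is a self-contained sketch that assembles the pieces: mapExpr (written with a local val, equivalent to the version above), constfold1 as shown, and an assumed, incomplete algebraic1 (the Add and Minus rules listed earlier plus two Mul rules). The lambda node => algebraic1(constfold1(node)) is the same function as constfold1 `andThen` algebraic1:

```scala
// Hypothetical expression type (assumption)
enum FullExpr:
  case Number(value: Double)
  case Var(name: String)
  case Add(e1: FullExpr, e2: FullExpr)
  case Minus(e1: FullExpr, e2: FullExpr)
  case Mul(e1: FullExpr, e2: FullExpr)
  case Div(e1: FullExpr, e2: FullExpr)
  case Neg(e: FullExpr)

import FullExpr.*

// Bottom-up traversal: map the children first, then apply f to the rebuilt node
def mapExpr(e: FullExpr, f: FullExpr => FullExpr): FullExpr =
  val withMappedChildren = e match
    case Number(_) | Var(_) => e
    case Add(e1, e2) => Add(mapExpr(e1, f), mapExpr(e2, f))
    case Minus(e1, e2) => Minus(mapExpr(e1, f), mapExpr(e2, f))
    case Mul(e1, e2) => Mul(mapExpr(e1, f), mapExpr(e2, f))
    case Div(e1, e2) => Div(mapExpr(e1, f), mapExpr(e2, f))
    case Neg(e1) => Neg(mapExpr(e1, f))
  f(withMappedChildren)

// Single-node constant folding
def constfold1(expr: FullExpr): FullExpr = expr match
  case Add(Number(n1), Number(n2)) => Number(n1 + n2)
  case Minus(Number(n1), Number(n2)) => Number(n1 - n2)
  case Mul(Number(n1), Number(n2)) => Number(n1 * n2)
  case Div(Number(n1), Number(n2)) => Number(n1 / n2)
  case Neg(Number(n)) => Number(-n)
  case e => e

// Single-node algebraic rules (an incomplete, assumed subset)
def algebraic1(expr: FullExpr): FullExpr = expr match
  case Add(Number(0.0), e2) => e2
  case Add(e1, Number(0.0)) => e1
  case Minus(Number(0.0), e) => Neg(e)
  case Minus(e, Number(0.0)) => e
  case Minus(e1, e2) if e1 == e2 => Number(0)
  case Mul(Number(0.0), _) => Number(0)
  case Mul(e, Number(1.0)) => e
  case e => e

// constfold1 runs first at every node, then algebraic1
def simplify(e: FullExpr): FullExpr =
  mapExpr(e, node => algebraic1(constfold1(node)))

simplify(Minus(Minus(Var("y"), Var("y")), Number(0))) // Number(0.0)
```

Because constfold1 runs first at every node, a Minus(Number(0.0), Number(0.0)) is folded before the Minus(Number(0), e) rule can turn it into a negation, which is exactly the ordering lesson from earlier.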