Tying the knot in Haskell, OCaml and Idris

18 April 2018

Categories: Programming Tags: haskell ocaml idris

From the 9th to the 13th of April was the 2018 Midlands Graduate School at the University of Nottingham. It was a very pleasant and interesting event, with a record number of participants(!), many of whom I really enjoyed meeting and talking to.

The invited course was given by Edwin Brady on Idris, a functional language with dependent types, primarily directed toward programming rather than theorem proving. I can’t encourage you enough to check it out: it’s a great language, with great tools to ease development!

This post assumes some familiarity with Haskell and OCaml. If you’re familiar with one of Haskell and Idris but not the other, you shouldn’t have any trouble understanding what’s going on, as the two are very close.

The problem

Unlike Haskell, Idris is strict. However, it supports explicitly lazy values, which are distinguished at the type level from ordinary values by the type constructor Lazy.

This got my supervisor wondering whether laziness in Idris was powerful enough to support some classic « tying the knot » techniques that rely on laziness.

Here is the problem he asked me to try and solve:

Given a binary tree:

data Tree = Leaf Int | Node Int Tree Tree

Write a function that, in a single traversal of the tree, constructs a new tree of the same shape in which every integer is replaced by the sum of all the integers in the original tree.

For example, the following tree:

     9
    / \
   /   \
  7     8
 / \   / \
1  2  3   4
         / \
        5   6

should be mapped to:

      45
     /  \
    /    \
   /      \
  45      45
 /  \    /  \
45  45  45  45
           /  \
          45  45
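To pin down the specification, here is a straightforward two-pass solution in OCaml, given only as a reference point (a sketch; the names sum, replace and two_pass are ours, not the post’s). It walks the tree twice, once to add up the integers and once to rebuild it, which is exactly what the single-traversal requirement rules out.

```ocaml
(* A plain, strict version of the tree from the problem statement. *)
type tree = Leaf of int | Node of int * tree * tree

let example =
  Node (9,
        Node (7, Leaf 1, Leaf 2),
        Node (8, Leaf 3, Node (4, Leaf 5, Leaf 6)))

(* First traversal: add up all the integers. *)
let rec sum = function
  | Leaf n -> n
  | Node (n, t1, t2) -> n + sum t1 + sum t2

(* Second traversal: rebuild the tree, putting s everywhere. *)
let rec replace s = function
  | Leaf _ -> Leaf s
  | Node (_, t1, t2) -> Node (s, replace s t1, replace s t2)

let two_pass tree = replace (sum tree) tree

let expected =
  Node (45,
        Node (45, Leaf 45, Leaf 45),
        Node (45, Leaf 45, Node (45, Leaf 45, Leaf 45)))
```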

In Haskell

In Haskell, here is how one does it. First write an auxiliary function:

recSumAux :: Tree -> Int -> (Tree, Int)
recSumAux (Leaf n) s = (Leaf s, n)
recSumAux (Node n t1 t2) s = (Node s t1' t2', n + s1 + s2)
   where (t1', s1) = recSumAux t1 s
         (t2', s2) = recSumAux t2 s

recSumAux nearly does the two things we want: it constructs a new tree, replacing every integer with the value s passed as its second argument, and at the same time calculates the sum of all the values in the old tree. The only thing left to do is to supply it with the right s. But the right s is precisely the second output of recSumAux, so we just have to « tie the knot », and voilà:

recSum :: Tree -> Tree
recSum tree = ntree
  where (ntree, s) = recSumAux tree s

Why does this work? Because Haskell is lazy, it only evaluates s when it needs to, and recSumAux never needs the value of s to produce its result: when we construct, say, Leaf s, we don’t need to know what s is. We just need to point to a place where its value can eventually be found (such a place is called a thunk).

So we’ve seen how to do this in a lazy language. Now can we do it with « manual » laziness annotations?

In OCaml

I know OCaml a lot better than Idris, and OCaml resembles Idris with regard to laziness: the lazy keyword takes an expression of type 'a, doesn’t evaluate it, and gives you back a value of type 'a lazy_t. To evaluate a lazy value, you just use Lazy.force. So instead of directly trying to write an Idris version, I first tried to write an OCaml one.
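A quick sketch of these two operations, using a hypothetical counter (evaluations) to observe when the suspended computation actually runs. Storing a lazy value inside a constructor leaves it unevaluated, and only Lazy.force evaluates it; this is the property knot-tying relies on, since the tree can be built around a value that is not known yet.

```ocaml
type tree = Leaf of int Lazy.t | Node of int Lazy.t * tree * tree

(* Hypothetical instrumentation: counts how many times the suspension runs. *)
let evaluations = ref 0

(* lazy suspends the computation: nothing has run yet. *)
let s = lazy (incr evaluations; 45)

(* Storing s inside constructors does not force it. *)
let t = Node (s, Leaf s, Leaf s)
let runs_after_construction = !evaluations

(* Lazy.force runs the suspended computation. *)
let root_value = match t with Leaf v | Node (v, _, _) -> Lazy.force v
let runs_after_force = !evaluations
```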

rec_sum_aux can be ported quite literally:

type lazy_int = int lazy_t

type tree =
  | Leaf of lazy_int
  | Node of lazy_int * tree * tree

(* rec_sum_aux : tree -> lazy_int -> tree * lazy_int *)
let rec rec_sum_aux tree s =
  match tree with
  | Leaf(n) -> (Leaf s, n)
  | Node(n, t1, t2) ->
     let (t1, s1) = rec_sum_aux t1 s in
     let (t2, s2) = rec_sum_aux t2 s in
     (Node(s, t1, t2), lazy (Lazy.force n + Lazy.force s1 + Lazy.force s2))

However, we can’t tie the knot like we did in Haskell. This:

let rec_sum tree =
  let rec res = rec_sum_aux tree (snd res) in
  fst res

will cause my least favourite compile-time error from OCaml: This kind of expression is not allowed as right-hand side of `let rec'.
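The restriction concerns the shape of the right-hand side: let rec accepts, for example, constructor applications and function abstractions, which cannot inspect the value being defined, but it rejects a function application such as rec_sum_aux tree (snd res). A minimal illustration (ones is just an example name):

```ocaml
(* Accepted: the right-hand side is a constructor application, so OCaml can
   allocate the cons cell and make its tail point back to the cell itself. *)
let rec ones = 1 :: ones

(* The result is cyclic: the tail of ones is physically ones itself. *)
let tail_is_self = List.tl ones == ones

(* Rejected with the error above: the right-hand side is a function
   application, whose evaluation might need the value being defined.
   let rec bad = List.map succ bad *)
```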

To work around this, one idea is to use a fixed-point combinator:

(* fix : ('a lazy_t -> 'a) -> 'a *)
let rec fix f = f (lazy (fix f))

fix takes a function f and applies it to a suspended call to fix f. This can only work if f is lazy in its argument; otherwise it will obviously loop forever. There is always a way to transform a let rec block into an application of fix, and our rec_sum function is no exception:

(* map_lazy : ('a -> 'b) -> 'a lazy_t -> 'b lazy_t *)
let map_lazy f v = lazy (f (Lazy.force v))

let rec_sum tree =
  let rec_fun res =
    let (res_tree, lazy_int) = rec_sum_aux tree (map_lazy snd res) in
    (res_tree, Lazy.force lazy_int) in
  fst (fix rec_fun)

And it works! There’s a catch, however: using fix actually entails two traversals of the tree rather than one. Indeed, if we unfold the first two calls to fix in rec_sum, we get:

let rec_sum tree =
  let rec_fun res = ... in
  fst (rec_fun (lazy (rec_fun (lazy (fix rec_fun)))))

What happens then is that the inner call to rec_fun calculates the sum by descending along the tree. It also builds a new tree, where each lazy_int field points to the thunk map_lazy snd (lazy (fix rec_fun)); that thunk is never evaluated, and this tree is simply discarded. The outer call to rec_fun goes down the tree as well, building the tree that is actually returned, whose fields point to the sum calculated by the inner call. Since OCaml is strict, the outer call actually runs first, and the inner one only runs, once, the first time a value of the returned tree is forced (all the fields share the same thunk). There are therefore in fact two traversals of the tree, instead of one in the Haskell solution.
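The re-running can be observed on a smaller example than rec_sum. Here is a sketch applying fix to a hypothetical stream type (not from the post), with a hypothetical counter calls recording how often the body of f runs. fix does not build a cyclic structure: every tail is a fresh call to fix f, so f runs again each time a new tail is forced, and only then is the result cached.

```ocaml
(* fix : ('a Lazy.t -> 'a) -> 'a, as defined above *)
let rec fix f = f (lazy (fix f))

(* Hypothetical stream type whose tail is lazy, so that a function
   building a stream does not need to force its argument. *)
type stream = Cons of int * stream Lazy.t

(* Hypothetical counter: how many times the body of f has run. *)
let calls = ref 0

(* The function only stores self, so it is lazy in its argument. *)
let ones = fix (fun self -> incr calls; Cons (1, self))

(* Forces the first n tails one by one. *)
let rec take n (Cons (x, rest)) =
  if n <= 0 then [] else x :: take (n - 1) (Lazy.force rest)
```

Taking three elements forces three tails, so f runs three more times; taking them again reuses the cached tails.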

If you’re not convinced, you can slightly tweak the type of the tree so that each node contains a mutable reference counting the number of times it has been visited:

type tree =
  | Leaf of int ref * lazy_int
  | Node of int ref * lazy_int * tree * tree

Then you just have to increment the references in rec_sum_aux:

let rec rec_sum_aux tree s =
  match tree with
  | Leaf(r, n) -> incr r; (Leaf(r, s), n)
  | Node(r, n, t1, t2) ->
     let (t1, s1) = rec_sum_aux t1 s in
     let (t2, s2) = rec_sum_aux t2 s in
     incr r;
     (Node(r, s, t1, t2), lazy (Lazy.force n + Lazy.force s1 + Lazy.force s2))

And you’ll see that after the call to rec_sum, once the values of the new tree have been forced (for instance by printing it), these references contain 2, indicating two traversals. I have no idea, however, how to avoid this, so contributions are welcome!
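Here is a self-contained sketch of this experiment, built on the first fix-based rec_sum (the one whose rec_sum_aux returns a lazy_int); leaf, node, counter and values are hypothetical helpers for building the example tree and inspecting the result. In this sketch the counters read 1 right after rec_sum returns, because the second traversal is suspended until a value of the new tree is first forced; after that they read 2, and forcing the other values does not increase them further, since all nodes share one thunk.

```ocaml
(* Visit-counting variant of the first fix-based solution. *)
type lazy_int = int Lazy.t

type tree =
  | Leaf of int ref * lazy_int
  | Node of int ref * lazy_int * tree * tree

(* fix : ('a Lazy.t -> 'a) -> 'a *)
let rec fix f = f (lazy (fix f))

(* map_lazy : ('a -> 'b) -> 'a Lazy.t -> 'b Lazy.t *)
let map_lazy f v = lazy (f (Lazy.force v))

(* rec_sum_aux : tree -> lazy_int -> tree * lazy_int *)
let rec rec_sum_aux tree s =
  match tree with
  | Leaf (r, n) -> incr r; (Leaf (r, s), n)
  | Node (r, n, t1, t2) ->
     let (t1, s1) = rec_sum_aux t1 s in
     let (t2, s2) = rec_sum_aux t2 s in
     incr r;
     (Node (r, s, t1, t2), lazy (Lazy.force n + Lazy.force s1 + Lazy.force s2))

let rec_sum tree =
  let rec_fun res =
    let (res_tree, lazy_int) = rec_sum_aux tree (map_lazy snd res) in
    (res_tree, Lazy.force lazy_int) in
  fst (fix rec_fun)

(* Hypothetical helpers to build the example tree and inspect the result. *)
let leaf n = Leaf (ref 0, Lazy.from_val n)
let node n l r = Node (ref 0, Lazy.from_val n, l, r)
let counter = function Leaf (c, _) | Node (c, _, _, _) -> !c
let rec values = function
  | Leaf (_, v) -> [Lazy.force v]
  | Node (_, v, l, r) -> Lazy.force v :: values l @ values r

let example =
  node 9 (node 7 (leaf 1) (leaf 2)) (node 8 (leaf 3) (node 4 (leaf 5) (leaf 6)))

let result = rec_sum example
let visits_before_forcing = counter result  (* 1: only the outer call has run *)
let all_values = values result              (* forcing triggers the inner call *)
let visits_after_forcing = counter result   (* 2 *)
```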

We can slightly improve the formulation above, as it is a bit « unnecessarily lazy ». Indeed, we don’t need the sum returned by rec_sum_aux to be computed lazily; we only need the constructed tree to contain lazy values. So we can rewrite rec_sum_aux and rec_sum as:

(* rec_sum_aux : tree -> lazy_int -> tree * int *)
let rec rec_sum_aux tree s =
  match tree with
  | Leaf(n) -> (Leaf s, Lazy.force n)
  | Node(n, t1, t2) ->
     let (t1, s1) = rec_sum_aux t1 s in
     let (t2, s2) = rec_sum_aux t2 s in
     (Node(s, t1, t2), Lazy.force n + s1 + s2)

let rec_sum tree =
  let rec_fun res = rec_sum_aux tree (map_lazy snd res) in
  fst (fix rec_fun)
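Putting the improved pieces together, here is a self-contained sketch that runs it on the example tree; leaf, node and values are hypothetical helpers added only to build the input and read back the output.

```ocaml
type lazy_int = int Lazy.t

type tree =
  | Leaf of lazy_int
  | Node of lazy_int * tree * tree

let rec fix f = f (lazy (fix f))
let map_lazy f v = lazy (f (Lazy.force v))

(* rec_sum_aux : tree -> lazy_int -> tree * int *)
let rec rec_sum_aux tree s =
  match tree with
  | Leaf n -> (Leaf s, Lazy.force n)
  | Node (n, t1, t2) ->
     let (t1, s1) = rec_sum_aux t1 s in
     let (t2, s2) = rec_sum_aux t2 s in
     (Node (s, t1, t2), Lazy.force n + s1 + s2)

let rec_sum tree =
  let rec_fun res = rec_sum_aux tree (map_lazy snd res) in
  fst (fix rec_fun)

(* Hypothetical helpers: build the example tree and force every value. *)
let leaf n = Leaf (Lazy.from_val n)
let node n l r = Node (Lazy.from_val n, l, r)
let rec values = function
  | Leaf v -> [Lazy.force v]
  | Node (v, l, r) -> Lazy.force v :: values l @ values r

let example =
  node 9 (node 7 (leaf 1) (leaf 2)) (node 8 (leaf 3) (node 4 (leaf 5) (leaf 6)))
```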

In Idris

The OCaml version can be ported directly to Idris. In Idris, the equivalent of the lazy keyword is Delay, and the equivalent of Lazy.force is simply Force.

However, Idris can infer from the types where Delay and Force are needed and insert them automatically, which makes the Idris version a lot more pleasant to read:

data Tree : Type where
  Leaf : Lazy Int -> Tree
  Node : Lazy Int -> Tree -> Tree -> Tree

recSumAux : Tree -> Lazy Int -> (Tree, Int)
recSumAux (Leaf n) s = (Leaf s, n)
recSumAux (Node n t1 t2) s =
  let (t1', s1) = recSumAux t1 s
      (t2', s2) = recSumAux t2 s
  in (Node s t1' t2', n + s1 + s2)

fix : (Lazy a -> a) -> a
fix f = f (fix f)

recSum : Tree -> Tree
recSum tree =
  fst (fix (\res => recSumAux tree (snd res)))

This version has the same double-traversal problem as the OCaml one, and it also won’t pass Idris’s totality checker. That isn’t a real problem, since partial functions are allowed in Idris, but it is a bit unsatisfactory. Again, contributions are welcome!

Another problem with the Idris version is that, unlike in OCaml or Haskell, the result of evaluating a lazy value isn’t cached: lazy values basically behave like functions of type () -> a. To see this, you can replace s in recSumAux with unsafePerformIO (print "Foo" *> pure s), for instance, and you’ll see that the side effect is executed every time s is evaluated.
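The difference can be illustrated in OCaml, whose lazy values are cached, by comparing one with a plain unit -> int closure, which re-runs its body on every call. This is only an analogy (the counters are hypothetical instrumentation), not a model of how Idris implements Lazy. In rec_sum every node shares a single thunk s, so without caching each forced value would re-run the whole computation behind s, including its traversal of the tree.

```ocaml
(* Hypothetical counters recording how often each computation runs. *)
let closure_runs = ref 0
let lazy_runs = ref 0

(* Uncached: a thunk modelled as a function of type unit -> int. *)
let as_closure () =
  incr closure_runs;
  45

(* Cached: an OCaml lazy value is computed at most once. *)
let as_lazy = lazy (incr lazy_runs; 45)

let () =
  for _i = 1 to 3 do
    assert (as_closure () = 45);
    assert (Lazy.force as_lazy = 45)
  done
```

After the loop, the closure has run three times and the lazy value only once.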

Conclusion

We’ve seen that it is possible to solve this problem in a strict language with careful use of laziness, even though in both cases we get a less efficient solution that traverses the tree twice instead of once. In the case of Idris, the fact that the results of evaluating lazy values are not cached makes it, at least for now, quite unsuitable for this sort of circular programming technique.

Post-scriptum

If you really have to solve this problem in a strict language, there is a way that doesn’t quite respect the rules (in that it doesn’t construct the new tree directly). Instead of returning a tree, you can make recSumAux return a closure that reconstructs the tree when given the sum, and then make recSum pass the sum to this closure:

recSumAux : Tree -> (Int -> Tree, Int)
recSumAux (Leaf n) = (\sum => Leaf sum, n)
recSumAux (Node x t1 t2) =
  let (ct1, s1) = recSumAux t1
      (ct2, s2) = recSumAux t2
  in (\sum => Node sum (ct1 sum) (ct2 sum), x + s1 + s2)

recSum : Tree -> Tree
recSum tree =
  let (ctree, sum) = recSumAux tree
  in ctree sum
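The same closure-passing idea ports directly to OCaml, and needs no lazy values at all. A sketch of such a port, using a plain tree type with int fields (example is just a test input):

```ocaml
(* A plain, strict tree: no lazy values are needed for this approach. *)
type tree = Leaf of int | Node of int * tree * tree

(* rec_sum_aux : tree -> (int -> tree) * int
   Visits each node of the input once, returning the subtree sum together
   with a closure that rebuilds the subtree once the total is known. *)
let rec rec_sum_aux = function
  | Leaf n -> ((fun sum -> Leaf sum), n)
  | Node (x, t1, t2) ->
     let (ct1, s1) = rec_sum_aux t1 in
     let (ct2, s2) = rec_sum_aux t2 in
     ((fun sum -> Node (sum, ct1 sum, ct2 sum)), x + s1 + s2)

let rec_sum tree =
  let (ctree, sum) = rec_sum_aux tree in
  ctree sum

let example =
  Node (9, Node (7, Leaf 1, Leaf 2), Node (8, Leaf 3, Node (4, Leaf 5, Leaf 6)))
```

The input tree is only pattern-matched once; the work of rebuilding it is moved into the chain of closures, which is why this doesn’t quite count as constructing the tree directly.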

Or, if your language has references, have your tree contain references to integers instead of plain integers:

type tree =
  | Leaf of int ref
  | Node of int ref * tree * tree

let rec rec_sum_aux sum = function
  | Leaf(v) ->
     sum := !sum + !v;
     Leaf(sum)
  | Node(v, t1, t2) ->
     let t1 = rec_sum_aux sum t1 in
     let t2 = rec_sum_aux sum t2 in
     sum := !sum + !v;
     Node(sum, t1, t2)

let rec_sum tree = rec_sum_aux (ref 0) tree
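A short usage sketch of this version on the example tree, repeating the definitions so that it runs on its own (leaf, node and values are hypothetical helpers). Every node of the result points to the very same int ref, which is what lets a single traversal fill in all the nodes after the fact; the flip side is that updating one node of the new tree would update all of them.

```ocaml
type tree =
  | Leaf of int ref
  | Node of int ref * tree * tree

let rec rec_sum_aux sum = function
  | Leaf v ->
     sum := !sum + !v;
     Leaf sum
  | Node (v, t1, t2) ->
     let t1 = rec_sum_aux sum t1 in
     let t2 = rec_sum_aux sum t2 in
     sum := !sum + !v;
     Node (sum, t1, t2)

let rec_sum tree = rec_sum_aux (ref 0) tree

(* Hypothetical helpers to build the example tree and read a tree back. *)
let leaf n = Leaf (ref n)
let node n l r = Node (ref n, l, r)
let rec values = function
  | Leaf v -> [!v]
  | Node (v, l, r) -> !v :: values l @ values r

let example =
  node 9 (node 7 (leaf 1) (leaf 2)) (node 8 (leaf 3) (node 4 (leaf 5) (leaf 6)))

let result = rec_sum example
```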

Acknowledgement

Many thanks to Edwin Brady for helping me with the Idris implementation (that is, fixing the very stupid mistakes I had made and understanding what the resulting Idris errors meant) and to Jonathan King, from Imperial College, for the discussion we had on how to make the Idris implementation pass the totality checker.