CPS fold -- fold with early exit · David Raab

The most general function for traversing a data structure is fold. But fold has one drawback: it always traverses the whole data structure, and we cannot abort the recursion early.

But sometimes that is exactly what we want to do. For example, when we search for a specific element in a list, we don't want to go through the remaining list once we have found it. When we check whether all elements in a list satisfy a predicate, we can stop at the first element that does not satisfy it. There are dozens of other cases where an early abort is helpful.

We can always write our own recursive functions for those cases, but then we must ensure that we get tail recursion right. Wouldn't it be better if we could abstract it just like fold? In this article I explain how to write a CPS fold that allows us to do this.

Continuation

The important concept in implementing a CPS fold is the continuation function; the technique is called Continuation-Passing Style, or CPS for short. The idea of CPS is that we pass an additional function as an argument that the user can call to explicitly recurse. This way the user of the CPS fold is in control of the recursion and can decide whether to continue traversing the data structure or to return a value instead. But before we look at how to implement the function, let's see some use cases.

We name our function foldk. It looks nearly the same as fold; the only difference is that the function we pass to foldk now receives three arguments instead of two.

[1..100] |> List.fold (fun acc x   -> acc + x) 0 // 5050
[1..100] |> foldk     (fun acc x k -> acc + x) 0 // 1

When we pass nearly the same folder function to both, we already see how they differ. fold always runs through all elements of the list. It computes the accumulator and does all the recursion itself.

foldk, on the other hand, doesn't do any recursion on its own. It always does just a single step: it extracts one element from our list and calls the folder function with the provided acc and that element.

That's why we get 1 as a result. It just calculates acc + x, or 0 + 1 in the above example, and then ends immediately. We must explicitly tell foldk when it should recurse.

That's the reason for the third argument. k is the continuation function, and it expects the next accumulator. When we call k, we recurse on the next element of our list. The primary difference from fold is that the user of foldk has explicit control over when the recursion happens.

[1..100] |> foldk (fun acc x k -> k (acc + x)) 0 // 5050

When we want to traverse all elements of our list, then we just call k with the next accumulator. But if we wanted to do that, we also could just use fold. So here is a more practical example in comparison to fold.

[5;10;15;10;5] |> List.fold (fun acc x ->
    if   x < 11
    then acc + x
    else acc
) 0
// 30

[5;10;15;10;5] |> foldk (fun acc x k ->
    if   x < 11
    then k (acc + x)
    else acc
) 0
// 15

Now we get different results. fold picks every element that is smaller than 11 and adds them together. foldk, on the other hand, runs only as long as the elements are smaller than 11 and adds up just those it has seen up to that point. As soon as it encounters a bigger number, it stops traversing the list. fold calculated 5 + 10 + 10 + 5, while foldk only summed the first two elements, 5 + 10.
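To make the early exit visible, one can count how often the folder function is actually invoked. The following is only a sketch: it repeats the foldk definition from the "Implementing foldk" section so that it runs on its own, and the helper name sumWhileSmall and the visited counter are made up for illustration.

```fsharp
// foldk exactly as defined in the "Implementing foldk" section.
let rec foldk f (acc: 'State) xs =
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldk f lacc xs)

// Hypothetical helper: returns the sum together with the number of
// elements the folder function has looked at.
let sumWhileSmall xs =
    // A ref cell is used because closures cannot capture `let mutable` locals.
    let visited = ref 0
    let sum =
        xs |> foldk (fun acc x k ->
            visited.Value <- visited.Value + 1
            if x < 11 then k (acc + x) else acc
        ) 0
    sum, visited.Value

let result = sumWhileSmall [5; 10; 15; 10; 5] // (15, 3)
```

The folder is called three times: for 5, for 10, and for 15, where it returns acc without calling k. The last two elements are never inspected.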

Implementing foldk

Implementing foldk is actually pretty easy, let's go over it:

let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

First we start with the general pattern for traversing a list: we test whether the list is empty, or else we extract its first element. Then we decide what to do in each case.

The empty case is pretty easy. When we have reached the end of our list, we cannot advance any further, so we just return the accumulator acc.

Otherwise we have an element, and we need to do something with it. That is what our f is for, so we call f acc x ???. In a normal fold the recursion happens inside fold, but in foldk we want to give the user the ability to recurse. That is why the third argument is a continuation function, (fun lacc -> foldk f lacc xs). The continuation function expects the next accumulator. As you can see, when the user decides to call k, it just calls foldk again with the rest of the list.
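One way to check this understanding is to recover an ordinary fold from foldk: a folder that always calls k with the next accumulator never stops early. The sketch below assumes the foldk definition above; the name foldViaFoldk is chosen only for this illustration.

```fsharp
// foldk as defined above.
let rec foldk f (acc: 'State) xs =
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldk f lacc xs)

// A plain fold is the special case where the folder always continues.
let foldViaFoldk f acc xs =
    foldk (fun acc x k -> k (f acc x)) acc xs

let sum      = foldViaFoldk (fun acc x -> acc + x) 0 [1..100]    // 5050
let reversed = foldViaFoldk (fun acc x -> x :: acc) [] [1; 2; 3] // [3; 2; 1]
```

Both results agree with what List.fold returns for the same arguments.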

Let's go over a simple example to see how it works:

[1..5] |> foldk (fun acc x k -> k (acc + x)) 0 // 15

The result of our function is 15. Let's see step-by-step how we got this result. When we call foldk we start with 0 as our acc and the list [1;2;3;4;5] as our starting list. As a reminder this is foldk.

let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

When I write f acc x k in the Code column, I refer to the whole right-hand side of the x::xs case, that is, f acc x (fun lacc -> foldk f lacc xs). I just use [] and x::xs to represent the pattern matching in the foldk function.

Code                  Evaluation / Description
--------------------  ------------------------------------------------------------
foldk f 0 [1..5]      First call, we start foldk
[]                    No, we did not reach the end
x::xs                 1::[2;3;4;5] / Yes, it matches
f acc x k             f 0 1 (fun lacc -> foldk f lacc [2;3;4;5]) / We now execute f
k (acc + x)           k (0 + 1) / k is the lambda function passed to f
foldk f 1 [2;3;4;5]
[]                    No
x::xs                 2::[3;4;5]
f acc x k             f 1 2 (fun lacc -> foldk f lacc [3;4;5])
k (acc + x)           k (1 + 2)
foldk f 3 [3;4;5]
[]                    No
x::xs                 3::[4;5]
f acc x k             f 3 3 (fun lacc -> foldk f lacc [4;5])
k (acc + x)           k (3 + 3)
foldk f 6 [4;5]
[]                    No
x::xs                 4::[5]
f acc x k             f 6 4 (fun lacc -> foldk f lacc [5])
k (acc + x)           k (6 + 4)
foldk f 10 [5]
[]                    No
x::xs                 5::[]
f acc x k             f 10 5 (fun lacc -> foldk f lacc [])
k (acc + x)           k (10 + 5)
foldk f 15 []
[] -> acc             [] -> 15 / Yes, we just return acc (15)
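The same steps can be reproduced by letting the function print its arguments on every call. This is only a debugging sketch; foldkTrace is a hypothetical name and is not part of the article's foldk.

```fsharp
// Same shape as foldk, but prints the accumulator and the remaining
// list each time it is entered.
let rec foldkTrace f (acc: 'State) xs =
    printfn "foldk f %A %A" acc xs
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldkTrace f lacc xs)

let total = [1..5] |> foldkTrace (fun acc x k -> k (acc + x)) 0
// prints:
// foldk f 0 [1; 2; 3; 4; 5]
// foldk f 1 [2; 3; 4; 5]
// foldk f 3 [3; 4; 5]
// foldk f 6 [4; 5]
// foldk f 10 [5]
// foldk f 15 []
```

Each printed line corresponds to one foldk row in the walkthrough.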

And one more time, with an example that stops earlier:

[1..5] |> foldk (fun acc x k -> if x < 3 then k (acc + x) else acc) 0 // 3

Code                  Evaluation / Description
--------------------  ------------------------------------------------------------
foldk f 0 [1..5]      First call, we start foldk
[]                    No
x::xs                 1::[2;3;4;5]
f acc x k             f 0 1 (fun lacc -> foldk f lacc [2;3;4;5])
x < 3                 1 < 3 / True, then branch
k (acc + x)           k (0 + 1)
foldk f 1 [2;3;4;5]
[]                    No
x::xs                 2::[3;4;5]
f acc x k             f 1 2 (fun lacc -> foldk f lacc [3;4;5])
x < 3                 2 < 3 / True, then branch
k (acc + x)           k (1 + 2)
foldk f 3 [3;4;5]
[]                    No
x::xs                 3::[4;5]
f acc x k             f 3 3 (fun lacc -> foldk f lacc [4;5])
x < 3                 3 < 3 / False, else branch
else acc              else 3 / The pattern match on x::xs now returns 3

Implementing some other functions

Now, let's use our foldk function to implement some other functions. First, tryPick:

let tryPick predicate xs =
    xs |> foldk (fun acc x k ->
        if   predicate x
        then Some x
        else k acc
    ) None

[1..100] |> tryPick (fun x -> x % 5 = 0) // Some 5
[1..100] |> tryPick (fun x -> x > 10)    // Some 11
[1..100] |> tryPick (fun x -> x > 1000)  // None

We start with None as our default value. Once we find a matching x we return it immediately; otherwise we recurse. When we reach the end without finding an element, we return acc, which still contains None.
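A side note on naming: the built-in List.tryPick takes a chooser function of type 'T -> 'U option rather than a bool predicate, so the tryPick above has the shape of List.tryFind. A chooser-based variant could be sketched with foldk as follows; the name tryPickWith is made up for this example.

```fsharp
// foldk as defined in the "Implementing foldk" section.
let rec foldk f (acc: 'State) xs =
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldk f lacc xs)

// Returns the first Some produced by the chooser, or None if the
// chooser returns None for every element.
let tryPickWith chooser xs =
    xs |> foldk (fun acc x k ->
        match chooser x with
        | Some _ as found -> found // stop here, do not call k
        | None -> k acc            // keep searching
    ) None

let firstSquare = [1..100] |> tryPickWith (fun x -> if x % 5 = 0 then Some (x * x) else None) // Some 25
let nothing     = [1..100] |> tryPickWith (fun x -> if x > 1000 then Some x else None)        // None
```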

let contains y xs =
    xs |> foldk (fun acc x k ->
        if   x = y
        then true
        else k acc
    ) false

[1..100] |> contains 10 // true
[1..100] |> contains 0  // false

Very similar to tryPick. We start with false and return it when we reach the end of the list without finding the wanted element. Otherwise we immediately return true. I think the rest of the functions are nearly self-explanatory, as they are all very similar.

let exists predicate xs =
    xs |> foldk (fun acc x k ->
        if   predicate x
        then true
        else k acc
    ) false

[1..100] |> exists (fun x -> x % 50 = 0) // true
[1..100] |> exists (fun x -> x < 0)      // false



let forall predicate xs =
    xs |> foldk (fun acc x k ->
        if   predicate x
        then k acc
        else false
    ) true

[2;4;6;8] |> forall (fun x -> x % 2 = 0) // true
[2;4;6;8] |> forall (fun x -> x % 2 = 1) // false



let item idx xs =
    xs |> foldk (fun acc x k ->
        if   idx = acc
        then x
        else k (acc + 1)
    ) 0

[2..2..100] |> item 0  // 2
[2..2..100] |> item 1  // 4
[2..2..100] |> item 2  // 6
[2..2..100] |> item 10 // 22



let take amount xs =
    xs |> foldk (fun (collected,acc) x k ->
        if   collected < amount
        then k (collected+1, x::acc)
        else (collected,acc)
    ) (0,[]) |> snd  |> List.rev

[1..100] |> take 0   // []
[1..100] |> take 3   // [1;2;3]
[1..3]   |> take 100 // [1;2;3]

By the way, the last line does not match the behaviour of List.take. The standard implementation throws an exception if the input list does not have enough elements. We can achieve the same by checking the collected field after foldk has finished and throwing an exception if it is not equal to amount. But I like my behaviour more than the default.
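The stricter variant described above might look like the following sketch. The name takeExact, the use of invalidOp, and the error message are assumptions made for this example; they are not taken from the standard library's source.

```fsharp
// foldk as defined in the "Implementing foldk" section.
let rec foldk f (acc: 'State) xs =
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldk f lacc xs)

// Like take, but fails when fewer than `amount` elements were collected.
let takeExact amount xs =
    let collected, acc =
        xs |> foldk (fun (collected, acc) x k ->
            if   collected < amount
            then k (collected + 1, x :: acc)
            else (collected, acc)
        ) (0, [])
    // Check the counter after foldk has finished, as described in the text.
    if   collected <> amount
    then invalidOp "The input list does not have enough elements."
    else List.rev acc

let firstThree = [1..100] |> takeExact 3 // [1; 2; 3]
// [1..3] |> takeExact 100 raises an InvalidOperationException
```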

I could continue by implementing further functions, but I think at this point it should be obvious how foldk works and how we use it.
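One limitation is worth pointing out for item: because the accumulator (the position counter) and the result share the same 'State type, item as written only works on int lists, and for an index past the end it returns the final counter value rather than an element. A possible way around this, sketched under the assumption that returning an option is acceptable, is to carry the result alongside the counter; tryItem is a hypothetical name.

```fsharp
// foldk as defined in the "Implementing foldk" section.
let rec foldk f (acc: 'State) xs =
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldk f lacc xs)

// The state is a pair: the current position and the element found so far.
let tryItem idx xs =
    xs
    |> foldk (fun (pos, _) x k ->
        if   pos = idx
        then (pos, Some x)      // found: stop without calling k
        else k (pos + 1, None)  // not yet: move on to the next position
    ) (0, None)
    |> snd

let eleventh = [2..2..100] |> tryItem 10   // Some 22
let missing  = [2..2..100] |> tryItem 100  // None
let letter   = ["a"; "b"; "c"] |> tryItem 1 // Some "b"
```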

Summary

Implementing a CPS fold gives the caller of foldk control over when to recurse.

Some tasks are easier to solve with foldk. Others can be more efficient because we don't need to traverse the complete data structure. In general, if you ever wanted something similar to break or continue as you know them from imperative loops, foldk gives you that ability, and it works fine with immutable data structures.
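As a rough illustration of the break and continue analogy: calling k with the unchanged accumulator skips the current element, and returning the accumulator without calling k ends the traversal. The function name sumOddsUntilNegative is invented for this sketch.

```fsharp
// foldk as defined in the "Implementing foldk" section.
let rec foldk f (acc: 'State) xs =
    match xs with
    | [] -> acc
    | x :: xs -> f acc x (fun lacc -> foldk f lacc xs)

// Sums the odd numbers and stops at the first negative number.
let sumOddsUntilNegative xs =
    xs |> foldk (fun acc x k ->
        if   x < 0     then acc         // "break": return without calling k
        elif x % 2 = 0 then k acc       // "continue": skip this element
        else k (acc + x)                // normal step
    ) 0

let result = [1; 2; 3; 4; -1; 5] |> sumOddsUntilNegative // 4
```

Here 1 and 3 are added, 2 and 4 are skipped, and the -1 stops the traversal before 5 is reached.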
