# CPS fold -- fold with early exit

The most general function for traversing a data structure is `fold`. But `fold` has one drawback: it always traverses the whole data structure, and we cannot abort the recursion early.

But sometimes that is exactly what we want to do. For example, when we search for a specific element in a list, we don't want to go through the remaining list once we have found it. When we check whether all elements in a list satisfy a predicate, we can stop at the first element that does not satisfy it. There are dozens of other cases where an early abort is helpful.

We can always write our own recursive functions for those cases, but then we must ensure that we get tail recursion right. Wouldn't it be better if we could abstract it just like `fold`? In this article I explain how to write a CPS fold that allows us to do this.
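To make the alternative concrete, here is a minimal sketch of such a hand-written recursive function. The name `tryFindRec` is hypothetical and only used for illustration. It stops at the first match, and the recursive call is the last thing the function does, which is what tail recursion requires.

```fsharp
// Hand-written search that stops at the first match.
// `tryFindRec` is a hypothetical name used only for this sketch.
let rec tryFindRec predicate xs =
    match xs with
    | []      -> None
    | x::rest ->
        if predicate x
        then Some x                     // early exit: `rest` is never visited
        else tryFindRec predicate rest  // tail call: nothing is left to do afterwards

let firstBig = [1..100] |> tryFindRec (fun x -> x > 10)
printfn "%A" firstBig // Some 11
```

Every function of this kind repeats the same pattern match and the same recursive call, which is the part a CPS fold can factor out.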

## Continuation

The important concept in implementing a CPS fold is the so-called continuation function; the technique is known as Continuation-Passing Style (CPS). The idea of CPS is that we pass an additional function as an argument that the user can call to explicitly recurse. This way the user of the CPS fold is in control of the recursion and can decide whether to continue traversing the data structure or to return a value instead. But before we look at how to implement the function, let's see some use cases.

We name our function `foldk`. Apart from the name, `foldk` looks nearly the same as `fold`. The only difference is that the function we pass to `foldk` now receives three arguments instead of two.

```fsharp
[1..100] |> List.fold (fun acc x   -> acc + x) 0 // 5050
[1..100] |> foldk     (fun acc x k -> acc + x) 0 // 1
```

When we provide the same code to both, we already see how they differ. `fold` always runs through all elements of the list. It computes the accumulator and does all the recursion by itself.

`foldk`, on the other hand, doesn't do any recursion on its own. It always performs just a single step: it extracts the first element from the list and calls the folder function with the provided `acc` and that element.

That's why we get `1` as a result. It just calculates `acc + x`, or `0 + 1` in the above example, and then it immediately ends. We must explicitly tell `foldk` when it should recurse.

That's the reason why we have the third argument. `k` is the continuation function, and it expects the next accumulator. When we call `k`, we continue with the next element in the list. The primary difference to `fold` is that the user of `foldk` has explicit control over when recursion happens.

```fsharp
[1..100] |> foldk (fun acc x k -> k (acc + x)) 0 // 5050
```

When we want to traverse all elements of our list, we just call `k` with the next accumulator. But if we wanted to do that, we could also just use `fold`. So here is a more practical example in comparison to `fold`.

```fsharp
[5;10;15;10;5] |> List.fold (fun acc x ->
    if x < 11
    then acc + x
    else acc
) 0
// 30

[5;10;15;10;5] |> foldk (fun acc x k ->
    if x < 11
    then k (acc + x)
    else acc
) 0
// 15
```

Now we get different results. `fold` picks every element that is smaller than `11` and adds them together. `foldk`, on the other hand, runs only as long as the elements are smaller than `11` and adds up those it has seen up to that point. As soon as it encounters a bigger number, it stops traversing the list. `fold` calculated `5 + 10 + 10 + 5`, while `foldk` just summed up the first two elements, `5 + 10`.
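One way to observe the early exit is to count how often the folder function runs. The following is a small sketch, not part of the original examples: `visited` is a hypothetical counter added only to make the traversal visible, and `foldk` is defined here the same way as in the implementation section of this article so that the snippet runs on its own.

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

// Hypothetical counter: incremented each time the folder function is called.
let visited = ref 0

let sum =
    [5;10;15;10;5] |> foldk (fun acc x k ->
        visited.Value <- visited.Value + 1
        if x < 11
        then k (acc + x)
        else acc
    ) 0

printfn "sum = %d, visited = %d" sum visited.Value // sum = 15, visited = 3
```

The folder function runs for `5`, `10` and `15` only; the trailing `10` and `5` are never looked at.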

## Implementing foldk

Implementing `foldk` is actually pretty easy, let's go over it:

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)
```

First we start with the general pattern to traverse a list. We test whether we have an empty list, or else we extract the first element of the list. Then we think about what to do in each case.

The empty case is pretty easy. When we have reached the end of the list, we cannot advance any further, so we just return the accumulator `acc`.

Otherwise we have an element, and we need to do something with it. That is what `f` is for, so we call `f acc x ???`. In a normal `fold` we do the recursion inside `fold`, but in `foldk` we want to give the user the ability to recurse. That is why the third argument is a continuation function, `(fun lacc -> foldk f lacc xs)`. The continuation function expects the next accumulator. As you can see, when the user decides to call `k`, it just calls `foldk` again with the rest of the list.
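As a sanity check on this design, an ordinary fold can be recovered from `foldk` by always calling the continuation. The sketch below assumes nothing beyond the `foldk` definition above; `foldViaFoldk` is a hypothetical name chosen for illustration.

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

// Hypothetical helper: a plain fold that never exits early,
// because it unconditionally passes the next accumulator to `k`.
let foldViaFoldk folder state xs =
    foldk (fun acc x k -> k (folder acc x)) state xs

let viaFold  = [1..100] |> List.fold    (fun acc x -> acc + x) 0
let viaFoldk = [1..100] |> foldViaFoldk (fun acc x -> acc + x) 0
printfn "%d %d" viaFold viaFoldk // 5050 5050
```

So `foldk` can do everything `fold` can; it additionally leaves the decision to continue to the caller.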

Let's go over a simple example to see how it works:

```fsharp
[1..5] |> foldk (fun acc x k -> k (acc + x)) 0 // 15
```

The result of our function is `15`. Let's see step by step how we got this result. When we call `foldk`, we start with `0` as our `acc` and `[1;2;3;4;5]` as our list. As a reminder, this is `foldk`:

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)
```

When I write `f acc x k` in the Code column, I refer to the whole right-hand side of the `x::xs` case, that is, `f acc x (fun lacc -> foldk f lacc xs)`. I just use `[]` and `x::xs` to represent the pattern matching in the `foldk` function.

| Code | Evaluation / Description |
| --- | --- |
| `foldk f 0 [1..5]` | First call, we start `foldk` |
| `[]` | No, we did not reach the end |
| `x::xs` | `1::[2;3;4;5]` / Yes, it matches |
| `f acc x k` | `f 0 1 (fun lacc -> foldk f lacc [2;3;4;5])` / We now execute `f` |
| `k (acc + x)` | `k (0 + 1)` / `k` is the lambda function passed to `f` |
| `foldk f 1 [2;3;4;5]` | |
| `[]` | No |
| `x::xs` | `2::[3;4;5]` |
| `f acc x k` | `f 1 2 (fun lacc -> foldk f lacc [3;4;5])` |
| `k (acc + x)` | `k (1 + 2)` |
| `foldk f 3 [3;4;5]` | |
| `[]` | No |
| `x::xs` | `3::[4;5]` |
| `f acc x k` | `f 3 3 (fun lacc -> foldk f lacc [4;5])` |
| `k (acc + x)` | `k (3 + 3)` |
| `foldk f 6 [4;5]` | |
| `[]` | No |
| `x::xs` | `4::[5]` |
| `f acc x k` | `f 6 4 (fun lacc -> foldk f lacc [5])` |
| `k (acc + x)` | `k (6 + 4)` |
| `foldk f 10 [5]` | |
| `[]` | No |
| `x::xs` | `5::[]` |
| `f acc x k` | `f 10 5 (fun lacc -> foldk f lacc [])` |
| `k (acc + x)` | `k (10 + 5)` |
| `foldk f 15 []` | |
| `[] -> acc` | `[] -> 15` / Yes, we just return `acc` (15) |

And one more time, with an example that stops earlier:

```fsharp
[1..5] |> foldk (fun acc x k -> if x < 3 then k (acc + x) else acc) 0 // 3
```

| Code | Evaluation / Description |
| --- | --- |
| `foldk f 0 [1..5]` | First call, we start `foldk` |
| `[]` | No |
| `x::xs` | `1::[2;3;4;5]` |
| `f acc x k` | `f 0 1 (fun lacc -> foldk f lacc [2;3;4;5])` |
| `x < 3` | `1 < 3` / True, then branch |
| `k (acc + x)` | `k (0 + 1)` |
| `foldk f 1 [2;3;4;5]` | |
| `[]` | No |
| `x::xs` | `2::[3;4;5]` |
| `f acc x k` | `f 1 2 (fun lacc -> foldk f lacc [3;4;5])` |
| `x < 3` | `2 < 3` / True, then branch |
| `k (acc + x)` | `k (1 + 2)` |
| `foldk f 3 [3;4;5]` | |
| `[]` | No |
| `x::xs` | `3::[4;5]` |
| `f acc x k` | `f 3 3 (fun lacc -> foldk f lacc [4;5])` |
| `x < 3` | `3 < 3` / False, else branch |
| `else acc` | `else 3` / The pattern match on `x::xs` now returns 3 |

## Implementing some other functions

Now let's use our `foldk` function to implement some other functions. First, `tryPick`:

```fsharp
let tryPick predicate xs =
    xs |> foldk (fun acc x k ->
        if predicate x
        then Some x
        else k acc
    ) None

[1..100] |> tryPick (fun x -> x % 5 = 0) // Some 5
[1..100] |> tryPick (fun x -> x > 10)    // Some 11
[1..100] |> tryPick (fun x -> x > 1000)  // None
```

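We start with `None` as our default value. Once we find a matching `x`, we just return it wrapped in `Some`; otherwise we recurse. When we reach the end without finding an element, we return `acc`, which still contains `None`. Because it takes a predicate, this version behaves like the standard `List.tryFind`.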

```fsharp
let contains y xs =
    xs |> foldk (fun acc x k ->
        if x = y
        then true
        else k acc
    ) false

[1..100] |> contains 10 // true
[1..100] |> contains 0  // false
```

This is very similar to `tryPick`. We start with `false` and return it when we reach the end of the list without finding the wanted element. Otherwise we immediately return `true`. I think the rest of the functions are nearly self-explanatory, as they are all very similar.

```fsharp
let exists predicate xs =
    xs |> foldk (fun acc x k ->
        if predicate x
        then true
        else k acc
    ) false

[1..100] |> exists (fun x -> x % 50 = 0) // true
[1..100] |> exists (fun x -> x < 0)      // false

let forall predicate xs =
    xs |> foldk (fun acc x k ->
        if predicate x
        then k acc
        else false
    ) true

[2;4;6;8] |> forall (fun x -> x % 2 = 0) // true
[2;4;6;8] |> forall (fun x -> x % 2 = 1) // false

// The index counter and the result share one accumulator,
// so this version is limited to int lists.
let item idx xs =
    xs |> foldk (fun acc x k ->
        if idx = acc
        then x
        else k (acc + 1)
    ) 0

[2..2..100] |> item 0  // 2
[2..2..100] |> item 1  // 4
[2..2..100] |> item 2  // 6
[2..2..100] |> item 10 // 22

let take amount xs =
    xs |> foldk (fun (collected,acc) x k ->
        if collected < amount
        then k (collected+1, x::acc)
        else (collected,acc)
    ) (0,[])
    |> snd
    |> List.rev

[1..100] |> take 0   // []
[1..100] |> take 3   // [1;2;3]
[1..3]   |> take 100 // [1;2;3]
```

By the way, the last `take` example does not match the exact behaviour of `List.take`. The standard implementation throws an exception if the input list does not have enough elements. We can achieve the same by checking the `collected` field after `foldk` has finished and throwing an exception if it is not equal to `amount`. But I like my behaviour more than the default implementation.
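The stricter variant described here could look roughly like the following sketch. The name `takeExact` is hypothetical, and the use of `invalidArg` (which raises `System.ArgumentException`) is an assumption made for this example; it is not claimed to match the exact exception or message that `List.take` produces.

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

// Hypothetical strict variant of `take`: fails when the list is too short.
let takeExact amount xs =
    let collected, acc =
        xs |> foldk (fun (collected, acc) x k ->
            if collected < amount
            then k (collected + 1, x :: acc)
            else (collected, acc)
        ) (0, [])
    // Compare what was actually collected with what was requested.
    if collected <> amount then
        invalidArg "amount" "The input list has fewer elements than requested."
    List.rev acc

printfn "%A" ([1..100] |> takeExact 3) // [1; 2; 3]
```

Calling `[1..3] |> takeExact 100` raises the exception instead of silently returning a shorter list.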

I could continue by implementing further functions, but I think at this point it should be obvious how `foldk` works and how we use it.
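One more sketch may still be useful, because `item` above passes both the index counter and the found element through the same accumulator, which ties it to `int` lists. Assuming we are happy to return an option, the state can instead be a pair of the current index and the result. `tryItem` is a hypothetical name for this variant.

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

// Hypothetical generic variant of `item`.
// State: (current index, element found so far).
let tryItem idx xs =
    xs |> foldk (fun (i, _) x k ->
        if i = idx
        then (i, Some x)       // found: stop without calling k
        else k (i + 1, None)   // not yet: move on to the next index
    ) (0, None)
    |> snd

printfn "%A" (["a"; "b"; "c"] |> tryItem 1)  // Some "b"
printfn "%A" (["a"; "b"; "c"] |> tryItem 10) // None
```

With this shape an out-of-range index yields `None` rather than a value that could be mistaken for an element.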

## Summary

Implementing a CPS fold gives control over when to recurse to the caller of `foldk`.

Some tasks are easier to solve with `foldk`. Other tasks can be more efficient because we don't need to traverse the complete data structure. In general, if you ever wanted something similar to `break` or `continue` as you know them from imperative looping constructs, `foldk` gives you this ability, and it works fine with immutable data structures.
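The comparison with `break` and `continue` can be sketched in a single folder function. The example below is illustrative only: the input list and the rules (skip odd numbers, stop at the first number above 10) are made up for this sketch.

```fsharp
let rec foldk f (acc:'State) xs =
    match xs with
    | []    -> acc
    | x::xs -> f acc x (fun lacc -> foldk f lacc xs)

// Sum the even numbers, but stop completely at the first number above 10.
let result =
    [2; 3; 4; 7; 12; 6] |> foldk (fun acc x k ->
        if   x > 10     then acc         // like `break`: return without calling k
        elif x % 2 <> 0 then k acc       // like `continue`: skip x, keep going
        else k (acc + x)                 // normal step: update the accumulator
    ) 0

printfn "%d" result // 6
```

Returning `acc` ends the traversal, while calling `k` with an unchanged accumulator skips the current element.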
