Memoization for Pure Functions
# Levenshtein Distance / String Edit Distance
(Note: All code in this post is part of guile-algorithms, which is AGPLv3 licensed.)
For some time I had the idea to look into how Levenshtein distance, also known as string edit distance, is calculated and implemented. Recently, I finally looked at its definition on Wikipedia. The naive recursive definition appeared to be quite simple.
I went ahead and implemented it in GNU Guile:
```scheme
(define lev-naive
  (lambda* (str1 str2 #:key (equal? string=?))
    (cond
     [(string-null? str1) (string-length str2)]
     [(string-null? str2) (string-length str1)]
     [else
      (let ([head1 (substring str1 0 1)]
            [head2 (substring str2 0 1)]
            [tail1 (substring str1 1)]
            [tail2 (substring str2 1)])
        (cond
         [(equal? head1 head2) (lev-naive tail1 tail2)]
         [else
          (+ 1
             (min
              ;; first character of str2 is inserted
              (lev-naive str1 tail2)
              ;; first character of str1 is inserted
              (lev-naive tail1 str2)
              ;; first character is replaced
              (lev-naive tail1 tail2)))]))])))
```
As one can see, this solution uses recursion and, in functional style, involves no mutation. Instead, it solves smaller parts of the problem by passing new, different values in the recursive calls. The whole thing has no side effects and can easily be tested, since the same input arguments will always result in the same return value [1]. The result depends entirely on the input arguments. Neat!
# Fatal Flaw
However, while this code is correct and works, it has a significant deficiency: it performs the same work many times over. This gives it exponential running time in the lengths of the input strings.
In the else branch of the cond form, we call lev-naive in several ways. Consider the calls that result from these recursive calls:

- (lev-naive str1 tail2) and, inside it, the call that shortens the first string: first shortening the second string, then shortening the first string
- (lev-naive tail1 str2) and, inside it, the call that shortens the second string: first shortening the first string, then shortening the second string

Or the following recursive calls:

- (lev-naive str1 tail2) and, inside it, the call that shortens the first string
- (lev-naive tail1 tail2)

These reach calls with the same arguments, namely tail1 and tail2, but there is no mechanism that prevents the computational process from performing this duplicate work.
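One impure but quick way to observe the duplicate work is to instrument the naive recursion with a call counter. The lev-counted variant below is a throwaway sketch just for this measurement, not part of guile-algorithms:

```scheme
;; Count how often the naive recursion is entered. The counter uses
;; set!, so this is deliberately impure instrumentation code.
(define call-count 0)

(define (lev-counted str1 str2)
  (set! call-count (+ call-count 1))
  (cond
   [(string-null? str1) (string-length str2)]
   [(string-null? str2) (string-length str1)]
   [(string=? (substring str1 0 1) (substring str2 0 1))
    (lev-counted (substring str1 1) (substring str2 1))]
   [else
    (+ 1 (min (lev-counted str1 (substring str2 1))
              (lev-counted (substring str1 1) str2)
              (lev-counted (substring str1 1) (substring str2 1))))]))

(lev-counted "aaaa" "bbbb")
call-count  ; ⇒ 481
```

481 calls for two strings of length 4, even though there are only (4 + 1) × (4 + 1) = 25 distinct argument pairs of suffixes. Memoization would cap the work at those 25.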
# Solution in the Imperative World
This is the point where many computer programmers, and especially leetcoders, will start shouting: "Dynamic programming!" [2]. That is a fancy, but not really meaningful, term for "caching" or "memoization" of intermediate results of function calls. In typical leetcoder fashion, most would quickly introduce some kind of lookup table, usually a mutable hash table. During the execution of the code, the procedure would store results of recursive calls in the lookup table. At each procedure [3] call one would then check whether the result for the specific tuple of procedure arguments is already stored in the lookup table. If it is, one would return that result immediately, instead of performing the duplicate work of calculating it again.
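In Guile, such an imperative memoized variant might look like the following sketch. lev-imperative and its internals are my own illustration, using Guile's built-in mutable hash tables, and are not part of guile-algorithms:

```scheme
(define (lev-imperative str1 str2)
  ;; Mutable hash table in an outer scope, keyed by the pair of
  ;; remaining strings.
  (define table (make-hash-table))
  (define (lev str1 str2)
    (let ([key (cons str1 str2)])
      ;; Return the cached result if present ...
      (or (hash-ref table key)
          ;; ... otherwise compute it, store it, then return it.
          (let ([result
                 (cond
                  [(string-null? str1) (string-length str2)]
                  [(string-null? str2) (string-length str1)]
                  [(string=? (substring str1 0 1) (substring str2 0 1))
                   (lev (substring str1 1) (substring str2 1))]
                  [else
                   (+ 1 (min (lev str1 (substring str2 1))
                             (lev (substring str1 1) str2)
                             (lev (substring str1 1) (substring str2 1))))])])
            (hash-set! table key result)
            result))))
  (lev str1 str2))

(lev-imperative "kitten" "sitting")  ; ⇒ 3
```

(The or trick is safe here because a cached distance of 0 is still a true value in Scheme; only #f counts as false.)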
# Solutions in the Functional World
Sticking with the functional programming paradigm, we cannot use a mutable data structure, since updating it would require mutation. So what can we do? Have we already lost? No. There are ways to avoid mutation and still have a lookup table that prevents duplicate work.
The "trick" is to use a persistent, or purely functional, data structure. Such a data structure does not allow mutation, but still offers all the operations one needs. These operations take the data structure plus whatever data is needed to build an updated version of it. Instead of mutating the data structure, they return a new, updated version of it. All previous handles on the passed-in data structure remain bound to the same value as before (the "persistence" in "persistent data structure"), and no other part of the program ever needs to worry about missing a change. The effect is entirely local and encapsulated in the return value of the function.
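Plain Scheme association lists already behave this way and make for a minimal illustration of persistence:

```scheme
;; "Adding" an entry builds a new list; the old one is untouched.
(define table-a (acons 'a 1 '()))
(define table-b (acons 'b 2 table-a))

table-a  ; ⇒ ((a . 1))          unchanged by the update
table-b  ; ⇒ ((b . 2) (a . 1))  the new, extended version
```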
# A Functional Lookup Table
Fortunately, there is a library for GNU Guile, which already contains some of these purely functional data structures: guile-pfds, which is also available on GNU Guix.
guile-pfds does not contain a functional lookup table or hash table. Many functional data structures are actually tree data structures. One of the data structures in guile-pfds is a "bbtree", a "balanced binary tree". We can implement a functional lookup table based on balanced binary trees. This might not be the very best solution, but it is sufficient for our purposes:
```scheme
(define-module (functional-hash)
  #:use-module (ice-9 control)
  ;; #:use-module (pfds sets)
  #:use-module (pfds bbtrees)
  ;; for functional structs (not part of srfi-9 directly)
  #:use-module (srfi srfi-9 gnu)
  #:export (make-hash-table
            hash-table-set
            hash-table-ref
            hash-table-contains?
            hash-table-empty?))

;; Define a unique value, to make sure it is always distinguishable
;; from any value that could be in hash table.
(define-immutable-record-type <lookup-failure>
  (make-lookup-failure)
  lookup-failure?)

(define make-hash-table
  (λ (<)
    (make-bbtree <)))

(define hash-table-set
  (lambda* (bbtree key value #:key (update (λ (old new) new)) (default #f))
    (define update-proc (λ (old) (update old value)))
    (bbtree-update bbtree key update-proc default)))

(define hash-table-ref
  (lambda* (bbtree key #:optional (lookup-failure-proc (λ () #f)))
    (let ([result (bbtree-ref bbtree key (make-lookup-failure))])
      (if (lookup-failure? result)
          (lookup-failure-proc)
          result))))

(define hash-table-contains?
  (lambda* (bbtree key)
    (hash-table-ref bbtree key (λ () #f))))

(define hash-table-empty?
  (λ (bbtree)
    (= (bbtree-size bbtree) 0)))
```
Note that there is an operation <, which one needs to pass in to get a balanced binary tree instance. This is due to the tree nature of the data structure: it requires an ordering, and that ordering is defined solely by the < operation we pass to make-hash-table. The ordering operation needs to be applicable to the things we want to store in the lookup table. This will become relevant later.
# Adding the Lookup Table
Now that we have a functional data structure, how do we use it in the string distance function? There are more difficulties ahead.
If we look at the recursive calls in the cond form of lev-naive again, we can see that the arguments of min are all recursive calls, each of which can result in duplicate work.

In the imperative solution, we could simply look up one argument after another in the lookup table, calculate it if it is not there yet, and memorize the result value in the lookup table. The lookup table would be defined in an outer scope, so that its state persists across the recursive calls.
Here is how we can implement it in purely functional style:
```scheme
(define string-pair<
  (λ (pair1 pair2)
    (let ([pair1-str1 (car pair1)]
          [pair1-str2 (cdr pair1)]
          [pair2-str1 (car pair2)]
          [pair2-str2 (cdr pair2)])
      (or (string< pair1-str1 pair2-str1)
          (and (string=? pair1-str1 pair2-str1)
               (string< pair1-str2 pair2-str2))))))
```
Since we want to store distances between pairs of strings in our lookup table, we need to define a < operation that works on pairs of strings. string-pair< does exactly that. It compares the pairs "hierarchically": if the first string of the first pair is "less" than the first string of the second pair, it considers the first pair to be less; if the first strings are equal, it compares the second strings of the pairs.
Using the ordering operation string-pair< we can create a lookup table as follows:
```scheme
(define lookup-table (make-hash-table string-pair<))
```
In our purely functional implementation, however, we cannot mutate the lookup table. We can only get a new version of it, which contains more entries. The recursive calls in the else branch of the cond form each require a lookup table that already contains the results of the recursive calls that come before them. This means the calls need to happen sequentially, and we must make sure that we pass updated versions of the lookup table along. Here is one way to do it:
```scheme
(define lev-naive-with-memoization
  (lambda* (str1 str2 #:key (equal? string=?))
    (let iter ([str1 str1]
               [str2 str2]
               [lookup-table (make-hash-table string-pair<)])
      (let ([cached (hash-table-ref lookup-table (cons str1 str2) (λ () #f))])
        (cond
         [cached (values cached lookup-table)]
         [(string-null? str1)
          (let ([result (string-length str2)])
            (values result
                    (hash-table-set lookup-table (cons str1 str2) result)))]
         [(string-null? str2)
          (let ([result (string-length str1)])
            (values result
                    (hash-table-set lookup-table (cons str1 str2) result)))]
         [else
          (let ([head1 (substring str1 0 1)]
                [head2 (substring str2 0 1)]
                [tail1 (substring str1 1)]
                [tail2 (substring str2 1)])
            (cond
             [(equal? head1 head2) (iter tail1 tail2 lookup-table)]
             [else
              ;; Case in which the first character of str2 is inserted.
              (let-values ([(result^ lookup-table^)
                            (iter str1 tail2 lookup-table)])
                ;; Case in which the first character of str1 is inserted.
                (let-values ([(result^^ lookup-table^^)
                              (iter tail1 str2 lookup-table^)])
                  ;; Case in which first character is replaced.
                  (let-values ([(result^^^ lookup-table^^^)
                                (iter tail1 tail2 lookup-table^^)])
                    ;; Now we have all the parts and can calculate a result.
                    (let ([result (+ 1 (min result^ result^^ result^^^))])
                      (values result
                              (hash-table-set lookup-table^^^
                                              (cons str1 str2)
                                              result))))))]))])))))
```
There are a few things that can be seen here:
- The recursive calls to iter now return multiple values. Multiple values here means actually multiple values, not a list or tuple or some other object that consists of multiple values. Not like someone hands you a box which contains multiple things, but like someone throws multiple things at you and says: "Catch!", and you need both hands to quickly catch everything. That's where let-values comes into play: it is a form that is able to receive multiple values returned by a procedure call.
- The two values returned by iter are the string edit distance for the arguments of the call (result^) and an updated lookup table (lookup-table^).
- Each updated lookup table is used for the next recursive call to iter, until all arguments of min are calculated.
- At no point do we actually replace the original lookup-table binding or mutate anything.
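The multiple values mechanism can be tried out in isolation; min-and-max here is a made-up helper just for the demonstration:

```scheme
(use-modules (srfi srfi-11))  ; let-values

;; Returns two separate values, not a list containing two things.
(define (min-and-max a b)
  (values (min a b) (max a b)))

;; let-values "catches" both values at once.
(let-values ([(lo hi) (min-and-max 7 3)])
  (list lo hi))  ; ⇒ (3 7)
```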
# CPS Solution
There is also another solution one could implement:
```scheme
(define lev-naive-with-memoization-cps
  (lambda* (str1 str2 #:key (equal? string=?))
    (let iter ([str1 str1]
               [str2 str2]
               [lookup-table (make-hash-table string-pair<)]
               [cont (λ (distance lookup-table) distance)])
      (let ([cached (hash-table-ref lookup-table (cons str1 str2) (λ () #f))])
        (cond
         [cached (cont cached lookup-table)]
         [(string-null? str1)
          (let ([result (string-length str2)])
            (cont result
                  (hash-table-set lookup-table (cons str1 str2) result)))]
         [(string-null? str2)
          (let ([result (string-length str1)])
            (cont result
                  (hash-table-set lookup-table (cons str1 str2) result)))]
         [else
          (let ([head1 (substring str1 0 1)]
                [head2 (substring str2 0 1)]
                [tail1 (substring str1 1)]
                [tail2 (substring str2 1)])
            (cond
             [(equal? head1 head2) (iter tail1 tail2 lookup-table cont)]
             [else
              ;; Case in which the first character of str2 is inserted.
              (iter str1 tail2 lookup-table
                    (λ (result^ lookup-table^)
                      ;; Case in which the first character of str1 is
                      ;; inserted.
                      (iter tail1 str2 lookup-table^
                            (λ (result^^ lookup-table^^)
                              ;; Case in which first character is
                              ;; replaced.
                              (iter tail1 tail2 lookup-table^^
                                    (λ (result^^^ lookup-table^^^)
                                      ;; Now we have all the parts and
                                      ;; can calculate a result. But
                                      ;; there could be more
                                      ;; computations to be done,
                                      ;; wrapped inside the continuation
                                      ;; cont. We pass the result of
                                      ;; this whole branch to the
                                      ;; continuation.
                                      (let ([result (+ 1 (min result^
                                                              result^^
                                                              result^^^))])
                                        (cont result
                                              (hash-table-set
                                               lookup-table^^^
                                               (cons str1 str2)
                                               result)))))))))]))])))))
```
In this solution the recursive calls to iter take an additional argument, cont. This is short for "continuation". It can be understood as the code that still has to run once the call to iter finishes. When calculating one argument of min, that remaining code is the calculation of the next argument, which in turn needs to continue with the calculation of the third argument. Along with the calls to iter, the result values for the string distances (result^, result^^, result^^^) need to be passed, as well as updated versions of the lookup table (lookup-table^, lookup-table^^, lookup-table^^^).

Instead of handling all calculations of the arguments of min in the same call to iter, as lev-naive-with-memoization does, iter in lev-naive-with-memoization-cps accumulates a continuation, which will calculate the remaining arguments of min later. The continuation of each later argument calculation is already clear: it is simply the calculation of the respective next argument, and ultimately the call to min and the increment by 1. All of this yields the result for another, outer, earlier recursive call, unless we are in the initial top-level call of iter, at which point the function is done calculating the string distance.
This style of writing the function is called "continuation-passing style", or "CPS" for short.
# Differences between direct style and CPS
Some differences can be observed:
- Each recursive call is now a call in tail position, a tail call. Since Scheme guarantees proper tail calls, the stack does not grow; the pending work is instead accumulated in the heap-allocated continuations.
- Though for this lev-naive-with-memoization-cps function, that is not really an advantage in itself.
- The code in CPS seems harder to understand, because it diverges from the usual flow, in which there is an easily visible return value. In CPS there does not need to be a return value at all. If the procedures called have side effects, or if only some result at the very end of the whole computation is relevant, then a program written this way need never return, or may return only once at the very end of all recursive calls. Building up continuations and calling them can simply continue for the whole runtime of the program.
- Let's imagine we wanted to have concurrency. If our function is written in CPS, then at any point at which a continuation would be called, we could pause the evaluation of the function, keep the continuation around somewhere, and resume it at a later time. Meanwhile we could run other code.
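To see the pattern in isolation, here is a toy computation written in direct style and then in CPS; the -cps helpers are made up for this example:

```scheme
;; Direct style: the intermediate results are implicit.
(define (sum-of-squares a b)
  (+ (* a a) (* b b)))

;; CPS: every step receives the code that consumes its result as an
;; explicit continuation k.
(define (square-cps x k) (k (* x x)))
(define (add-cps a b k) (k (+ a b)))

(define (sum-of-squares-cps a b k)
  (square-cps a
              (λ (a2)
                (square-cps b
                            (λ (b2)
                              (add-cps a2 b2 k))))))

;; The initial continuation just hands back the final result.
(sum-of-squares-cps 3 4 (λ (result) result))  ; ⇒ 25
```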
| [1] | Called "referential transparency". |
| [2] | Dynamic programming is nothing but a fancy-sounding term that does nothing to explain what needs to be done to fix the flaw we discovered. Even the inventor of the term admitted it has no deeper meaning. When I hear the term "dynamic programming", I keep finding myself explaining it to myself as: "The program changes its behavior over the course of its run, by accumulating results of sub-problems, therefore it programs itself dynamically (at runtime)." However, this explanation is after the fact and not in any way the intention of the person who invented the term. |
| [3] | Since the code would involve mutating a thing in an outer scope, and thus have side effects, it is no longer really correct to call it a "function" in the mathematical sense, which is why we call it a "procedure". |
# Conclusions
- Functional data structures are great and enable us to stay in the functional realm, avoiding the complexity of mutation.
- There are solutions for algorithms that require memoization which stay within the constraints of the functional programming paradigm.
- We have seen CPS in action, even though for this particular example, it might not have been so useful.
- There are potential applications of CPS code for concurrency.
- Memoization, caching or lookup tables are not something that can only be done in imperative style. Purely functional lookup tables exist and we can add entries to them.