On Recursion, Continuations and Trampolines



How is tail recursion different from regular recursion? What do continuations have to do with this, what is CPS, and how do trampolines help? This article provides an introduction, with code samples in Python and Clojure.

Recursion and Tail Recursion

Here's a textbook version of a recursive factorial implementation in Python:

def fact_rec(n):
    if n == 0:
        return 1
    else:
        return n * fact_rec(n - 1)

Tail recursion is when the recursive call happens in tail position, meaning that it is the last thing the function does before returning its own result. Here's a tail-recursive version of factorial:

def fact_tailrec(n, result=1):
    if n == 0:
        return result
    else:
        return fact_tailrec(n - 1, result * n)

The tail call doesn't have to be directly recursive. It can call another function as well, implementing mutual recursion or some more complex scheme. Here's a canonical example of mutual recursion - a silly way to tell whether a number is odd or even:

def is_even(n):
    if n == 0:
        return True
    else:
        return is_odd(n - 1)

def is_odd(n):
    if n == 0:
        return False
    else:
        return is_even(n - 1)

All the function calls here are in tail position.

Both these examples are simple in a way, because they only contain a single call within each function. When functions make multiple calls, things become more challenging. Computing the Fibonacci sequence is a good example:

def fib_rec(n):
    if n < 2:
        return 1
    else:
        return fib_rec(n - 1) + fib_rec(n - 2)

Here we have two recursive calls to fib_rec within itself. Converting this function to a tail-call variant will be more challenging. How about this attempt:

def fib_almost_tail(n, result=1):
    if n < 2:
        return result
    prev2 = fib_almost_tail(n - 2)
    return fib_almost_tail(n - 1, prev2 + result)

The last thing fib_almost_tail does is call itself; so is this function tail-recursive? No, because there's another call to fib_almost_tail, and that one is not in tail position. Here's a more thorough attempt:

def fib_tail(n, accum1=1, accum2=1):
    if n < 2:
        return accum1
    else:
        return fib_tail(n - 1, accum1 + accum2, accum1)

Note that this conversion wasn't as simple as for the factorial; it's much less obvious how to come up with the algorithm, and we even changed the number of calls - there's only one recursive call here, while the original fib_rec had two. Obviously, it's challenging to have no calls outside a tail position in a function that calls multiple functions.
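
To see why it works, it helps to trace the calls for a small input - each step shifts the accumulators one Fibonacci number forward:

fib_tail(5)             # accum1=1, accum2=1
fib_tail(4, 2, 1)
fib_tail(3, 3, 2)
fib_tail(2, 5, 3)
fib_tail(1, 8, 5)       # n < 2, so return accum1 = 8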

Blowing up the stack

Recursive solutions tend to be succinct and elegant; however, they carry a dangerous burden - the possibility of blowing up the runtime stack. In Python [1], the default call stack depth is 1000. If we try the fact_rec function shown above in a terminal, we'll get:

>>> fact_rec(900)
... some uselessly huge number

>>> fact_rec(1000)
... spew ...
RecursionError: maximum recursion depth exceeded in comparison

You may say - who needs to compute a factorial of 1000? And that may be true; however, with multiple recursion the problems start much earlier. If you try to compute the 50th Fibonacci number using fib_rec as shown above, you'll end up waiting for a very long time, even though the request seems modest at first glance. The reason is the exponential complexity of the naive implementation.

Note that fib_tail doesn't suffer from this problem because there's no exponential tree of calls, but it will also happily blow the stack when run with a sufficiently large number. The same is true for fact_tailrec, by the way. Tail recursion itself doesn't solve the stack issue; another ingredient is required and we'll cover it shortly.
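
For instance, the tail-recursive factorial hits the same recursion limit (the exact threshold depends on how many frames your session is already using, so treat this as illustrative):

>>> fact_tailrec(1000)
... spew ...
RecursionError: maximum recursion depth exceeded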

Solutions: TCO or manual conversion to iteration

The problems described in the previous section help motivate the discussion of tail calls. Why convert to tail calls at all? Because then, in some languages, the compiler can automatically elide the stack buildup by converting the tail call to a jump. This trick is called tail-call optimization (TCO) [2]. Scheme has been doing it since the 1970s - indeed, since Scheme encourages programmers to write recursive algorithms, TCO is at the core of the language. More modern languages are catching up too - Lua supports TCO and JavaScript will too, once ES6 becomes the de-facto universal version.

Some languages do not support TCO, however. Python is one of those - Guido explicitly states that TCO is unpythonic and he doesn't want it in the language. At the end of this post I'll explain why I think it's not a big deal for Python. For other languages, it's a much bigger problem.

Take Clojure for example. Since Clojure is built on top of the JVM, it has to use JVM semantics for calls (if it wants any speed at all). The JVM doesn't have full support for TCO, so Clojure - a Lisp, mind you - ends up without TCO [3]. Clojure takes a pragmatic approach and faces this problem with valor - it encourages a manual TCO conversion using the loop...recur pair:

(defn fib_iterative
  [n]
  (loop [n n
         accum1 1
         accum2 1]
    (if (< n 2)
      accum1
      (recur (- n 1) (+ accum1 accum2) accum1))))

Note the similarity between this code and the Python fib_tail shown earlier. This is not a coincidence! Once the algorithm is expressed in tail form, it's pretty easy to convert it to an iteration pattern manually; if it wasn't easy, compilers wouldn't be able to do it automatically for the past 40 years!

Just as a point of reference, here's fib_iterative in Python:

def fib_iterative(n):
    accum1, accum2 = 1, 1
    while n >= 2:
        n -= 1
        accum1, accum2 = accum1 + accum2, accum1
    return accum1

Only slightly more awkward than the Clojure version - but it's essentially the same approach. Since the tail call carries the whole state around in arguments, we just imitate this using an explicit loop and state variables.
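
The same mechanical translation applies to the factorial; here's a sketch that mirrors fact_tailrec with a loop and a single state variable:

def fact_iterative(n):
    result = 1
    while n > 0:
        result, n = result * n, n - 1
    return result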

The iterative solution is what we really want here - it avoids the exponential algorithm and the stack explosion. It also doesn't incur the costs of a function call and return for every iteration. The only problem is that we have to do this manually in languages that don't support TCO. The beauty of automatic TCO is that you can write your algorithm recursively, and get the performance & runtime characteristics of an iterative solution.

At this point you may wonder how to convert indirect / mutual recursion to an iterative pattern - for example the even / odd pair above. While this doesn't present a problem for the compiler [4], to do it manually is indeed more challenging. We'll be covering this topic later in the article when we get to trampolines.

More realistic examples

Before we get to the more advanced topics, I'd like to present a few more realistic functions with an elegant recursive formulation that would be challenging to rewrite iteratively.

Let's start with merge sorting. Here's a straightforward Python implementation:

def merge_sort(lst):
    n = len(lst)
    if n <= 1:
        return lst
    else:
        mid = int(n / 2)
        return merge(merge_sort(lst[:mid]), merge_sort(lst[mid:]))

def merge(lst1, lst2):
    """Merges two sorted lists into a single sorted list.

    Returns new list. lst1 and lst2 are destroyed in the process."""
    result = []
    while lst1 or lst2:
        if not lst1:
            return result + lst2
        elif not lst2:
            return result + lst1
        if lst1[0] < lst2[0]:
            # Note: pop(0) may be slow -- this isn't optimized code.
            result.append(lst1.pop(0))
        else:
            result.append(lst2.pop(0))
    return result

In a way, merge sort always reminds me of postorder tree traversal - we recurse to the left, then recurse to the right, then combine the results. Such algorithms are fairly tricky to convert to non-recursive code. Try it! Chances are you'll end up emulating a stack, or coming up with an entirely different algorithm.
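
If you do try, one way out is the "entirely different algorithm" route: a bottom-up merge sort that merges ever-larger runs in a loop. Here's a sketch reusing the merge function above (illustrative, not tuned):

def merge_sort_bottom_up(lst):
    width = 1
    while width < len(lst):
        merged = []
        for i in range(0, len(lst), 2 * width):
            # merge destroys its inputs, so pass slices (copies).
            merged.extend(merge(lst[i:i + width], lst[i + width:i + 2 * width]))
        lst = merged
        width *= 2
    return lst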

Merge sort is an example of multiple recursion, which as we've seen even for the simple Fibonacci, presents a challenge for TCO. Another common problem is indirect recursion. We've seen the trivial case of even / odd. For something more realistic consider a recursive-descent parser for this grammar:

<expr>    : <term> + <expr>
          | <term>
<term>    : <factor> * <factor>
          | <factor>
<factor>  : <number>
          | '(' <expr> ')'
<number>  : \d+

The full Python code is here; parse_expr calls parse_term; parse_term calls parse_factor; parse_factor, in turn, calls parse_expr. For a complex expression, the call stack will end up containing multiple instances of each function, and at least in theory it's unbounded.
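
To get a feel for the call structure without reading the full code, here's a bare-bones sketch of such a parser (hypothetical helper names, evaluating as it parses; the linked code is more complete). Each function takes a list of tokens and returns (value, remaining tokens):

def parse_expr(tokens):
    # <expr> : <term> + <expr> | <term>
    value, rest = parse_term(tokens)
    if rest and rest[0] == '+':
        rhs, rest = parse_expr(rest[1:])
        value += rhs
    return value, rest

def parse_term(tokens):
    # <term> : <factor> * <factor> | <factor>
    value, rest = parse_factor(tokens)
    if rest and rest[0] == '*':
        rhs, rest = parse_factor(rest[1:])
        value *= rhs
    return value, rest

def parse_factor(tokens):
    # <factor> : <number> | '(' <expr> ')'
    if tokens[0] == '(':
        value, rest = parse_expr(tokens[1:])
        return value, rest[1:]  # skip the closing ')'
    return int(tokens[0]), tokens[1:]

Parsing 2 * (3 + 4) with parse_expr(['2', '*', '(', '3', '+', '4', ')']) returns (14, []), and the call chain passes through all three functions before re-entering parse_expr.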

Continuations and CPS

Continuations are a cool concept in computer science, hailing from the earliest days of functional programming. There's tons of information online about continuations; my modest attempt to explain them here is just the beginning! If this looks interesting, make sure to google for more information.

Consider the following expression:

2 * (3 + 4)

One way to reason about its evaluation is:

  1. Compute value = 3 + 4
  2. Then compute 2 * value

We can view 2 * value as the continuation of value = 3 + 4. Similarly, if the expression above is part of the bigger expression:

2 * (3 + 4) + 100

We can say that value + 100 is the continuation of 2 * (3 + 4). This may seem a bit abstract, so let's convert it to Lisp-y syntax [5] to bring it back to the domain of programming. Here is one way to compute 2 * (3 + 4):

(defn expr
  []
  (* 2 (+ 3 4)))

Now we can call (expr) and get 14 back. Another way to express the same computation is:

(defn end-cont [value] (print value))

(defn apply-cont
  [cont value]
  (cont value))

(defn expr-cps
  [cont]
  (apply-cont cont (+ 3 4)))

(defn make-double-cont
  [saved-cont]
  (fn [value]
    (apply-cont saved-cont (* 2 value))))

(expr-cps (make-double-cont end-cont))

We represent continuations as functions taking a single argument value. We also abstract away the concept of applying a continuation with apply-cont. The final continuation end-cont consumes the result of the whole computation and prints it out. Note how continuations are composed here: we invoke expr-cps, which expects a continuation. We use the make-double-cont constructor to create a continuation that doubles its value. Note how this doubling continuation works: it knows what its own continuation is, and applies it to (* 2 value). In fact, make-double-cont is just syntactic sugar; we could do without it:

(expr-cps (fn [value] (apply-cont end-cont (* 2 value))))

Now let's see how to do this for the longer expression. We keep the utilities defined earlier and add:

(defn make-plus100-cont
  [saved-cont]
  (fn [value]
    (apply-cont saved-cont (+ value 100))))

(expr-cps (make-double-cont (make-plus100-cont end-cont)))

What happens in this last invocation?

  1. expr-cps gets called with some continuation. It computes 3 + 4 and passes the result into the continuation.
  2. This continuation happens to be the doubling continuation, which applies 2 * value to its value and passes this result to its own continuation.
  3. That continuation, in turn, is a "plus 100" continuation: it applies value + 100 to its value and passes the result to its own continuation.
  4. The last continuation in the chain happens to be end-cont, which prints the overall result: 114

If all of this looks like a masochistic exercise in inverting the call stack (note how the continuations are composed - from the inside out), just a bit more patience - it will all start making sense soon. The -cps suffix of expr-cps stands for Continuation Passing Style, by the way, which is the style of programming we're seeing here; converting "normal" code into this style is called CPS-transform (or CPS conversion).

The lightbulb should go on when you make the following observation: all the expressions computed in this CPS approach are in tail position. Wait, what does that mean? The original function computing the full expression is:

(defn expr
  []
  (+ (* 2 (+ 3 4)) 100))

The only tail call here is to the outermost +. Both the * and the inner + are not in tail position. However, if you carefully examine the CPS approach, all the operator calls are in tail positions - their results are passed directly into the relevant continuations, without any changes. For this purpose we do not count the continuation application as a function call. We're going to be using this wonderful feature of CPS very soon to great benefit. But first, a brief dip into theory.

Undelimited and delimited continuations

The formulation of end-cont I'm using in the example above may appear peculiar. It prints its value - but what if we want to do something else with it? The strange print is a trick to emulate real, or undelimited continuations in a language that doesn't support them [6]. Applying an undelimited continuation is not just calling a function - it's passing control without hope of return. Just like coroutines, or longjmp in C.

Undelimited continuations do not return to their caller. They express a flow of computation where results flow in one direction, without ever returning. This is getting beyond the scope of the article, but when undelimited continuations can be treated as first-class values in a language, they become so powerful that they can be used to implement pretty much any control-flow feature you can imagine - exceptions, threads, coroutines and so on. This is precisely what continuations are sometimes used for in the implementation of functional languages, where CPS-transform is one of the compilation stages.

I'd love to expound more on the topic, but I'll have to leave it to another day. If you're interested, read some information online and play with a language that supports real continuations - like Scheme with call/cc. It's fun and scary at the same time.

Even though most programming languages don't support real, undelimited continuations, delimited continuations are another matter. A delimited continuation is just a function that returns to its caller. We can still use CPS but just have to be realistic about our expectations. Applying a delimited continuation simply means calling a function - so the stack will grow.

If we cycle back to our expression, we can stop pretending our continuations are anything except a simulation, and just define:

(defn end-cont [value] value)

In fact, we don't even have to pretend (apply-cont cont value) is any different from simply calling (cont value), so now we can rewrite our CPS expression much more succinctly:

(defn real-end-cont [value] value)

(expr-cps (fn [value]
            ((fn [value] (real-end-cont (+ value 100))) (* 2 value))))

It looks a bit weird because we could just inline the (* 2 value) into the internal call, but keeping them separate will help us later.
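
Since the next section moves back to Python, here's the same chain transliterated (a direct sketch of the Clojure snippet above):

end_cont = lambda value: value

def expr_cps(cont):
    return cont(3 + 4)

# double the value, then add 100, then hand the result to end_cont:
expr_cps(lambda value: (lambda v2: end_cont(v2 + 100))(2 * value))   # => 114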

Synthesizing tail calls with CPS-transform

Armed with this new tool, let's revisit some of the Python functions from the beginning of the article. For the factorial, we used an extra parameter to get a tail-call version; for Fibonacci we needed two; for more advanced examples (like merge sort) it wasn't very clear how to do the conversion. CPS-transform to the rescue!

It turns out we can convert any function to use tail calls instead of recursion (direct or indirect) by applying the following recipe:

  1. Pass each function an extra parameter - cont.
  2. Whenever the function returns an expression that doesn't contain function calls, send that expression to the continuation cont instead.
  3. Whenever a function call occurs in a tail position, call the function with the same continuation - cont.
  4. Whenever a function call occurs in an operand (non-tail) position, instead perform this call in a new continuation that gives a name to the result and continues with the expression.

This may not make much sense without examples. Let's review the recursive factorial first:

def fact_rec(n):
    if n == 0:
        return 1                       # (1)
    else:
        return n * fact_rec(n - 1)     # (2)

The line marked with (1) hits step 2 of the recipe; the line marked with (2) hits step 4, since a function call (to fact_rec itself) occurs in an operand position. Here is how we transform this function to CPS:

def fact_cps(n, cont):
    if n == 0:
        return cont(1)
    else:
        return fact_cps(n - 1, lambda value: cont(n * value))

The application of steps 1 and 2 is straightforward. Step 4 requires a bit more explanation. Since the call fact_rec(n - 1) is the one occurring in operand position, we extract it out and perform it in a new continuation. This continuation then passes n * value to the original continuation of fact_cps. Take a moment to convince yourself that this code does, in fact, compute the factorial. We have to run it with the "end continuation" discussed before:

>>> end_cont = lambda value: value
>>> fact_cps(6, end_cont)
720

Now let's do the same thing for Fibonacci, which demonstrates a more complex recursion pattern:

def fib_rec(n):
    if n < 2:
        return 1                                 # (1)
    else:
        return fib_rec(n - 1) + fib_rec(n - 2)   # (2)

Once again, applying steps 1 and 2 is trivial. Step 4 will have to be applied to the line marked with (2), but twice, since we have two function calls in operand positions. Let's handle fib_rec(n - 1) first, similarly to what we did for the factorial:

def fib_cps_partial(n, cont):
    if n < 2:
        return cont(1)
    else:
        return fib_cps_partial(
                n - 1,
                lambda value: value + fib_cps_partial(n - 2, cont))

All calls in fib_cps_partial are in tail position now, but there's a problem. The continuation we crafted for the recursive call... itself has a call not in tail position. We'll have to apply CPS-transform once again, recursively. We'll treat the expression inside the lambda as just another function definition to transform. Here's the final version, which is fully transformed:

def fib_cps(n, cont):
    if n < 2:
        return cont(1)
    else:
        return fib_cps(
                 n - 1,
                 lambda value: fib_cps(
                                 n - 2,
                                 lambda value2: cont(value + value2)))

And once again, it's easy to see this version contains no calls that aren't in tail position. As opposed to the conversions shown in the beginning of the article, this one is much less ad-hoc and follows a clear recipe. In fact, it can be performed automatically by a compiler or a source transformation tool!

Just to show this is actually helpful in more general cases, let's tackle merge sort again. We have the recursive implementation at the top of this post, with the tricky part in the line:

return merge(merge_sort(lst[:mid]), merge_sort(lst[mid:]))

But transforming merge sort to CPS turns out to be not much different from transforming Fibonacci. I won't go through the partial stage for this one, and will just present the final answer:

def merge_sort_cps(lst, cont):
    n = len(lst)
    if n <= 1:
        return cont(lst)
    else:
        mid = int(n / 2)
        return merge_sort_cps(
                lst[:mid],
                lambda v1: merge_sort_cps(lst[mid:],
                                          lambda v2: cont(merge(v1, v2))))
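
As with the factorial, we kick it off with the identity "end continuation":

>>> merge_sort_cps([4, 1, 3, 2], lambda lst: lst)
[1, 2, 3, 4]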

The recursive-descent parser sample has an example of a more complex CPS-transform applied to realistic code, if you're interested.

Trampolines to avoid stack growth in tail-recursive calls

Now we're ready to discuss why we want to place all calls in tail position, even if our language doesn't support TCO. The final tool that ties things together is trampolines.

Let's borrow a definition from Wikipedia:

As used in some Lisp implementations, a trampoline is a loop that iteratively invokes thunk-returning functions (continuation-passing style). A single trampoline suffices to express all control transfers of a program; a program so expressed is trampolined, or in trampolined style; converting a program to trampolined style is trampolining. Programmers can use trampolined functions to implement tail-recursive function calls in stack-oriented programming languages.

But wait, what's a "thunk-returning function"?

A thunk, in programming language jargon, is simply some expression wrapped in an argument-less function. This wrapping delays the evaluation of the expression until the point at which the function is called:

>>> 2 * (3 + 4)
14
>>> thunk = lambda: 2 * (3 + 4)
>>> thunk
<function <lambda> at 0x7f2c2977c510>
>>> thunk()
14

This example shows how we ask the interpreter to evaluate an expression. Then, we wrap it in a thunk: in Python simply a lambda with no arguments. The thunk itself is just a function. But when we call the thunk, the expression is actually evaluated. Thunks can be used to emulate Lazy Evaluation in languages that don't support it by default (like Python, or Clojure). But for our uses in this post, thunks are an essential part of the solution to the stack explosion problem.

The missing part of the puzzle is this:

def trampoline(f, *args):
    v = f(*args)
    while callable(v):
        v = v()
    return v

The trampoline is a Python function. It takes a function and a sequence of arguments, and applies the function to the arguments. Nothing more exciting than delayed evaluation so far. But there's more. If the function returns a callable, the trampoline assumes it's a thunk and calls it. And so on, until the function returns something that's no longer callable [7].

Remember how I said, when discussing undelimited continuations, that in "regular" languages like Python we're just cheating and simulating continuations with function calls? Trampolines are what make this viable without blowing the stack. Let's see how. Here's our CPS version of factorial, transformed once again to return a thunk:

def fact_cps_thunked(n, cont):
    if n == 0:
        return cont(1)
    else:
        return lambda: fact_cps_thunked(
                         n - 1,
                         lambda value: lambda: cont(n * value))

In this case the transformation is straightforward: we just wrap the tail calls in an argument-less lambda [8]. To invoke this function properly, we have to use a trampoline. So, to compute the factorial of 6, we'll do:

>>> trampoline(fact_cps_thunked, 6, end_cont)
720

Now comes the bang! If you carefully trace the execution of this trampoline, you'll immediately note that the stack doesn't grow! Instead of calling itself, fact_cps_thunked returns a thunk, so the call is done by the trampoline. Indeed, if we trace the function calls for the recursive factorial we get:

fact_rec(6)
  fact_rec(5)
    fact_rec(4)
      fact_rec(3)
        fact_rec(2)
          fact_rec(1)
            fact_rec(0)

But if we do the same for the thunked version, we get:

trampoline(<callable>, 6, <callable>)
  fact_cps_thunked(6, <callable>)
  fact_cps_thunked(5, <callable>)
  fact_cps_thunked(4, <callable>)
  fact_cps_thunked(3, <callable>)
  fact_cps_thunked(2, <callable>)
  fact_cps_thunked(1, <callable>)
  fact_cps_thunked(0, <callable>)

Remember how, earlier in the post, we discovered the maximum stack depth of Python by invoking fact_rec(1000) and observing it blow up? No such problem with the thunked version:

>>> trampoline(fact_cps_thunked, 1000, end_cont)
... number with 2568 digits

>>> trampoline(fact_cps_thunked, 2000, end_cont)
... number with 5736 digits

The full Fibonacci sample shows how to use thunks and trampolines to compute the Fibonacci sequence without growing the stack.
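
That sample is along these lines - a sketch that follows the same thunk-wrapping as fact_cps_thunked (the actual code in the sample may differ in details):

def fib_cps_thunked(n, cont):
    if n < 2:
        return cont(1)
    else:
        return lambda: fib_cps_thunked(
                 n - 1,
                 lambda v1: lambda: fib_cps_thunked(
                                      n - 2,
                                      lambda v2: lambda: cont(v1 + v2)))

>>> trampoline(fib_cps_thunked, 20, end_cont)
10946

It's still exponential in time, of course - only the stack stays flat.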

I hope the pieces have fallen into place by now. By using a combination of CPS and trampolines, we've taken arbitrary recursive functions and converted them to tail-recursive versions that consume only a bounded number of stack frames. All of this in a language without TCO support.

Trampolines for mutual recursion

If you're left wondering how realistic this is, let's go back to the topic of mutual recursion. As I've mentioned before, Clojure doesn't support TCO, even though it's a Lisp. To overcome this, the recommended programming style in Clojure is explicit loop...recur iteration, which makes tail-recursive algorithms relatively easy (and efficient) to express. But this still leaves Clojure with the problem of mutual or indirect recursion, where loop...recur doesn't help.

Here's that silly even/odd example again, this time in Clojure:

(declare is-even?)

(defn is-odd?
  [n]
  (if (= n 0)
    false
    (is-even? (- n 1))))

(defn is-even?
  [n]
  (if (= n 0)
    true
    (is-odd? (- n 1))))

We can't express this mutual recursion with loop...recur, since the tail calls bounce between two functions. But Clojure solves the problem by offering trampoline in the language core! Here's a thunked version:

(declare is-even-thunked?)

(defn is-odd-thunked?
  [n]
  (if (= n 0)
    false
    #(is-even-thunked? (- n 1))))

(defn is-even-thunked?
  [n]
  (if (= n 0)
    true
    #(is-odd-thunked? (- n 1))))

To invoke it:

=> (trampoline is-even-thunked? 3)
false

Note how small the difference from the non-thunked version is. This is due to Clojure's awesome syntax for anonymous functions, where a thunk is simply #(...).

Clojure's own implementation of trampoline is about what we'd expect. Here it is, pasted in full, including its educational docstring:

(defn trampoline
  "trampoline can be used to convert algorithms requiring mutual
  recursion without stack consumption. Calls f with supplied args, if
  any. If f returns a fn, calls that fn with no arguments, and
  continues to repeat, until the return value is not a fn, then
  returns that non-fn value. Note that if you want to return a fn as a
  final value, you must wrap it in some data structure and unpack it
  after trampoline returns."
  {:added "1.0"
   :static true}
  ([f]
     (let [ret (f)]
       (if (fn? ret)
         (recur ret)
         ret)))
  ([f & args]
     (trampoline #(apply f args))))

Back to reality

While the tools described in this post can (and do) serve as building blocks for some compilers of functional languages, how relevant are they to day-to-day programming in languages like Python and Clojure?

The answer is, IMHO, not very, but they're still worth knowing about. For Clojure, quite obviously Rich Hickey found trampolines important enough to include them in the language. Since Clojure is not TCO'd and loop...recur is only good for direct recursion, some solution had to be offered for mutual/indirect recursion. But how often would you use it anyway?

Algorithms like merge-sort, or any tree-like traversal, tend to be fine just with regular recursion because the supported depth is more than sufficient. Because of the logarithmic nature of depth vs. total problem size, you're unlikely to recurse into merge-sort more than a few dozen times. An array needing just 30 divisions has to contain about a billion items.

The same is true for recursive-descent parsing, since realistic expressions get only so large. However, with some algorithms like graph traversals we definitely have to be more careful.

Another important use case is state machines, which may be conveniently expressed with indirect recursive calls for state transitions. Here's a code sample. Guido mentions this problem in his post on TCO in Python, and actually suggests trampolined tail-calls as a possible solution.
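
As a toy illustration (hypothetical states, not the linked sample), here's a tiny recognizer for strings of the form "abab...": each state "transitions" by returning a thunk, and the trampoline from earlier drives the machine without growing the stack:

def state_a(s, i):
    if i == len(s):
        return True    # accept only after a complete 'ab' pair (or empty input)
    if s[i] == 'a':
        return lambda: state_b(s, i + 1)
    return False

def state_b(s, i):
    if i == len(s):
        return False   # dangling 'a'
    if s[i] == 'b':
        return lambda: state_a(s, i + 1)
    return False

>>> trampoline(state_a, "abab", 0)
True
>>> trampoline(state_a, "aba", 0)
False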

Python generators and coroutines as an alternative

That said, I personally believe that Python offers better ways to solve these problems. I've written before about using coroutines for modeling state machines. Since then, Python grew more supportive features - I've written about using yield from to implement lexical scanning, for example, and similar techniques can be adapted for parsing.

In Python 3.5, even more features were added to support coroutines. I plan to find some time to dig into these and write more about them.

All of this is to say that I wouldn't find much use for direct expression of CPS and trampolines in Python code these days. But I may be wrong! Please feel free to make suggestions in the comments or by email - I'll be really curious to know about realistic use cases where these techniques could be employed in modern Python.

Regardless of their usefulness in modern Python, these topics are fascinating and I feel they improve my understanding of how some programming languages work under the hood.


[1]For this post I'm using Python 3.5 to run and test all the code, on an Ubuntu Linux machine; for a different version of Python or environment, the recursion limit may be different.
[2]Alternatively you may see TRE (Tail Recursion Elimination). TCO is more general because tail calls don't necessarily have to be directly recursive (as the is_even / is_odd example demonstrates).
[3]AFAIU, some JVM languages emulate TCO by using trampolines, but Clojure doesn't since it prefers to be compatible with Java and retain high performance. My information may be out of date here, so please feel free to comment if I'm wrong.
[4]The compiler can cheat by emitting constructs not accessible from the source language. What happens in practice, on the lowest level of emitted machine code, is that instead of a call the compiler just prepares a stack frame for the called function and jumps to it. Since the results of the called function will not be used in the caller (except returning them further up the chain), the compiler doesn't have to save any caller state.
[5]I'm using a Lisp (Clojure) here because it unifies the syntax of mathematical operations with function calls, which makes the explanation less convoluted.
[6]Clojure in this case, but Python doesn't support them either.
[7]Astute readers will note this is problematic when our code makes heavy use of first-class functions and we may legitimately return a function from another function call. This is true, and in this case it should be easy to make a thunk something more explicit than simply a lambda. We could, for example, encapsulate it in a type - Thunk, that we would check with isinstance inside trampoline, instead of just callable.
[8]Note that the new continuation constructed for the recursive call also returns a thunk. Figuring out why this is necessary is a good exercise!

Summary of reading: January - March 2017



  • "The history of the Supreme Court" by Peter Irons (audio course) - a comprehensive history of the US Supreme Court, focusing on the most influential judges and the most seminal decisions it made over the years. I liked the latter parts of the book better than the former, except hearing about the lecturer's own opinion on the cases a bit too much for my taste. I expected a more objective historic treatment.
  • "The Three Comrades" by Erich Maria Remarque - a somewhat dark, but well-written novel about life in Germany in the years after WWI. Not quite as good as the Western Front, but still enjoyable. I wonder how much money the main character would have if he just wouldn't drink so much, but that probably doesn't matter in the long term. I just seem to have something against books focused on drinking (this is why I find Hemingway painful to read).
  • "Strangers in Their Own Land: Anger and Mourning on the American Right" by Arlie Russell Hochschild - Set to understand the paradox of many "red state" voters in voting against their apparent interests (in terms of public services, welfare and environmental issues), the author - a liberal Berkeley professor - went to Louisiana to befriend some of its inhabitants and to bridge the "empathy wall" between left and right. Very timely book with the recent election results. I wouldn't say the book fully answers the big question, but it does provide some interesting and useful clues. The most insightful observation is the "cutting in line" theory, in which white, working Christian males from the south see their status in society continually eroded by granting rights and recognition to other segments of the population - blacks, women, poor folks on welfare and most recently LGBT; the latter, along with abortions go against these voters' notion of a normative family, and makes their deepest traditions seem under attack. It's not very clear what the actionable conclusions from this book might be, but it definitely is a good attempt to help left-wing liberals understand what drives folks on the other side of the divide.
  • "The Words We Live By" by Linda R. Monk - this is an annotated guide to the US constitution. The constitution is spread out across 270 (large) pages, article by article, section by section, interspersed with historical background, relevant supreme court cases and quotes from politicians, judges and notable activists. Very nicely put together - this book is pretty interesting to just read cover-to-cover, and can serve as a good introductory reference material to the constitution as well.
  • "The Blessing of a Skinned Knee" by Wendy Mogel - nominally, this is a parenting book written by a child psychologist but focusing mostly on learnings from Jewish tradition and religious texts. In the role of a parenting book I found it fairly average - it lists some good ideas I agree with, but suffers from evanescence typical of self-help books. However, it a role of conveying the important of religion and tradition to overall family life and well-being, I found it intriguing. I doubt I can bring myself to do the leap of faith necessary to follow the author's path, but I found these parts of the book thought provoking nevertheless.
  • "The Gene: An Intimate History" by Siddhartha Mukherjee - good, but not great, book about genetics. I was very happy when the historical coverage appeared to have ended shortly after the middle of the book, anticipating lots of interesting reading about modern research. However, at that point the author spent too much paper talking about the moral implications of genetic research, rather than focusing on the science itself. The writing is excellent, but for me this book didn't really stand out among the crowd of other books on this subject.
  • "1491: New Revelations of the Americas Before Columbus" by Charles C. Mann - this book is set to attack the myth that prior to the arrival Europeans, native Americans were sparsely settled in small populations that didn't affect the environment too much. It provides evidence and historical/archaeological research in favor of much larger populations that shaped their environment and developed highly complex cultures. The book is nice to read, but I found many of the arguments unconvincing (although I wanted to be convinced). It seems like there's a dearth of real scientific data, and a lot of "proof" relies on badly preserved verbal accounts of various European explorers and ambiguous archaeological findings. Some of the arguments are more convincing, like the extent to which European disease decimated local populations. For some reason the book also spends lots of pages recounting various native American myths and folk stories, which feels absolutely unnecessary. It would be a better, albeit shorter book without this.
  • "The Blank Slate: The Modern Denial of Human Nature" by Steven Pinker - somewhat long-winded treatise on the generic vs. environmental factors involved in the psychological development of humans. Lots of pages spent flaming opposing works, which may make someone feel on the scientific edge, but I personally felt wasn't that interesting. All in all, I found this book less interesting and useful than previous books I read on this and similar topics. One notable mention is Pinker's treatment of post-modernism in art. Specifically, it suddenly hit me that post-modernism is precisely what Any Rand was mocking in "The Fountainhead".
  • "A Conflict of Visions" by Thomas Sowell - one of the densest and least readable books I've encountered in the past few years. The laconic audio narration didn't help here. The ideas Sowell presents are interesting, with the caveat that you can never divide something so complex into two categories cleanly. It's certainly a brave attempt at digging up the real roots of the differences between conservatives and liberals. That said, it utterly fails to explain mixed views some people hold that are part-conservative and part-liberal (myself included).

Re-reads:

  • "The Grapes of Wrath" by John Steinbeck
  • "Peopleware: Productive Projects and Teams" by Tom DeMarco and Tim Lister

Adventures in JIT compilation: Part 2 - an x64 JIT



In the first part of the series I briefly introduced the BF source language and went on to present four interpreters with increasing degrees of optimization. That post should serve as good background before diving into actual JIT-ing.

Another important part of the background puzzle is my How to JIT - an introduction post from 2013; there, I discuss some of the basic tools needed to emit executable x64 machine code at run-time and actually run it on Linux. Please go through it quickly if these things are new to you.

The two phases of JIT

As I wrote previously, the JIT technique is easier to understand when divided into two distinct phases:

  1. Create machine code at program run-time.
  2. Execute that machine code, also at program run-time.

Phase 2 for our BF JIT is exactly identical to the method described in that introductory post. Take a look at the JitProgram class in jit_utils for details. We'll be more focused on phase 1, which will be translating BF to x64 machine code; per the definition quoted in part 1 of the series, we're going to develop an actual BF compiler (compiling from BF source to x64 machine code).

Compilers, assemblers and instruction encoding

Traditionally, compilation was divided into several stages. The actual source language compiler would translate some higher-level language to target-specific assembly; then, an assembler would translate assembly to actual machine code [1]. There are a number of important benefits assembly language provides over raw machine code. Salient examples include:

  • Instruction encoding: it's certainly nicer to write inc %r13 to increment the contents of register r13 than to write 0x49, 0xFF, 0xC5. Instruction encoding for the popular architectures is notoriously complicated.
  • Naming labels and procedures for jumps/calls: it's easier to write jl loop than to figure out the encoding for the instruction, along with the relative position of the loop label and encoding the delta to it (not to mention this delta changes every time we add instructions in between and needs to be recomputed). Similarly for functions, call foo instead of doing it by address.

One of my guiding principles through the field of programming is that before diving into the possible solutions for a problem (for example, some library for doing X) it's worth working through the problem manually first (doing X by hand, without libraries). Grinding your teeth over issues for a while is the best way to appreciate what the shrinkwrapped solution/library does for you.

In this spirit, our first JIT is going to be completely hand-written.

Simple JIT - hand-rolling x64 instruction encoding

Our first JIT for this post is simplejit.cpp. Similarly to the interpreters of part 1, all the action happens in a single function (here called simplejit) invoked from main. simplejit goes through the BF source and emits x64 machine code into a memory buffer; in the end, it jumps to this machine code to run the BF program.

Here's its beginning:

std::vector<uint8_t> memory(MEMORY_SIZE, 0);

// Registers used in the program:
//
// r13: the data pointer -- contains the address of memory.data()
//
// rax, rdi, rsi, rdx: used for making system calls, per the ABI.

CodeEmitter emitter;

// Throughout the translation loop, this stack contains offsets (in the
// emitter code vector) of locations for fixup.
std::stack<size_t> open_bracket_stack;

As usual, we have our BF memory buffer in a std::vector. The comments reveal some of the conventions used throughout the emitted program: our "data pointer" will be in r13.

CodeEmitter is a very simple utility to append bytes and words to a vector of bytes. Its full code is here. It's platform independent except the assumption of little-endian (for EmitUint64 it will write the lowest byte of the 64-bit word first, then the second lowest byte, etc.)

Our first bit of actual machine code emission follows:

// movabs <address of memory.data>, %r13
emitter.EmitBytes({0x49, 0xBD});
emitter.EmitUint64((uint64_t)memory.data());

And it's a cool one, mixing elements from the host (the C++ program doing the emission) and the JITed code. First note the usage of movabs, an x64 instruction useful for placing 64-bit immediates in a register. This is exactly what we're doing here - placing the address of the data buffer of memory in r13. The call to EmitBytes with a cryptic sequence of hex values is preceded by a snippet of assembly in a comment - the assembly conveys the meaning for human readers, the hex values are the actual encoding the machine will understand.

Then comes the BF compilation loop, which looks at the next BF instruction and emits the appropriate machine code for it. Our compiler works in a single pass; this means that there's a bit of trickiness in handling the jumps, as we will soon see.

for (size_t pc = 0; pc < p.instructions.size(); ++pc) {
  char instruction = p.instructions[pc];
  switch (instruction) {
  case '>':
    // inc %r13
    emitter.EmitBytes({0x49, 0xFF, 0xC5});
    break;
  case '<':
    // dec %r13
    emitter.EmitBytes({0x49, 0xFF, 0xCD});
    break;
  case '+':
    // Our memory is byte-addressable, so using addb/subb for modifying it.
    // addb $1, 0(%r13)
    emitter.EmitBytes({0x41, 0x80, 0x45, 0x00, 0x01});
    break;
  case '-':
    // subb $1, 0(%r13)
    emitter.EmitBytes({0x41, 0x80, 0x6D, 0x00, 0x01});
    break;

These are pretty straightforward; since r13 is the data pointer, > and < increment and decrement it, while + and - increment and decrement what it's pointing to. One slightly subtle aspect is that I chose byte-valued memory for our BF implementations; this means we have to be careful when reading or writing to memory and use byte-sized operations (the b suffixes on add and sub above) rather than full 64-bit ones.

The code emitted for . and , is a bit more exciting; in the effort of avoiding any external dependencies, we're going to invoke Linux system calls directly. WRITE for .; READ for ,. We're using the x64 ABI here with the syscall identifier in rax:

case '.':
  // To emit one byte to stdout, call the write syscall with fd=1 (for
  // stdout), buf=address of byte, count=1.
  //
  // mov $1, %rax
  // mov $1, %rdi
  // mov %r13, %rsi
  // mov $1, %rdx
  // syscall
  emitter.EmitBytes({0x48, 0xC7, 0xC0, 0x01, 0x00, 0x00, 0x00});
  emitter.EmitBytes({0x48, 0xC7, 0xC7, 0x01, 0x00, 0x00, 0x00});
  emitter.EmitBytes({0x4C, 0x89, 0xEE});
  emitter.EmitBytes({0x48, 0xC7, 0xC2, 0x01, 0x00, 0x00, 0x00});
  emitter.EmitBytes({0x0F, 0x05});
  break;
case ',':
  // To read one byte from stdin, call the read syscall with fd=0 (for
  // stdin), buf=address of byte, count=1.
  emitter.EmitBytes({0x48, 0xC7, 0xC0, 0x00, 0x00, 0x00, 0x00});
  emitter.EmitBytes({0x48, 0xC7, 0xC7, 0x00, 0x00, 0x00, 0x00});
  emitter.EmitBytes({0x4C, 0x89, 0xEE});
  emitter.EmitBytes({0x48, 0xC7, 0xC2, 0x01, 0x00, 0x00, 0x00});
  emitter.EmitBytes({0x0F, 0x05});
  break;

The comments certainly help, don't they? I hope these snippets are a great motivation for using assembly language rather than encoding instructions manually :-)

The jump instructions are always the most interesting in BF. For [ we do:

case '[':
  // For the jumps we always emit the instruction for 32-bit pc-relative
  // jump, without worrying about potentially short jumps and relaxation.

  // cmpb $0, 0(%r13)
  emitter.EmitBytes({0x41, 0x80, 0x7d, 0x00, 0x00});

  // Save the location in the stack, and emit JZ (with 32-bit relative
  // offset) with 4 placeholder zeroes that will be fixed up later.
  open_bracket_stack.push(emitter.size());
  emitter.EmitBytes({0x0F, 0x84});
  emitter.EmitUint32(0);
  break;

Note that we don't know where this jump leads at this point - it will go to the matching ], which we haven't encountered yet! Therefore, to keep our compilation in a single pass [2] we use the time-honored technique of backpatching by emitting a placeholder value for the jump and fixing it up once we encounter the matching label. Another thing to note is always using a 32-bit pc-relative jump, for simplicity; we could save a couple of bytes with a short jump in most cases (see my article on assembler relaxation for the full scoop), but I don't think it's worth the effort here.

Compiling the matching ] is a bit trickier; I hope the comments do a good job explaining what's going on, and the code itself is optimized for readability rather than cleverness:

case ']': {
  if (open_bracket_stack.empty()) {
    DIE << "unmatched closing ']' at pc=" << pc;
  }
  size_t open_bracket_offset = open_bracket_stack.top();
  open_bracket_stack.pop();

  // cmpb $0, 0(%r13)
  emitter.EmitBytes({0x41, 0x80, 0x7d, 0x00, 0x00});

  // open_bracket_offset points to the JZ that jumps to this closing
  // bracket. We'll need to fix up the offset for that JZ, as well as emit a
  // JNZ with a correct offset back. Note that both [ and ] jump to the
  // instruction *after* the matching bracket if their condition is
  // fulfilled.

  // Compute the offset for this jump. The jump start is computed from after
  // the jump instruction, and the target is the instruction after the one
  // saved on the stack.
  size_t jump_back_from = emitter.size() + 6;
  size_t jump_back_to = open_bracket_offset + 6;
  uint32_t pcrel_offset_back =
      compute_relative_32bit_offset(jump_back_from, jump_back_to);

  // jnz <open_bracket_location>
  emitter.EmitBytes({0x0F, 0x85});
  emitter.EmitUint32(pcrel_offset_back);

  // Also fix up the forward jump at the matching [. Note that here we don't
  // need to add the size of this jmp to the "jump to" offset, since the jmp
  // was already emitted and the emitter size was bumped forward.
  size_t jump_forward_from = open_bracket_offset + 6;
  size_t jump_forward_to = emitter.size();
  uint32_t pcrel_offset_forward =
      compute_relative_32bit_offset(jump_forward_from, jump_forward_to);
  emitter.ReplaceUint32AtOffset(open_bracket_offset + 2,
                                pcrel_offset_forward);
  break;
}

This concludes the compiler loop; we end up with a bunch of potentially executable machine code in a vector. This code refers to the host program (the address of memory.data()), but that's OK since the host program's lifetime wraps the lifetime of the JITed code. What remains is to actually invoke this machine code:

// ... after the compilation loop

// The emitted code will be called as a function from C++; therefore it has to
// use the proper calling convention. Emit a 'ret' for orderly return to the
// caller.
emitter.EmitByte(0xC3);

// Load the emitted code to executable memory and run it.
std::vector<uint8_t> emitted_code = emitter.code();
JitProgram jit_program(emitted_code);

// JittedFunc is the C++ type for the JIT function emitted here. The emitted
// function is callable from C++ and follows the x64 System V ABI.
using JittedFunc = void (*)(void);

JittedFunc func = (JittedFunc)jit_program.program_memory();
func();

The call should be familiar from reading the How to JIT post. Note that here we opted for the simplest function possible - no arguments, no return value; in future sections we'll spice it up a bit.

Taking our JIT for a spin

In part 1, I presented a trivial BF program that prints the numbers 1 to 5 to the screen:

++++++++ ++++++++ ++++++++ ++++++++ ++++++++ ++++++++
>+++++
[<+.>-]

Let's see what our compiler translates it to. Even though the code vector inside simplejit is ephemeral (lives only temporarily in memory), we can serialize it to a binary file which we can then disassemble (with objdump -D -b binary -mi386:x86-64). The following is the disassembly listing with comments I embedded to explain what's going on:

 # The runtime address of memory.data() goes into r13; note that this will
 # likely be a different value in every invocation of the JIT.

  0:   49 bd f0 54 e3 00 00    movabs $0xe354f0,%r13
  7:   00 00 00

 # A sequence of 48 instructions that all do the same, for the initial sequence
 # of +s; this makes me miss our optimizing interpreter, but worry not - we'll
 # make this go away later in the post.

  a:   41 80 45 00 01          addb   $0x1,0x0(%r13)
  f:   41 80 45 00 01          addb   $0x1,0x0(%r13)

 # [...] 46 more 'addb'

 # >+++++

 fa:   49 ff c5                inc    %r13
 fd:   41 80 45 00 01          addb   $0x1,0x0(%r13)
102:   41 80 45 00 01          addb   $0x1,0x0(%r13)
107:   41 80 45 00 01          addb   $0x1,0x0(%r13)
10c:   41 80 45 00 01          addb   $0x1,0x0(%r13)
111:   41 80 45 00 01          addb   $0x1,0x0(%r13)

 # Here comes the loop! Note that the relative jump offset is already inserted
 # into the 'je' instruction by the backpatching process.

116:   41 80 7d 00 00          cmpb   $0x0,0x0(%r13)
11b:   0f 84 35 00 00 00       je     0x156
121:   49 ff cd                dec    %r13
124:   41 80 45 00 01          addb   $0x1,0x0(%r13)

 # The '.' is translated into a syscall to WRITE

129:   48 c7 c0 01 00 00 00    mov    $0x1,%rax
130:   48 c7 c7 01 00 00 00    mov    $0x1,%rdi
137:   4c 89 ee                mov    %r13,%rsi
13a:   48 c7 c2 01 00 00 00    mov    $0x1,%rdx
141:   0f 05                   syscall
143:   49 ff c5                inc    %r13
146:   41 80 6d 00 01          subb   $0x1,0x0(%r13)
14b:   41 80 7d 00 00          cmpb   $0x0,0x0(%r13)

 # Jump back to beginning of loop

150:   0f 85 cb ff ff ff       jne    0x121

 # We're done

156:   c3                      retq

How does it perform?

It's time to measure the performance of our JIT against the interpreters from part 1. optinterp3 was about 10x faster than the naive interpreter - how will this JIT measure up? Note that it has no optimizations (except not having to recompute the jump destination for every loop iteration as the naive interpreter did). Can you guess? The results may surprise you...

The simple JIT runs mandelbrot in 2.89 seconds, and factor in 0.94 seconds - much faster still than optinterp3; here's the comparison plot (omitting the slower interpreters since they skew the scale):

BF opt3 vs simplejit

Why is this so? optinterp3 is heavily optimized - it folds entire loops into a single operation; simplejit does none of this - we've just seen the embarrassing sequence of addbs it emits for a long sequence of +s.

The reason is that the baseline performance of the JIT is vastly better. I've mentioned this briefly in part 1 - imagine what's needed to interpret a single instruction in the fastest interpreter.

  1. Advance pc and compare it to program size.
  2. Grab the instruction at pc.
  3. Switch on the value of the instruction to the right case.
  4. Execute the case.

This requires a whole sequence of machine instructions, with at least two branches (one for the loop, one for the switch). On the other hand, the JIT just emits a single instruction - no branches. I would say that - depending on what the compiler did while compiling the interpreter - the JIT is between 4 and 8 times faster at running any given BF operation. It has to run many more BF operations because it doesn't optimize, but this difference is insufficient to close the huge baseline gap. Later in this post we're going to see an optimized JIT which performs even better.

But first, let's talk about this painful instruction encoding business.

Manually encoding instructions

As promised, simplejit is completely self-contained. It doesn't use any external libraries, and encodes all the instructions by hand. It's not hard to see how painful that process is, and the code is absolutely unreadable unless accompanied by detailed comments; moreover, changing the code is a pain, and changes happen in unexpected ways. For example, if we want to use some other register in an instruction, the change to the emitted code won't be intuitive. add %r9, %rax is encoded as 0x4C, 0x01, 0xC8, but add %r10, %rax is 0x4C, 0x01, 0xD0; since registers are specified in sub-byte bit fields, one needs very good memory and tons of experience to predict what goes where.
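
For the curious, here's roughly where those register bits go - a back-of-the-envelope Python sketch, not code from the JIT: the low three bits of each register number land in the ModRM byte, and the fourth bit spills into the REX prefix (which is why it's 0x4C above).

def modrm(mod, reg, rm):
    # mod is 2 bits; reg and rm each hold the low 3 bits of a register number
    return (mod << 6) | ((reg & 7) << 3) | (rm & 7)

print(hex(modrm(0b11, 9, 0)))    # 0xc8 -- add %r9, %rax (reg=r9, rm=rax)
print(hex(modrm(0b11, 10, 0)))   # 0xd0 -- add %r10, %rax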

Would you expect related instructions to look somewhat similar? They don't. inc %r13 is encoded as 0x49, 0xFF, 0xC5, for example. To put it bluntly - unless you're Mel, you're going to have a hard time. Now imagine that you have to support emitting code for multiple architectures!

This is why all compilers, VMs and related projects have their own layers to help with this encoding task, along with related tasks like labels and jump computations. Most are not exposed for easy usage outside their project; others, like DynASM (developed as part of the LuaJIT project) are packaged for separate usage. DynASM is an example of a low-level framework - providing instruction encoding and not much else; some frameworks are higher-level, doing more compiler-y things like register allocation. One example is libjit; another is LLVM.

asmjit

While looking for a library to help me encode instructions, I initially tried DynASM. It's an interesting approach - and you can see Josh Haberman's post about using it for a simple BF JIT, but I found it to be a bit too abandonware-ish for my taste. Besides, I don't like the funky preprocessor approach with a dependency on Lua.

So I found another project that seemed to fit the bill - asmjit - a pure C++ library without any preprocessing. asmjit began about 3 years ago to ease its author's development of fast kernels for graphics code. Its documentation isn't much better than DynASM's, but being just a C++ library I found it easier to dive into the source when questions arose that the docs couldn't answer. Besides, the author is very active and quick in answering questions on Github and adding missing features. Therefore, the rest of this post shows BF JITs that use asmjit - these can also serve as a non-trivial tutorial for the library.

simpleasmjit - JIT with sane instruction encoding

Enter simpleasmjit.cpp - the same simple JIT (no optimizations) as simplejit, but using asmjit for the instruction encoding, labels and so on. Just for fun, we'll mix things up a bit. First, we'll change the JITed function signature from void (*)(void) to void (*)(uint64_t); the address of the BF memory buffer will be passed as argument into the JITed function rather than hard-coded into it.

Second, we'll use actual C functions to emit / input characters, rather than system calls. Moreover, since putchar and getchar may be macros on some systems, taking their address can be unsafe. So we'll wrap them in actual C++ functions, whose address it is safe to take in emitted code:

void myputchar(uint8_t c) {
  putchar(c);
}

uint8_t mygetchar() {
  return getchar();
}

simpleasmjit starts by initializing an asmjit runtime, code holder and assembler [3]:

asmjit::JitRuntime jit_runtime;
asmjit::CodeHolder code;
code.init(jit_runtime.getCodeInfo());
asmjit::X86Assembler assm(&code);

Next, we'll give a mnemonic name to our data pointer, and emit a copy of the address of the memory buffer into it (it's in rdi initially, as the first function argument in the x64 ABI):

// We pass the data pointer as an argument to the JITed function, so it's
// expected to be in rdi. Move it to r13.
asmjit::X86Gp dataptr = asmjit::x86::r13;
assm.mov(dataptr, asmjit::x86::rdi);

Then we get to the usual BF processing loop that emits code for every BF op:

for (size_t pc = 0; pc < p.instructions.size(); ++pc) {
  char instruction = p.instructions[pc];
  switch (instruction) {
  case '>':
    // inc %r13
    assm.inc(dataptr);
    break;
  case '<':
    // dec %r13
    assm.dec(dataptr);
    break;
  case '+':
    // addb $1, 0(%r13)
    assm.add(asmjit::x86::byte_ptr(dataptr), 1);
    break;
  case '-':
    // subb $1, 0(%r13)
    assm.sub(asmjit::x86::byte_ptr(dataptr), 1);
    break;

Notice the difference! No more obscure hex codes - assm.inc(dataptr) is so much nicer than 0x49, 0xFF, 0xC5, isn't it?

For input and output we emit calls to our wrapper functions:

case '.':
  // call myputchar [dataptr]
  assm.movzx(asmjit::x86::rdi, asmjit::x86::byte_ptr(dataptr));
  assm.call(asmjit::imm_ptr(myputchar));
  break;
case ',':
  // [dataptr] = call mygetchar
  // Store only the low byte to memory to avoid overwriting unrelated data.
  assm.call(asmjit::imm_ptr(mygetchar));
  assm.mov(asmjit::x86::byte_ptr(dataptr), asmjit::x86::al);
  break;

The magic is in the imm_ptr modifier, which places the address of the function in the emitted code.

Finally, the code handling [ and ] is also much simpler due to asmjit's labels, which can be used before they're actually emitted:

case '[': {
  assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
  asmjit::Label open_label = assm.newLabel();
  asmjit::Label close_label = assm.newLabel();

  // Jump past the closing ']' if [dataptr] = 0; close_label wasn't bound
  // yet (it will be bound when we handle the matching ']'), but asmjit lets
  // us emit the jump now and will handle the back-patching later.
  assm.jz(close_label);

  // open_label is bound past the jump; all in all, we're emitting:
  //
  //    cmpb 0(%r13), 0
  //    jz close_label
  // open_label:
  //    ...
  assm.bind(open_label);

  // Save both labels on the stack.
  open_bracket_stack.push(BracketLabels(open_label, close_label));
  break;
}
case ']': {
  if (open_bracket_stack.empty()) {
    DIE << "unmatched closing ']' at pc=" << pc;
  }
  BracketLabels labels = open_bracket_stack.top();
  open_bracket_stack.pop();

  //    cmpb 0(%r13), 0
  //    jnz open_label
  // close_label:
  //    ...
  assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
  assm.jnz(labels.open_label);
  assm.bind(labels.close_label);
  break;
}

We just have to remember which label we used for the jump and emit the exact same Label object - asmjit handles the backpatching on its own! Moreover, all the jump offset computations are performed automatically.
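One detail the excerpts leave out: after the loop over all BF ops, the JITed function still has to return to its caller. With asmjit that's a single call - here as a sketch of the missing epilogue:

// End of the JITed function: return to the caller.
assm.ret();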

Finally, after emitting the code we can call it:

using JittedFunc = void (*)(uint64_t);

JittedFunc func;
asmjit::Error err = jit_runtime.add(&func, &code);
// [...]
// Call it, passing the address of memory as a parameter.
func((uint64_t)memory.data());
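The // [...] above stands for error handling that I've elided. I won't reproduce the exact code here, but a minimal check - my sketch, using the same DIE macro the JITs use for other errors - could look like this:

if (err) {
  // jit_runtime.add failed - e.g. no memory for the executable buffer.
  DIE << "error calling jit_runtime.add";
}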

That's it. This JIT emits virtually the same code as simplejit, so we don't expect it to perform any differently. The main point of this exercise is to show how much simpler and more pleasant emitting code is with a library like asmjit. It hides all the icky encoding and offset computations, letting us focus on what's actually unique to our program - the sequence of instructions emitted.

optasmjit - combining BF optimizations with a JIT

Finally, it's time to combine the clever optimizations we've developed in part 1 with the JIT. Here, I'm essentially taking optinterp3 from part 1 and bolting a JIT backend onto it. The result is optasmjit.cpp.

Recall that instead of the 8 BF ops, we now have an extended op set that carries integer arguments and, in some cases, expresses higher-level operations:

enum class BfOpKind {
  INVALID_OP = 0,
  INC_PTR,
  DEC_PTR,
  INC_DATA,
  DEC_DATA,
  READ_STDIN,
  WRITE_STDOUT,
  LOOP_SET_TO_ZERO,
  LOOP_MOVE_PTR,
  LOOP_MOVE_DATA,
  JUMP_IF_DATA_ZERO,
  JUMP_IF_DATA_NOT_ZERO
};
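The emission code below accesses op.kind and op.argument; the op struct itself isn't shown in these excerpts, but conceptually it's just a kind paired with a signed integer argument. A minimal sketch, with field types and names assumed for illustration:

// A single translated op: what to do, plus an integer argument (how many
// times, or how far). Sketch only - the real definition is shared with the
// optimizing interpreters from part 1.
struct BfOp {
  BfOpKind kind = BfOpKind::INVALID_OP;
  int64_t argument = 0;
};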

The translation phase from BF ops to a sequence of BfOpKind is exactly the same as it was in optinterp3. Let's take a look at how a couple of the new ops are implemented now:

case BfOpKind::INC_PTR:
  assm.add(dataptr, op.argument);
  break;

As before with the interpreters, an increment of 1 is replaced by the addition of an argument. We use a different instruction for this - add instead of inc [4]. How about something more interesting:

case BfOpKind::LOOP_MOVE_DATA: {
  // Only move if the current data isn't 0:
  //
  //   cmpb 0(%r13), 0
  //   jz skip_move
  //   <...> move data
  // skip_move:
  asmjit::Label skip_move = assm.newLabel();
  assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
  assm.jz(skip_move);

  assm.mov(asmjit::x86::r14, dataptr);
  if (op.argument < 0) {
    assm.sub(asmjit::x86::r14, -op.argument);
  } else {
    assm.add(asmjit::x86::r14, op.argument);
  }
  // Use rax as a temporary holding the value at the original pointer;
  // then use al to add it to the new location, so that only the target
  // byte is affected: addb %al, 0(%r14)
  assm.movzx(asmjit::x86::rax, asmjit::x86::byte_ptr(dataptr));
  assm.add(asmjit::x86::byte_ptr(asmjit::x86::r14), asmjit::x86::al);
  assm.mov(asmjit::x86::byte_ptr(dataptr), 0);
  assm.bind(skip_move);
  break;
}

I'll just note again how much simpler this code is to write with asmjit than without it. Also note the careful handling of byte-sized cells when touching memory - I ran into a number of nasty bugs while developing this. In fact, using the native machine word size (64 bits in this case) for BF memory cells would've made everything much simpler, but 8-bit cells are closer to the common semantics of the language and provide an extra challenge.
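The other loop ops follow the same pattern, only simpler. For example, LOOP_SET_TO_ZERO (the [-] idiom) needs no loop at all in the JIT - it can be lowered to a single store. This case isn't shown in the excerpts above, so treat the following as an illustrative sketch rather than a copy of optasmjit:

case BfOpKind::LOOP_SET_TO_ZERO:
  // [-] just clears the current cell: movb $0, 0(%r13)
  assm.mov(asmjit::x86::byte_ptr(dataptr), 0);
  break;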

Performance

Let's see how optasmjit fares against the fastest interpreter and the unoptimized JIT - 0.93 seconds for mandelbrot, 0.3 seconds for factor - another factor of 3 in performance:

[Chart: BF opt3 vs simplejit vs optasmjit]

Notably, the performance delta with the optimized interpreter is huge: the JIT is more than 4x faster. If we compare it all the way to the initial simple interpreter, optasmjit is about 40x faster - making it hard to even compare on the same chart :-)

[Chart: BF full performance comparison for part 2]

JITs are fun!

I find writing JITs lots of fun. It's really nice to be able to hand-craft every instruction emitted by the compiler. While this is quite painful to do without any encoding help, libraries like asmjit make the process much more pleasant.

We've done quite a bit in this part of the series. optasmjit is a genuine optimizing JIT for BF! It:

  • Parses BF source
  • Translates it to a sequence of higher-level ops
  • Optimizes these ops
  • Compiles the ops to tight x64 assembly in memory and runs them

Let's connect these steps to some real compiler jargon. The BfOpKind ops can be seen as the compiler IR. Translating human-readable source code to IR is often the first step in compilation (though for realistic languages it is often itself divided into multiple steps). The translation/compilation of ops to assembly is often called "lowering"; in some compilers this involves multiple steps and intermediate IRs.
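To make the IR notion concrete, here's roughly how a tiny BF fragment maps onto these ops, following the translation developed in part 1:

BF source:  +++>>[-]
IR:         INC_DATA 3, INC_PTR 2, LOOP_SET_TO_ZERO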

I left a lot of code out of the blog post - otherwise it would be huge! I encourage you to go back through the full source files discussed here and understand what's going on - every JIT is a single standalone C++ file.


[1]I said traditionally because many modern compilers no longer work this way. For example, LLVM compiles IR to another, much lower-level IR that represents machine-code level instructions; assembly can be emitted from this IR, but also machine code directly - so the assembler is integrated into the compiler.
[2]Some compilers would do two passes; this is similar to our first interpreter optimization in part 1: the first pass collects information (such as location of all matching ]s), so the second pass already knows what offsets to emit.
[3]Please refer to asmjit's documentation for the full scoop. I'll also mention that asmjit has a "compiler" layer which does more sophisticated things like register allocation; in this post I'm only using the base assembly layer.
[4]Wondering whether we could have just used add 1 instead of inc in the first place? Certainly! In fact, while there probably used to be a good reason for a separate inc instruction, in these days of complex multi-port pipelined x64 CPUs, it's not clear which one is faster. I just wanted to show both for diversity.