This post is about a somewhat more interesting and complex use of llvmlite than the basic example presented in my previous article on the subject.

I see compilation as a meta-tool. It lets us build new levels of abstraction and expressiveness within our code. We can use it to build additional languages on top of our host language (common for C, C++ and Java-based systems, less common for Python), to accelerate some parts of our host language (more common in Python), or anything in between.

To fully harness the power of runtime compilation (JITing), however, it's very useful to know how to bridge the gap between the host language and the JITed language; preferably in both directions. As the previous article shows, calling from the host into the JITed language is trivial. In fact, this is what JITing is mostly about. But what about the other direction? This is somewhat more challenging, but leads to interesting uses and additional capabilities.

While the post uses llvmlite for the JITing, I believe it presents general concepts that are relevant for any programming environment.

Callback from JITed code to Python

Let's start with a simple example: we want to be able to invoke some Python function from within JITed code.

from ctypes import c_int64, c_void_p, CFUNCTYPE
import sys

import llvmlite.ir as ir
import llvmlite.binding as llvm

def create_caller(m):
    # define i64 @caller(i64 (i64, i64)* nocapture %f, i64 %i) #0 {
    # entry:
    #   %mul = shl nsw i64 %i, 1
    #   %call = tail call i64 %f(i64 %i, i64 %mul) #1
    #   ret i64 %call
    # }
    i64_ty = ir.IntType(64)

    # The callback function that 'caller' accepts is a pointer to a function
    # with the appropriate signature.
    cb_func_ptr_ty = ir.FunctionType(i64_ty, [i64_ty, i64_ty]).as_pointer()
    caller_func_ty = ir.FunctionType(i64_ty, [cb_func_ptr_ty, i64_ty])

    caller_func = ir.Function(m, caller_func_ty, name='caller')
    caller_func.args[0].name = 'f'
    caller_func.args[1].name = 'i'
    irbuilder = ir.IRBuilder(caller_func.append_basic_block('entry'))
    mul = irbuilder.mul(caller_func.args[1],
                        ir.Constant(i64_ty, 2),
                        name='mul')
    call = irbuilder.call(caller_func.args[0], [caller_func.args[1], mul])
    irbuilder.ret(call)

create_caller creates a new LLVM IR function called caller and injects it into the given module m.

If you're not an expert at reading LLVM IR, caller is equivalent to this C function:

int64_t caller(int64_t (*f)(int64_t, int64_t), int64_t i) {
  return f(i, i * 2);
}

It takes f - a pointer to a function accepting two integers and returning an integer (all integers in this post are 64-bit), and i - an integer. It calls f with i and i*2 as the arguments. That's it - pretty simple, but sufficient for our demonstration's purposes.
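
To pin down the calling convention without any LLVM in the picture, here's a pure-Python analogue of caller (just an illustration - this is not part of the JITed code):

```python
def caller(f, i):
    # Mirrors the JITed function: compute i*2, then invoke the
    # callback with (i, i*2) and return whatever it returns.
    return f(i, i * 2)
```

Calling caller(lambda a, b: a + b, 42) yields 126, matching the JITed run shown below.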

Now let's define a Python function:

def myfunc(a, b):
    print('I was called with {0} and {1}'.format(a, b))
    return a + b

Finally, let's see how we can pass myfunc as the callback caller will invoke. This is fairly straightforward, thanks to the support for callback functions in ctypes. In fact, it's exactly the same as the way you'd pass Python callbacks to C code via ctypes, without any JITing involved:

def main():
    module = ir.Module()
    create_caller(module)

    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()

    llvm_module = llvm.parse_assembly(str(module))
    tm = llvm.Target.from_default_triple().create_target_machine()

    # Compile the module to machine code using MCJIT.
    with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
        ee.finalize_object()

        # Obtain a pointer to the compiled 'caller' - it's the address of its
        # JITed code in memory.
        CBFUNCTY = CFUNCTYPE(c_int64, c_int64, c_int64)
        cfptr = ee.get_function_address('caller')
        callerfunc = CFUNCTYPE(c_int64, CBFUNCTY, c_int64)(cfptr)

        # Wrap myfunc in CBFUNCTY and pass it as a callback to caller.
        cb_myfunc = CBFUNCTY(myfunc)
        print('Calling "caller"')
        res = callerfunc(cb_myfunc, 42)
        print('  The result is', res)

If we run this code, we get the expected result:

Calling "caller"
I was called with 42 and 84
  The result is 126

Registering host functions in JITed code

When developing a JIT, one need that comes up very often is to delegate some of the functionality in the JITed code to the host language. For example, if you're developing a JIT to implement a fast DSL, you may not want to reimplement a whole I/O stack in your language. So you'd prefer to delegate all I/O to the host language. Taking C as a sample host language, you just want to call printf from your DSL and somehow have it routed to the host call.

How do we accomplish this feat? The solution here, naturally, depends on both the host language and the DSL you're JITing. Let's take the LLVM tutorial as an example. The Kaleidoscope language does computations on floating point numbers, but it has no I/O facilities of its own. Therefore, the Kaleidoscope compiler exposes a putchard function from the host (C++) to be callable in Kaleidoscope. For Kaleidoscope this is fairly simple, because the host is C++ and is compiled into machine code in the same process with the JITed code. All the JITed code needs to know is the symbol name of the host function to call and the call can happen (as long as the calling conventions match, of course).

Alas, for Python as a host language, things are not so straightforward. This is why, in my reimplementation of Kaleidoscope with llvmlite, I resorted to implementing the builtins in LLVM IR, emitting them into the module along with compiled Kaleidoscope code. These builtins just call the underlying C functions (which still reside in the same process, since Python itself is written in C) and don't call into Python.

But say we wanted to actually call Python. How would we go about that?

Well, we've seen a way to call Python from JITed code in this post. Can this approach be used? Yes, though it's quite cumbersome. The problem is that the only place where we have an actual interface between Python and the JITed code is when we invoke a JITed function. Somehow we have to use this interface to convey to the JIT side which Python functions are available to it and how to call them. Essentially, we'll have to implement something akin to the following schematic symbol table interface in the JITed code:

typedef int64_t (*CallbackType)(int64_t, int64_t);
std::unordered_map<std::string, CallbackType> symtab;

void register_callback(std::string name, CallbackType callback) {
  symtab[name] = callback;
}

CallbackType get_callback(std::string name) {
  auto iter = symtab.find(name);
  if (iter != symtab.end()) {
    return iter->second;
  } else {
    return nullptr;
  }
}

To register Python callbacks with the JIT, we'll call register_callback from Python, passing it a name and the callback (CFUNCTYPE as shown in the code sample at the top). The JIT side will remember this mapping in a symbol table. When it needs to invoke a Python function it will use get_callback to get the pointer by name.
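
To make the shape of this interface concrete, here's a host-side Python model of the same symbol table (a sketch only - in the real scheme, register_callback and get_callback would be compiled into the JITed module, not implemented as Python functions):

```python
from ctypes import CFUNCTYPE, c_int64

CallbackType = CFUNCTYPE(c_int64, c_int64, c_int64)

# Models the JITed symbol table: every call pays a name lookup.
symtab = {}

def register_callback(name, pyfunc):
    # Storing the CFUNCTYPE wrapper in the table also keeps it alive.
    symtab[name] = CallbackType(pyfunc)

def get_callback(name):
    return symtab.get(name)

register_callback('myfunc', lambda a, b: a + b)
cb = get_callback('myfunc')
```

Every invocation goes through get_callback first - which is exactly the per-call dispatch cost discussed next.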

In addition to being cumbersome to implement [1], this is also inefficient. It seems wasteful to go through a symbol table lookup for every call to a Python builtin. It's not like these mappings ever change in a typical use case! We are emitting code at runtime here and have so much flexibility at our command - so this lookup feels like a crutch.

Moreover, this is a simplified example - every callback takes two integer arguments. In real scenarios, the signatures of callback functions can be arbitrary, so we'd have to implement full-blown FFI dispatching for the calls.

Breaching the compile/run-time barrier

We can do better. For every Python function we intend to call from the JITed code, we can emit a JITed wrapper. This wrapper will hard-code a call to the Python function, thus removing this dispatching (the symbol table shown above) from run-time; this totally makes sense because we know at compile time which Python functions are needed and where to find them.

Let's write the code to do this with llvmlite:

import ctypes
from ctypes import c_int64, c_void_p
import sys

import llvmlite.ir as ir
import llvmlite.binding as llvm

cb_func_ty = ir.FunctionType(ir.IntType(64),
                             [ir.IntType(64), ir.IntType(64)])
cb_func_ptr_ty = cb_func_ty.as_pointer()
i64_ty = ir.IntType(64)

def create_addrcaller(m, addr):
    # define i64 @addrcaller(i64 %a, i64 %b) #0 {
    # entry:
    #   %f = inttoptr i64 ADDR to i64 (i64, i64)*
    #   %call = tail call i64 %f(i64 %a, i64 %b)
    #   ret i64 %call
    # }
    addrcaller_func_ty = ir.FunctionType(i64_ty, [i64_ty, i64_ty])
    addrcaller_func = ir.Function(m, addrcaller_func_ty, name='addrcaller')
    a = addrcaller_func.args[0]; a.name = 'a'
    b = addrcaller_func.args[1]; b.name = 'b'
    irbuilder = ir.IRBuilder(addrcaller_func.append_basic_block('entry'))
    f = irbuilder.inttoptr(ir.Constant(i64_ty, addr),
                           cb_func_ptr_ty, name='f')
    call = irbuilder.call(f, [a, b])
    irbuilder.ret(call)

The IR function created by create_addrcaller is somewhat similar to the one we've seen above with create_caller, but there's a subtle difference. addrcaller does not take a function pointer at runtime. It has knowledge of this function pointer encoded into it when it's generated. The addr argument passed into create_addrcaller is the runtime address of the function to call. addrcaller converts it to a function pointer (using the inttoptr instruction, which is somewhat similar to a reinterpret_cast in C++) and calls it [2].

Here's how to use it:

def main():
    CBFUNCTY = ctypes.CFUNCTYPE(c_int64, c_int64, c_int64)
    def myfunc(a, b):
        print('I was called with {0} and {1}'.format(a, b))
        return a + b
    cb_myfunc = CBFUNCTY(myfunc)
    cb_addr = ctypes.cast(cb_myfunc, c_void_p).value
    print('Callback address is 0x{0:x}'.format(cb_addr))

    module = ir.Module()
    create_addrcaller(module, cb_addr)
    print(module)

    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()

    llvm_module = llvm.parse_assembly(str(module))

    tm = llvm.Target.from_default_triple().create_target_machine()

    # Compile the module to machine code using MCJIT
    with llvm.create_mcjit_compiler(llvm_module, tm) as ee:
        ee.finalize_object()
        # Now call addrcaller
        print('Calling "addrcaller"')
        addrcaller = ctypes.CFUNCTYPE(c_int64, c_int64, c_int64)(
            ee.get_function_address('addrcaller'))
        res = addrcaller(105, 23)
        print('  The result is', res)

The key trick here is the call to ctypes.cast. It takes a Python function wrapped in a ctypes.CFUNCTYPE and casts it to a void*; in other words, it obtains its address [3]. This is the address we pass into create_addrcaller. The code ends up having exactly the same effect as the previous sample, but with an important difference: whereas previously the dispatch to myfunc happened at run-time, here it happens at compile-time.
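
This cast can be demonstrated in isolation, without any JIT involved. Per the ctypes documentation, a CFUNCTYPE instance can also be constructed from an integer address, so the address round-trips (a minimal sketch):

```python
import ctypes
from ctypes import CFUNCTYPE, c_int64, c_void_p

CBFUNCTY = CFUNCTYPE(c_int64, c_int64, c_int64)

def add(a, b):
    return a + b

cb = CBFUNCTY(add)

# Cast the callback wrapper to void* to recover the raw in-memory
# address of its thunk - the same trick main() uses above.
addr = ctypes.cast(cb, c_void_p).value

# Round-trip: constructing a CFUNCTYPE from the integer address
# yields a callable pointing at the same thunk (cb must stay alive).
cb2 = CBFUNCTY(addr)
```

This is precisely the integer that create_addrcaller bakes into the emitted IR via inttoptr.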

This is a synthetic example, but it should be clear how to extend it to the full thing mentioned earlier: for each builtin needed by the JITed code from the host code, we emit a JITed wrapper to call it. No symbol table dispatching at runtime. Even better, since these builtins can have arbitrary signatures, the JITed wrapper can handle all of that efficiently. PyPy uses this technique to make calls into C (via the cffi library) much more efficient than they are with ctypes. ctypes uses libffi, which has to pack all the arguments to a function at runtime, according to a type signature it was provided. However, since this type signature almost never changes during the runtime of one program, this packing can be done much more efficiently with JITing.

Conclusion

Hopefully it's clear that while this article focuses on a very specific technology (using llvmlite to JIT native code from Python), its principles are universal. The overarching idea here is that the barrier between what happens when the program is compiled and what happens when it runs is artificial. We can breach it in many ways, and use it to build increasingly complex abstractions. Some languages, like the Lisp family, list this mixture of compile-time and run-time as one of their unique strengths, and have been preaching it for decades. I fondly recall my own first real-world use of this technique many years ago - reading a configuration file and generating code at runtime that unpacks data based on that configuration. That task - emitting Perl code from Perl according to an XML config - may appear worlds away from the topic of this post - emitting LLVM IR from Python according to a function signature - but if you really think about it, it's exactly the same thing.

I suspect this is one of the most abstruse articles I've written lately; if you read this far, I sure hope you found it interesting and helpful. Let me know in the comments if anything isn't clear or if you have relevant ideas - I love discussing this topic!


[1]We'll have to compile the equivalent of a hash table implementation into our JITed code. While not impossible, this may be overkill if you really just want a quick-and-simple DSL.
[2]This delightful mixture of compile-time and run-time is by far the most important part of this article; if you remember just one thing from here, this should be it. Let me know in the comments if it's not clear.
[3]The concept of "address" for a Python function may raise an eyebrow. Keep in mind that this isn't a pure Python function we're talking about here. It's wrapped in a ctypes.CFUNCTYPE, which is a dispatcher created by ctypes ("thunk" in the nomenclature of libffi, the underlying mechanism behind ctypes) to perform argument conversion and make the actual call.