This is part 2 in the series of articles on multiple dispatch. Part 1 introduced the problem and discussed the issues surrounding it, along with a couple of possible solutions in C++. In this part, I'm going to talk about implementing multiple dispatch in a completely different language - Python.

A brief re-statement of the problem

It's very important to read part 1 in the series before reading this one, but for the sake of completeness here's a (very) brief re-statement of the programming problem we're trying to solve:

We have different "shape" types - rectangles, ellipses, triangles and so on. We'd like to compute intersections between these shapes, but the algorithm may be completely different based on the two shapes being intersected. The problem is how to structure the code for maximal flexibility, correctness and maintainability.

Brute force solution in Python

Part 1 showed a "brute force" approach in C++ that uses a nested chain of if-else conditions. It's instructive to start by presenting the same approach in Python [1], so that we have a useful baseline against which to compare subsequent solutions. Here is a basic Shape hierarchy:

class Shape:
    @property
    def name(self):
        return self.__class__

class Rectangle(Shape): pass

class Ellipse(Shape): pass

class Triangle(Shape): pass

And this is the intersect function:

def intersect(s1, s2):
    if isinstance(s1, Rectangle) and isinstance(s2, Ellipse):
        print('Rectangle x Ellipse [names s1=%s, s2=%s]' % (s1.name, s2.name))
    elif isinstance(s1, Rectangle) and isinstance(s2, Rectangle):
        print('Rectangle x Rectangle [names s1=%s, s2=%s]' % (s1.name, s2.name))
    else:
        # Generic shape intersection.
        print('Shape x Shape [names s1=%s, s2=%s]' % (s1.name, s2.name))

Though it requires somewhat less code than in C++, this approach is equivalent to the brute-force checking using dynamic_cast we saw in part 1, and it suffers from the same problems:

  1. A large amount of code that grows quadratically with the number of types.
  2. Brittleness due to the subtle ordering of checks within an inheritance hierarchy.
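To make the second point concrete, here is a minimal sketch that uses a hypothetical Square subclass of Rectangle and a hypothetical Square-specific branch. It also assumes the branches return strings instead of printing. Because isinstance also matches subclasses, the Rectangle check runs first and catches every Square, so the more specific branch is never reached:

```python
class Shape: pass
class Rectangle(Shape): pass
class Ellipse(Shape): pass
class Square(Rectangle): pass  # hypothetical subclass, for illustration only

def intersect(s1, s2):
    # isinstance(s1, Rectangle) is also True for a Square, so this branch
    # catches Square x Ellipse calls as well...
    if isinstance(s1, Rectangle) and isinstance(s2, Ellipse):
        return 'Rectangle x Ellipse'
    # ...which makes this more specific branch unreachable for those calls.
    elif isinstance(s1, Square) and isinstance(s2, Ellipse):
        return 'Square x Ellipse'
    else:
        return 'Shape x Shape'

print(intersect(Square(), Ellipse()))  # prints "Rectangle x Ellipse"
```

The fix is to move the Square branch above the Rectangle one. However, every new subclass reopens the question of where its checks belong in the chain.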

The brute-force solution appears odious to C++ programmers, and in Python it's even more so. Python programmers are not constrained by static typing. They have a fully dynamic, duck-typed language at their disposal, in which isinstance is usually a code smell. Python's reflection and meta-programming capabilities also go far beyond what C++ templates allow, so there must be a better way.

Improving brute-force by dispatching on types

In Python it's very easy to create maps (dicts, to be precise) with almost anything as a key. Types are first-class objects, so they can serve as keys too. Let's replace the if-else chain with a map lookup:

def intersect_rectangle_ellipse(r, e):
    print('Rectangle x Ellipse [names r=%s, e=%s]' % (r.name, e.name))

def intersect_rectangle_rectangle(r1, r2):
    print('Rectangle x Rectangle [names r1=%s, r2=%s]' % (r1.name, r2.name))

def intersect_generic(s1, s2):
    print('Shape x Shape [names s1=%s, s2=%s]' % (s1.name, s2.name))

_dispatch_map = {
    (Rectangle, Ellipse): intersect_rectangle_ellipse,
    (Rectangle, Rectangle): intersect_rectangle_rectangle,
}

def intersect(s1, s2):
    handler = _dispatch_map.get((type(s1), type(s2)), intersect_generic)
    handler(s1, s2)

This is much better! Now, whenever we have a new combination of shapes with a custom intersection handler, we just add the new function, add a line to _dispatch_map, and we're done. Note that we do introduce a subtle issue here: base-class defaults no longer work, because the dispatch is keyed on the exact types of the arguments. In other words, if we have a Square inheriting from Rectangle, we may want to reuse intersect_rectangle_rectangle for intersecting two squares, but right now we'll fall into the generic Shape handler. I'll examine this problem and possible solutions in more detail later on.
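The following sketch shows why the lookup misses. It assumes a hypothetical Square subclass, and placeholder strings stand in for the real handler functions. The name find_handler is also hypothetical. Because type() returns an object's exact class, the key built for two squares is (Square, Square). That key is not in the map, even though isinstance would accept both squares as rectangles:

```python
class Shape: pass
class Rectangle(Shape): pass
class Square(Rectangle): pass  # hypothetical subclass

# Placeholder strings stand in for the real handler functions.
_dispatch_map = {(Rectangle, Rectangle): 'intersect_rectangle_rectangle'}

def find_handler(s1, s2):
    # The key uses the exact classes, so subclasses get no special treatment.
    return _dispatch_map.get((type(s1), type(s2)), 'intersect_generic')

sq1, sq2 = Square(), Square()
print(isinstance(sq1, Rectangle))  # True
print(type(sq1) is Rectangle)      # False
print(find_handler(sq1, sq2))      # intersect_generic
```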

Decorators and callable wrapper objects FTW

The type mapping code in the last section isn't bad, but we can do better in Python. In fact, Guido van Rossum posted a nice article on this topic a long while ago. The main insight is to use decorators in combination with callable objects that wrap the handler functions, so that client code looks like this:

@multimethod(Rectangle, Ellipse)
def intersect(r, e):
    print('Rectangle x Ellipse [names r=%s, e=%s]' % (r.name, e.name))

@multimethod(Rectangle, Rectangle)
def intersect(r1, r2):
    print('Rectangle x Rectangle [names r1=%s, r2=%s]' % (r1.name, r2.name))

@multimethod(Shape, Shape)
def intersect(s1, s2):
    print('Shape x Shape [names s1=%s, s2=%s]' % (s1.name, s2.name))

The magic that makes this tick is:

class _MultiMethod:
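    """Maps a tuple of types to the function to call for these types."""
    def __init__(self, name):
        self.name = name
        self.typemap = {}

    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        # Only the lookup is guarded, so a KeyError raised inside the handler
        # itself is not misreported as a missing registration.
        try:
            function = self.typemap[types]
        except KeyError:
            raise TypeError('no match %s for types %s' % (self.name, types)) from None
        return function(*args)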

    def register_function_for_types(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function


# Maps function.__name__ -> _MultiMethod object.
_multi_registry = {}

def multimethod(*types):
    def register(function):
        name = function.__name__
        mm = _multi_registry.get(name)
        if mm is None:
            mm = _multi_registry[name] = _MultiMethod(name)
        mm.register_function_for_types(types, function)
        return mm
    return register

There are two levels of map dispatching here. _multi_registry maps the name of each function we consider a "multi method" to a wrapper object of type _MultiMethod. The intersection algorithm is one example of a multi method, and there may be others (maybe collide or some such). Each instance of _MultiMethod holds its own mapping akin to _dispatch_map from the previous section, mapping the types of the arguments to the actual function. Finally, the multimethod decorator completes the picture for a very nice solution overall [2], except that it has some issues.

Handling symmetry

I presented these issues in part 1, and they bite us again here. First, symmetry: the way @multimethod is currently defined, we specify the exact order of the types dispatched upon. In the sample above, we defined:

@multimethod(Rectangle, Ellipse)
def intersect(r, e):
    print('Rectangle x Ellipse [names r=%s, e=%s]' % (r.name, e.name))

However, if a call to intersect is made where the first argument is an Ellipse and the second argument a Rectangle, we're most likely interested in dispatching to the same function. It won't work the way the code is currently structured, but it's fairly easy to fix.

Most obviously, we can have something like:

def symmetric_intersection_rectangle_ellipse(r, e):
    print('Rectangle x Ellipse [names r=%s, e=%s]' % (r.name, e.name))

@multimethod(Rectangle, Ellipse)
def intersect(r, e):
    symmetric_intersection_rectangle_ellipse(r, e)

@multimethod(Ellipse, Rectangle)
def intersect(e, r):
    symmetric_intersection_rectangle_ellipse(r, e)

This is similar to the C++ approach in part 1. However, if we always want symmetry for these dispatches, we can try something more advanced. Python makes it easy to sort the types, and sorting them provides symmetry automatically. Here's the new _MultiMethod.register_function_for_types [3]:

def register_function_for_types(self, types, function):
    # Sort the tuple of types before setting it in the dispatch map.
    types = tuple(sorted(types, key=id))
    if types in self.typemap:
        raise TypeError("duplicate registration")
    self.typemap[types] = function

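The difference is that the tuple of types is now sorted by id. id returns an integer that is unique among objects alive at the same time. Classes normally live for the whole run of the program, so this gives the types a consistent ordering.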

The lookup in _MultiMethod.__call__ also has to sort the types, but there's an additional subtlety. If we reorder the arguments, we have to find a way to route them to the handler function in the right order. In our intersect example, the handler function takes the rectangle first, then the ellipse. So if we dispatch intersect(e, r) to the same handler, the order of arguments has to be reversed:

def __call__(self, *args):
    # Find the right function to call based on a sorted tuple of types. We
    # have to sort the call arguments themselves together with the types,
    # so that the handler function can get them in the order it expects.
    args_with_types = sorted(
        zip(args, (arg.__class__ for arg in args)),
        key=lambda pair: id(pair[1]))
    types = tuple(ty for _, ty in args_with_types)
    # Guard only the lookup, not the handler call itself.
    try:
        function = self.typemap[types]
    except KeyError:
        raise TypeError('no match %s for types %s' % (self.name, types)) from None
    return function(*(arg for arg, _ in args_with_types))

Now using the single multi-method definition:

@multimethod(Rectangle, Ellipse)
def intersect(r, e):
    print('Rectangle x Ellipse [names r=%s, e=%s]' % (r.name, e.name))

We can run:

r = Rectangle()
e = Ellipse()

intersect(r, e)
intersect(e, r)

And get:

Rectangle x Ellipse [names r=<class '__main__.Rectangle'>, e=<class '__main__.Ellipse'>]
Rectangle x Ellipse [names r=<class '__main__.Rectangle'>, e=<class '__main__.Ellipse'>]

It works, but there are (at least) two problems:

  1. The runtime cost of the dispatch is now excessive. Every call to intersect has to sort the types of arguments to determine which handler to dispatch to. This can be alleviated by some sort of caching, but it won't be free.
  2. The intersect(r, e) handler always accepts the rectangle first. This implicitly relies on the relative sorting order between Rectangle and Ellipse: we assume Rectangle will be sorted lower. This is a pretty bad assumption to make. In CPython, id is the object's memory address, so the order depends on where the class objects happen to be allocated. To solve this we could use keyword arguments, for example.

Even though these problems have solutions, it should be clear that making symmetry work "magically" is tricky. Is it worth the trouble? YMMV.
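One alternative sidesteps both problems by doing the reordering work once, at registration time. Each handler is registered under every permutation of its type tuple. Each entry is paired with a small wrapper that restores the argument order the handler declared. This is a different technique from sorting by id, swapped in here for illustration. SymmetricMultiMethod and _reorder are hypothetical names, and the sketch assumes the same (types, function) registration interface as _MultiMethod:

```python
import itertools

class SymmetricMultiMethod:
    """Hypothetical variant of _MultiMethod that accepts arguments in any order."""
    def __init__(self, name):
        self.name = name
        self.typemap = {}

    def __call__(self, *args):
        # Plain exact-type lookup: no sorting happens at call time.
        types = tuple(arg.__class__ for arg in args)
        try:
            function = self.typemap[types]
        except KeyError:
            raise TypeError('no match %s for types %s' % (self.name, types)) from None
        return function(*args)

    def register_function_for_types(self, types, function):
        types = tuple(types)
        if types in self.typemap:
            raise TypeError("duplicate registration")
        identity = tuple(range(len(types)))
        for perm in itertools.permutations(identity):
            # Key for a call whose k-th argument has type types[perm[k]].
            key = tuple(types[i] for i in perm)
            # Skip keys already taken (e.g. when types repeat); the identity
            # permutation comes first, so the handler's own order wins.
            if key in self.typemap:
                continue
            self.typemap[key] = (function if perm == identity
                                 else self._reorder(function, perm))

    @staticmethod
    def _reorder(function, perm):
        def wrapper(*args):
            # Argument k belongs at position perm[k] of the handler's signature.
            ordered = [None] * len(args)
            for k, i in enumerate(perm):
                ordered[i] = args[k]
            return function(*ordered)
        return wrapper


class Rectangle: pass
class Ellipse: pass

intersect = SymmetricMultiMethod('intersect')
intersect.register_function_for_types(
    (Rectangle, Ellipse),
    lambda r, e: 'Rectangle x Ellipse [r=%s, e=%s]' % (type(r).__name__, type(e).__name__))

print(intersect(Rectangle(), Ellipse()))  # Rectangle x Ellipse [r=Rectangle, e=Ellipse]
print(intersect(Ellipse(), Rectangle()))  # Rectangle x Ellipse [r=Rectangle, e=Ellipse]
```

The cost moves to registration. A handler dispatching on n types gets up to n! entries, which is trivial for two arguments but grows quickly. Because the handler's declared order is preserved, nothing depends on how Rectangle and Ellipse happen to sort.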

Handling base-class defaults

Another interesting issue the Python solution doesn't address yet is base-class defaults. It suffers from the same problem noted above for the manual dispatch dict approach. Suppose we add a Square shape:

class Square(Rectangle): pass

But unless we define a version of intersect for the type tuple (Square, Ellipse), when we attempt to intersect squares and ellipses we get:

>>> e = Ellipse()
>>> sq = Square()
>>> intersect(sq, e)
Traceback (most recent call last):
  File "multi_with_base_class_defaults.py", line 32, in __call__
    return self.typemap[types](*args)
KeyError: (<class '__main__.Square'>, <class '__main__.Ellipse'>)

This happens even though a Square is-a Rectangle, so we could reasonably expect this call to dispatch to the (Rectangle, Ellipse) handler that we did define.

There are a couple of ways we could go about this. The most obvious is to find a handler at call time. When intersect is called and we don't find a handler for the exact type tuple passed in, we keep looking, creating tuples of the types' superclasses until we find something (eventually we'd hit a handler for (Shape, Shape)). Note that this is a lookup over all combinations of superclasses of the input types: for each superclass of the left-hand-side type, we have to check every superclass of the right-hand-side type.

This is obviously costly, and doing it for each dispatch is unthinkable. So how about we shift it to registration time? When a handler is registered, we can walk over all the combinations of subclasses of the input types and register the same handler for those too. The actual dispatch in _MultiMethod.__call__ then remains unchanged. Here's the new registration method [4]:

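import itertools

def register_function_for_types(self, types, function):
    # For each argument position, collect the registered type plus all of
    # its (transitive) subclasses known at this point.
    types_with_subclasses = []
    for ty in types:
        types_with_subclasses.append([ty] + all_subclasses(ty))
    # Register the same handler for every combination of those types.
    for type_tuple in itertools.product(*types_with_subclasses):
        # Here we explicitly support overriding the registration, so that
        # more specific dispatches can override earlier-defined generic
        # dispatches.
        self.typemap[type_tuple] = function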

It uses this helper function to recursively list all subclasses of a given class:

def all_subclasses(cls):
    """Returns a list of *all* subclasses of cls, recursively."""
    subclasses = cls.__subclasses__()
    for subcls in cls.__subclasses__():
        subclasses.extend(all_subclasses(subcls))
    return subclasses
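The following standalone sketch shows what the registration-time expansion actually produces. It reuses all_subclasses and computes the keys that registering a (Rectangle, Ellipse) handler would create. Circle is a hypothetical Ellipse subclass added only for this illustration:

```python
import itertools

def all_subclasses(cls):
    """Returns a list of *all* subclasses of cls, recursively."""
    subclasses = cls.__subclasses__()
    for subcls in cls.__subclasses__():
        subclasses.extend(all_subclasses(subcls))
    return subclasses

class Shape: pass
class Rectangle(Shape): pass
class Square(Rectangle): pass
class Ellipse(Shape): pass
class Circle(Ellipse): pass  # hypothetical, for illustration only

types = (Rectangle, Ellipse)
# Same expansion as register_function_for_types: each type plus its subclasses.
keys = list(itertools.product(*([ty] + all_subclasses(ty) for ty in types)))
print([tuple(ty.__name__ for ty in key) for key in keys])
# [('Rectangle', 'Ellipse'), ('Rectangle', 'Circle'), ('Square', 'Ellipse'), ('Square', 'Circle')]
```

Each registration therefore costs one dict entry per combination, which is the product of the sizes of the subclass trees involved.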

This works as expected. Registering intersect for (Rectangle, Ellipse) also registered the same handler for (Square, Ellipse), so the dispatch succeeds and we get:

>>> e = Ellipse()
>>> sq = Square()
>>> intersect(sq, e)
Rectangle x Ellipse [names r=<class '__main__.Square'>, e=<class '__main__.Ellipse'>]

There are a couple of caveats we have to be aware of with this solution:

  1. When registering handlers, the generic (base) handlers must be registered before the more concrete (subclass) handlers, because later registrations override earlier ones. If we register the (Shape, Shape) handler last, it will override all other handlers because all the shapes are subclasses of Shape.
  2. The handlers must be registered after the whole type hierarchy has been declared. Since register_function_for_types walks the tree of subclasses of any type it sees, it has to know about all subclasses at that point in the execution. As a concrete example, if we register a handler for (Ellipse, Triangle) and only later define a new shape Circle inheriting from Ellipse, a (Circle, Triangle) call will not be correctly routed to the (Ellipse, Triangle) handler, since Circle was unknown when the handler was registered.
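Both caveats come from doing the work at registration time. For comparison, here is a minimal sketch of the call-time alternative described earlier, with a per-type-tuple cache to soften its cost. It walks each argument's __mro__ and tries ancestor combinations closest-first, meaning the smallest total distance up the hierarchies. Ties go to the candidate that keeps the left argument more specific. That tie-breaking policy is an assumption of the sketch, and MROMultiMethod is a hypothetical name:

```python
import itertools

class MROMultiMethod:
    """Hypothetical dispatcher that resolves base-class defaults at call time."""
    def __init__(self, name):
        self.name = name
        self.typemap = {}
        self._cache = {}

    def register_function_for_types(self, types, function):
        types = tuple(types)
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function
        # A new handler may change how already-cached type tuples resolve.
        self._cache.clear()

    def _resolve(self, types):
        mros = [ty.__mro__ for ty in types]
        # Every combination of ancestor indices, ordered by total distance;
        # sorted() is stable, so ties keep itertools.product's order.
        candidates = sorted(
            itertools.product(*(range(len(mro)) for mro in mros)), key=sum)
        for indices in candidates:
            key = tuple(mro[i] for mro, i in zip(mros, indices))
            if key in self.typemap:
                return self.typemap[key]
        raise TypeError('no match %s for types %s' % (self.name, types))

    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self._cache.get(types)
        if function is None:
            function = self._cache[types] = self._resolve(types)
        return function(*args)


class Shape: pass
class Rectangle(Shape): pass
class Ellipse(Shape): pass

intersect = MROMultiMethod('intersect')
# Registration order does not matter here: the generic handler can come last.
intersect.register_function_for_types((Rectangle, Ellipse), lambda r, e: 'Rectangle x Ellipse')
intersect.register_function_for_types((Shape, Shape), lambda s1, s2: 'Shape x Shape')

# A subclass defined after registration is still routed correctly.
class Square(Rectangle): pass

print(intersect(Square(), Ellipse()))  # Rectangle x Ellipse
print(intersect(Ellipse(), Square()))  # Shape x Shape
```

The price is a search on the first call for each new tuple of argument types. The cache is keyed by exact types, so it grows with the number of distinct combinations actually seen.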

Generalized dispatch

Armed with the magic capabilities of Python, everything seems doable and we feel our programming powers grow; along with them, so does our appetite. So far we've looked at dispatching fairly rigidly by the types of the arguments. Can we come up with something more... general? For example, can we place arbitrary predicates in our decorator, and have the dispatcher invoke the handler only if the predicate is true for an argument? We sure can, and there's a Python library that does these things - PEAK-rules.

I won't spend much time on it, but will just say that it lets us specify dispatch conditions like:

>>> @when(pprint, "isinstance(ob,list) and len(ob)>50")
... def pprint_long_list(ob):
...     ...

It then parses the Python code in the condition and uses that to dispatch the calls when the condition is satisfied [5]. This is just a small example - the library is very general.

It shouldn't be hard to implement a basic form of this functionality on our own, and you should definitely try it as an exercise. I will spend more time on generalized dispatch in parts 3 and 4 of the series.
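As a starting point for that exercise, here is a minimal sketch of predicate dispatch that uses plain lambdas instead of parsing condition strings. PredicateDispatcher, its register method, and describe are hypothetical names, not PEAK-rules' API. The sketch also assumes a simple policy: the first matching predicate, in registration order, wins.

```python
class PredicateDispatcher:
    """Hypothetical predicate dispatcher: the first matching rule wins."""
    def __init__(self, default):
        self.default = default
        self.rules = []  # (predicate, function) pairs in registration order

    def register(self, predicate):
        def decorator(function):
            self.rules.append((predicate, function))
            # Return the dispatcher so the decorated name stays bound to it.
            return self
        return decorator

    def __call__(self, *args):
        for predicate, function in self.rules:
            if predicate(*args):
                return function(*args)
        return self.default(*args)


def describe(ob):
    return 'generic: %r' % (ob,)

describe = PredicateDispatcher(describe)

@describe.register(lambda ob: isinstance(ob, list) and len(ob) > 50)
def describe(ob):
    return 'long list of %d items' % len(ob)

print(describe([1, 2, 3]))         # generic: [1, 2, 3]
print(describe(list(range(100))))  # long list of 100 items
```

The predicates are opaque functions, so this sketch cannot reason about how they relate to each other, as in the foo > 20 versus foo > 40 case from footnote [5]. The caller has to register more specific rules first.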

Summary

This article demonstrates how multiple dispatch can be easily bolted on top of Python, leveraging the language's dynamism and permissive typing. Moreover, with the help of decorators, multiple dispatch can be handsomely integrated into the program's syntax in a very natural way. What's most appealing is that after being marked with @multimethod, handlers can be used just like any other Python function. Calling them looks no different, even though the underlying mechanism is more sophisticated.

However, as we've seen here, there are several variations on the multiple dispatch scheme, and each requires slightly different implementation considerations. This is why, given a real-world need for multiple dispatch, I'd be more inclined to roll a domain-specific solution than to use a library. The basic form of multiple dispatch is just about 30 lines of Python. Moreover, this is fairly run-of-the-mill code once you have some experience with decorators and callable objects. Is using something like PEAK-rules worth the learning curve? I'm not convinced at all.

In the next part of the series we'll see how multiple dispatch works in a language that supports it as a built-in feature.


[1] I'm using Python 3 (more specifically 3.4) for these code samples.
[2] It's even true multiple dispatch, rather than just double dispatch. Nothing in this approach limits the number of arguments dispatched upon to 2.
[3] The full code sample for multi-dispatch with symmetry is available here.
[4] The full code sample for multi-dispatch with base-class defaults is available here.
[5] Why does it need to parse Python where a lambda would do? The documentation of PEAK-rules is somewhat dense, but I believe it tries to optimize dispatches based on possible sub-expressions, and so on. For example, if one handler is registered for when foo > 20 and another for when foo > 40, then the first handler should fire too whenever the second one does. The library needs to understand the code in question to make that possible. Note that the library is fairly old, predating Python's ast module, so it does its own source-to-AST parsing of a subset of Python.
