Type inference is a major feature of several programming languages, most notably languages from the ML family like Haskell. In this post I want to provide a brief overview of type inference, along with a simple Python implementation for a toy ML-like language.

Uni-directional type inference

While static typing is very useful, one of its potential downsides is verbosity. The programmer has to annotate values with types throughout the code, which results in more effort and clutter. What's really annoying, though, is that in many cases these annotations feel superfluous. Consider this classical C++ example from pre-C++11 times:

std::vector<Blob*> blobs;
std::vector<Blob*>::iterator iter = blobs.begin();

Clearly when the compiler sees blobs.begin(), it knows the type of blobs, so it also knows the type of the begin() method invoked on it because it is familiar with the declaration of begin. Why should the programmer be burdened with spelling out the type of the iterator? Indeed, one of the most welcome changes in C++11 was lifting this burden by repurposing auto for basic type inference:

std::vector<Blob*> blobs;
auto iter = blobs.begin();

Go has a similar capability with the := syntax. Given some function:

func parseThing(...) (Node, error) {
}

We can simply write:

node, err := parseThing(...)

Without having to explicitly declare that node has type Node and err has type error.

These features are certainly useful, and they involve some degree of type inference from the compiler. Some functional programming proponents say this is not real type inference, but I think the difference is just a matter of degree. There's certainly some inference going on here, with the compiler calculating and assigning the right types for expressions without the programmer's help. Since this calculation flows in one direction (from the declaration of the vector::begin method to the auto assignment), I'll call it uni-directional type inference [1].

Bi-directional type inference (Hindley-Milner)

If we define a new map function in Haskell to map a function over a list, we can do it as follows:

mymap f [] = []
mymap f (first:rest) = f first : mymap f rest

Note that we did not specify the types for either the arguments of mymap, or its return value. The Haskell compiler can infer them on its own, using the definition provided:

> :t Main.mymap
Main.mymap :: (t1 -> t) -> [t1] -> [t]

The compiler has determined that the first argument of mymap is a generic function, assigning its argument the type t1 and its return value the type t. The second argument of mymap has the type [t1], which means "list of t1", and the return value of mymap has the type [t], "list of t". How was this accomplished?

Let's start with the second argument. From the [] = [] variant, and also from the (first:rest) deconstruction, the compiler infers that it has a list type. But nothing else in the code constrains the element type, so the compiler chooses a generic type variable: t1. The expression f first applies f to an element of this list, so f has to take a t1; nothing constrains its return type, so it gets the generic t. The result is that f has type (t1 -> t), which in Haskell parlance means "a function from t1 to t". Finally, f first : mymap f rest conses a value of type t onto the result of the recursive call, so mymap returns [t].

Here is another example, written in a toy language I put together for the sake of this post. The language is called microml, and its implementation is described at the end of the post:

foo f g x = if f(x == 1) then g(x) else 20

Here foo is declared as a function with three arguments. What is its type? Let's try to run type inference manually. First, note that the body of the function consists of an if expression. As is common in programming languages, if has strict typing rules in microml: the type of the condition is boolean (Bool), and the types of the then and else clauses must match.

So we know that f(x == 1) has to return a Bool, and since x == 1 is itself a Bool, f also takes a Bool argument. Moreover, since x is compared to an integer, x is an Int. What is the type of g? Well, it takes an Int argument (x), and its return value must match the type of the else clause, which is an Int as well.

To summarize:

  • The type of x is Int
  • The type of f is Bool -> Bool
  • The type of g is Int -> Int

So the overall type of foo is:

((Bool -> Bool), (Int -> Int), Int) -> Int

It takes three arguments, the types of which we have determined, and returns an Int.

Note how this type inference process is not just going in one direction, but seems to be "jumping around" the body of the function figuring out known types due to typing rules. This is why I call it bi-directional type inference, but it's much better known as Hindley-Milner type inference, since it was independently discovered by Roger Hindley in 1969 and Robin Milner in 1978.

How Hindley-Milner type inference works

We've seen a couple of examples of manually running type inference on some code above. Now let's see how to translate it to an implementable algorithm. I'm going to present the process in several separate stages, for simplicity. Some other presentations of the algorithm combine several of these stages, but seeing them separately is more educational, IMHO.

The stages are:

  1. Assign symbolic type names (like t1, t2, ...) to all subexpressions.
  2. Using the language's typing rules, write a list of type equations (or constraints) in terms of these type names.
  3. Solve the list of type equations using unification.

Let's use this example again:

foo f g x = if f(x == 1) then g(x) else 20

Starting with stage 1, we'll list all subexpressions in this declaration (starting with the declaration itself) and assign unique type names to them:

foo                                       t0
f                                         t1
g                                         t2
x                                         t3
if f(x == 1) then g(x) else 20            t4
f(x == 1)                                 t5
x == 1                                    t6
x                                         t3
g(x)                                      t7
20                                        Int

Note that every subexpression gets a type, and we de-duplicate them (e.g. x is encountered twice and gets the same type name assigned). Constant nodes get known types.
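As a rough illustration of stage 1, here is a minimal Python sketch. The AST classes (Identifier, IntConstant, Op, App, If, Lambda) and the assign_typenames function are hypothetical stand-ins written for this example, not microml's actual code. The sketch assumes that a symbol table keyed by variable name is enough to give every occurrence of a lambda parameter the same type name.

```python
import itertools


class Node:
    """Base class for AST nodes; stage 1 stores a type name in .typename."""
    typename = None


class Identifier(Node):
    def __init__(self, name):
        self.name = name


class IntConstant(Node):
    def __init__(self, value):
        self.value = value


class Op(Node):
    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right


class App(Node):
    def __init__(self, func, args):
        self.func, self.args = func, args


class If(Node):
    def __init__(self, cond, then, else_):
        self.cond, self.then, self.else_ = cond, then, else_


class Lambda(Node):
    def __init__(self, params, body):
        self.params, self.body = params, body


def children(node):
    """Subexpressions of node, in left-to-right order."""
    if isinstance(node, Op):
        return [node.left, node.right]
    if isinstance(node, App):
        return [node.func] + node.args
    if isinstance(node, If):
        return [node.cond, node.then, node.else_]
    return []


def assign_typenames(node, symtab, fresh):
    """Assign a type name to every subexpression of node, in place.

    symtab maps each bound variable to its type name, so all occurrences
    of a variable share one name; fresh yields new names t0, t1, ...
    """
    if isinstance(node, IntConstant):
        node.typename = 'Int'                  # constants have a known type
    elif isinstance(node, Identifier):
        node.typename = symtab[node.name]      # KeyError for an unbound name
    elif isinstance(node, Lambda):
        node.typename = next(fresh)
        scope = dict(symtab)                   # parameters are local to the body
        for param in node.params:
            scope[param] = next(fresh)
        assign_typenames(node.body, scope, fresh)
    else:
        node.typename = next(fresh)
        for child in children(node):
            assign_typenames(child, symtab, fresh)


# foo f g x = if f(x == 1) then g(x) else 20
body = If(App(Identifier('f'), [Op('==', Identifier('x'), IntConstant(1))]),
          App(Identifier('g'), [Identifier('x')]),
          IntConstant(20))
foo = Lambda(['f', 'g', 'x'], body)
assign_typenames(foo, {}, (f't{i}' for i in itertools.count()))
print(foo.typename, body.typename, body.cond.typename, body.then.typename)
```

Naming the lambda first, then its parameters, then each node before its children reproduces the numbering in the table above, with both occurrences of x sharing t3.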

In stage 2, we'll use the language's typing rules to write down equations involving these type names. Usually books and papers use slightly scary formal notation for typing rules; for example, for if:

\[\frac{\Gamma \vdash e_0 : Bool, \Gamma \vdash e_1 : T, \Gamma \vdash e_2 : T}{\Gamma \vdash if\: e_0\: then\: e_1\: else\: e_2 : T}\]

All this expresses is the intuitive typing of if we've described above: the condition is expected to be boolean, the types of the then and else clauses are expected to match, and that common type becomes the type of the whole expression.

To unravel the notation, prepend "given that" to the expression above the line and "we can derive" to the expression below the line; \(\Gamma \vdash e_0 : Bool\) means that \(e_0\) is typed to Bool in the set of typing assumptions called \(\Gamma\).

Similarly, a typing rule for single-argument function application would be:

\[\frac{\Gamma \vdash e_0 : T, \Gamma \vdash f : T \rightarrow U}{\Gamma \vdash f(e_0) : U}\]
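In the same notation, the rule for the == operator used in our example might be written as follows. This is a sketch that assumes, as the equations microml generates for x == 1 later in this post suggest, that both operands must be Int and that the result is a Bool:

\[\frac{\Gamma \vdash e_1 : Int, \Gamma \vdash e_2 : Int}{\Gamma \vdash e_1 == e_2 : Bool}\]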

The real trick of type inference is running these typing rules in reverse. The rule tells us how to assign types to the whole expression given its constituent types, but we can also use it as an equation that works both ways and lets us infer constituent types from the whole expression's type.

Let's see what equations we can come up with, looking at the code:

From f(x == 1) we infer t1 = (t6 -> t5), because t1 is the type of f, t6 is the type of x == 1, and t5 is the type of f(x == 1). Note that we're using the typing rules for function application here. Moreover, we can infer that t3 is Int and t6 is Bool because of the typing rule of the == operator.

Similarly, from g(x) we infer t2 = (t3 -> t7).

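From the if expression, we infer that t5 is Bool (since f(x == 1) is the condition of the if), and that t4 = t7 and t4 = Int, because the then and else clauses must match and their common type is the type of the whole if. Finally, since foo takes f, g and x and returns the value of the if, t0 = ((t1, t2, t3) -> t4).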
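These equations can also be generated mechanically, by walking the AST once and emitting whatever each typing rule dictates. The following is a minimal sketch, not microml's generate_equations: it assumes a hypothetical tuple-based AST in which every node already carries its stage-1 type name, represents a function type as ('->', argument_types, return_type), and treats == as requiring Int operands.

```python
# Hypothetical AST: each node is a tuple whose first two items are its kind
# and the type name assigned in stage 1:
#   ('int', 'Int', value)               ('var', tname, name)
#   ('==', tname, left, right)          ('app', tname, func, [args])
#   ('if', tname, cond, then, else_)    ('lambda', tname, [param tnames], body)


def func_type(arg_types, ret_type):
    """A function type is represented as ('->', argument_types, return_type)."""
    return ('->', tuple(arg_types), ret_type)


def show(t):
    """Format a type in the style of the output shown in this post."""
    if isinstance(t, tuple):
        _, args, ret = t
        if len(args) == 1:
            arg_str = show(args[0])
        else:
            arg_str = '(' + ', '.join(show(a) for a in args) + ')'
        return f'({arg_str} -> {show(ret)})'
    return t


def generate_equations(node, eqs):
    """Append (left, right) type equations for node and its subtree to eqs."""
    kind, tname = node[0], node[1]
    if kind == 'int':
        eqs.append((tname, 'Int'))           # a constant has a known type
    elif kind == 'var':
        pass                                 # constrained only by its uses
    elif kind == '==':
        _, _, left, right = node
        generate_equations(left, eqs)
        generate_equations(right, eqs)
        eqs.append((left[1], 'Int'))         # both operands are Int...
        eqs.append((right[1], 'Int'))
        eqs.append((tname, 'Bool'))          # ...and the comparison is Bool
    elif kind == 'app':
        _, _, func, args = node
        generate_equations(func, eqs)
        for arg in args:
            generate_equations(arg, eqs)
        # The callee maps the argument types to the application's type.
        eqs.append((func[1], func_type([a[1] for a in args], tname)))
    elif kind == 'if':
        _, _, cond, then, else_ = node
        for child in (cond, then, else_):
            generate_equations(child, eqs)
        eqs.append((cond[1], 'Bool'))        # the condition is Bool
        eqs.append((tname, then[1]))         # both branches share the
        eqs.append((tname, else_[1]))        # type of the whole if
    elif kind == 'lambda':
        _, _, params, body = node
        generate_equations(body, eqs)
        eqs.append((tname, func_type(params, body[1])))
    else:
        raise ValueError(f'unknown node kind: {kind}')


# foo f g x = if f(x == 1) then g(x) else 20, with stage-1 type names.
foo = ('lambda', 't0', ['t1', 't2', 't3'],
       ('if', 't4',
        ('app', 't5', ('var', 't1', 'f'),
         [('==', 't6', ('var', 't3', 'x'), ('int', 'Int', 1))]),
        ('app', 't7', ('var', 't2', 'g'), [('var', 't3', 'x')]),
        ('int', 'Int', 20)))

equations = []
generate_equations(foo, equations)
for left, right in equations:
    print(f'{show(left):15} {show(right)}')
```

Because every child is visited before its parent's own equations are emitted, this sketch produces eleven equations in the same order as the microml output shown further below, including trivial ones like Int = Int for constants.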

Now we have a list of equations, and our task is to find the most general solution, treating the equations as constraints. This is done by using the unification algorithm which I described in detail in the previous post. The solution we're seeking here is precisely the most general unifier.
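For a self-contained picture, here is a compact sketch of substitution-based unification with an occurs check, solving foo's equations. It is written for this example and is not the code from typing.py or from the previous post. It assumes types are strings (type variables like 't3', or the constants 'Int' and 'Bool') or function types written as ('->', argument_types, return_type), and the equations list is copied from the microml output shown later in this post.

```python
CONSTANTS = {'Int', 'Bool'}


class TypeMismatch(Exception):
    pass


def is_var(t):
    return isinstance(t, str) and t not in CONSTANTS


def show(t):
    """Format a type in the style of the output shown in this post."""
    if isinstance(t, tuple):
        _, args, ret = t
        arg_str = (show(args[0]) if len(args) == 1
                   else '(' + ', '.join(show(a) for a in args) + ')')
        return f'({arg_str} -> {show(ret)})'
    return t


def occurs(v, t, subst):
    """True if type variable v appears in t, following bindings in subst."""
    if t == v:
        return True
    if is_var(t) and t in subst:
        return occurs(v, subst[t], subst)
    if isinstance(t, tuple):
        _, args, ret = t
        return any(occurs(v, a, subst) for a in args) or occurs(v, ret, subst)
    return False


def unify_var(v, t, subst):
    if v in subst:
        return unify(subst[v], t, subst)
    if is_var(t) and t in subst:
        return unify(v, subst[t], subst)
    if occurs(v, t, subst):
        raise TypeMismatch(f'infinite type: {v} = {show(t)}')
    return {**subst, v: t}


def unify(x, y, subst):
    """Return subst extended so that x and y are equal, or raise TypeMismatch."""
    if x == y:
        return subst
    if is_var(x):
        return unify_var(x, y, subst)
    if is_var(y):
        return unify_var(y, x, subst)
    if isinstance(x, tuple) and isinstance(y, tuple):
        if len(x[1]) != len(y[1]):
            raise TypeMismatch(f'arity mismatch: {show(x)} vs {show(y)}')
        for a, b in zip(x[1], y[1]):
            subst = unify(a, b, subst)
        return unify(x[2], y[2], subst)
    raise TypeMismatch(f'cannot unify {show(x)} with {show(y)}')


def resolve(t, subst):
    """Apply subst to t until no bound type variables remain."""
    if is_var(t) and t in subst:
        return resolve(subst[t], subst)
    if isinstance(t, tuple):
        _, args, ret = t
        return ('->', tuple(resolve(a, subst) for a in args), resolve(ret, subst))
    return t


equations = [
    ('Int', 'Int'), ('t3', 'Int'), ('Int', 'Int'), ('t6', 'Bool'),
    ('t1', ('->', ('t6',), 't5')), ('t2', ('->', ('t3',), 't7')),
    ('Int', 'Int'), ('t5', 'Bool'), ('t4', 't7'), ('t4', 'Int'),
    ('t0', ('->', ('t1', 't2', 't3'), 't4')),
]

subst = {}
for left, right in equations:
    subst = unify(left, right, subst)
print(show(resolve('t0', subst)))
```

Resolving t0 through the final substitution gives the type of foo, matching what microml prints. A contradictory constraint, such as unifying Int with Bool, raises TypeMismatch instead.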

For our expression, the algorithm will find the type of foo to be:

((Bool -> Bool), (Int -> Int), Int) -> Int

As expected.

If we make a slight modification to the expression to remove the comparison of x with 1:

foo f g x = if f(x) then g(x) else 20

Then we can no longer constrain the type of x, since all we know about it is that it's passed into functions f and g, and nothing else constrains the arguments of these functions. The type inference process will thus calculate this type for foo:

((a -> Bool), (a -> Int), a) -> Int

It assigns x the generic type name a, and uses it for the arguments of f and g as well.
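Unification by itself leaves x's type variable (t3, if we number subexpressions as before) unresolved in the result; turning it into the friendlier a is a separate renaming pass, presumably in the spirit of the rename_types=True flag passed in main.py. Here is a minimal sketch of such a pass, not microml's implementation. It assumes the ('->', argument_types, return_type) tuple representation with Int and Bool as the only constants, and its input is the type we'd expect the solver to leave for the modified foo, written out by hand.

```python
import string

CONSTANTS = {'Int', 'Bool'}


def show(t):
    """Format a type in the style of the output shown in this post."""
    if isinstance(t, tuple):
        _, args, ret = t
        arg_str = (show(args[0]) if len(args) == 1
                   else '(' + ', '.join(show(a) for a in args) + ')')
        return f'({arg_str} -> {show(ret)})'
    return t


def rename_type_vars(t):
    """Rename type variables to a, b, c, ... in order of first appearance.

    This simple sketch supports at most 26 distinct variables.
    """
    names = {}

    def walk(t):
        if isinstance(t, tuple):
            _, args, ret = t
            # Arguments are walked before the return type, matching print order.
            return ('->', tuple(walk(a) for a in args), walk(ret))
        if t in CONSTANTS:
            return t
        if t not in names:
            names[t] = string.ascii_lowercase[len(names)]
        return names[t]

    return walk(t)


# Hand-written input: ((t3 -> Bool), (t3 -> Int), t3) -> Int
resolved = ('->', (('->', ('t3',), 'Bool'), ('->', ('t3',), 'Int'), 't3'), 'Int')
print(show(rename_type_vars(resolved)))
```

Since all three occurrences share t3, they all become a, which is what makes the result read as "f and g accept whatever type x has".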

The implementation

An implementation of microml is available here, as a self-contained Python program that parses a microml declaration and infers its type. The best starting point is main.py, which spells out the stages of type inference:

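# parser and typing here are microml's own modules (parser.py, typing.py),
# not the Python standard library modules of the same names.
import parser
import typing

code = 'foo f g x = if f(x == 1) then g(x) else 20'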
print('Code', '----', code, '', sep='\n')

# Parse the microml code snippet into an AST.
p = parser.Parser()
e = p.parse_decl(code)
print('Parsed AST', '----', e, '', sep='\n')

# Stage 1: Assign symbolic typenames
typing.assign_typenames(e.expr)
print('Typename assignment', '----',
      typing.show_type_assignment(e.expr), '', sep='\n')

# Stage 2: Generate a list of type equations
equations = []
typing.generate_equations(e.expr, equations)
print('Equations', '----', sep='\n')
for eq in equations:
    print('{:15} {:20} | {}'.format(str(eq.left), str(eq.right), eq.orig_node))

# Stage 3: Solve equations using unification
unifier = typing.unify_all_equations(equations)
print('', 'Inferred type', '----',
      typing.get_expression_type(e.expr, unifier, rename_types=True),
      sep='\n')

This will print out:

Code
----
foo f g x = if f(x == 1) then g(x) else 20

Parsed AST
----
Decl(foo, Lambda([f, g, x], If(App(f, [(x == 1)]), App(g, [x]), 20)))

Typename assignment
----
Lambda([f, g, x], If(App(f, [(x == 1)]), App(g, [x]), 20))   t0
If(App(f, [(x == 1)]), App(g, [x]), 20)                      t4
App(f, [(x == 1)])                                           t5
f                                                            t1
(x == 1)                                                     t6
x                                                            t3
1                                                            Int
App(g, [x])                                                  t7
g                                                            t2
x                                                            t3
20                                                           Int

Equations
----
Int             Int                  | 1
t3              Int                  | (x == 1)
Int             Int                  | (x == 1)
t6              Bool                 | (x == 1)
t1              (t6 -> t5)           | App(f, [(x == 1)])
t2              (t3 -> t7)           | App(g, [x])
Int             Int                  | 20
t5              Bool                 | If(App(f, [(x == 1)]), App(g, [x]), 20)
t4              t7                   | If(App(f, [(x == 1)]), App(g, [x]), 20)
t4              Int                  | If(App(f, [(x == 1)]), App(g, [x]), 20)
t0              ((t1, t2, t3) -> t4) | Lambda([f, g, x], If(App(f, [(x == 1)]), App(g, [x]), 20))

Inferred type
----
(((Bool -> Bool), (Int -> Int), Int) -> Int)

There are many more examples of type-inferred microml code snippets in the test file test_typing.py. Here's another example which is interesting:

> foo f x = if x then lambda t -> f(t) else lambda j -> f(x)
((Bool -> a), Bool) -> (Bool -> a)

The actual inference is implemented in typing.py, which is fairly well commented and should be easy to understand after reading this post. The trickiest part is probably the unification algorithm, but that one is just a slight adaptation of the algorithm presented in the previous post.


[1]

After this post was published, it was pointed out that another type checking / inference technique is already called bi-directional (see this paper for example); while it's related to Hindley-Milner (HM), it's a distinct method, so my terminology here may cause some confusion.

I'll emphasize that my only use of the term "bi-directional" is to distinguish what HM does from the simpler "uni-directional" inference described at the beginning.