Why Optimising Python is Hard (3): Tail Call Optimisation

The Idea of Tail Call Optimisation

There is a famous family of numerical sequences, which can be built according to a very simple pattern. You start with an arbitrary positive integer as the first number in the sequence. Whenever your number is even, you divide it by two. Otherwise, you multiply it by three, and add one. Here is an example of such a sequence:

15, 46, 23, 70, 35, 106, 53, 160, 80, 40, 20, 10, 5, 16, 8, 4, 2, 1

The remarkable thing about this sequence is that you always seem to end up at 1 pretty quickly. But let’s not dwell on the beauty of mathematics, and rather ponder how you would write a program to spit out this sequence in Python. Here is an example of how you could do it:

def collatz(x):
    while x != 1:
        print(x, end=', ')
        if x % 2 == 0:
            x = x // 2
        else:
            x = 3 * x + 1
    print(x)    # print the final 1 as well

But, perhaps you prefer a more functional approach, and like to write it using recursion:

def collatz(x):
    print(x, end=', ')
    if x == 1:
        return
    if x % 2 == 0:
        collatz(x // 2)
    else:
        collatz(3 * x + 1)

There is a certain elegance to this way of writing code. But it comes at a cost. Each time we call collatz again (with a new argument), a new frame is created (that is, the environment holding the local variables, such as x in this case). The calling function, however, is still running until the next invocation of collatz has run its course and returns. Hence, these frames with their local variables start piling up, and if you are lucky enough to find a starting number with a really long sequence, you hit Python’s recursion limit. Python then tells you that the maximum recursion depth has been exceeded, and your program breaks.
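You can provoke this quite easily by lowering the recursion limit. Here is a minimal sketch, assuming the recursive collatz from above has been defined (the limit of 50 is arbitrary and only serves to trigger the error quickly):

import sys

sys.setrecursionlimit(50)      # artificially low limit, just for demonstration
try:
    collatz(27)                # 27 starts a rather long sequence
except RecursionError as error:
    print()
    print(error)               # maximum recursion depth exceeded ...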

If you look closer, running into this limit on the recursion depth is entirely unnecessary. The first instance of collatz is waiting for the second instance to end, just so it can end, too. But once the second instance has finished, the first instance does not do anything anymore. Now, a clever system (compiler and/or interpreter) recognises that the first instance, and its local variables, are not needed anymore anyway. Thus, instead of doing a proper call and creating a new frame, it might just replace the current frame by a new one, and jump back to the beginning of the function’s code.

In principle, it is not even necessary to replace the frame with a new one. A good compiler would just change the value of x, basically rewriting the recursive version into the iterative version we had before. Sure, this seems like coming full circle in our toy example. But imagine you started with the recursive version, and could rely on the compiler to find the best implementation for you – be it, in the end, truly recursive or iterative. This is known as Tail Call Optimisation.

In short, if the last action of a function is to call itself with new arguments, the compiler can simply reset the variables to the appropriate values and jump back to the beginning of the function’s code. In principle, this last call could also be to any other function, but things get slightly more difficult then (because the other function has a different set of local variables).
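For our collatz function, the rewritten code might look roughly like the following sketch (collatz_tco is just a made-up name here; a real compiler would, of course, work on the bytecode rather than the source):

def collatz_tco(x):
    while True:                 # the tail call becomes a jump back to the top
        print(x, end=', ')
        if x == 1:
            return
        if x % 2 == 0:
            x = x // 2          # "calling" with a new argument = rebinding x
        else:
            x = 3 * x + 1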

Getting the Entire Sequence

Our (recursive) collatz function prints each number, one after the other. But we would rather get a list col_numbers with all the numbers of a particular sequence. To make the problem interesting, let us assume that we cannot modify the collatz function itself (just imagine it is some very weird, convoluted, and complex piece of code, and you absolutely do not want to tamper with it).

One option certainly is to just replace the print function. Something like the following should do the trick nicely:

col_numbers = []
# shadow the built-in print: collect the number instead of printing it
print = lambda number, end: col_numbers.append(number)
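After calling collatz, removing this binding again brings back the built-in print. A small usage sketch, assuming the two lines above have already run:

collatz(15)

del print              # drop the shadowing name; the built-in print is back
print(col_numbers)     # [15, 46, 23, 70, 35, 106, 53, 160, 80, 40, 20, 10, 5, 16, 8, 4, 2, 1]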

Another option is to wrap the original collatz function. Instead of collecting the output, we collect the arguments passed to the function (since we are using the recursive version here, that should lead to exactly the same list). What we want is something like this:

old_collatz = collatz

def collatz(x):
    col_numbers.append(x)   # record the argument, then delegate to the original
    old_collatz(x)

Indeed, this works. But why? After all, once collatz is called, it just recursively calls itself, right? Not quite: Python stipulates that – even in the case of a seemingly recursive function – before calling the function, the interpreter looks up whatever is currently named collatz, and calls that. So, if we change what the name collatz refers to, the original function is, strictly speaking, not recursive anymore. Rather, it calls the new collatz function, which in turn calls the previously recursive version. If we call the original collatz function simply C, and the new wrapper W, we get the call sequence:

W -> C -> W -> C -> W -> C -> ...

Obviously, this trick with a wrapper function only works because Python looks up every name each time it is used, even if it is a (recursive) function call. On the other hand, this means that the Python compiler cannot know in advance whether the ostensibly tail-recursive function collatz really is tail recursive. There is always a chance that, somewhere in the code, the functions get renamed or wrapped, and the entire call structure becomes much more complicated.
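To see this lookup-at-call-time behaviour in isolation, here is a minimal sketch (C and W are just the names from above, and the print calls are only markers, not part of the original code):

def C(x):
    print('C', x)
    if x > 0:
        C(x - 1)       # the name C is looked up again at every call

old_C = C

def C(x):              # rebind the name C to a wrapper
    print('W', x)
    old_C(x)

C(2)                   # prints: W 2, C 2, W 1, C 1, W 0, C 0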

Decorators

As an aside: the wrapping is usually done using a decorator. You might end up wrapping more than one function, so you write a wrapper factory: a function that takes a function f and returns a new function wrapper, which first adds the argument to the list col_numbers before calling f itself.

def wrap(f):
    def wrapper(x):
        col_numbers.append(x)
        f(x)
    return wrapper

Once you have this wrapper factory, you could write collatz = wrap(collatz). Usually, however, you let a decorator do the wrapping for you:

@wrap
def collatz(x):
    print(x, end=', ')
    if x == 1:
        return
    ...

And if you are up for it, you can go all the way and write a factory for wrapper factories, so that you can also select the list in which the arguments are collected (this might hurt a little, though, given that we now have a function that creates a function, which takes a function, and then returns a function, …):

def wrap_with_list(arg_list):
    def wrap(f):
        def wrapper(x):
            arg_list.append(x)
            f(x)
        return wrapper
    return wrap
    
@wrap_with_list(col_numbers)
def collatz(x):
    print(x, end=', ')
    ...

Isn’t Python just awesome?

Counting Frames

It is a fun game to try to figure out which starting number produces the longest sequence before it ends up at 1. In that case, all we care about is the length of the sequence, which is easily obtained once we have the list. If all we are interested in is the length, however, we do not even need to store the entire list, but can just use a counter instead:

old_collatz = collatz
counter = 0
def collatz(x):
    global counter
    counter += 1
    old_collatz(x)

Nothing new here, problem solved, and we are done. But…

Just for the fun of it, let us consider another, much more complicated (=more interesting) way of getting the length of the list. Let us count the frames on the call stack.

Do you remember the frames from above? Each time a function is called, Python creates a new frame just for this specific invocation of the function. The frame then holds the local variables, among other things. In fact, each frame has a field f_locals, containing the values of the local variables, a field f_code, referring to the function’s code, and a field f_back, which refers to the calling frame. If a runtime error occurs, Python usually spills the entire chain of frames (called a traceback) onto the screen. If you want to, you can easily do basically the same with code like this:

def show_traceback():
    import inspect
    frame = inspect.currentframe()
    while frame.f_back is not None:
        frame = frame.f_back      # walk towards the callers
        print(frame.f_code.co_name, 'in', frame.f_code.co_filename)

Today, however, we are not going to print a traceback. Instead, we are measuring the depth of the call stack. More precisely, we go through the entire call stack, just as in the code above, and count how often the original collatz function appears there.

import inspect

def count_code(f):
    # count how many frames on the current call stack run the code of f
    code = f.__code__
    frame = inspect.currentframe()
    result = 0
    while frame is not None:
        if frame.f_code is code:
            result += 1
        frame = frame.f_back
    return result
    
old_collatz = collatz
sequence_length = 0
    
def collatz(x):
    global sequence_length
    # we need to add one manually, because the frame of the
    # old_collatz call a few lines further down does not exist yet
    seq_len = count_code(old_collatz) + 1
    if seq_len > sequence_length:
        sequence_length = seq_len
    old_collatz(x)
    return sequence_length
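Calling the wrapped function then returns the length of the sequence. A quick check with the starting number 15 from the example at the top, whose sequence has 18 numbers (assuming old_collatz is the plain recursive version from the beginning):

length = collatz(15)   # prints the sequence as a side effect
print()                # end the line of numbers
print(length)          # 18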

Sure, this is a grossly over-engineered and not particularly performant way to get the length of a Collatz sequence. But playing around with Python on this level is just too much fun.

Instead of using the inspect module, you will often find an alternative using the sys module:

import sys
frame = sys._getframe(0)

This does the same as the currentframe function from inspect. The difference is that sys._getframe is a private, not entirely official function. There is no guarantee that it will still be around in the next version of Python (even though, let’s be honest, it is used far too often for it to simply be removed).

Even after having enjoyed ourselves, there is a moral to this story…

The idea of counting frames is perfectly legal in Python, and follows just the standard Python specs. In other words: counting frames is not some implementation detail that exists only in CPython, but something we can rely on in every complete implementation of Python.

Tail call optimisation, on the other hand, is based on the idea that we reuse or replace frames, instead of creating a new frame for each call. So, if a brave new compiler optimised our tail call away, and replaced it with a simple jump, the result of our program would always be one – regardless of the actual length of our sequence. Hence, we have just found a completely valid Python program that “breaks” our hypothetical compiler, and proves that it does not adhere to the Python specifications.

Once again, there are so many ways to mess with Python that it is almost impossible for a static compiler to perform a tail call optimisation. You cannot optimise Python without the risk of breaking it, which, in turn, is an absolute no-go for any compiler.

Guards and Dynamic Analysis

Is all lost, then? Not entirely! If we are willing to place a guard around our optimisation, and if we are willing to sacrifice compatibility with counting frames, we might still be able to do our tail call optimisation.

After the compiler detects the tail call, it transforms the function’s code into a loop. At the end, just before going into the next round of the iteration, it looks up collatz and checks whether it still refers to this very function. If not, it does a proper function call to whatever collatz refers to now.

def collatz(x):
    while True:
        print(x, end=', ')
        if x == 1:
            return
        if x % 2 == 0:
            y = x // 2
        else:
            y = 3 * x + 1
        if collatz is not this_function:  # <- this is the guard
            collatz(y)                    # the name was rebound: do a proper call
            return
        else:
            x = y                         # still us: just loop instead of calling

this_function = collatz                   # remember the original function object

This guarded version runs about 10% faster, but it might still not work properly if we are counting frames: if the loop iterates for a few rounds before falling back on calling collatz, some frames are missing, distorting the picture. Of course, in this particular toy example, the compiler can actually be quite certain that collatz is not suddenly going to change, as there is no opportunity for it (i.e. the code does not call any other functions which could mess with the namespace) – at least as long as there is no concurrency involved.

Such optimisations with guards are particularly useful when combined with Just In Time compilers (JIT, for short). A JIT compiler does not only look at the static program code, but also takes into account the current (and previous) values of variables. You have seen the idea of recording all the arguments of a function above. Now, a JIT compiler will look at this list and figure out that x always refers to a positive integer. It then aggressively optimises the entire collatz function, based on the assumption that x is always a positive integer. At the very beginning, however, there is a guard which checks whether the assumption about x really holds. Should it ever fail, the code falls back to the original Python version, just as we fall back to properly calling the collatz function in our code above.
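In plain Python, this guard idea looks roughly like the following sketch (guarded_collatz and its "fast path" are purely illustrative names; a real JIT would emit specialised machine code instead of more Python):

def guarded_collatz(x):
    # guard: check the assumption recorded about x
    if not (type(x) is int and x > 0):
        return collatz(x)      # assumption violated: fall back to the generic version
    # "fast path", specialised under the assumption that x is a positive integer
    while x != 1:
        print(x, end=', ')
        x = x // 2 if x % 2 == 0 else 3 * x + 1
    print(x)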

Projects like PyPy and Numba do exactly that: they track the types of variables, and then use a JIT compiler, which compiles the Python code down to optimised, native machine code. At the first sign of trouble, however, they fall back to the (slow) Python code, making sure that the program still behaves exactly as it is supposed to. PEP 509 also added a version tag to dictionaries, which helps exactly these guards detect any changes.

So, why is there no Ahead Of Time (AOT) compiler for Python doing the same thing? Because the tricky part is getting a reliable estimate of what type a variable has. Doing this during runtime is easy (you can simply do it yourself, purely in Python). Doing it just by looking at the program code is very hard, though, particularly if you load data from an external source, such as a file.
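Recording the observed types at runtime really is just another small wrapper, much like collecting the arguments above (observed_types and record_types are made-up names for this sketch, and collatz is assumed to be the plain recursive version):

observed_types = set()

def record_types(f):
    def wrapper(x):
        observed_types.add(type(x))   # note which types actually show up at runtime
        return f(x)
    return wrapper

collatz = record_types(collatz)
collatz(15)
print()
print(observed_types)                 # {<class 'int'>}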