Implementing Code Transformations in Python

Have you ever wondered how a compiler might actually optimise your code to make it run faster? Or would you like to know what an Abstract Syntax Tree (AST) is, and what it might be used for?

In this article I am going to give an overview of how Python code is transformed into tree form (AST). Once you have the AST of your program, you can then look for opportunities to optimise and transform your code. Be aware, though, that optimising Python programs in non-trivial ways is extremely hard, as I will discuss elsewhere.

The Program Code as a Tree

How does the computer make sure that your expressions are evaluated in the correct order? It first translates your program code into a tree structure called an abstract syntax tree (AST).

The traditional notion of an interpreted language (like Python) is that the interpreter goes through your program code, and executes whatever it finds there on the spot, without doing any translation of the Python code to machine code. In reality, however, such an execution scheme causes quite a few issues, which make it rather impractical.

Take, for instance, the simple issue of operator precedence. In an expression like 3 + 4 * x, the 4 * x must be computed first, and only then the addition of 3 with the result of the multiplication can be carried out. In math class, we teach operator precedence by drawing trees beneath the expression:

3 + 4 * x
\   \   /
 \   \ /
  \   M
   \ /
    A

Python adheres to the standard rules of mathematical notation (do multiplication first, and addition afterwards). In order to get the precedence right, like most programming languages, Python first constructs the same tree as in the picture above. The overall operation is an addition (at the root of the tree), and while the left hand side of this addition is just a number, the right hand side is a multiplication. The resulting data structure looks like this:

BinOp(
  left  = Num(3),
  op    = Add(),
  right = BinOp(
            left  = Num(4),
            op    = Mult(),
            right = Name('x')
          )
)

BinOp stands for Binary Operation, referring to the fact that operations such as Addition and Multiplication have two operands. Obviously, you cannot do the addition without having a proper value for the right hand side, hence the multiplication must be computed first.
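To see that this data structure is real Python and not just notation, here is a minimal sketch that builds the same tree by hand with the ast module and evaluates it. It uses ast.Constant, the name under which Python 3.8 and later represent numbers (the Num of the picture above), and ast.fix_missing_locations to fill in the source positions that compile() insists on; the variable names are illustrative.

```python
import ast

# 3 + 4 * x, written out node by node (ast.Constant plays the role of Num)
expr = ast.BinOp(
    left=ast.Constant(value=3),
    op=ast.Add(),
    right=ast.BinOp(
        left=ast.Constant(value=4),
        op=ast.Mult(),
        right=ast.Name(id='x', ctx=ast.Load()),
    ),
)

# Wrap the expression so that compile() accepts it in "eval" mode,
# and fill in the line/column information that compile() requires.
tree = ast.fix_missing_locations(ast.Expression(body=expr))
code = compile(tree, "<ast>", "eval")

print(eval(code, {"x": 5}))  # the multiplication happens first: 3 + 20 = 23
```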

In the field of compilers and programming languages, such a tree is called an Abstract Syntax Tree, or AST for short. The AST in the example above comprises two BinOp-nodes, two Num-nodes, and one Name-node.

A nice feature of Python is that you can directly inspect and print the AST of any given Python program. All you have to do is import the standard module ast, parse your program, and then dump the result to your screen (parsing, by the way, is the process of translating a source program into an AST).

import ast
my_tree = ast.parse("3 + 4*x")
print(ast.dump(my_tree))

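You will notice, however, that the AST generated by Python has some additional nodes and fields, and is printed on a single line, which makes it look more complicated (but it is not). The output below comes from an older Python version; since Python 3.8, numbers are represented by Constant nodes, so you will see Constant(value=3) instead of Num(n=3), possibly followed by an extra type_ignores field: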

Module(body=[Expr(value=BinOp(left=Num(n=3), op=Add(), right=BinOp(left=Num(n=4), op=Mult(), right=Name(id='x', ctx=Load()))))])

Let’s break this into individual nodes as before, and you will rediscover the AST from above as part of the entire tree:

Module(body = [
    Expr(
        value = BinOp(
            left  = Num(n=3),
            op    = Add(),
            right = BinOp(
                left  = Num(n=4),
                op    = Mult(),
                right = Name(id='x', ctx=Load())
            )
        )
    )
])

Obviously, Python thinks that the string we gave it to parse constitutes an entire module. The body of the module is a list of all statements inside the module. The only statement in our example is an expression Expr, whose value is exactly what we have discussed above.
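If you want to check the node counts mentioned earlier for yourself, ast.walk visits every node of the tree. A small sketch (on Python 3.8 and later the two numbers show up as Constant nodes rather than Num):

```python
import ast
from collections import Counter

tree = ast.parse("3 + 4*x")

# ast.walk yields every node in the tree (in no guaranteed order),
# including operator and context nodes such as Add() and Load().
counts = Counter(type(node).__name__ for node in ast.walk(tree))
print(counts)
```

Besides the two BinOp nodes, the two number nodes and the single Name node, the count also includes the Module and Expr wrappers as well as the Add, Mult and Load nodes.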

Note that the Name node has an additional field ctx (short for context), which has a value of Load(). This is Python’s way to say that we are using the value stored in the variable x, as opposed to (re)defining or deleting the name x. Go ahead, and try and parse something like del x, or x = 123, and you will see how the ctx field in the Name node changes to Del(), or Store(), respectively.
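Here is a quick way to try this out: a small helper (ctx_of is just an illustrative name) that parses a snippet and reports the context of the first Name node it finds.

```python
import ast

def ctx_of(source: str) -> str:
    """Return the context class name of the first Name node found in source."""
    tree = ast.parse(source)
    name = next(node for node in ast.walk(tree) if isinstance(node, ast.Name))
    return type(name.ctx).__name__

for source in ("x", "x = 123", "del x"):
    print(f"{source:10} -> {ctx_of(source)}")
# x          -> Load
# x = 123    -> Store
# del x      -> Del
```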

By the way: if you install the third-party astunparse module, you can print your ASTs much more nicely, and even transform the AST back into valid Python code.
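If you are on Python 3.9 or later, the standard library covers both needs without an extra install: ast.dump accepts an indent argument, and ast.unparse turns a tree back into source code. A short sketch:

```python
import ast

tree = ast.parse("3 + 4*x")

# Multi-line, indented dump (the indent parameter exists since Python 3.9).
print(ast.dump(tree, indent=4))

# Turn the tree back into source code (also Python 3.9+); spacing is normalised.
print(ast.unparse(tree))  # 3 + 4 * x
```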

The Rest of The Compilation Process

Once we have the AST of a program, it is in principle possible to execute the program by going through the AST and performing all the operations accordingly. There are at least two drawbacks to this, though. First, the AST might take up a relatively large amount of memory space, particularly if it includes redundant information. And second, the traversal of the AST might take longer than necessary. In short: you can do it, but it is inefficient.

Instead of executing the AST directly, the compiler will produce bytecode, which is then executed by Python’s virtual machine. While it is beyond the scope of this article to go into details here, the basic principle is that the compiler translates the AST into Reverse Polish Notation (RPN). Instead of having the operator + between the left and the right operand, it comes after both operands. For the example 3 + 4*x above we get the sequence 3 4 x * + (what makes this notation particularly nice is that the sequence itself makes it evident that we need to do the multiplication first, and then the addition). Since each of the five elements in this sequence can be encoded very compactly (the operation itself fits into a single byte), the resulting code is called bytecode. Python then uses a stack-based virtual machine to execute this code efficiently.
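You can look at this bytecode with the standard dis module. The listing printed by dis.dis differs between Python versions (for instance, BINARY_MULTIPLY in older versions versus BINARY_OP in 3.11+), so the small helper below, rpn, which is purely illustrative, reduces the instructions to operands and operators to expose the RPN order. It is a sketch for simple arithmetic expressions only, not a general decompiler.

```python
import dis

code = compile("3 + 4*x", "<string>", "eval")
dis.dis(code)  # human-readable listing; opcode names vary by Python version

def rpn(code):
    """Rough RPN view of a simple arithmetic expression's bytecode."""
    out = []
    for ins in dis.get_instructions(code):
        if ins.opname.startswith("LOAD_"):
            # Pushes an operand onto the stack (a constant or a variable).
            out.append(str(ins.argval))
        elif ins.opname.startswith("BINARY_"):
            # 3.11+: BINARY_OP with argrepr '*' or '+'; older: BINARY_MULTIPLY etc.
            out.append(ins.argrepr or ins.opname)
    return out

print(rpn(code))
```

On Python 3.11 and later this prints ['3', '4', 'x', '*', '+'], which is exactly the sequence 3 4 x * + from above.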

In other words, the compilation process of a Python program has two steps. First the input program is parsed, producing an abstract syntax tree (AST). Then the compiler generates bytecode by going through the AST. The Python interpreter will then execute this bytecode. When doing optimisations, we can do that either on the AST, or on the bytecode. Both options have their benefits and disadvantages.

Finally, be aware that while the AST is typically shared across Python implementations, the translation of the AST to bytecode might vary, and some Python implementations might produce, say, JavaScript instead of bytecode.

Other Programming Language Paradigms

Not every programming language uses an infix notation like Python does. Two notable examples are PostScript, where the program is directly written in Reverse Polish Notation, and Lisp, of course, where programs are basically written in Polish Notation. So, our example expression would in Lisp look like this: (+ 3 (* 4 x)).
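Polish and Reverse Polish Notation are just two different ways of walking the very same tree. As an illustration, the two small functions below (to_prefix and to_postfix are names made up for this sketch, handling only numbers, names and the four basic operators) print the AST of our example both in Lisp style and in the RPN order discussed earlier:

```python
import ast

# Symbols for the few operators this sketch handles.
SYMBOLS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/"}

def leaf(node: ast.expr) -> str:
    if isinstance(node, ast.Constant):
        return repr(node.value)
    if isinstance(node, ast.Name):
        return node.id
    raise NotImplementedError(type(node).__name__)

def to_prefix(node: ast.expr) -> str:
    """Polish notation, Lisp style: the operator comes before its operands."""
    if isinstance(node, ast.BinOp):
        return f"({SYMBOLS[type(node.op)]} {to_prefix(node.left)} {to_prefix(node.right)})"
    return leaf(node)

def to_postfix(node: ast.expr) -> str:
    """Reverse Polish Notation: the operator comes after both operands."""
    if isinstance(node, ast.BinOp):
        return f"{to_postfix(node.left)} {to_postfix(node.right)} {SYMBOLS[type(node.op)]}"
    return leaf(node)

expr = ast.parse("3 + 4*x", mode="eval").body
print(to_prefix(expr))   # (+ 3 (* 4 x))
print(to_postfix(expr))  # 3 4 x * +
```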

Transforming a Node inside the AST

Once you have the AST of your program, how do you transform individual parts of that tree? Use Python’s handy builtin facilities.

When we look at an AST and discover, say, that the left and right fields of a BinOp node are both numbers (Num nodes), we can do the respective computation ahead of time, and then replace the BinOp by a simple Num node.

Of course, we must be very careful not to change the behaviour of the program when doing such a transformation. As an example, in len([a(), b(), c(), d()]), we can tell that the result will be 4. But we cannot replace the entire expression by 4, because the four functions a, b, c, d still need to get properly called.
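To make this caveat concrete, here is a hypothetical LenFolder transformer (not part of any library) that folds len(...) of a list or tuple display only when every element is a constant, so that no function call can be skipped. It assumes that the name len still refers to the builtin, which is the same kind of assumption we will run into with pi later on. It relies on NodeTransformer, ast.copy_location and ast.unparse (Python 3.9+), all of which are introduced in this article.

```python
import ast

class LenFolder(ast.NodeTransformer):
    """Hypothetical: fold len(<list or tuple display>) when it is side-effect free.
    Assumes the name len has not been rebound."""

    def visit_Call(self, node: ast.Call):
        self.generic_visit(node)  # fold nested calls first
        if (isinstance(node.func, ast.Name) and node.func.id == 'len'
                and len(node.args) == 1 and not node.keywords):
            arg = node.args[0]
            # Only literals: evaluating them cannot call any function.
            if isinstance(arg, (ast.List, ast.Tuple)) and all(
                    isinstance(elt, ast.Constant) for elt in arg.elts):
                return ast.copy_location(ast.Constant(value=len(arg.elts)), node)
        return node

for source in ("n = len([1, 2, 3, 4])", "n = len([a(), b(), c(), d()])"):
    tree = LenFolder().visit(ast.parse(source))
    print(ast.unparse(tree))
# n = 4
# n = len([a(), b(), c(), d()])
```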

Anyway, let us start with an easy optimisation. Whenever we encounter the name pi in the source program, we replace it by the value 3.14159265. Python’s ast module already provides the necessary structures to achieve this: a transformer class NodeTransformer, which goes through the entire AST and checks, for each node, whether it should be replaced. By default, the transformation method for each node just returns the original node, so that we end up with exactly the same AST as we started out with. But we can easily override the method for Name nodes, check if we are dealing with pi, and then return a number node (ast.Num in older Python versions, ast.Constant since Python 3.8) instead of the original name.

import ast

class MyOptimizer(ast.NodeTransformer):

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi':
            # ast.Constant replaces the older ast.Num (deprecated since Python 3.8)
            return ast.Constant(value=3.14159265)
        return node

tree = ast.parse("y = 2 * pi")
optimizer = MyOptimizer()
tree = optimizer.visit(tree)
print(ast.dump(tree))

In order to have the transformer/optimiser go through our tree, we have to call its visit method, which then returns the new, transformed, tree.

Unfortunately, we cannot compile and run the resulting AST due to a technical detail. So far invisible, (almost) all nodes in the AST also carry the attributes lineno and col_offset (and, since Python 3.8, end_lineno and end_col_offset). They give the exact position of the respective node inside the original source code. Our newly created number node has none of them, and if lineno and col_offset are not set, the compiler will complain and refuse to do its job.

So, let us copy the respective fields from the original Name node to the new Num node. Then, we can compile and execute the resulting AST:

import ast

class MyOptimizer(ast.NodeTransformer):

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi':
            result = ast.Constant(value=3.14159265)
            # Copy the source position from the Name node we are replacing
            result.lineno = node.lineno
            result.col_offset = node.col_offset
            return result
        return node

tree = ast.parse("print(2 * pi)")
optimizer = MyOptimizer()
tree = optimizer.visit(tree)
code = compile(tree, "<string>", "exec")
exec(code)

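Note that the compile function requires not only the source (which can be the original program as a string, or an AST), but also a filename (which we set to "<string>"), and a mode: "exec" for a sequence of statements such as a whole module, "eval" for a single expression, or "single" for a single interactive statement.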
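The difference between "exec" and "eval" matters when you want a value back. A minimal sketch: parsing with mode="eval" yields an Expression node, whose compiled code evaluates to a value, whereas code compiled in "exec" mode runs statements and leaves its results in a namespace.

```python
import ast

# "eval": a single expression; running the code object produces a value.
tree = ast.parse("6 * 7", mode="eval")
code = compile(tree, "<string>", "eval")
print(eval(code))  # 42

# "exec": statements; results have to be fetched from the namespace afterwards.
namespace = {}
exec(compile(ast.parse("answer = 6 * 7"), "<string>", "exec"), namespace)
print(namespace["answer"])  # 42
```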

The need to copy the fields for the position of a node within the source file occurs quite often. Therefore, the ast module has a dedicated function copy_location for exactly this purpose, so that we can write:

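import ast

class MyOptimizer(ast.NodeTransformer):

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi':
            result = ast.Constant(value=3.14159265)
            # copy_location copies lineno, col_offset (and end positions) from node
            return ast.copy_location(result, node)
        return node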
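A related helper, ast.fix_missing_locations, takes a broader approach: instead of copying the position of one specific node, it walks the whole tree and fills in any missing positions from the parent node. The sketch below deliberately skips copy_location and lets fix_missing_locations repair the tree before compiling. This is convenient, but the positions it assigns are only approximate (here, the new number simply inherits the position of the enclosing expression).

```python
import ast

class MyOptimizer(ast.NodeTransformer):

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi':
            # No location is copied here on purpose.
            return ast.Constant(value=3.14159265)
        return node

tree = MyOptimizer().visit(ast.parse("result = 2 * pi"))
# Fill in lineno/col_offset for every node that lacks them, taken from its parent.
tree = ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, "<string>", "exec"), namespace)
print(namespace["result"])
```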

Finally, we can extend the previous example to do an actual optimisation, namely on the BinOp node. Our transformation rule says that we first need to transform/optimise the left and the right operand of the BinOp. If both of them then turn out to be numbers, we can perform the actual computation right here, and replace the original BinOp by a single number holding the result of the operation.

import ast

def is_number(node):
    # Only fold numeric literals, which is what the older ast.Num nodes covered.
    return isinstance(node, ast.Constant) and type(node.value) in (int, float, complex)

class MyOptimizer(ast.NodeTransformer):

    def visit_BinOp(self, node: ast.BinOp):
        # Optimise both operands first, so that nested expressions fold bottom-up.
        node.left = self.visit(node.left)
        node.right = self.visit(node.right)
        if is_number(node.left) and is_number(node.right):
            if isinstance(node.op, ast.Add):
                result = ast.Constant(value=node.left.value + node.right.value)
                return ast.copy_location(result, node)
            elif isinstance(node.op, ast.Mult):
                result = ast.Constant(value=node.left.value * node.right.value)
                return ast.copy_location(result, node)
        return node

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi':
            result = ast.Constant(value=3.14159265)
            return ast.copy_location(result, node)
        return node

tree = ast.parse("y = 2 * pi + 1")
optimizer = MyOptimizer()
tree = optimizer.visit(tree)
print(ast.dump(tree))

Actually, CPython’s compiler already optimises BinOp nodes as presented here. The respective code is written in C and can be found in Python/ast_opt.c. Note that CPython’s optimiser is more general, as it is not limited to numbers, as in our example here, but can work with various types of constant values.
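You can watch this happen with the standard dis module. In the sketch below, the constant expression 2 * 3 + 1 has already been replaced by 7 in the compiled bytecode, and no multiplication or addition instruction is left. Because opcode names and the exact listing differ between Python versions, the check looks at instruction arguments rather than at a specific listing.

```python
import dis

code = compile("y = 2 * 3 + 1", "<string>", "exec")
dis.dis(code)  # the exact listing differs between Python versions

instructions = list(dis.get_instructions(code))
folded = 7 in [ins.argval for ins in instructions]
has_binop = any(ins.opname.startswith("BINARY_") for ins in instructions)
print("folded to 7:", folded, "/ arithmetic left:", has_binop)
```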

Checking the Nodes in the AST

How do we make sure that our transformations are correct? For a start, we have to look at the whole program before we change any part of it.

The optimiser presented above still has a serious glitch. What happens if you were to redefine pi somewhere in your program? Just imagine something as simple and meaningful as pi = 4. Our optimiser will happily replace the pi on the left-hand side by the numeric value 3.14159265, and Python will then refuse to compile the result, because you cannot assign something to a literal value.

This behaviour might be exactly what you want in the first place, making pi a true constant, which gets replaced during compilation, and can never be reassigned to stand for a different value. But it certainly violates the semantics of Python.

So, what if we want to stick to the semantics of Python, and still replace pi whenever we can? Then we first have to go through the entire program and check if pi gets assigned a value anywhere in the program. For the time being, we keep it simple: we bail out of substituting pi if there is any assignment to pi anywhere in the program.

We now use a node visitor, which is similar to the node transformer above. In contrast to the transformer, the visitor is not supposed to modify any of the nodes, but just go through the AST and look at the nodes (visit them). Accordingly, the visit methods do not return anything.

In our case, we check if a Name node refers to pi and does something other than load the value of pi (remember the context field ctx).

import ast

class MyVisitor(ast.NodeVisitor):

    def __init__(self):
        self.modify_pi = False

    def visit_FunctionDef(self, node: ast.FunctionDef):
        if node.name == 'pi':
            self.modify_pi = True
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi' and not isinstance(node.ctx, ast.Load):
            self.modify_pi = True

program = """
def pi():
    return 3.1415
print(2 * pi())
"""
tree = ast.parse(program)
my_visitor = MyVisitor()
my_visitor.visit(tree)
print("Pi modified:", my_visitor.modify_pi)

The method generic_visit(node) is what the visitor calls for each node that has no specialised visit method; it simply visits all of the node's children. Note that NodeVisitor itself has no visit_FunctionDef method that we could call using super(). So once we define our own visit_FunctionDef, we have to call generic_visit explicitly to make sure that the entire body of the function is processed as well. Otherwise someone could hide a global pi statement inside a function, and change the value of pi globally without our optimiser noticing.
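Putting the two halves together, a minimal sketch of the "analyse first, then transform" approach could look as follows. PiModifiedChecker, PiReplacer and optimise are illustrative names; the checker is the visitor from above, and the replacer now leaves pi alone unless it is being loaded. ast.unparse (Python 3.9+) is only used to show the result as source code.

```python
import ast

class PiModifiedChecker(ast.NodeVisitor):
    """Sets modify_pi if the program defines, assigns or deletes the name pi."""

    def __init__(self):
        self.modify_pi = False

    def visit_FunctionDef(self, node: ast.FunctionDef):
        if node.name == 'pi':
            self.modify_pi = True
        self.generic_visit(node)  # keep descending into the function body

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi' and not isinstance(node.ctx, ast.Load):
            self.modify_pi = True

class PiReplacer(ast.NodeTransformer):

    def visit_Name(self, node: ast.Name):
        if node.id == 'pi' and isinstance(node.ctx, ast.Load):
            return ast.copy_location(ast.Constant(value=3.14159265), node)
        return node

def optimise(source: str) -> str:
    tree = ast.parse(source)
    checker = PiModifiedChecker()
    checker.visit(tree)
    if not checker.modify_pi:  # only substitute if pi is never rebound
        tree = PiReplacer().visit(tree)
    return ast.unparse(tree)

print(optimise("y = 2 * pi"))  # y = 2 * 3.14159265
print(optimise("def f():\n    global pi\n    pi = 4\ny = 2 * pi"))  # pi left untouched
```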

Locals in Python

Our method for determining if the name pi is modified by the programmer is rather crude. Nevertheless, Python’s compiler works in a very similar way when it determines which names in a function scope correspond to local variables. If a variable is assigned anywhere in a function scope (and has not been explicitly declared global, e.g. through a global statement), then it is treated as a local variable in the entire function scope.

The following example would run just fine without the fourth line. However, even though the x = 0 on the fourth line is never executed, it still counts as an assignment to x, and therefore makes x a local variable for the entire function, including line three. That is why Python will complain, with an UnboundLocalError, that the local variable x on line three is unbound (has no value yet).

x = 1
def print_x():
    print(x)
    if False: x = 0
print_x()

If you are interested in what Python does in detail, have a look at Python/symtable.c.
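The same analysis is also available from Python itself through the standard symtable module. A small sketch using the example above: the symbol table for print_x reports x as local, and calling the function indeed raises UnboundLocalError.

```python
import symtable

source = """\
x = 1
def print_x():
    print(x)
    if False: x = 0
"""

module = symtable.symtable(source, "<string>", "exec")
func = module.get_children()[0]  # the symbol table of print_x
print(func.get_name(), "x is local:", func.lookup("x").is_local())  # print_x x is local: True

# Running the code confirms it: print(x) on line three fails.
namespace = {}
exec(source, namespace)
try:
    namespace["print_x"]()
except UnboundLocalError as err:
    print("UnboundLocalError:", err)
```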

Conclusion

Like most programming languages, Python does not execute a given program directly from the source code. Rather, the source code is translated in a two-step process to an abstract syntax tree (AST), and then to bytecode for a stack-based virtual machine. Python also provides some really nice facilities to analyse, and even transform the AST of any given Python program, and allows the modified AST to be compiled and executed afterwards. This means that we can easily implement our own optimisations.

There are, of course, quite a few details I have glossed over. Making sure that your optimisation is correct in all possible circumstances and cases is a rather involved process. But the intent of this article is not to write a production-ready optimisation, but to give you a basic understanding of how Python analyses your program code, and how you can start to tinker around with code transformations, and eventually write your own code optimisations.