Jim Crist-Harif

GSoC Week 7: Expression Trees and Substitution

Posted on July 06, 2014

Late post this week due to the celebration of everything that is American. This week I finally got my first PR merged. The general linearization code is now part of the Sympy codebase. I currently have a PR for Lagrange support waiting to be reviewed. After that I just need to write the ReST docs and the first part of the project is "complete". The rest of the week was spent on more optimization work. I'm getting closer to being able to solve the bicycle example in a reasonable amount of time!

Last week's post showed the issues with large expression sizes and subs (it takes forever to run). I took some time this week to look into how expressions work in sympy, and wrote a specialized subs function for use in sympy.physics.mechanics. The rest of this post gives an overview of this code, along with some benchmarks.

Expression Trees

In sympy, expressions are stored as trees. Each node is an object with an attribute args that holds a tuple of its child nodes. This is best shown by an example:

In [1]:
from sympy import *
from sympy.printing.dot import dotprint
In [2]:
a, b, c = symbols('a, b, c')
In [3]:
test = a*cos(a + b) + a**2/b
test
Out[3]:
a**2/b + a*cos(a + b)
In [4]:
with open('test.dot', 'w') as fil:
    fil.write(dotprint(test))
In [5]:
%%bash
dot -Tpng test.dot -o test.png
In [6]:
from IPython.core.display import Image 
Image(filename='test.png')
Out[6]:
[Image: test.png, the tree diagram of test produced by dot]

The root node of this tree is an Add object, as the outermost operation is adding two smaller expressions together. Going down the left side first, we see that $a^2/b$ is stored as multiplying $a^2$ times $b^{-1}$. Sympy doesn't have Div objects; fractions are all expressed using Mul, with the denominator wrapped in a Pow(den, -1). Traversing the right side, $a \cos(a + b)$ is stored as multiplying a and a cos object together. The cos object itself contains one child node - an Add - which holds a and b.
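
The same structure can be inspected without graphviz. This is a minimal sketch, assuming only a standard sympy install; the variable names expr and frac are mine, and srepr is sympy's built-in printer for the full tree form:

```python
from sympy import symbols, cos, srepr, Add, Mul, Pow

a, b = symbols('a, b')
expr = a*cos(a + b) + a**2/b

# The root is an Add whose children are the two product terms.
print(expr.func)
print(expr.args)

# The fraction a**2/b is a Mul of a**2 and b**-1; there is no Div node.
frac = a**2/b
print(frac.func)
print(srepr(frac))
```

srepr spells out every nested constructor call, which makes it a handy text-only substitute for the dot diagram.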

Crawling the Tree

The design of sympy expression objects uses two key features: args and func. As mentioned above, object.args holds a tuple of all the child nodes for that object. Likewise, object.func returns the class of the object, which can be called with arguments to create an instance of that class. For most objects in sympy, running

object.func(*object.args) == object

will evaluate to True. I say most because not all objects adhere to this (leaf nodes). There has been a lot of discussion about this here, if you're interested.
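
To make the invariant concrete, here is a small sketch that checks it on every non-leaf node of the example expression. The helper name rebuilds is hypothetical, and leaf nodes are skipped on purpose since they have empty args:

```python
from sympy import symbols, cos

a, b = symbols('a, b')
expr = a*cos(a + b) + a**2/b

def rebuilds(node):
    """Return True if node and all its non-leaf descendants satisfy
    node.func(*node.args) == node."""
    if not node.args:
        # Leaf nodes (Symbol, Integer, ...) have empty args and are skipped.
        return True
    return (node.func(*node.args) == node
            and all(rebuilds(arg) for arg in node.args))

print(rebuilds(expr))
# A Symbol is a leaf: its args tuple is empty.
print(a.args)
```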

One of the great things about this design is that it makes it very easy to write operations that modify these trees using recursion. Normally, recursion in python is frowned upon because the function calls add overhead that could be removed if rewritten as a loop. There is also a maximum recursion depth (default of 1000) to prevent stack overflow conditions. However, a sympy expression that has 1000 nested nodes is highly unlikely (even in mechanics), and the recursion makes the code much more readable. As the Zen of Python says, "Readability counts".
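
To get a feel for how far a typical expression is from that limit, the nesting depth can be measured directly. A minimal sketch; depth is a hypothetical helper, not something sympy provides under that name:

```python
import sys
from sympy import symbols, cos

a, b = symbols('a, b')
expr = a*cos(a + b) + a**2/b

def depth(node):
    """Number of nodes on the longest path from node down to a leaf."""
    return 1 + max((depth(arg) for arg in node.args), default=0)

# The longest path here is Add -> Mul -> cos -> Add -> a.
print(depth(expr))
print(sys.getrecursionlimit())
```

What matters for recursion is the depth of the tree, not the total number of nodes, so even expressions with very large operation counts tend to stay far below the limit.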

A simple crawler to print every node is written below. It consists of two functions.

The first one is a generic function crawl, that crawls the expression tree, calls func on each node, and returns the result if there is one. Otherwise it recurses down a level into the child nodes, forming a list of new_args, and then calls expr.func to rebuild the expression from those args.

The second one is a printer function. As we don't want to modify the expression at all, we'll just print the node, and then return the expression if it doesn't have args (it's a leaf node). Note that there are more efficient ways to traverse the tree and print all the nodes - this is mostly to demonstrate crawl, as it will be used later.

Using these two functions, a function that crawls the tree, prints every node, and returns the original expression can be composed using a simple lambda expression.

In [7]:
def crawl(expr, func, *args, **kwargs):
    """Crawl the expression tree, and apply func to every node."""
    val = func(expr, *args, **kwargs)
    if val is not None:
        return val
    new_args = (crawl(arg, func, *args, **kwargs) for arg in expr.args)
    return expr.func(*new_args)

def printer(expr):
    """Print out every node"""
    print(expr)
    if not expr.args:
        return expr

print_expr = lambda expr: crawl(expr, printer)

Let's test this function on our expression from above:

In [8]:
temp = print_expr(test)
# Checking that the expression was unchanged (temp == test)
assert temp == test
a**2/b + a*cos(a + b)
a**2/b
a**2
a
2
1/b
b
-1
a*cos(a + b)
a
cos(a + b)
a + b
b
a

Comparing the printed results with the tree diagram from before, one can see how each big expression can be decomposed into smaller expressions. Further, the rebuilt expression after traversing was identical to the input expression. In the next section we'll write another function that changes the expression tree using crawl.
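
If all you need is to visit the nodes without rebuilding anything, sympy's built-in preorder_traversal does the job with no rebuilding step. A short sketch on the same expression as the tree example (the name nodes is mine):

```python
from sympy import symbols, cos, preorder_traversal

a, b = symbols('a, b')
expr = a*cos(a + b) + a**2/b

# preorder_traversal yields each node before its children,
# including repeated leaves such as the several occurrences of a.
nodes = list(preorder_traversal(expr))
for node in nodes:
    print(node)
```

Unlike crawl, this only reads the tree. It can't replace nodes, which is why crawl is still the tool used in the rest of the post.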

A Custom subs Function

In sympy.physics.mechanics, we deal with symbols, dynamic symbols (which are of type AppliedUndef), and derivatives of these dynamic symbols. Unfortunately, the provided subs function traverses inside the Derivative terms, giving undesired results:

In [9]:
from sympy.physics.mechanics import dynamicsymbols
x, y = dynamicsymbols('x, y')
test = a*(x + x.diff())
test
Out[9]:
a*(x(t) + Derivative(x(t), t))
In [10]:
# Subbing in b for x. Desired result is a*(b + x.diff())
test.subs({x: b})
Out[10]:
a*(b + Derivative(b, t))

To get around this problem, we've been using a custom function _subs_keep_derivs. This function creates two substitution dictionaries - one with Derivative, and one without. Four substitutions then take place:

  1. Perform subs with the terms in the derivative dictionary
  2. Substitute in Dummy symbols for all Derivative terms in the resulting expression
  3. Perform subs with the terms in the non-derivative dictionary
  4. Substitute back the original Derivative terms from the Dummy symbols

This is slow due to the need for four calls to expr.subs. Also, subs applies substitutions sequentially (i.e. each term in the substitution dict requires its own tree traversal). For our purposes in mechanics, this is unnecessary. Using the already written crawl function, we can compose our own subs that ignores terms inside derivative objects fairly easily:

In [11]:
def sub_func(expr, sub_dict):
    """Perform expression substitution, ignoring derivatives."""
    if expr in sub_dict:
        return sub_dict[expr]
    elif not expr.args or expr.is_Derivative:
        return expr
    
new_subs = lambda expr, sub_dict: crawl(expr, sub_func, sub_dict)

That's it. Due to the composable nature of crawl, the code needed to perform this operation is incredibly simple. Let's test it on our previous expression to make sure it works:

In [12]:
# Simple example
new_subs(test, {x: b})
Out[12]:
a*(b + Derivative(x(t), t))

So it leaves terms inside derivatives alone, exactly as desired. We can see how this compares to the previous implementation:

In [13]:
# Old way of doing things, taken from sympy.physics.mechanics.functions
def _subs_keep_derivs(expr, sub_dict):
    """Performs subs exactly as subs normally would be,
    but doesn't sub in expressions inside Derivatives."""

    ds = expr.atoms(Derivative)
    gs = [Dummy() for d in ds]
    items = sub_dict.items()
    deriv_dict = dict((i, j) for (i, j) in items if i.is_Derivative)
    sub_dict = dict((i, j) for (i, j) in items if not i.is_Derivative)
    dict_to = dict(zip(ds, gs))
    dict_from = dict(zip(gs, ds))
    return expr.subs(deriv_dict).subs(dict_to).subs(sub_dict).subs(dict_from)
In [14]:
# Benchmark substitution
print("Using _subs_keep_derivs")
%timeit _subs_keep_derivs(test, {x: b})
print("Using new_subs")
%timeit new_subs(test, {x: b})
Using _subs_keep_derivs
100 loops, best of 3: 3 ms per loop
Using new_subs
10000 loops, best of 3: 39.9 µs per loop

So it's significantly faster. For this small benchmark, approximately 75x faster. Also, as it works in only one traversal of the tree it has a smaller computational complexity - meaning that for larger expressions this speed increase will be even higher. For kicks, let's see how it compares to normal subs for an expression without derivatives:

In [15]:
# For kicks, see how it compares to subs for expr without derivative:
test = a*cos(x + b) + (a - x)**2
print("Using subs")
%timeit test.subs({a: 1, b: 2})
print("Using new subs")
%timeit new_subs(test, {a: 1, b: 2})
Using subs
1000 loops, best of 3: 676 µs per loop
Using new subs
1000 loops, best of 3: 247 µs per loop

So our implementation of subs is faster than sympy's. However, this micro-benchmark isn't all that meaningful, and like all micro-benchmarks it should be taken with a grain of salt. Normal subs works in multiple passes so that expressions inside the sub_dict are also affected. subs also incorporates math knowledge about the expressions, while ours just does a naive direct match and replace. For our purposes though, this is sufficient. The main point is that our custom subs function is significantly faster than the old method of _subs_keep_derivs.
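
Both differences are easy to see on tiny inputs. In the sketch below, crawl and sub_func are repeated from earlier so the block runs on its own, and new_subs is written as a def instead of a lambda. The variable names sequential, single_pass, math_aware and literal are just labels for the comparison:

```python
from sympy import symbols

a, b, c = symbols('a, b, c')

def crawl(expr, func, *args, **kwargs):
    """Crawl the expression tree, and apply func to every node."""
    val = func(expr, *args, **kwargs)
    if val is not None:
        return val
    new_args = (crawl(arg, func, *args, **kwargs) for arg in expr.args)
    return expr.func(*new_args)

def sub_func(expr, sub_dict):
    """Perform expression substitution, ignoring derivatives."""
    if expr in sub_dict:
        return sub_dict[expr]
    elif not expr.args or expr.is_Derivative:
        return expr

def new_subs(expr, sub_dict):
    return crawl(expr, sub_func, sub_dict)

# Sequential: a -> b first, then every b (including the new one) -> c.
sequential = (a + b).subs([(a, b), (b, c)])
# Single pass: each node is matched against the dict exactly once.
single_pass = new_subs(a + b, {a: b, b: c})

# subs recognizes a**4 as (a**2)**2; the direct match does not.
math_aware = (a**2 + a**4).subs(a**2, b)
literal = new_subs(a**2 + a**4, {a**2: b})

print(sequential, single_pass)
print(math_aware, literal)
```

So the single-pass version behaves like a simultaneous, purely structural replacement. That is fine when the keys are plain symbols and dynamic symbols, as they are in mechanics.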

Custom Simplification

The other major issue I discussed last week was that some expressions result in nan or oo (infinity) when not simplified, but after simplification result in a realizable expression. For example:

In [16]:
test = sin(a)/tan(a)
print("Before simplification:", test.subs(a, 0))
print("After simplification:", simplify(test).subs(a, 0))
Before simplification: nan
After simplification: 1

For small expressions, doing simplify before evaluation is acceptable, but for larger ones simplification is way too slow. However, these divide by zero errors are caused by subexpressions, not the whole expression. Using some knowledge of the types of expressions present in the mechanics module, we can come up with some simple heuristics for what can result in nan, oo and zoo:

  1. tan(pi/2)
  2. Fractions with 0 in the denominator

In reality, these are the same thing, because

$$ \tan(\pi/2) = \frac{\sin(\pi/2)}{\cos(\pi/2)} = \frac{1}{0} $$
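
A quick check of how sympy reports these cases, as a sketch using only standard sympy names (zoo is sympy's complex infinity, S.NaN its nan singleton):

```python
from sympy import symbols, sin, cos, tan, pi, S, zoo

a = symbols('a')

# tan(pi/2) and the explicit 1/0 form both evaluate to complex infinity.
print(tan(pi/2))
print(sin(pi/2)/cos(pi/2))

# 0 * (1/0) has no defined value, so the unsimplified fraction gives nan.
print((sin(a)/tan(a)).subs(a, 0))
```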

Using this knowledge, we can come up with a simple algorithm for performing subs and catching these issues at the same time:

  1. Replace all tan(*) with sin(*)/cos(*). This will allow us to only have to check for denominator = 0 conditions.
  2. For nodes that are fractions, check if the denominator evaluates to 0. If so, apply simplify to the fraction, and then carry on as normal.

This may not catch every instance that would result in a nan without simplification, but it should catch most of them. Also, the algorithm is so simple that it can be implemented in only a few lines. First, the tan replacement. This requires almost no new code, as it can be composed using the already written crawl function:

In [17]:
def tan_repl_func(expr):
    """Replace tan with sin/cos"""
    if isinstance(expr, tan):
        return sin(*expr.args)/cos(*expr.args)
    elif not expr.args or expr.is_Derivative:
        return expr

tan_repl = lambda expr: crawl(expr, tan_repl_func)

Testing it:

In [18]:
tan_repl(tan(a) + tan(a)/tan(b))
Out[18]:
sin(a)/cos(a) + sin(a)*cos(b)/(sin(b)*cos(a))

So that works as expected. Now for the second pass; the subs with denominator checking and selective simplification. This takes a little bit more code than before, but I've heavily commented it so you should be able to see what's going on:

In [19]:
def smart_subs_func(expr, sub_dict):
    """Perform subs, simplifying fractions whose denominator becomes 0."""
    # Decompose the expression into num, den
    num, den = fraction_decomp(expr)
    if den != 1:
        # If there is a non trivial denominator, we need to handle it
        denom_subbed = smart_subs_func(den, sub_dict)
        if denom_subbed.evalf() == 0:
            # If denom is 0 after this, attempt to simplify the bad expr
            expr = simplify(expr)
        else:
            # Expression won't result in nan, find numerator
            num_subbed = smart_subs_func(num, sub_dict)
            return num_subbed/denom_subbed
    # We have to crawl the tree manually, because `expr` may have been
    # modified in the simplify step. First, perform subs as normal:
    val = sub_func(expr, sub_dict)
    if val is not None:
        return val
    new_args = (smart_subs_func(arg, sub_dict) for arg in expr.args)
    return expr.func(*new_args)

def fraction_decomp(expr):
    """Return num, den such that expr = num/den"""
    if not isinstance(expr, Mul):
        return expr, 1
    num = []
    den = []
    for a in expr.args:
        if a.is_Pow and a.args[1] == -1:
            den.append(1/a)
        else:
            num.append(a)
    if not den:
        return expr, 1
    num = Mul(*num)
    den = Mul(*den)
    return num, den
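
For comparison, sympy also ships a general-purpose fraction function. The sketch below restates the helper so that each factor goes to either num or den (the loop variable is renamed to factor), then runs both on two small inputs. Swapping fraction into smart_subs_func is not something tested here:

```python
from sympy import symbols, fraction, Mul

a, b = symbols('a, b')

def fraction_decomp(expr):
    """Return num, den such that expr = num/den.

    Only factors with an exponent of exactly -1 count as denominator."""
    if not isinstance(expr, Mul):
        return expr, 1
    num, den = [], []
    for factor in expr.args:
        if factor.is_Pow and factor.args[1] == -1:
            den.append(1/factor)
        else:
            num.append(factor)
    if not den:
        return expr, 1
    return Mul(*num), Mul(*den)

print(fraction_decomp(a**2/b))
# b**-2 is not matched by the exponent == -1 test, so nothing is split off.
print(fraction_decomp(a/b**2))
# The built-in treats any negative exponent as part of the denominator.
print(fraction(a/b**2))
```

The narrow test keeps the helper cheap, at the cost of missing denominators such as b**2.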

Finally, we can put everything from above inside a nice wrapper function:

In [20]:
def smart_subs(expr, sub_dict):
    """Performs subs, checking for conditions that may result in `nan` or 
    `oo`, and attempts to simplify them out.

    The expression tree is traversed twice, and the following steps are
    performed on each expression node:
    - First traverse: 
        Replace all `tan` with `sin/cos`.
    - Second traverse:
        If node is a fraction, check if the denominator evaluates to 0.
        If so, attempt to simplify it out. Then if node is in sub_dict,
        sub in the corresponding value."""
    expr = crawl(expr, tan_repl_func)
    return smart_subs_func(expr, sub_dict)

Let's see if it works as expected. Using a simple test case:

In [21]:
test = tan(a)*cos(a)
sub_dict = {a: pi/2}
print("Without `smart_subs`:")
print(new_subs(test, sub_dict))
print("With `smart_subs`:")
print(smart_subs(test, sub_dict))
Without `smart_subs`:
nan
With `smart_subs`:
1

And some timings:

In [22]:
print("Using `smart_subs`:")
%timeit smart_subs(test, sub_dict)
print("Using simplification, then normal subs")
%timeit simplify(test).subs(sub_dict)
print("Using trigsimp, then normal subs")
%timeit trigsimp(test).subs(sub_dict)
Using `smart_subs`:
10000 loops, best of 3: 92 µs per loop
Using simplification, then normal subs
10 loops, best of 3: 42.9 ms per loop
Using trigsimp, then normal subs
10 loops, best of 3: 30.8 ms per loop

Using selective simplification, the same results can be obtained for a fraction of the cost. Going by the timings above, smart_subs is roughly 330 to 470 times faster for this small test.

Let's see what the overhead of smart_subs is for an expression that doesn't need it:

In [23]:
test = sin(x + y)/cos(x + y)* (a + b) + a/(a + x)
sub_dict = {x: a, y: 1}
print("new_subs")
%timeit new_subs(test, sub_dict)
print("smart_subs")
%timeit smart_subs(test, sub_dict)
print("Using simplification, then normal subs")
%timeit simplify(test).subs(sub_dict)
print("Using trigsimp, then normal subs")
%timeit trigsimp(test).subs(sub_dict)
new_subs
1000 loops, best of 3: 222 µs per loop
smart_subs
1000 loops, best of 3: 2.34 ms per loop
Using simplification, then normal subs
1 loops, best of 3: 1.11 s per loop
Using trigsimp, then normal subs
1 loops, best of 3: 815 ms per loop

So there's considerable overhead, which was expected. Still, it's much faster than using simplify first, and then running subs.

The best approach would be to first try new_subs, and if you get a nan or oo, then try smart_subs. To aid in this, we can write a nice wrapper function msubs that contains both methods:

In [24]:
def msubs(expr, sub_dict, smart=False):
    """A custom subs for use on expressions derived in physics.mechanics.

    Traverses the expression tree once, performing the subs found in sub_dict.
    Terms inside `Derivative` expressions are ignored:

    >>> x = dynamicsymbols('x')
    >>> msubs(x.diff() + x, {x: 1})
    Derivative(x(t), t) + 1

    If smart=True, also checks for conditions that may result in `nan`, but
    if simplified would yield a valid expression. For example:

    >>> (sin(a)/tan(a)).subs(a, 0)
    nan
    >>> msubs(sin(a)/tan(a), {a: 0}, smart=True)
    1

    It does this by first replacing all `tan` with `sin/cos`. Then each node
    is traversed. If the node is a fraction, subs is first evaluated on the
    denominator. If this results in 0, simplification of the entire fraction
    is attempted. Using this selective simplification, only subexpressions
    that result in 1/0 are targeted, resulting in faster performance."""

    if smart:
        func = smart_subs
    else:
        func = new_subs
    if isinstance(expr, Matrix):
        # For matrices, func is applied to each element using `applyfunc`
        return expr.applyfunc(lambda x: func(x, sub_dict))
    else:
        return func(expr, sub_dict)
In [25]:
#Check that the results are the same:
msubs(test, sub_dict) == new_subs(test, sub_dict)
msubs(test, sub_dict, smart=True) == smart_subs(test, sub_dict)
Out[25]:
True
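
SymPy releases include a function named msubs in sympy.physics.mechanics. Assuming an install where that import works and behaves as its docstring describes, a usage sketch covering the three cases (plain, smart, and Matrix input) looks like this. The names plain, smart and mat are just labels:

```python
from sympy import symbols, sin, tan, Matrix
from sympy.physics.mechanics import dynamicsymbols, msubs

a, b = symbols('a, b')
x = dynamicsymbols('x')

# Terms inside Derivative are left alone.
plain = msubs(x.diff() + x, {x: 1})

# smart=True simplifies fractions whose denominator would become 0.
smart = msubs(sin(a)/tan(a), {a: 0}, smart=True)

# Matrices are handled element by element.
mat = msubs(Matrix([x + a, x.diff()]), {x: b})

print(plain)
print(smart)
print(mat)
```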

This code has been included in my optimization branch, and is performing admirably against all the included tests (so much speed up!!!). The big test though is the bicycle example.

The Big Test

To really put this new code to the test I applied it to the bicycle example, which has operation counts ranging from 400,000 to several million depending on which expression you're working with. How did it fare? Mixed results...

Using msubs with smart=False in the formation of the linearizer object resulted in a huge speed increase. Previously using _subs_keep_derivs there resulted in a run time of several hours. Now it runs in 14 seconds!

The second test is in the application of the operating point to the M, A, and B matrices. Previously, evaluating these had resulted in nan and oo, and with operation counts exceeding 400,000 for M, pre-simplification is out of the question. I tried using msubs with smart=True on this M, and let it run for 6 hours. It ended up consuming 98% of my RAM (4 GB worth), and it still wasn't done :( So for reasonably sized expressions the smart_subs we've implemented is acceptable, but it still doesn't work for the huge expressions. I'll have to keep working on optimizations to make this faster/reduce the initial expression size.

Still, all is not lost. Going from formulation to M, A, and B without the operating point substitution now only takes 20 minutes - down from the several hours before. This is actually faster than the previous linearize method, but unfortunately results in expressions that evaluate to nan.

Future Work

Per my original timeline I was supposed to be working on "pre-linearization" routines for generating the equations of motion directly in linearized form. I think this will be dropped in favor of cleaning up and speeding up the existing codebase. There's some dead code that needs removing, and I have some other ideas for speed-ups that I'd like to try.

After that, I hope to get to the original part 3 of the project, which is adding support for matrices to the code-generation code. I'm not sure if this would live in sympy or pydy, but it would be extremely useful to me and others. For now I plan on taking it week-by-week.


If you have thoughts on how to deal with the nan and oo issue, please let me know in the comments. Thanks!