Jim Crist-Harif

Extending Numba with overload

Posted on April 06, 2020

Numba is an amazing project. It's able to just-in-time (JIT) compile a fairly large subset of numerical Python code to fast machine code, with the user only having to apply a decorator to their function. It looks like this:

import random
import numba as nb

@nb.njit   # <- magic Numba decorator
def monte_carlo_pi(nsamples):
    """Estimate the value of Pi using the monte-carlo method"""
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x ** 2 + y ** 2) < 1.0:
            acc += 1
    return 4.0 * acc / nsamples

The above is an example routine for estimating Pi using the Monte Carlo method. Without the nb.njit decorator monte_carlo_pi would successfully run (everything is valid Python), but with the decorator it runs significantly faster (roughly 30x in my quick benchmark).

Numba does this by analyzing the decorated function's bytecode to build an intermediate representation (IR) of its structure. Type inference is then applied, followed by a series of IR transformations. Finally the IR is used to generate LLVM IR, which is then compiled to machine code. Part of this transformation process involves swapping out calls to Python functions (like random.random above) with faster compiled versions that Numba knows about.
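To get a feel for the starting point of that pipeline, the standard library's dis module can show the CPython bytecode of the undecorated function. This is only the raw input that Numba analyzes, not Numba's own IR, and the exact instructions differ between Python versions.

```python
import dis
import random


def monte_carlo_pi(nsamples):
    """Estimate the value of Pi using the Monte Carlo method"""
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x ** 2 + y ** 2) < 1.0:
            acc += 1
    return 4.0 * acc / nsamples


# The bytecode instructions CPython generated for the function body
instructions = list(dis.get_instructions(monte_carlo_pi))
for instr in instructions[:8]:
    print(instr.opname, instr.argrepr)

# Global and attribute names the bytecode refers to, such as ``random``.
# Calls through names like these are what a compiler has to resolve.
referenced_names = monte_carlo_pi.__code__.co_names
print(referenced_names)
```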

Numba natively supports a decent subset of Python and NumPy (the Numba documentation lists the supported Python and NumPy features in full). But sometimes you'll run into situations where Numba doesn't know about the function you're referencing. In this case you'll get an error:

>>> import numpy as np
>>> import numba as nb

>>> @nb.njit
... def clipped_sum(x):
...     return np.clip(x, 0, 1).sum()

>>> clipped_sum(np.arange(10))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/miniconda3/envs/numba/lib/python3.8/site-packages/numba/dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
File "/opt/miniconda3/envs/numba/lib/python3.8/site-packages/numba/dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
File "/opt/miniconda3/envs/numba/lib/python3.8/site-packages/numba/six.py", line 668, in reraise
    raise value.with_traceback(tb)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Use of unsupported NumPy function 'numpy.clip' or unsupported use of the function.

File "<ipython-input-8-3524b5de079c>", line 3:
def clipped_sum(x):
    return np.clip(x, 0, 1).sum()
    ^

[1] During: typing of get attribute at <ipython-input-8-3524b5de079c> (3)

File "<ipython-input-8-3524b5de079c>", line 3:
def clipped_sum(x):
    return np.clip(x, 0, 1).sum()
    ^

Numba (at least version 0.48.0) doesn't know how to handle the np.clip call. We'll need to use Numba's extension API to add support for np.clip.

Implementing new functions with overload

To add support for a new function to Numba, we can make use of the numba.extending.overload decorator. The decorated function is called at compile time with the types of the arguments, and should return an implementation for those given types. This implementation will then be jit compiled and used in place of the overloaded function.

Let's work a simplified example first, before handling np.clip. Say we wanted to add Numba support for the following function:

def myfunc(x):
    if isinstance(x, int):
        return x + 1
    elif isinstance(x, float):
        return x * 2
    else:
        raise TypeError("x must be an int or a float")

The above checks if x is an int, and if so increments it. If it's a float it doubles it. Otherwise it raises a TypeError.

These type checks happen at runtime. When writing our Numba implementation, the types are known at compile time. This means that we can elide the type checks at runtime by handling the type-based dispatching (and erroring) at compile time.

The corresponding Numba implementation for this function is:

from numba import types
from numba.errors import TypingError  # moved to numba.core.errors in Numba 0.49+
from numba.extending import overload

@overload(myfunc)
def implement_myfunc(x):
    # This is a code generator for ``myfunc``.
    # Here x is the compile-time *type*

    if isinstance(x, types.Integer):
        def impl(x):
            # This is an *implementation* of ``myfunc`` (in this case the
            # implementation for integer values of x).
            # Here x is the runtime *value*
            return x + 1

    elif isinstance(x, types.Float):
        def impl(x):
            return x * 2

    else:
        # If an invalid type is passed to ``implement_myfunc``, a
        # ``numba.errors.TypingError`` should be raised. This helps inform
        # the user what went wrong.
        raise TypingError("x must be an int or a float")

    return impl

At a high level implement_myfunc and myfunc look quite similar. Both branch on the type of x, with branches for integers, floats, and errors. But while myfunc returns values, implement_myfunc returns a callable that will be JIT compiled by Numba and used to implement myfunc for the provided type.

Note that Numba types are distinct from their Python counterparts (but there's usually a one-to-one mapping between them). These types can be found in numba.types (documentation). If you don't know what the corresponding Numba type is for something, you can use numba.typeof.

>>> from numba import typeof
>>> typeof(1)
int64
>>> typeof((1, 5.0))
Tuple(int64, float64)

Also, when an invalid type (or combination of types) is provided, the decorated function (e.g. implement_myfunc) should raise a numba.errors.TypingError instead of a TypeError. This will be reported back to the user.

Adding an overload for np.clip

Now let's apply the same process to np.clip, reimplementing the function in a way that Numba's JIT can reason about. To make sure the NumPy and Numba versions are compatible, we first check the docstring:

"""
Clip (limit) the values in an array.

Given an interval, values outside the interval are clipped to
the interval edges.  For example, if an interval of ``[0, 1]``
is specified, values smaller than 0 become 0, and values larger
than 1 become 1.

Equivalent to but faster than ``np.maximum(a_min, np.minimum(a, a_max))``.
No check is performed to ensure ``a_min < a_max``.

Parameters
----------
a : array_like
    Array containing elements to clip.
a_min : scalar or array_like or None
    Minimum value. If None, clipping is not performed on lower
    interval edge. Not more than one of `a_min` and `a_max` may be
    None.
a_max : scalar or array_like or None
    Maximum value. If None, clipping is not performed on upper
    interval edge. Not more than one of `a_min` and `a_max` may be
    None. If `a_min` or `a_max` are array_like, then the three
    arrays will be broadcasted to match their shapes.
...
"""

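We don't need to support all of np.clip's possible arguments yet, just the ones we need. To simplify things, we'll support a being either a scalar int/float or a 1D array of ints/floats, and a_min and a_max being scalar ints/floats or None (but not both None).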

After a bit of work, I ended up with the following implementation:

import numpy as np
from numba import types
from numba.errors import TypingError
from numba.extending import overload

@overload(np.clip)
def impl_clip(a, a_min, a_max):
    # Check that `a_min` and `a_max` are scalars, and at most one of them is None.
    if not isinstance(a_min, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("a_min must be a scalar int/float or None")
    if not isinstance(a_max, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("a_max must be a scalar int/float or None")
    if isinstance(a_min, types.NoneType) and isinstance(a_max, types.NoneType):
        raise TypingError("a_min and a_max can't both be None")

    if isinstance(a, (types.Integer, types.Float)):
        # `a` is a scalar with a valid type
        if isinstance(a_min, types.NoneType):
            # `a_min` is None
            def impl(a, a_min, a_max):
                return min(a, a_max)
        elif isinstance(a_max, types.NoneType):
            # `a_max` is None
            def impl(a, a_min, a_max):
                return max(a, a_min)
        else:
            # neither `a_min` nor `a_max` is None
            def impl(a, a_min, a_max):
                return min(max(a, a_min), a_max)
    elif (
        isinstance(a, types.Array) and
        a.ndim == 1 and
        isinstance(a.dtype, (types.Integer, types.Float))
    ):
        # `a` is a 1D array of the proper type
        def impl(a, a_min, a_max):
            # Allocate an output array using standard numpy functions
            out = np.empty_like(a)
            # Iterate over `a`, calling `np.clip` on every element
            for i in range(a.size):
                # This will dispatch to the proper scalar implementation (as
                # defined above) at *compile time*. There should be no
                # overhead at runtime.
                out[i] = np.clip(a[i], a_min, a_max)
            return out
    else:
        raise TypingError("`a` must be an int/float or a 1D array of ints/floats")

    # The call to `np.clip` has arguments with valid types, return our
    # numba-compatible implementation
    return impl

With our implementation registered, we should now be able to use np.clip with Numba. Verifying:

>>> import numpy as np
>>> import numba as nb

>>> @nb.njit
... def test_clip(x, a_min, a_max):
...     return np.clip(x, a_min, a_max)

>>> x = np.arange(10)

>>> test_clip(x, 2, 5)
array([2, 2, 2, 3, 4, 5, 5, 5, 5, 5])

>>> test_clip(x, None, 5)
array([0, 1, 2, 3, 4, 5, 5, 5, 5, 5])

>>> test_clip(5.0, 0, 3)
3.0

Registering Numba Extensions with Entry Points

Our example using np.clip worked because our overloaded definition was already registered with Numba. As long as our overload-decorated functions have been loaded before Numba tries to compile something that relies on them, everything should just work. However, sometimes you may need (or want) to store the overloaded definitions in a package that would not normally be imported by users. For example, the numba-scipy package adds Numba support for the SciPy library, but is a separate package from scipy.

To avoid forcing users to import numba_scipy to enable the extension, Numba relies on entry points to automatically discover any installed extensions.

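To register a module as a Numba extension, you need to define a function that initializes the extension (typically by importing the modules that contain your overloads), and register that function in your package's setup.py under the numba_extensions entry point group with the name init.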

For more information on registering Numba extensions using entry points, see the documentation.

Wrapping Up

Unmodified, Numba is able to compile a decent subset of Python and NumPy. If you're writing code that looks similar to how it'd be done in a "low-level" language like C (e.g. loops, arithmetic, arrays of scalars, ...) you may never need to use the extension API. But when needed, using overload to add support for new functions can be quite useful.

The full code for our example extension module can be found here: numba-overload-example.