# Jim Crist-Harif

Posted on April 06, 2020

Numba is an amazing project. It's able to just-in-time (JIT) compile a fairly large subset of numerical Python code to fast machine code, with the user only having to apply a decorator to their function. It looks like this:

import random
import numba as nb

@nb.njit   # <- magic Numba decorator
def monte_carlo_pi(nsamples):
    """Estimate the value of Pi using the monte-carlo method"""
    acc = 0
    for i in range(nsamples):
        x = random.random()
        y = random.random()
        if (x ** 2 + y ** 2) < 1.0:
            acc += 1
    return 4.0 * acc / nsamples


The above is an example routine for estimating Pi using the Monte Carlo method. Without the nb.njit decorator, monte_carlo_pi would still run (everything is valid Python), but with the decorator it runs significantly faster (roughly 30x in my quick benchmark).

Numba does this by analyzing the decorated function's bytecode to build an intermediate representation (IR) of its structure. Type inference is then applied, followed by a series of IR transformations. Finally, the IR is used to generate LLVM IR, which is then compiled to machine code. Part of this transformation process involves swapping out calls to Python functions (like random.random above) with faster compiled versions that Numba knows about.
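To get a feel for the input Numba starts from, the standard-library dis module can list the CPython bytecode of a function. This is only a sketch of that first step: add_one is a made-up example function, and the output is CPython's bytecode, not Numba's own IR. The exact opcode names vary between Python versions.

```python
import dis


def add_one(x):
    # A tiny, hypothetical function used only for illustration
    return x + 1


# Collect the opcode names of the compiled function body. A bytecode-based
# compiler would start from this instruction stream.
opnames = [ins.opname for ins in dis.get_instructions(add_one)]
print(opnames)
```

The list contains a load of the local variable, a binary operation, and a return instruction, which is the kind of structure a compiler front end has to reconstruct into a typed IR.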

Numba natively supports a decent subset of Python and NumPy (see here and here for a full reference). But sometimes you'll run into situations where Numba doesn't know about the function you're referencing. In this case you'll get an error:

>>> import numpy as np
>>> import numba as nb

>>> @nb.njit
... def clipped_sum(x):
...     return np.clip(x, 0, 1).sum()

>>> clipped_sum(np.arange(10))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/miniconda3/envs/numba/lib/python3.8/site-packages/numba/dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/opt/miniconda3/envs/numba/lib/python3.8/site-packages/numba/dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "/opt/miniconda3/envs/numba/lib/python3.8/site-packages/numba/six.py", line 668, in reraise
    raise value.with_traceback(tb)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Use of unsupported NumPy function 'numpy.clip' or unsupported use of the function.

File "<ipython-input-8-3524b5de079c>", line 3:
def clipped_sum(x):
    return np.clip(x, 0, 1).sum()
    ^

[1] During: typing of get attribute at <ipython-input-8-3524b5de079c> (3)

File "<ipython-input-8-3524b5de079c>", line 3:
def clipped_sum(x):
    return np.clip(x, 0, 1).sum()
    ^


Numba (at least version 0.48.0) doesn't know how to handle the np.clip call. We'll need to use Numba's extension API to add support for np.clip.

## Implementing new functions with overload

To add support for a new function to Numba, we can make use of the numba.extending.overload decorator. The decorated function is called at compile time with the types of the arguments, and should return an implementation for those given types. This implementation will then be JIT compiled and used in place of the overloaded function.

Let's work through a simplified example first, before handling np.clip. Say we wanted to add Numba support for the following function:

def myfunc(x):
    if isinstance(x, int):
        return x + 1
    elif isinstance(x, float):
        return x * 2
    else:
        raise TypeError("x must be an int or a float")


The above checks if x is an int, and if so increments it. If it's a float it doubles it. Otherwise it raises a TypeError.

These type checks happen at runtime. When writing our Numba implementation, the types are known at compile time. This means that we can elide the type checks at runtime by handling the type-based dispatching (and erroring) at compile time.

The corresponding Numba implementation for this function is:

from numba import types
from numba.errors import TypingError
from numba.extending import overload

@overload(myfunc)
def implement_myfunc(x):
    # This is a code generator for myfunc.
    # Here x is the compile-time *type*

    if isinstance(x, types.Integer):
        def impl(x):
            # This is an *implementation* of myfunc (in this case the
            # implementation for integer values of x).
            # Here x is the runtime *value*
            return x + 1

    elif isinstance(x, types.Float):
        def impl(x):
            return x * 2

    else:
        # If an invalid type is passed to implement_myfunc, a
        # numba.errors.TypingError should be raised. This helps inform
        # the user what went wrong.
        raise TypingError("x must be an int or a float")

    return impl


At a high level, implement_myfunc and myfunc look quite similar. Both branch on the type of x, with branches for integers, floats, and errors. But while myfunc returns values, implement_myfunc returns a callable that will be JIT compiled by Numba and used to implement myfunc for the provided type.
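The split between "generator called with a type" and "implementation called with a value" can be modelled in plain Python. The following is a toy sketch only, not how Numba works internally: generate_impl, dispatch, and the cache are hypothetical names, and Python classes stand in for Numba types. It shows that the type-based branching runs once per argument type, after which the selected implementation is reused.

```python
def generate_impl(x_type):
    """Toy code generator: receives a *type*, returns an implementation."""
    if issubclass(x_type, int):
        def impl(x):
            return x + 1
    elif issubclass(x_type, float):
        def impl(x):
            return x * 2
    else:
        raise TypeError("x must be an int or a float")
    return impl


_impl_cache = {}       # maps argument type -> selected implementation
generator_calls = []   # records each time the generator actually runs


def dispatch(x):
    """Pick (and cache) an implementation based on the type of x."""
    x_type = type(x)
    if x_type not in _impl_cache:
        generator_calls.append(x_type)
        _impl_cache[x_type] = generate_impl(x_type)
    return _impl_cache[x_type](x)


results = [dispatch(1), dispatch(2), dispatch(1.5)]
print(results)          # [2, 3, 3.0]
print(generator_calls)  # [<class 'int'>, <class 'float'>]
```

The generator ran only twice for three calls. In the real system the returned implementation is additionally compiled to machine code, so the isinstance checks never appear in the compiled function.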

Note that Numba types are distinct from their Python counterparts (but there's usually a one-to-one mapping between them). These types can be found in numba.types (documentation). If you don't know what the corresponding Numba type is for something, you can use numba.typeof.

>>> from numba import typeof
>>> typeof(1)
int64
>>> typeof((1, 5.0))
Tuple(int64, float64)


Also, when an invalid type (or combination of types) is provided, the decorated function (e.g. implement_myfunc) should raise a numba.errors.TypingError instead of a TypeError. This will be reported back to the user.

Now let's apply the same process to np.clip, reimplementing the function in a way that Numba's JIT can reason about. To make sure the NumPy and Numba versions are compatible, we first check the docstring:

"""
Clip (limit) the values in an array.

Given an interval, values outside the interval are clipped to
the interval edges.  For example, if an interval of [0, 1]
is specified, values smaller than 0 become 0, and values larger
than 1 become 1.

Equivalent to but faster than np.maximum(a_min, np.minimum(a, a_max)).
No check is performed to ensure a_min < a_max.

Parameters
----------
a : array_like
    Array containing elements to clip.
a_min : scalar or array_like or None
    Minimum value. If None, clipping is not performed on lower
    interval edge. Not more than one of a_min and a_max may be
    None.
a_max : scalar or array_like or None
    Maximum value. If None, clipping is not performed on upper
    interval edge. Not more than one of a_min and a_max may be
    None. If a_min or a_max are array_like, then the three
    arrays will be broadcasted to match their shapes.
...
"""


We don't need to support all of np.clip's possible arguments yet, just the ones we need. To simplify things, we'll support:

• Scalar values for a_min/a_max (int, float, or None).
• Either scalar or 1D-array values for a.

After a bit of work, I ended up with the following implementation:

import numpy as np
from numba import types
from numba.errors import TypingError
from numba.extending import overload

@overload(np.clip)
def impl_clip(a, a_min, a_max):
    # Check that a_min and a_max are scalars, and at most one of them is None.
    if not isinstance(a_min, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("a_min must be a scalar int/float or None")
    if not isinstance(a_max, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("a_max must be a scalar int/float or None")
    if isinstance(a_min, types.NoneType) and isinstance(a_max, types.NoneType):
        raise TypingError("a_min and a_max can't both be None")

    if isinstance(a, (types.Integer, types.Float)):
        # a is a scalar with a valid type
        if isinstance(a_min, types.NoneType):
            # a_min is None
            def impl(a, a_min, a_max):
                return min(a, a_max)
        elif isinstance(a_max, types.NoneType):
            # a_max is None
            def impl(a, a_min, a_max):
                return max(a, a_min)
        else:
            # neither a_min nor a_max is None
            def impl(a, a_min, a_max):
                return min(max(a, a_min), a_max)
    elif (
        isinstance(a, types.Array) and
        a.ndim == 1 and
        isinstance(a.dtype, (types.Integer, types.Float))
    ):
        # a is a 1D array of the proper type
        def impl(a, a_min, a_max):
            # Allocate an output array using standard numpy functions
            out = np.empty_like(a)
            # Iterate over a, calling np.clip on every element
            for i in range(a.size):
                # This will dispatch to the proper scalar implementation (as
                # defined above) at *compile time*. There should be no
                # overhead at runtime.
                out[i] = np.clip(a[i], a_min, a_max)
            return out
    else:
        raise TypingError("a must be an int/float or a 1D array of ints/floats")

    # The call to np.clip has arguments with valid types, return our
    # numba-compatible implementation
    return impl


With our implementation registered, we should now be able to use np.clip with Numba. Verifying:

>>> import numpy as np
>>> import numba as nb

>>> @nb.njit
... def test_clip(x, a_min, a_max):
...     return np.clip(x, a_min, a_max)

>>> x = np.arange(10)

>>> test_clip(x, 2, 5)
array([2, 2, 2, 3, 4, 5, 5, 5, 5, 5])

>>> test_clip(x, None, 5)
array([0, 1, 2, 3, 4, 5, 5, 5, 5, 5])

>>> test_clip(5.0, 0, 3)
3.0


## Registering Numba Extensions with Entry Points

Our above example using np.clip worked because our overloaded definition was already registered with Numba. As long as our overload-decorated functions have been loaded before Numba tries to compile something that relies on them, everything should just work. However, sometimes you may need (or want) to store the overloaded definitions in a package that would not normally be imported by users. For example, the numba-scipy package adds Numba support for the SciPy library, but is a separate package from scipy.

To avoid forcing users to import numba_scipy to enable the extension, Numba relies on entry points to automatically discover any installed extensions.
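Entry points are a general packaging mechanism, and the standard library can list them. The sketch below is not Numba's actual discovery code; it only shows, under that assumption, how any tool could look up what installed packages have registered under the numba_extensions group. The find_numba_extensions helper is a hypothetical name, and the branch on select covers the API difference between Python versions.

```python
from importlib.metadata import entry_points


def find_numba_extensions():
    """Return (name, target) pairs registered under 'numba_extensions'."""
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: entry points are filtered with select()
        selected = eps.select(group="numba_extensions")
    else:
        # Python 3.8/3.9: entry_points() returns a dict keyed by group
        selected = eps.get("numba_extensions", [])
    return [(ep.name, ep.value) for ep in selected]


extensions = find_numba_extensions()
print(extensions)
```

On an environment with no extension packages installed this prints an empty list. Each entry's value names the function to call, and calling ep.load() on an entry point would import and return that function.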

To register a module as a Numba extension, you need to:

• Define an init function to set up your extension (in our case this is just importing any modules with overload definitions):

# numba_overload_example/__init__.py
def init():
    # Import the overloads module, registering any functions or types
    from . import overloads  # noqa: F401

• Register this init function as an entry point under the numba_extensions group.

# setup.py
setup(
    ...,
    entry_points={
        "numba_extensions": [