numba-mpi / numba-mpi

Numba @njittable wrappers for MPI C API tested on Linux, macOS and Windows

numba-mpi


numba-mpi provides Python wrappers to the C MPI API callable from within Numba JIT-compiled code (@njit mode).

Support is provided for a subset of MPI routines covering: size/rank, send/recv, allreduce, bcast, scatter/gather & allgather, barrier, wtime, and basic asynchronous communication with isend/irecv (only for contiguous arrays), along with request handling via wait/waitall/waitany and test/testall/testany.

The API uses NumPy and supports both numeric and character datatypes (e.g., broadcast). Auto-generated docstring-based API docs are published on the web:

Packages can be obtained from PyPI, Conda Forge, Arch Linux or by invoking pip install git+

numba-mpi is a pure-Python package. The codebase includes a test suite used through the GitHub Actions workflows (thanks to mpi4py's setup-mpi!) for automated testing on: Linux (MPICH, OpenMPI & Intel MPI), macOS (MPICH & OpenMPI) and Windows (MS MPI).

Features that are not implemented yet include (help welcome!):

  • support for non-default communicators
  • support for MPI_IN_PLACE in [all]gather/scatter and allreduce
  • support for MPI_Type_create_struct (NumPy structured arrays)
  • ...

Hello world send/recv example:

import numba, numba_mpi, numpy

@numba.njit()
def hello():
    src = numpy.array([1., 2., 3., 4., 5.])
    dst_tst = numpy.empty_like(src)

    # rank 0 sends the array to rank 1; any other ranks do nothing
    if numba_mpi.rank() == 0:
        numba_mpi.send(src, dest=1, tag=11)
    elif numba_mpi.rank() == 1:
        numba_mpi.recv(dst_tst, source=0, tag=11)

hello()  # needs at least two MPI processes


Example comparing numba-mpi vs. mpi4py performance:

The example below compares the performance of Numba + mpi4py with that of Numba + numba-mpi. The sample code estimates $\pi$ by integrating $4/(1+x^2)$ between 0 and 1, dividing the n_intervals integration intervals among the MPI processes and then summing the partial results using allreduce. The computation is carried out in a JIT-compiled function and is repeated N_TIMES. For mpi4py, the repetitions and the MPI-handled reduction are done outside of the JIT-compiled block; for numba-mpi, they are done inside it. Timing is repeated N_REPEAT times and the minimum time is reported. The generated plot shown below depicts the speedup obtained by replacing mpi4py with numba_mpi as a function of n_intervals: the more often communication is needed (smaller n_intervals), the larger the expected speedup.
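
The work split described above can be checked without MPI. The following is a minimal serial sketch, not part of numba-mpi: the names `pi_part` and `pi_serial_allreduce` are hypothetical, the ranks are simulated in a plain loop, and a Python `sum` stands in for the allreduce with SUM. It assumes the midpoint rule with intervals assigned to ranks in a strided fashion.

```python
import math

def pi_part(n_intervals, rank, size):
    """Midpoint-rule contribution of one simulated rank to the integral of 4/(1+x^2) on [0, 1]."""
    h = 1.0 / n_intervals
    partial_sum = 0.0
    # simulated rank r handles intervals r+1, r+1+size, r+1+2*size, ... (1-based)
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)  # midpoint of the i-th interval
        partial_sum += 4.0 / (1.0 + x * x)
    return h * partial_sum

def pi_serial_allreduce(n_intervals, size):
    """Add up the per-rank parts; a serial stand-in for an MPI allreduce with SUM."""
    return sum(pi_part(n_intervals, rank, size) for rank in range(size))

if __name__ == "__main__":
    estimate = pi_serial_allreduce(n_intervals=1000, size=4)
    print(abs(estimate - math.pi) / math.pi)  # relative error of the estimate
```

Because every interval is assigned to exactly one simulated rank, the combined result does not depend on the number of ranks (up to floating-point rounding). In the MPI benchmark, each process computes only its own part and allreduce performs the sum.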

import timeit, mpi4py, numba, numpy as np, numba_mpi

N_TIMES = 10000
RTOL = 1e-3

@numba.njit
def get_pi_part(n_intervals=1000000, rank=0, size=1):
    # midpoint-rule contribution of one rank: intervals rank+1, rank+1+size, ...
    h = 1 / n_intervals
    partial_sum = 0.0
    for i in range(rank + 1, n_intervals + 1, size):  # + 1 so the last interval is included
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum

@numba.njit
def pi_numba_mpi(n_intervals):
    # both the repetition loop and the reduction run inside JIT-compiled code
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, numba_mpi.rank(), numba_mpi.size())
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

def pi_mpi4py(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, mpi4py.MPI.COMM_WORLD.rank, mpi4py.MPI.COMM_WORLD.size)
        mpi4py.MPI.COMM_WORLD.Allreduce(part, (pi, mpi4py.MPI.DOUBLE), op=mpi4py.MPI.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

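N_REPEAT = 10  # timing repetitions; the value is assumed here, the text does not state it

plot_x = list(range(1, 11))
plot_y = {'numba_mpi': [], 'mpi4py': []}
for x in plot_x:
    for impl in plot_y:
        # time a single call per repetition and keep the fastest of N_REPEAT runs
        plot_y[impl].append(min(timeit.repeat(
            f"pi_{impl}(n_intervals={N_TIMES // x})",
            globals=globals(),
            number=1,
            repeat=N_REPEAT
        )))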

if numba_mpi.rank() == 0:
    from matplotlib import pyplot
    pyplot.figure(figsize=(8.3, 3.5), tight_layout=True)
    pyplot.plot(plot_x, np.array(plot_y['mpi4py'])/np.array(plot_y['numba_mpi']), marker='o')
    pyplot.xlabel('number of MPI calls per interval')
    pyplot.ylabel('mpi4py/numba_mpi wall-time ratio')
    pyplot.title(f'mpiexec -np {numba_mpi.size()}')
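
The timing approach described above (repeat the measurement and report the minimum) can be tried in isolation with the standard-library `timeit` module. This is a sketch under stated assumptions: `workload`, `time_samples` and `best_time` are hypothetical names, the repetition count of 5 is arbitrary, and the workload is a stand-in for the benchmarked functions.

```python
import timeit

N_REPEAT = 5  # hypothetical repetition count, chosen only for this sketch

def workload(n):
    """Stand-in for the function being benchmarked."""
    total = 0
    for i in range(n):
        total += i * i
    return total

def time_samples(func, repeat=N_REPEAT):
    """Return one wall-time sample (in seconds) per repetition, one call each."""
    return timeit.repeat(func, number=1, repeat=repeat)

def best_time(func, repeat=N_REPEAT):
    """Report the fastest sample; slower ones mostly reflect interference from other activity."""
    return min(time_samples(func, repeat))

if __name__ == "__main__":
    print(best_time(lambda: workload(100_000)))
```

Taking the minimum also means that a one-off cost paid on the first call, such as JIT compilation, does not dominate the reported time as long as there is more than one repetition.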


MPI resources on the web:


Development of numba-mpi has been supported by the Polish National Science Centre (grant no. 2020/39/D/ST10/01220).


License:GNU General Public License v3.0

