Source code for Poisson.mpi.communicators

"""Halo exchange communicators for MPI domain decomposition."""

import numpy as np
from mpi4py import MPI

# Axis config: (array axis index, axis name) pairs. The name is the prefix
# used to build neighbor-dictionary keys such as 'z_lower' and 'z_upper'.
_AXES = [(0, "z"), (1, "y"), (2, "x")]


def _face_slice(axis, idx, has_halo):
    """Slice for face at idx, excluding halos of other decomposed axes."""
    return tuple(
        idx if ax == axis else slice(1, -1) if has_halo[ax] else slice(None)
        for ax in range(3)
    )


class _BaseHaloExchange:
    """Shared driver: walks the axes and delegates each one to a subclass.

    Subclasses are expected to provide
    ``_exchange_axis(u, axis, lo_rank, hi_rank, has_halo, comm, tag)``.
    """

    def exchange_halos(self, u, decomposition, rank, comm):
        """Exchange halo zones of ``u`` with the neighbors of ``rank``.

        Args:
            u: Array whose halo layers are updated by the subclass hook.
            decomposition: Object providing ``get_neighbors(rank)``, which
                returns a mapping with keys such as ``'z_lower'``.
            rank: Rank whose neighbors are looked up.
            comm: Communicator handed through to ``_exchange_axis``.
        """
        neighbors = decomposition.get_neighbors(rank)

        # An axis is treated as decomposed when its '<name>_lower' key exists
        # in the neighbor mapping (regardless of the value stored there).
        has_halo = {}
        for axis_index, axis_name in _AXES:
            has_halo[axis_index] = f"{axis_name}_lower" in neighbors

        for axis_index, axis_name in _AXES:
            if has_halo[axis_index]:
                lower = neighbors.get(f"{axis_name}_lower")
                upper = neighbors.get(f"{axis_name}_upper")
                # Each axis gets its own tag base (0, 100, 200).
                self._exchange_axis(
                    u, axis_index, lower, upper, has_halo, comm, tag=axis_index * 100
                )


class NumpyHaloExchange(_BaseHaloExchange):
    """Halo exchange that stages each face through a contiguous NumPy copy."""

    name = "NumPy"

    def _exchange_axis(self, u, axis, lo_rank, hi_rank, has_halo, comm, tag):
        """Exchange both faces of ``u`` along ``axis`` with its two neighbors.

        Args:
            u: 3-D array updated in place.
            axis: Axis index (0, 1 or 2) being exchanged.
            lo_rank: Rank of the lower neighbor, or ``None`` if there is none.
            hi_rank: Rank of the upper neighbor, or ``None`` if there is none.
            has_halo: Per-axis flags forwarded to ``_face_slice``.
            comm: Communicator providing ``Sendrecv``.
            tag: Tag for the first pass; the second pass uses ``tag + 1``.
        """
        lo = MPI.PROC_NULL if lo_rank is None else lo_rank
        hi = MPI.PROC_NULL if hi_rank is None else hi_rank

        # (plane to send, plane to fill, destination, source) per direction:
        # first plane -2 goes up while plane 0 is filled from below, then
        # plane 1 goes down while plane -1 is filled from above.
        passes = (
            (-2, 0, hi, lo),
            (1, -1, lo, hi),
        )
        for step, (send_idx, recv_idx, dest, source) in enumerate(passes):
            step_tag = tag + step
            outgoing = np.ascontiguousarray(u[_face_slice(axis, send_idx, has_halo)])
            incoming = np.empty_like(outgoing)
            comm.Sendrecv(outgoing, dest, step_tag, incoming, source, step_tag)

            halo = _face_slice(axis, recv_idx, has_halo)
            if source == MPI.PROC_NULL:
                # No neighbor on this side: the receive buffer was never
                # written, so the halo plane is zeroed instead.
                u[halo] = 0.0
            else:
                u[halo] = incoming
class CustomHaloExchange(_BaseHaloExchange):
    """Halo exchange using MPI derived datatypes (zero-copy).

    Faces are described by committed MPI datatypes built from ``MPI.DOUBLE``
    and are sent from / received into the array's own memory. The array must
    therefore be C-contiguous and of dtype float64; anything else is rejected
    instead of being exchanged incorrectly.
    """

    name = "MPI Datatype"

    def __init__(self):
        # (shape, has_halo items) -> {axis: committed MPI datatype}
        self._cache = {}

    def _exchange_axis(self, u, axis, lo_rank, hi_rank, has_halo, comm, tag):
        """Exchange both faces of ``u`` along ``axis`` with its two neighbors.

        Args:
            u: 3-D, C-contiguous float64 array updated in place.
            axis: Axis index (0, 1 or 2) being exchanged.
            lo_rank: Rank of the lower neighbor, or ``None`` if there is none.
            hi_rank: Rank of the upper neighbor, or ``None`` if there is none.
            has_halo: Mapping axis index -> bool; selects which other axes
                have their first/last layer excluded from the face.
            comm: Communicator providing ``Sendrecv``.
            tag: Tag for the upward pass; the downward pass uses ``tag + 1``.

        Raises:
            TypeError: If ``u`` is not float64 (the datatypes use MPI.DOUBLE).
            ValueError: If ``u`` is not C-contiguous.
        """
        if u.dtype != np.float64:
            raise TypeError(
                f"CustomHaloExchange requires a float64 array, got {u.dtype}"
            )
        if not u.flags.c_contiguous:
            # ravel() would return a copy here, so received data would be
            # written into a temporary and silently discarded.
            raise ValueError("CustomHaloExchange requires a C-contiguous array")

        lo = lo_rank if lo_rank is not None else MPI.PROC_NULL
        hi = hi_rank if hi_rank is not None else MPI.PROC_NULL

        # Get or create the datatypes for this configuration.
        key = (u.shape, tuple(has_halo.items()))
        if key not in self._cache:
            self._cache[key] = self._create_datatypes(u.shape, has_halo)
        dtype = self._cache[key][axis]

        _, ny, nx = u.shape
        u_flat = u.ravel()  # guaranteed to be a view after the check above
        extent = u.shape[axis]

        # First index of the face on each axis: 1 skips a halo layer.
        start = [1 if has_halo[ax] else 0 for ax in range(3)]

        def offset(plane):
            """Flat index of the first face element on ``plane`` of ``axis``."""
            z, y, x = (plane if ax == axis else start[ax] for ax in range(3))
            return (z * ny + y) * nx + x

        send_hi, recv_lo = offset(extent - 2), offset(0)
        send_lo, recv_hi = offset(1), offset(extent - 1)

        # Upward pass: last interior plane goes to `hi`, plane 0 is filled
        # from `lo`.
        comm.Sendrecv(
            [u_flat[send_hi:], 1, dtype],
            hi,
            tag,
            [u_flat[recv_lo:], 1, dtype],
            lo,
            tag,
        )
        if lo_rank is None:
            # No lower neighbor: nothing was received, zero the halo plane.
            u[_face_slice(axis, 0, has_halo)] = 0.0

        # Downward pass: first interior plane goes to `lo`, last plane is
        # filled from `hi`.
        comm.Sendrecv(
            [u_flat[send_lo:], 1, dtype],
            lo,
            tag + 1,
            [u_flat[recv_hi:], 1, dtype],
            hi,
            tag + 1,
        )
        if hi_rank is None:
            u[_face_slice(axis, -1, has_halo)] = 0.0

    def _create_datatypes(self, shape, has_halo):
        """Create and commit one MPI face datatype per axis.

        Args:
            shape: ``(nz, ny, nx)`` shape of the local array.
            has_halo: Mapping axis index -> bool; a truthy entry shrinks the
                face by one layer on each side of that axis.

        Returns:
            Dict mapping axis index (0, 1, 2) to a committed MPI datatype.
        """
        nz, ny, nx = shape
        nz_int = nz - 2 if has_halo[0] else nz
        ny_int = ny - 2 if has_halo[1] else ny
        nx_int = nx - 2 if has_halo[2] else nx

        datatypes = {}

        # Z-face: contiguous if no y/x halos, else strided rows.
        if not has_halo[1] and not has_halo[2]:
            dt = MPI.DOUBLE.Create_contiguous(ny * nx)
        else:
            dt = MPI.DOUBLE.Create_vector(ny_int, nx_int, nx)
        dt.Commit()
        datatypes[0] = dt

        # Y-face: one row per z-plane, strided by a whole plane.
        dt = MPI.DOUBLE.Create_vector(nz_int, nx_int, ny * nx)
        dt.Commit()
        datatypes[1] = dt

        # X-face: a column (stride nx) repeated per z-plane; hvector takes
        # its stride in bytes.
        row = MPI.DOUBLE.Create_vector(ny_int, 1, nx)
        row.Commit()
        dt = row.Create_hvector(nz_int, 1, ny * nx * MPI.DOUBLE.Get_size())
        dt.Commit()
        row.Free()
        datatypes[2] = dt

        return datatypes

    def __del__(self):
        """Free cached datatypes, unless MPI is gone or init never ran."""
        cache = getattr(self, "_cache", None)
        if not cache or MPI.Is_finalized():
            # Freeing a datatype after MPI has been finalized is an error.
            return
        for datatypes in cache.values():
            for dt in datatypes.values():
                dt.Free()
        cache.clear()