"""Analysis backends --- :mod:`MDAnalysis.analysis.backends`
============================================================
.. versionadded:: 2.8.0
The :mod:`backends` module provides :class:`BackendBase` base class to
implement custom execution backends for
:meth:`MDAnalysis.analysis.base.AnalysisBase.run` and its
subclasses.
.. SeeAlso:: :ref:`parallel-analysis`
.. _backends:
Backends
--------
Three built-in backend classes are provided:
* *serial*: :class:`BackendSerial`, that is equivalent to using no
  parallelization and is the default
* *multiprocessing*: :class:`BackendMultiprocessing` that supports
  parallelization via standard Python :mod:`multiprocessing` module
  and uses default :mod:`pickle` serialization
* *dask*: :class:`BackendDask`, that uses the same process-based
  parallelization as :class:`BackendMultiprocessing`, but different
  serialization algorithm via `dask <https://dask.org/>`_ (see `dask
  serialization algorithms
  <https://distributed.dask.org/en/latest/serialization.html>`_ for details)
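
A built-in backend is normally selected by name via the ``backend`` keyword of
:meth:`MDAnalysis.analysis.base.AnalysisBase.run` in analysis classes that
support parallel execution. A minimal sketch, assuming :class:`RMSD
<MDAnalysis.analysis.rms.RMSD>` supports parallelization:

.. code-block:: python

    import MDAnalysis as mda
    from MDAnalysis.tests.datafiles import PSF, DCD
    from MDAnalysis.analysis.rms import RMSD

    u = mda.Universe(PSF, DCD)
    R = RMSD(u, u)

    # split the frames between two worker processes
    R.run(backend="multiprocessing", n_workers=2)
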
Classes
-------
"""
import warnings
from typing import Callable
from MDAnalysis.lib.util import is_installed
class BackendBase:
    """Base class for backend implementation.
    Initializes an instance and performs checks for its validity, such as
    ``n_workers`` and possibly other ones.
    Parameters
    ----------
    n_workers : int
        number of workers (usually, processes) over which the work is split

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendBase

        class ThreadsBackend(BackendBase):
            def apply(self, func, computations):
                from multiprocessing.dummy import Pool

                with Pool(processes=self.n_workers) as pool:
                    results = pool.map(func, computations)
                return results

        import MDAnalysis as mda
        from MDAnalysis.tests.datafiles import PSF, DCD
        from MDAnalysis.analysis.rms import RMSD

        u = mda.Universe(PSF, DCD)
        ref = mda.Universe(PSF, DCD)

        R = RMSD(u, ref)

        n_workers = 2
        backend = ThreadsBackend(n_workers=n_workers)
        R.run(backend=backend, unsupported_backend=True)

    .. warning::
        Using the ``ThreadsBackend`` above will lead to erroneous results,
        since trajectory reading in MDAnalysis is not thread-safe; it is an
        educational example only. Do not use it for real analysis.

    .. versionadded:: 2.8.0
    """
    def __init__(self, n_workers: int):
        self.n_workers = n_workers
        self._validate()
    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure the
        validity of the backend instance
        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
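
        Examples
        --------
        A subclass can extend the checks by merging its own pairs into the
        base dictionary (a sketch; the condition and message below are
        illustrative):

        .. code-block:: python

            def _get_checks(self):
                return super()._get_checks() | {
                    is_installed("some_module"): "module 'some_module' is missing",
                }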
        """
        return {
            # key: the (already evaluated) condition; value: the error
            # message raised by _validate() if the condition is False
            isinstance(self.n_workers, int) and self.n_workers > 0: (
                f"n_workers should be a positive integer, got {self.n_workers=}"
            ),
        }
    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance
        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return dict()
    def _validate(self):
        """Check correctness (e.g. ``dask`` is installed if using ``backend='dask'``)
        and good usage (e.g. ``n_workers=1`` if backend is serial) of the backend
        Raises
        ------
        ValueError
            if one of the conditions in :meth:`_get_checks` is ``True``
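
        Examples
        --------
        A minimal sketch; an invalid ``n_workers`` fails the base check
        immediately on construction:

        .. code-block:: python

            from MDAnalysis.analysis.backends import BackendMultiprocessing

            BackendMultiprocessing(n_workers=0)  # raises ValueError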
        """
        for check, msg in self._get_checks().items():
            if not check:
                raise ValueError(msg)
        for check, msg in self._get_warnings().items():
            if not check:
                warnings.warn(msg)
    def apply(self, func: Callable, computations: list) -> list:
        """map function `func` to all tasks in the `computations` list
        Main method that will get called when using an instance of
        ``BackendBase``.  It is equivalent to running ``[func(item) for item in
        computations]`` while using the parallel backend capabilities.
        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to
        Returns
        -------
        list
            list of results of the function
        """
        raise NotImplementedError
class BackendSerial(BackendBase):
    """A built-in backend that does serial execution of the function, without any
    parallelization.
    Parameters
    ----------
    n_workers : int
        Is ignored in this class, and if ``n_workers`` > 1, a warning will be
        given.
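
    Examples
    --------
    A minimal sketch, mirroring the other built-in backends; passing
    ``n_workers`` greater than 1 triggers the warning described above:

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendSerial

        backend_obj = BackendSerial(n_workers=1)  # no warning
        backend_obj = BackendSerial(n_workers=4)  # warns: n_workers is ignored
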
    .. versionadded:: 2.8.0
    """
    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance. Here, it checks if the number
        of workers is not 1, otherwise gives warning.
        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return {
            self.n_workers == 1: (
                "n_workers is ignored when executing with backend='serial'"
            )
        }
    def apply(self, func: Callable, computations: list) -> list:
        """
        Serially applies `func` to each task object in ``computations``.
        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to
        Returns
        -------
        list
            list of results of the function
        """
        return [func(task) for task in computations]
class BackendMultiprocessing(BackendBase):
    """A built-in backend that executes a given function using the
    :meth:`multiprocessing.Pool.map <multiprocessing.pool.Pool.map>` method.
    Parameters
    ----------
    n_workers : int
        number of processes in :class:`multiprocessing.Pool
        <multiprocessing.pool.Pool>` to distribute the workload
        between. Must be a positive integer.
    Examples
    --------
    .. code-block:: python
        from MDAnalysis.analysis.backends import BackendMultiprocessing
        import multiprocessing as mp
        backend_obj = BackendMultiprocessing(n_workers=mp.cpu_count())
    .. versionadded:: 2.8.0
    """
    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations`` using `multiprocessing`'s `Pool.map`.
        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to
        Returns
        -------
        list
            list of results of the function
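
        Examples
        --------
        Since the standard :mod:`pickle` serialization is used, ``func`` must
        be picklable. A minimal sketch with a picklable built-in:

        .. code-block:: python

            from MDAnalysis.analysis.backends import BackendMultiprocessing

            backend = BackendMultiprocessing(n_workers=2)
            backend.apply(len, [[1, 2], [3, 4, 5]])  # returns [2, 3]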
        """
        from multiprocessing import Pool
        with Pool(processes=self.n_workers) as pool:
            results = pool.map(func, computations)
        return results
class BackendDask(BackendBase):
    """A built-in backend that executes a given function with *dask*.
    Execution is performed with the :func:`dask.compute` function of
    :class:`dask.delayed.Delayed` object (created with
    :func:`dask.delayed.delayed`) with ``scheduler='processes'`` and
    ``chunksize=1`` (this ensures uniform distribution of tasks among
    processes). Requires the `dask package <https://docs.dask.org/en/stable/>`_
    to be `installed <https://docs.dask.org/en/stable/install.html>`_.
    Parameters
    ----------
    n_workers : int
        number of processes in to distribute the workload
        between. Must be a positive integer. Workers are actually
        :class:`multiprocessing.pool.Pool` processes, but they use a different and
        more flexible `serialization protocol
        <https://docs.dask.org/en/stable/phases-of-computation.html#graph-serialization>`_.
    Examples
    --------
    .. code-block:: python
        from MDAnalysis.analysis.backends import BackendDask
        import multiprocessing as mp
        backend_obj = BackendDask(n_workers=mp.cpu_count())
    .. versionadded:: 2.8.0
    """
    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations``.
        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to
        Returns
        -------
        list
            list of results of the function
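
        Examples
        --------
        *dask* serializes tasks with ``cloudpickle``, so callables that plain
        :mod:`pickle` rejects (e.g. lambdas) can also be used (a minimal
        sketch):

        .. code-block:: python

            from MDAnalysis.analysis.backends import BackendDask

            backend = BackendDask(n_workers=2)
            backend.apply(lambda x: x**2, [1, 2, 3])  # returns [1, 4, 9]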
        """
        from dask.delayed import delayed
        import dask

        # wrap each task into a dask.delayed.Delayed object, then evaluate
        # them all in a single dask.compute() call; chunksize=1 distributes
        # the tasks uniformly among the worker processes
        computations = [delayed(func)(task) for task in computations]
        results = dask.compute(
            computations,
            scheduler="processes",
            chunksize=1,
            num_workers=self.n_workers,
        )[0]
        return results
    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure the
        validity of the backend instance. Here checks if ``dask`` module is
        installed in the environment.
        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
        """
        base_checks = super()._get_checks()
        checks = {
            is_installed("dask"): (
                "module 'dask' is missing. Please install 'dask': "
                "https://docs.dask.org/en/stable/install.html"
            )
        }
        return base_checks | checks