# Source code for MDAnalysis.analysis.backends

"""Analysis backends --- :mod:`MDAnalysis.analysis.backends`
============================================================

.. versionadded:: 2.8.0


The :mod:`backends` module provides the :class:`BackendBase` base class for
implementing custom execution backends for
:meth:`MDAnalysis.analysis.base.AnalysisBase.run` and its
subclasses.

.. SeeAlso:: :ref:`parallel-analysis`

.. _backends:

Backends
--------

Three built-in backend classes are provided:

* *serial*: :class:`BackendSerial`, which is equivalent to using no
  parallelization and is the default

* *multiprocessing*: :class:`BackendMultiprocessing`, which supports
  parallelization via the standard Python :mod:`multiprocessing` module
  and uses the default :mod:`pickle` serialization

* *dask*: :class:`BackendDask`, which, like :class:`BackendMultiprocessing`,
  uses process-based parallelization, but with a different serialization
  algorithm via `dask <https://dask.org/>`_ (see `dask
  serialization algorithms
  <https://distributed.dask.org/en/latest/serialization.html>`_ for details)

Classes
-------

"""
import warnings
from typing import Callable
from MDAnalysis.lib.util import is_installed


class BackendBase:
    """Base class for backend implementation.

    Initializes an instance and performs checks for its validity, such as
    ``n_workers`` and possibly other ones.

    Parameters
    ----------
    n_workers : int
        number of workers (usually, processes) over which the work is split

    Raises
    ------
    ValueError
        if any of the conditions returned by :meth:`_get_checks` does not
        hold (for this base class: ``n_workers`` is not a positive ``int``)

    Examples
    --------
    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendBase

        class ThreadsBackend(BackendBase):
            def apply(self, func, computations):
                from multiprocessing.dummy import Pool

                with Pool(processes=self.n_workers) as pool:
                    results = pool.map(func, computations)
                return results

        import MDAnalysis as mda
        from MDAnalysis.tests.datafiles import PSF, DCD
        from MDAnalysis.analysis.rms import RMSD

        u = mda.Universe(PSF, DCD)
        ref = mda.Universe(PSF, DCD)

        R = RMSD(u, ref)

        n_workers = 2
        backend = ThreadsBackend(n_workers=n_workers)
        R.run(backend=backend, unsupported_backend=True)

    .. warning::
        Using `ThreadsBackend` above will lead to erroneous results, since it
        is an educational example. Do not use it for real analysis.


    .. versionadded:: 2.8.0

    """

    def __init__(self, n_workers: int):
        self.n_workers = n_workers
        # validate eagerly so that a misconfigured backend fails at
        # construction time rather than in the middle of a run
        self._validate()

    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure
        the validity of the backend instance

        Each condition must evaluate to ``True`` for a valid backend; when a
        condition is ``False``, :meth:`_validate` raises :exc:`ValueError`
        with the associated message.

        Notes
        -----
        Conditions are used as dictionary *keys*, so conditions that evaluate
        to the same value collapse into a single entry (the later message
        wins). Validation still fails whenever any condition is ``False``,
        but if several checks fail only one of their messages is reported.

        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
        """
        # bool is accepted here as well, since it is a subclass of int
        n_workers_is_valid = (
            isinstance(self.n_workers, int) and self.n_workers > 0
        )
        return {
            n_workers_is_valid: f"n_workers should be positive integer, got {self.n_workers=}",
        }

    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance

        Each condition is expected to be ``True``; when it is ``False``,
        :meth:`_validate` emits a warning with the associated message. The
        base class defines no such conditions.

        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return {}

    def _validate(self):
        """Check correctness (e.g. ``dask`` is installed if using
        ``backend='dask'``) and good usage (e.g. ``n_workers=1`` if backend is
        serial) of the backend

        All error conditions from :meth:`_get_checks` are evaluated before any
        warning condition from :meth:`_get_warnings`.

        Raises
        ------
        ValueError
            if one of the conditions in :meth:`_get_checks` is ``False``

        Warns
        -----
        UserWarning
            if one of the conditions in :meth:`_get_warnings` is ``False``
        """
        for check, msg in self._get_checks().items():
            if not check:
                raise ValueError(msg)
        for check, msg in self._get_warnings().items():
            if not check:
                warnings.warn(msg)

    def apply(self, func: Callable, computations: list) -> list:
        """map function `func` to all tasks in the `computations` list

        Main method that will get called when using an instance of
        ``BackendBase``. It is equivalent to running ``[func(item) for item in
        computations]`` while using the parallel backend capabilities.
        Subclasses must override it.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function

        Raises
        ------
        NotImplementedError
            always, in the base class
        """
        raise NotImplementedError
class BackendSerial(BackendBase):
    """A built-in backend that does serial execution of the function, without
    any parallelization.

    Parameters
    ----------
    n_workers : int
        Is ignored in this class; any value other than ``1`` only triggers a
        warning when the backend is created.


    .. versionadded:: 2.8.0
    """

    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance.

        The only condition here is ``n_workers == 1``: since the serial
        backend ignores ``n_workers``, any other value produces a warning
        during the ``_validate()`` run.

        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return {
            self.n_workers
            == 1: "n_workers is ignored when executing with backend='serial'"
        }

    def apply(self, func: Callable, computations: list) -> list:
        """Serially apply `func` to each task object in ``computations``, one
        after another.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function, in the order of ``computations``
        """
        return [func(task) for task in computations]
class BackendMultiprocessing(BackendBase):
    """Built-in backend that spreads the work over a pool of worker processes
    using the :meth:`multiprocessing.Pool.map <multiprocessing.pool.Pool.map>`
    method.

    Parameters
    ----------
    n_workers : int
        size of the :class:`multiprocessing.Pool <multiprocessing.pool.Pool>`
        that the workload is divided across. Must be a positive integer.

    Examples
    --------
    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendMultiprocessing
        import multiprocessing as mp

        backend_obj = BackendMultiprocessing(n_workers=mp.cpu_count())


    .. versionadded:: 2.8.0

    """

    def apply(self, func: Callable, computations: list) -> list:
        """Run `func` on every element of ``computations`` through
        :meth:`multiprocessing.pool.Pool.map`.

        A new pool of ``n_workers`` processes is created for every call and
        torn down before the call returns.

        Parameters
        ----------
        func : Callable
            function applied to each task; it and the tasks are sent to the
            worker processes with :mod:`pickle`
        computations : list
            tasks that `func` is applied to

        Returns
        -------
        list
            results of `func`, in the same order as ``computations``
        """
        from multiprocessing import Pool

        # map() blocks until every result is collected, so leaving the
        # ``with`` block (which terminates the pool) is safe right after it
        with Pool(processes=self.n_workers) as worker_pool:
            return worker_pool.map(func, computations)
class BackendDask(BackendBase):
    """A built-in backend that executes a given function with *dask*.

    Execution is performed by calling :func:`dask.compute` on
    :class:`dask.delayed.Delayed` objects (created with
    :func:`dask.delayed.delayed`) with ``scheduler='processes'`` and
    ``chunksize=1`` (this ensures uniform distribution of tasks among
    processes). Requires the `dask package <https://docs.dask.org/en/stable/>`_
    to be `installed <https://docs.dask.org/en/stable/install.html>`_.

    Parameters
    ----------
    n_workers : int
        number of processes to distribute the workload between. Must be a
        positive integer. As with :class:`BackendMultiprocessing`, the work
        runs in separate worker processes (dask's ``"processes"`` scheduler),
        but dask uses a different and more flexible `serialization protocol
        <https://docs.dask.org/en/stable/phases-of-computation.html#graph-serialization>`_.

    Raises
    ------
    ValueError
        if ``n_workers`` is invalid or the ``dask`` module is not installed

    Examples
    --------
    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendDask
        import multiprocessing as mp

        backend_obj = BackendDask(n_workers=mp.cpu_count())


    .. versionadded:: 2.8.0

    """

    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations``.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        # imported lazily: dask is an optional dependency, checked in
        # _get_checks() when the backend is constructed
        from dask.delayed import delayed
        import dask

        # build one lazy task per item instead of rebinding the parameter
        delayed_tasks = [delayed(func)(task) for task in computations]
        # dask.compute returns a tuple with one entry per positional argument;
        # the single argument here is the list of delayed tasks
        results = dask.compute(
            delayed_tasks,
            scheduler="processes",
            chunksize=1,
            num_workers=self.n_workers,
        )[0]
        return results

    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure
        the validity of the backend instance.

        In addition to the base-class checks, checks that the ``dask`` module
        is installed in the environment.

        Notes
        -----
        The two dictionaries are merged with ``|`` and conditions are used as
        keys, so if a base-class check and the ``dask`` check both fail, only
        the ``dask`` message is reported.

        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
        """
        base_checks = super()._get_checks()
        checks = {
            is_installed("dask"): (
                "module 'dask' is missing. Please install 'dask': "
                "https://docs.dask.org/en/stable/install.html"
            )
        }
        return base_checks | checks