Source code for MDAnalysis.analysis.results

"""Analysis results and their aggregation --- :mod:`MDAnalysis.analysis.results`
================================================================================

This module introduces two classes, :class:`Results` and
:class:`ResultsGroup`, used for storing and aggregating data, respectively,
in :meth:`MDAnalysis.analysis.base.AnalysisBase.run()`.


Classes
-------

The :class:`Results` class is an extension of a built-in dictionary
type, that holds all assigned attributes in :attr:`self.data` and 
allows for access either via dict-like syntax, or via class-like syntax:

.. code-block:: python

    from MDAnalysis.analysis.results import Results
    r = Results()
    r.array = [1, 2, 3, 4]
    assert r['array'] == r.array == [1, 2, 3, 4]


The :class:`ResultsGroup` can merge multiple :class:`Results` objects.
It is mainly used by the :class:`MDAnalysis.analysis.base.AnalysisBase`
class, which uses the :meth:`ResultsGroup.merge()` method to aggregate
results from the multiple workers that are initialized during a parallel run:

.. code-block:: python

    from MDAnalysis.analysis.results import Results, ResultsGroup
    import numpy as np
    
    r1, r2 = Results(), Results()
    r1.masses = [1, 2, 3, 4, 5]
    r2.masses = [0, 0, 0, 0]
    r1.vectors = np.arange(10).reshape(5, 2)
    r2.vectors = np.arange(8).reshape(4, 2)

    group = ResultsGroup(
        lookup = {
            'masses': ResultsGroup.flatten_sequence,
            'vectors': ResultsGroup.ndarray_vstack
            }
        )

    r = group.merge([r1, r2])
    assert r.masses == list((*r1.masses, *r2.masses))
    assert (r.vectors == np.vstack([r1.vectors, r2.vectors])).all()
"""

from collections import UserDict
import numpy as np
from typing import Callable, Sequence


class Results(UserDict):
    r"""Container object for storing results.

    :class:`Results` are dictionaries that provide two ways by which values
    can be accessed: by dictionary key ``results["value_key"]`` or by object
    attribute, ``results.value_key``. :class:`Results` stores all results
    obtained from an analysis after calling :meth:`~AnalysisBase.run()`.

    The implementation is similar to the :class:`sklearn.utils.Bunch`
    class in `scikit-learn`_.

    .. _`scikit-learn`: https://scikit-learn.org/
    .. _`sklearn.utils.Bunch`: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html

    Raises
    ------
    AttributeError
        If an assigned attribute has the same name as a default attribute.

    ValueError
        If a key is a ``str`` that is not a valid Python identifier and
        therefore is not able to be accessed by attribute.

    Examples
    --------
    >>> from MDAnalysis.analysis.results import Results
    >>> results = Results(a=1, b=2)
    >>> results['b']
    2
    >>> results.b
    2
    >>> results.a = 3
    >>> results['a']
    3
    >>> results.c = [1, 2, 3, 4]
    >>> results['c']
    [1, 2, 3, 4]


    .. versionadded:: 2.0.0

    .. versionchanged:: 2.8.0
        Moved :class:`Results` to :mod:`MDAnalysis.analysis.results`
    """

    def _validate_key(self, key):
        """Reject keys that would shadow an attribute or are not identifiers.

        Raises
        ------
        AttributeError
            If `key` is already listed by ``dir(self)``.
        ValueError
            If `key` is a string that is not a valid identifier.
        """
        if key in dir(self):
            raise AttributeError(
                f"'{key}' is a protected dictionary attribute"
            )
        elif isinstance(key, str) and not key.isidentifier():
            raise ValueError(f"'{key}' is not a valid attribute")

    def __init__(self, *args, **kwargs):
        kwargs = dict(*args, **kwargs)
        if "data" in kwargs:
            raise AttributeError("'data' is a protected dictionary attribute")

        # Create the backing store directly in the instance ``__dict__`` so
        # the overridden ``__setattr__`` is not involved, then route every
        # initial item through ``__setitem__`` for validation.
        self.__dict__["data"] = {}
        self.update(kwargs)

    def __setitem__(self, key, item):
        self._validate_key(key)
        super().__setitem__(key, item)

    def __setattr__(self, attr, val):
        # ``data`` is the real backing attribute of UserDict; everything else
        # assigned as an attribute is stored as a dictionary item.
        if attr == "data":
            super().__setattr__(attr, val)
        else:
            self.__setitem__(attr, val)

    def __getattr__(self, attr):
        # ``__getattr__`` only runs when normal lookup failed. If the missing
        # name is ``data`` itself, the backing store does not exist (e.g. an
        # instance created with ``__new__``); looking it up via ``self[attr]``
        # would re-enter this method forever, so fail explicitly instead.
        if attr == "data":
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            )
        try:
            return self[attr]
        except KeyError as err:
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            ) from err

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError as err:
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            ) from err

    def __getstate__(self):
        # The pickled state is just the backing dictionary.
        return self.data

    def __setstate__(self, state):
        # Restores the backing dictionary; handled by the ``data`` branch of
        # ``__setattr__``.
        self.data = state
class ResultsGroup:
    """
    Management and aggregation of results stored in :class:`Results` instances.

    A :class:`ResultsGroup` is an optional description for :class:`Results`
    "dictionaries" that are used in analysis classes based on
    :class:`AnalysisBase`. For each *key* in a :class:`Results` it describes
    how multiple pieces of the data held under the key are to be aggregated.
    This approach is necessary when parts of a trajectory are analyzed
    independently (e.g., in parallel) and then need to be merged (with
    :meth:`merge`) to obtain a complete data set.

    Parameters
    ----------
    lookup : dict[str, Callable], optional
        aggregation functions lookup dict, by default None (treated as an
        empty lookup, i.e. no key has an aggregation function)

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.results import ResultsGroup, Results
        group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean})
        obj1 = Results(mass=1)
        obj2 = Results(mass=3)
        assert {'mass': 2.0} == group.merge([obj1, obj2])


    .. code-block:: python

        # you can also set `None` for those attributes that you want to skip
        lookup = {'mass': ResultsGroup.float_mean, 'skip': None}
        group = ResultsGroup(lookup)
        objects = [Results(mass=1, skip=None), Results(mass=3, skip=object)]
        assert group.merge(objects, require_all_aggregators=False) == {'mass': 2.0}

    .. versionadded:: 2.8.0
    """

    def __init__(self, lookup: dict[str, Callable] = None):
        # Normalise the default so that merge() can always call ``.get`` and
        # reports a missing aggregator with the documented ValueError instead
        # of failing with an AttributeError on ``None``.
        self._lookup = {} if lookup is None else lookup

    def merge(
        self, objects: Sequence[Results], require_all_aggregators: bool = True
    ) -> Results:
        """Merge multiple Results into a single Results instance.

        Merge multiple :class:`Results` instances into a single one, using the
        `lookup` dictionary to determine the appropriate aggregator functions
        for each named results attribute. If `objects` contains only a single
        element, that element itself is returned without using any
        aggregators.

        Parameters
        ----------
        objects : Sequence[Results]
            Multiple :class:`Results` instances with the same data attributes.
            Must not be empty; the keys of the first element determine which
            attributes are merged.
        require_all_aggregators : bool, optional
            if True, raise an exception when no aggregation function for a
            particular argument is found. Allows to skip aggregation for the
            parameters that aren't needed in the final object --
            see :class:`ResultsGroup`.

        Returns
        -------
        Results
            merged :class:`Results`

        Raises
        ------
        ValueError
            if no aggregation function for a key is found and
            ``require_all_aggregators=True``
        """
        if len(objects) == 1:
            merged_results = objects[0]
            return merged_results

        merged_results = Results()
        for key in objects[0].keys():
            agg_function = self._lookup.get(key, None)
            if agg_function is not None:
                # Collect this key's value from every object, in order.
                results_of_t = [obj[key] for obj in objects]
                merged_results[key] = agg_function(results_of_t)
            elif require_all_aggregators:
                raise ValueError(f"No aggregation function for {key=}")
        return merged_results

    @staticmethod
    def flatten_sequence(arrs: list[list]):
        """Flatten a list of lists into a list

        Parameters
        ----------
        arrs : list[list]
            list of lists

        Returns
        -------
        list
            flattened list
        """
        return [item for sublist in arrs for item in sublist]

    @staticmethod
    def ndarray_sum(arrs: list[np.ndarray]):
        """sums an ndarray along ``axis=0``

        Parameters
        ----------
        arrs : list[np.ndarray]
            list of input arrays. Must have the same shape.

        Returns
        -------
        np.ndarray
            sum of input arrays
        """
        return np.array(arrs).sum(axis=0)

    @staticmethod
    def ndarray_mean(arrs: list[np.ndarray]):
        """calculates mean of input ndarrays along ``axis=0``

        Parameters
        ----------
        arrs : list[np.ndarray]
            list of input arrays. Must have the same shape.

        Returns
        -------
        np.ndarray
            mean of input arrays
        """
        return np.array(arrs).mean(axis=0)

    @staticmethod
    def float_mean(floats: list[float]):
        """calculates mean of input float values

        Parameters
        ----------
        floats : list[float]
            list of float values

        Returns
        -------
        float
            mean value
        """
        return np.array(floats).mean()

    @staticmethod
    def ndarray_hstack(arrs: list[np.ndarray]):
        """Performs horizontal stack of input arrays

        Parameters
        ----------
        arrs : list[np.ndarray]
            input numpy arrays

        Returns
        -------
        np.ndarray
            result of stacking
        """
        return np.hstack(arrs)

    @staticmethod
    def ndarray_vstack(arrs: list[np.ndarray]):
        """Performs vertical stack of input arrays

        Parameters
        ----------
        arrs : list[np.ndarray]
            input numpy arrays

        Returns
        -------
        np.ndarray
            result of stacking
        """
        return np.vstack(arrs)