Source code for MDAnalysis.analysis.results
"""Analysis results and their aggregation --- :mod:`MDAnalysis.analysis.results`
================================================================================

Module introduces two classes, :class:`Results` and :class:`ResultsGroup`,
used for storing and aggregating data in
:meth:`MDAnalysis.analysis.base.AnalysisBase.run()`, respectively.

Classes
-------

The :class:`Results` class is an extension of a built-in dictionary
type, that holds all assigned attributes in :attr:`self.data` and
allows for access either via dict-like syntax, or via class-like syntax:

.. code-block:: python

    from MDAnalysis.analysis.results import Results
    r = Results()
    r.array = [1, 2, 3, 4]
    assert r['array'] == r.array == [1, 2, 3, 4]

The :class:`ResultsGroup` can merge multiple :class:`Results` objects.
It is mainly used by :class:`MDAnalysis.analysis.base.AnalysisBase` class,
that uses :meth:`ResultsGroup.merge()` method to aggregate results from
multiple workers, initialized during a parallel run:

.. code-block:: python

    from MDAnalysis.analysis.results import Results, ResultsGroup
    import numpy as np

    r1, r2 = Results(), Results()
    r1.masses = [1, 2, 3, 4, 5]
    r2.masses = [0, 0, 0, 0]
    r1.vectors = np.arange(10).reshape(5, 2)
    r2.vectors = np.arange(8).reshape(4, 2)

    group = ResultsGroup(
        lookup = {
            'masses': ResultsGroup.flatten_sequence,
            'vectors': ResultsGroup.ndarray_vstack
        }
    )

    r = group.merge([r1, r2])
    assert r.masses == list((*r1.masses, *r2.masses))
    assert (r.vectors == np.vstack([r1.vectors, r2.vectors])).all()
"""
from collections import UserDict
import numpy as np
from typing import Callable, Sequence
[docs]
class Results(UserDict):
    r"""Container object for storing results.

    :class:`Results` are dictionaries that provide two ways by which values
    can be accessed: by dictionary key ``results["value_key"]`` or by object
    attribute, ``results.value_key``. :class:`Results` stores all results
    obtained from an analysis after calling :meth:`~AnalysisBase.run()`.

    The implementation is similar to the :class:`sklearn.utils.Bunch`
    class in `scikit-learn`_.

    .. _`scikit-learn`: https://scikit-learn.org/
    .. _`sklearn.utils.Bunch`: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html

    Raises
    ------
    AttributeError
        If an assigned attribute has the same name as a default attribute.

    ValueError
        If a key is a ``str`` that is not a valid Python identifier and
        therefore is not able to be accessed by attribute.

    Examples
    --------
    >>> from MDAnalysis.analysis.results import Results
    >>> results = Results(a=1, b=2)
    >>> results['b']
    2
    >>> results.b
    2
    >>> results.a = 3
    >>> results['a']
    3
    >>> results.c = [1, 2, 3, 4]
    >>> results['c']
    [1, 2, 3, 4]


    .. versionadded:: 2.0.0

    .. versionchanged:: 2.8.0
        Moved :class:`Results` to :mod:`MDAnalysis.analysis.results`
    """

    def _validate_key(self, key):
        """Raise if `key` cannot be stored without breaking attribute access.

        Raises
        ------
        AttributeError
            if `key` is one of the names reported by ``dir(self)``
        ValueError
            if `key` is a string that is not a valid identifier
        """
        if key in dir(self):
            raise AttributeError(
                f"'{key}' is a protected dictionary attribute"
            )
        elif isinstance(key, str) and not key.isidentifier():
            raise ValueError(f"'{key}' is not a valid attribute")

    def __init__(self, *args, **kwargs):
        kwargs = dict(*args, **kwargs)
        # 'data' is checked explicitly because the backing dict does not exist
        # yet at this point, so _validate_key() could not see it in dir(self).
        if "data" in kwargs:
            raise AttributeError("'data' is a protected dictionary attribute")
        # Write straight into __dict__ so that the backing storage is a real
        # instance attribute and never routed through __setitem__.
        self.__dict__["data"] = {}
        self.update(kwargs)

    def __setitem__(self, key, item):
        self._validate_key(key)
        super().__setitem__(key, item)

    def __setattr__(self, attr, val):
        # Only the backing dict is stored as a real attribute; every other
        # assignment becomes a (validated) dictionary entry.
        if attr == "data":
            super().__setattr__(attr, val)
        else:
            self.__setitem__(attr, val)

    def __getattr__(self, attr):
        # __getattr__ is only consulted when normal lookup fails. Reaching it
        # for 'data' means the instance has no backing dict yet (for example
        # it was created with __new__ and neither __init__ nor __setstate__
        # has run). Item access below reads self.data, so without this guard
        # the lookup would recurse until RecursionError.
        if attr == "data":
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            )
        try:
            return self[attr]
        except KeyError as err:
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            ) from err

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError as err:
            raise AttributeError(
                f"'Results' object has no attribute '{attr}'"
            ) from err

    def __getstate__(self):
        # The backing dict is the complete state of the object.
        return self.data

    def __setstate__(self, state):
        self.data = state
[docs]
class ResultsGroup:
    """
    Management and aggregation of results stored in :class:`Results` instances.

    A :class:`ResultsGroup` is an optional description for :class:`Results`
    "dictionaries" that are used in analysis classes based on
    :class:`AnalysisBase`. For each *key* in a :class:`Results` it describes
    how multiple pieces of the data held under the key are to be aggregated.
    This approach is necessary when parts of a trajectory are analyzed
    independently (e.g., in parallel) and then need to be merged (with
    :meth:`merge`) to obtain a complete data set.

    Parameters
    ----------
    lookup : dict[str, Callable], optional
        aggregation functions lookup dict, by default None. ``None`` is
        treated as an empty lookup, i.e. no key has an aggregation function.

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.results import ResultsGroup, Results
        group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean})
        obj1 = Results(mass=1)
        obj2 = Results(mass=3)
        assert {'mass': 2.0} == group.merge([obj1, obj2])


    .. code-block:: python

        # you can also set `None` for those attributes that you want to skip
        lookup = {'mass': ResultsGroup.float_mean, 'skip': None}
        group = ResultsGroup(lookup)
        objects = [Results(mass=1, skip=None), Results(mass=3, skip=object)]
        assert group.merge(objects, require_all_aggregators=False) == {'mass': 2.0}

    .. versionadded:: 2.8.0
    """

    def __init__(self, lookup: dict[str, Callable] = None):
        # Normalise the documented default so that merge() can always call
        # ``.get`` on the lookup; otherwise a group built without arguments
        # fails with AttributeError instead of the documented ValueError.
        self._lookup = lookup if lookup is not None else {}

    def merge(
        self, objects: Sequence[Results], require_all_aggregators: bool = True
    ) -> Results:
        """Merge multiple Results into a single Results instance.

        Merge multiple :class:`Results` instances into a single one, using the
        `lookup` dictionary to determine the appropriate aggregator functions
        for each named results attribute. If `objects` contains only a single
        element, that element itself is returned without using any
        aggregators.

        The keys to merge are taken from the first element of `objects`.

        Parameters
        ----------
        objects : Sequence[Results]
            Multiple :class:`Results` instances with the same data attributes.
        require_all_aggregators : bool, optional
            if True, raise an exception when no aggregation function for a
            particular argument is found. Allows to skip aggregation for the
            parameters that aren't needed in the final object --
            see :class:`ResultsGroup`.

        Returns
        -------
        Results
            merged :class:`Results`

        Raises
        ------
        ValueError
            if no aggregation function for a key is found and
            ``require_all_aggregators=True``
        IndexError
            if `objects` is empty
        """
        if len(objects) == 1:
            # Nothing to aggregate: hand back the very same object.
            return objects[0]

        merged_results = Results()
        for key in objects[0].keys():
            agg_function = self._lookup.get(key, None)
            if agg_function is not None:
                # Collect the value stored under this key in every object, in
                # the order given, and let the aggregator combine them.
                results_of_t = [obj[key] for obj in objects]
                merged_results[key] = agg_function(results_of_t)
            elif require_all_aggregators:
                raise ValueError(f"No aggregation function for {key=}")
        return merged_results

    @staticmethod
    def flatten_sequence(arrs: list[list]):
        """Flatten a list of lists into a list

        Parameters
        ----------
        arrs : list[list]
            list of lists

        Returns
        -------
        list
            flattened list
        """
        return [item for sublist in arrs for item in sublist]

    @staticmethod
    def ndarray_sum(arrs: list[np.ndarray]):
        """sums an ndarray along ``axis=0``

        Parameters
        ----------
        arrs : list[np.ndarray]
            list of input arrays. Must have the same shape.

        Returns
        -------
        np.ndarray
            sum of input arrays
        """
        return np.array(arrs).sum(axis=0)

    @staticmethod
    def ndarray_mean(arrs: list[np.ndarray]):
        """calculates mean of input ndarrays along ``axis=0``

        Parameters
        ----------
        arrs : list[np.ndarray]
            list of input arrays. Must have the same shape.

        Returns
        -------
        np.ndarray
            mean of input arrays
        """
        return np.array(arrs).mean(axis=0)

    @staticmethod
    def float_mean(floats: list[float]):
        """calculates mean of input float values

        Parameters
        ----------
        floats : list[float]
            list of float values

        Returns
        -------
        float
            mean value
        """
        return np.array(floats).mean()

    @staticmethod
    def ndarray_hstack(arrs: list[np.ndarray]):
        """Performs horizontal stack of input arrays

        Parameters
        ----------
        arrs : list[np.ndarray]
            input numpy arrays

        Returns
        -------
        np.ndarray
            result of stacking
        """
        return np.hstack(arrs)

    @staticmethod
    def ndarray_vstack(arrs: list[np.ndarray]):
        """Performs vertical stack of input arrays

        Parameters
        ----------
        arrs : list[np.ndarray]
            input numpy arrays

        Returns
        -------
        np.ndarray
            result of stacking
        """
        return np.vstack(arrs)