# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# MDAnalysis --- https://www.mdanalysis.org
# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
# (see the file AUTHORS for the full list of names)
#
# Released under the GNU Public Licence, v2 or any higher version
#
# Please cite your use of MDAnalysis in published work:
#
# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler,
# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein.
# MDAnalysis: A Python package for the rapid analysis of molecular dynamics
# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th
# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy.
# doi: 10.25080/majora-629e541a-00e
#
# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein.
# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
"""
AMBER PRMTOP topology parser
============================

Reads an AMBER top file to build the system.

Amber keywords are turned into the following attributes:

+----------------------------+----------------------+
| AMBER flag                 | MDAnalysis attribute |
+----------------------------+----------------------+
| ATOM_NAME                  | names                |
+----------------------------+----------------------+
| CHARGE                     | charges              |
+----------------------------+----------------------+
| ATOMIC_NUMBER              | elements             |
+----------------------------+----------------------+
| MASS                       | masses               |
+----------------------------+----------------------+
| BONDS_INC_HYDROGEN         | bonds                |
| BONDS_WITHOUT_HYDROGEN     |                      |
+----------------------------+----------------------+
| ANGLES_INC_HYDROGEN        | angles               |
| ANGLES_WITHOUT_HYDROGEN    |                      |
+----------------------------+----------------------+
| DIHEDRALS_INC_HYDROGEN     | dihedrals, impropers |
| DIHEDRALS_WITHOUT_HYDROGEN |                      |
+----------------------------+----------------------+
| ATOM_TYPE_INDEX            | type_indices         |
+----------------------------+----------------------+
| AMBER_ATOM_TYPE            | types                |
+----------------------------+----------------------+
| RESIDUE_LABEL              | resnames             |
+----------------------------+----------------------+
| RESIDUE_POINTER            | residues             |
+----------------------------+----------------------+

TODO:
  Add support for Chamber-style topologies
  More stringent tests

.. Note::
   Amber charges are converted to units of the elementary charge, as used in
   MDAnalysis and other packages. To get back Amber charges, multiply
   by 18.2223.

   Chamber-style Amber topologies (i.e. topologies generated via ParmEd
   conversion of a CHARMM topology to an AMBER one) are not currently
   supported. Support will likely be added in future MDAnalysis releases.

   As of version 2.0.0, elements are no longer guessed if ATOMIC_NUMBER records
   are missing. In those scenarios, if elements are necessary, users will have
   to invoke the element guessers after parsing the topology file. Please see
   :mod:`MDAnalysis.topology.guessers` for more details.

.. _`PARM parameter/topology file specification`:
   http://ambermd.org/formats.html#topology
Classes
-------
.. autoclass:: TOPParser
   :members:
   :inherited-members:
"""
import numpy as np
import itertools
from .tables import Z2SYMB
from ..lib.util import openany, FORTRANReader
from .base import TopologyReaderBase
from ..core.topology import Topology
from ..core.topologyattrs import (
    Atomnames,
    Atomtypes,
    Atomids,
    Charges,
    Elements,
    Masses,
    Resnames,
    Resids,
    Resnums,
    Segids,
    AtomAttr,
    Bonds,
    Angles,
    Dihedrals,
    Impropers
)
import warnings
import logging
# Logger for this parser; TOPParser.parse() reports missing or unknown
# ATOMIC_NUMBER records through it in addition to ``warnings.warn``.
logger = logging.getLogger("MDAnalysis.topology.TOPParser")
class TypeIndices(AtomAttr):
    """Integer atom-type index of each Atom.

    Filled by :class:`TOPParser` from the AMBER ``ATOM_TYPE_INDEX`` section
    (the Lennard-Jones type index). Its plural attribute name is
    ``type_indices`` and its singular name is ``type_index``.
    """
    level = 'atom'
    attrname = 'type_indices'
    singular = 'type_index'
class TOPParser(TopologyReaderBase):
    """Reads topology information from an AMBER top file.

    Reads the following Attributes if in topology:

    - Atomnames
    - Charges
    - Masses
    - Elements
    - Atomtypes
    - Resnames
    - Type_indices
    - Bonds
    - Angles
    - Dihedrals (inc. impropers)

    The format is defined in `PARM parameter/topology file
    specification`_.  The reader tries to detect if it is a newer
    (AMBER 12?) file format by looking for the flag "ATOMIC_NUMBER".

    .. _`PARM parameter/topology file specification`:
       http://ambermd.org/formats.html#topology

    Notes
    -----
    Elements are obtained from the atomic numbers (if present). If a given
    input atomic number does not belong to a known element (usually either
    -1 or 0), the element will be assigned an empty record.

    .. versionchanged:: 0.7.6
      parses both amber10 and amber12 formats
    .. versionchanged:: 0.19.0
      parses bonds, angles, dihedrals, and impropers
    .. versionchanged:: 1.0.0
      warns users that chamber-style topologies are not currently supported
    .. versionchanged:: 2.0.0
      no longer guesses elements if missing
    """
    format = ['TOP', 'PRMTOP', 'PARM7']

    def parse(self, **kwargs):
        """Parse Amber PRMTOP topology file *filename*.

        Returns
        -------
        An MDAnalysis Topology object

        Raises
        ------
        ValueError
            If the first line does not start with ``%VE``, if the second line
            is not a ``TITLE`` flag, or if the file is a Chamber-style
            (``CTITLE``) topology.
        IndexError
            If a line where a ``%FLAG`` header is expected does not hold one.
        """
        # Sections that we grab as we parse the file. Each value is
        # (entries per record, entries per line, section parser, attrs key,
        #  index into the POINTERS block holding the number of records).
        sections = {
            "ATOM_NAME": (1, 20, self.parse_names, "name", 0),
            "CHARGE": (1, 5, self.parse_charges, "charge", 0),
            "ATOMIC_NUMBER": (1, 10, self.parse_elements, "elements", 0),
            "MASS": (1, 5, self.parse_masses, "mass", 0),
            "ATOM_TYPE_INDEX": (1, 10, self.parse_type_indices, "type_indices",
                                0),
            "AMBER_ATOM_TYPE": (1, 20, self.parse_types, "types", 0),
            "RESIDUE_LABEL": (1, 20, self.parse_resnames, "resname", 11),
            "RESIDUE_POINTER": (1, 10, self.parse_residx, "respoint", 11),
            "BONDS_INC_HYDROGEN": (3, 10, self.parse_bonded, "bondh", 2),
            "BONDS_WITHOUT_HYDROGEN": (3, 10, self.parse_bonded, "bonda", 3),
            "ANGLES_INC_HYDROGEN": (4, 10, self.parse_bonded, "angh", 4),
            "ANGLES_WITHOUT_HYDROGEN": (4, 10, self.parse_bonded, "anga", 5),
            "DIHEDRALS_INC_HYDROGEN": (5, 10, self.parse_bonded, "dihh", 6),
            "DIHEDRALS_WITHOUT_HYDROGEN": (5, 10, self.parse_bonded, "diha", 7)
        }

        attrs = {}  # empty dict for attrs that we'll fill

        # Open and check top validity
        # Reading header info POINTERS
        with openany(self.filename) as self.topfile:
            header = next(self.topfile)
            if not header.startswith("%VE"):
                raise ValueError(
                    "{0} is not a valid TOP file. %VE Missing in header"
                    "".format(self.filename))
            title = next(self.topfile).split()
            if not (title[1] == "TITLE"):
                # Raise a separate warning if Chamber-style TOP is detected
                if title[1] == "CTITLE":
                    emsg = ("{0} is detected as a Chamber-style TOP file. "
                            "At this time MDAnalysis does not support such "
                            "topologies".format(self.filename))
                else:
                    emsg = ("{0} is not a valid TOP file. "
                            "'TITLE' missing in header".format(self.filename))
                raise ValueError(emsg)

            while not header.startswith('%FLAG POINTERS'):
                header = next(self.topfile)
            # POINTERS is a 4-line fixed-width block; parse it by column
            # (consuming its %FORMAT line) instead of splitting on
            # whitespace, because a value that fills its whole field abuts
            # the previous one with no separating space.
            sys_info = self.parsesection_mapper(4, int)

            header = next(self.topfile)
            # grab the next section title
            next_section = header.split("%FLAG")[1].strip()

            while next_section is not None:
                try:
                    (num_per_record, per_line,
                     func, name, sect_num) = sections[next_section]
                except KeyError:
                    # A section we do not read; it is skipped below.
                    numlines = None
                else:
                    num = sys_info[sect_num] * num_per_record
                    numlines, remainder = divmod(num, per_line)
                    if remainder:
                        numlines += 1
                    attrs[name] = func(num_per_record, numlines)

                try:
                    if numlines is None:
                        line = self.skipper()
                    else:
                        line = next(self.topfile)
                        # Capture case where section is empty w/ 1 empty line
                        if numlines == 0 and not line.strip():
                            line = next(self.topfile)
                except StopIteration:
                    next_section = None
                else:
                    try:
                        next_section = line.split("%FLAG")[1].strip()
                    except IndexError:
                        errmsg = (f"%FLAG section not found, formatting error "
                                  f"for PARM7 file {self.filename} ")
                        raise IndexError(errmsg) from None

        # strip out a few values to play with them
        n_atoms = len(attrs['name'])

        # Residue pointers are the zero-based index of each residue's first
        # atom; appending n_atoms lets consecutive pairs bound every residue.
        resptrs = attrs.pop('respoint')
        resptrs.append(n_atoms)
        residx = np.zeros(n_atoms, dtype=np.int32)
        for i, (x, y) in enumerate(zip(resptrs[:-1], resptrs[1:])):
            residx[x:y] = i

        n_res = len(attrs['resname'])

        # Deal with recreating bonds and angle records here
        attrs['bonds'] = Bonds(list(itertools.chain(
            attrs.pop('bonda'), attrs.pop('bondh'))))
        attrs['angles'] = Angles(list(itertools.chain(
            attrs.pop('anga'), attrs.pop('angh'))))
        attrs['dihedrals'], attrs['impropers'] = self.parse_dihedrals(
            attrs.pop('diha'), attrs.pop('dihh'))

        # Warn user if elements not in topology
        if 'elements' not in attrs:
            msg = ("ATOMIC_NUMBER record not found, elements attribute will "
                   "not be populated. If needed these can be guessed using "
                   "MDAnalysis.topology.guessers.")
            logger.warning(msg)
            warnings.warn(msg)
        elif np.any(attrs['elements'].values == ""):
            # only send out one warning that some elements are unknown
            msg = ("Unknown ATOMIC_NUMBER value found for some atoms, these "
                   "have been given an empty element record. If needed these "
                   "can be guessed using MDAnalysis.topology.guessers.")
            logger.warning(msg)
            warnings.warn(msg)

        # atom ids are mandatory
        attrs['atomids'] = Atomids(np.arange(n_atoms) + 1)
        attrs['resids'] = Resids(np.arange(n_res) + 1)
        attrs['resnums'] = Resnums(np.arange(n_res) + 1)
        attrs['segids'] = Segids(np.array(['SYSTEM'], dtype=object))

        top = Topology(n_atoms, n_res, 1,
                       attrs=list(attrs.values()),
                       atom_resindex=residx,
                       residue_segindex=None)

        return top

    def skipper(self):
        """TOPParser :class: helper function, skips lines of input parm7 file
        until we find the next %FLAG entry and return that

        Returns
        -------
        line : string
            String containing the current line of the parm7 file

        Raises
        ------
        StopIteration
            If the end of the file is reached before another %FLAG line.
        """
        line = next(self.topfile)
        while not line.startswith("%FLAG"):
            line = next(self.topfile)
        return line

    def parse_names(self, num_per_record, numlines):
        """Extracts atoms names from parm7 file

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in the section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`Atomnames`
            A :class:`Atomnames` instance containing the names of each atom as
            defined in the parm7 file
        """
        vals = self.parsesection_mapper(numlines, lambda x: x)
        attr = Atomnames(np.array(vals, dtype=object))
        return attr

    def parse_resnames(self, num_per_record, numlines):
        """Extracts the names of each residue

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`Resnames`
            A :class:`Resnames` instance containing the names of each residue
            as defined in the parm7 file
        """
        vals = self.parsesection_mapper(numlines, lambda x: x)
        attr = Resnames(np.array(vals, dtype=object))
        return attr

    def parse_charges(self, num_per_record, numlines):
        """Extracts the partial charges for each atom

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`Charges`
            A :class:`Charges` instance containing the partial charges of each
            atom, converted from Amber units to electron charge units
        """
        vals = self.parsesection_mapper(numlines, lambda x: float(x))
        charges = np.array(vals, dtype=np.float32)
        charges /= 18.2223  # to electron charge units
        attr = Charges(charges)
        return attr

    def parse_masses(self, num_per_record, numlines):
        """Extracts the mass of each atom

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`Masses`
            A :class:`Masses` instance containing the mass of each atom as
            defined in the parm7 file
        """
        vals = self.parsesection_mapper(numlines, lambda x: float(x))
        # Pass an array, like every other per-atom parser in this class.
        attr = Masses(np.array(vals, dtype=np.float64))
        return attr

    def parse_elements(self, num_per_record, numlines):
        """Extracts the atomic numbers of each atom and converts to element type

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`Elements`
            A :class:`Elements` instance containing the element of each atom
            as defined in the parm7 file

        Note
        ----
        If the record contains unknown atomic numbers (e.g. <= 0, or values
        with no entry in the element table), these will be treated as unknown
        elements and assigned an empty string value. See issues #2306 and
        #2651 for more details.

        .. versionchanged:: 2.0.0
           Unrecognised elements will now return a empty string. The parser
           will no longer attempt to guess the element by default.
        """
        def _to_symbol(field):
            z = int(field)
            # Unknown numbers map to "" so parse() can emit a single warning
            # instead of the whole parse failing with a KeyError.
            return Z2SYMB.get(z, "") if z > 0 else ""

        vals = self.parsesection_mapper(numlines, _to_symbol)
        attr = Elements(np.array(vals, dtype=object))
        return attr

    def parse_types(self, num_per_record, numlines):
        """Extracts the force field atom types of each atom

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`Atomtypes`
            A :class:`Atomtypes` instance containing the atom types for each
            atom as defined in the parm7 file
        """
        vals = self.parsesection_mapper(numlines, lambda x: x)
        attr = Atomtypes(np.array(vals, dtype=object))
        return attr

    def parse_type_indices(self, num_per_record, numlines):
        """Extracts the index of atom types of the each atom involved in Lennard
        Jones (6-12) interactions.

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        attr : :class:`TypeIndices`
            A :class:`TypeIndices` instance containing the LJ 6-12 atom type
            index for each atom
        """
        vals = self.parsesection_mapper(numlines, lambda x: int(x))
        attr = TypeIndices(np.array(vals, dtype=np.int32))
        return attr

    def parse_residx(self, num_per_record, numlines):
        """Extracts the residue pointers

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section (unused input)
        numlines : int
            The number of lines to be parsed in current section

        Returns
        -------
        vals : list of int
            One entry per residue: the zero-based index of the residue's first
            atom (the file stores one-based values)
        """
        vals = self.parsesection_mapper(numlines, lambda x: int(x) - 1)
        return vals

    def parse_chunks(self, data, chunksize):
        """Helper function to parse AMBER PRMTOP bonds/angles.

        Parameters
        ----------
        data : list of int
            Input list of the parm7 bond/angle section, zero-indexed
        chunksize : int
            The number of entries for each record in the input list

        Returns
        -------
        vals : list of int tuples
            A list of tuples containing the atoms involved in a given bonded
            interaction

        Note
        ----
        In the parm7 format this information is structured in the following
        format: [ atoms 1:n, internal index ]
        Where 1:n represent the ids of the n atoms involved in the bond/angle
        and the internal index links to a given set of FF parameters.
        Therefore, to extract the required information, we split out the list
        into chunks of size chunksize, and only extract the atom ids.
        """
        vals = [tuple(data[x:x+chunksize-1])
                for x in range(0, len(data), chunksize)]
        return vals

    def parse_bonded(self, num_per_record, numlines):
        """Extracts bond information from PARM7 format files

        Parameters
        ----------
        num_per_record : int
            The number of entries for each record in section
        numlines : int
            The number of lines to be parsed for this section

        Returns
        -------
        section : list of int tuples
            One tuple of atom indices per bonded term; the trailing
            parameter-type index of each record is dropped

        Note
        ----
        For the bond/angle sections of parm7 files, the atom numbers are set to
        coordinate array index values. As detailed in
        http://ambermd.org/formats.html to recover the actual atom number, one
        should divide the values by 3 and add 1. Here, since we want to satisfy
        zero-indexing, we only divide by 3.
        """
        fields = self.parsesection_mapper(numlines, lambda x: int(x) // 3)
        section = self.parse_chunks(fields, num_per_record)
        return section

    def parsesection_mapper(self, numlines, mapper):
        """Parses FORTRAN formatted section, and returns a list of all entries
        in each line

        The first line read must be the section's ``%FORMAT(...)`` header; the
        following *numlines* lines are split into fixed-width columns as
        described by that format, and every non-blank field is passed through
        *mapper*.

        Parameters
        ----------
        numlines : int
            The number of lines to be parsed in this section
        mapper : callable
            Operator to format entries in current section

        Returns
        -------
        section : list
            A list of all entries in a given parm7 section
        """
        fmt = next(self.topfile).strip()
        # str.strip() treats its argument as a *set* of characters, not a
        # prefix, so remove the "%FORMAT(" ... ")" wrapper explicitly.
        prefix = "%FORMAT("
        if fmt.startswith(prefix):
            fmt = fmt[len(prefix):]
        if fmt.endswith(")"):
            fmt = fmt[:-1]
        reader = FORTRANReader(fmt)

        section = []
        for _ in range(numlines):
            line = next(self.topfile)
            for entry in reader.entries:
                val = line[entry.start:entry.stop].strip()
                # Short final lines leave trailing columns blank; skip them.
                if val:
                    section.append(mapper(val))
        return section

    def parse_dihedrals(self, diha, dihh):
        """Combines hydrogen and non-hydrogen containing AMBER dihedral lists
        and extracts sublists for conventional dihedrals and improper angles

        Parameters
        ----------
        diha : list of tuples
            The atom ids of dihedrals not involving hydrogens
        dihh : list of tuples
            The atom ids of dihedrals involving hydrogens

        Returns
        -------
        dihedrals : :class:`Dihedrals`
            A :class:`Dihedrals` instance containing a sorted list of all
            unique dihedrals as defined by the parm7 file
        impropers : :class:`Impropers`
            A :class:`Impropers` instance containing a list of all
            improper dihedrals as defined by the parm7 file

        Note
        ----
        As detailed in http://ambermd.org/formats.html, the dihedral sections
        of parm7 files contain information about both conventional dihedrals
        and impropers. The following must be accounted for:
        1) If the fourth atom in a dihedral entry is given a negative value,
        this indicates that it is an improper.
        2) If the third atom in a dihedral entry is given a negative value,
        this indicates that 1-4 NB interactions are ignored for this
        dihedral. This could be due to the dihedral being within a ring, being
        part of a multi-term dihedral definition, or being an improper.
        """
        improp = []
        dihed = []
        for i in itertools.chain(diha, dihh):
            if i[3] < 0:
                improp.append(i[:2] + (abs(i[2]), abs(i[3])))
            elif i[2] < 0:
                dihed.append(i[:2] + (abs(i[2]),) + i[3:])
            else:
                dihed.append(i)
        # The same dihedral may appear once per multi-term component.
        dihed = sorted(set(dihed))
        dihedrals = Dihedrals(dihed)
        impropers = Impropers(improp)
        return dihedrals, impropers