# Source code for openff.interchange.smirnoff._base

import abc
import json
from typing import TypeVar

from openff.models.models import DefaultModel
from openff.models.types import custom_quantity_encoder
from openff.toolkit import Quantity, Topology, unit
from openff.toolkit.typing.engines.smirnoff.parameters import (
    AngleHandler,
    BondHandler,
    ImproperTorsionHandler,
    ParameterHandler,
    ProperTorsionHandler,
)

from openff.interchange.components.potentials import Collection, Potential
from openff.interchange.exceptions import (
    InvalidParameterHandlerError,
    SMIRNOFFParameterAttributeNotImplementedError,
    UnassignedAngleError,
    UnassignedBondError,
    UnassignedTorsionError,
)
from openff.interchange.models import (
    LibraryChargeTopologyKey,
    PotentialKey,
    TopologyKey,
)

# Type variable bound to `SMIRNOFFCollection`; used by `SMIRNOFFCollection.create` so that the
# annotated return type is the class the classmethod was called on.
T = TypeVar("T", bound="SMIRNOFFCollection")
# Type variable bound to the toolkit's `ParameterHandler`; used to annotate the
# `parameter_handler` arguments of `store_potentials` and `create`.
TP = TypeVar("TP", bound="ParameterHandler")


def _sanitize(o) -> str | dict:
    # `BaseModel.json()` assumes that all keys and values in dicts are JSON-serializable, which is a problem
    # for the mapping dicts `key_map` and `potentials`.
    if isinstance(o, dict):
        return {_sanitize(k): _sanitize(v) for k, v in o.items()}
    elif isinstance(o, DefaultModel):
        return o.json()
    elif isinstance(o, unit.Quantity):
        return custom_quantity_encoder(o)
    return o


def dump_collection(v, *, default):
    """
    Dump a SMIRNOFFCollection to JSON after converting to compatible types.

    Parameters
    ----------
    v
        The object to serialize; it is first passed through `_sanitize`.
    default
        Forwarded unchanged as the `default` argument of `json.dumps`.

    Returns
    -------
    str
        The JSON text produced by `json.dumps`.
    """
    serializable = _sanitize(v)
    return json.dumps(serializable, default=default)


def collection_loader(data: str) -> dict:
    """
    Load a JSON blob dumped from a `Collection`.

    Parameters
    ----------
    data
        JSON text whose top level is an object.

    Returns
    -------
    dict
        The top-level entries of `data`, converted as follows:

        * `None`, numbers and booleans are copied as-is,
        * strings are copied as-is, except for the `"cutoff"` and `"switch_width"` entries, which
          are decoded as JSON and whose values are passed positionally to `Quantity`,
        * the `"key_map"` and `"potentials"` objects are rebuilt with model instances as keys,
        * values of any other JSON type (i.e. arrays) are left out of the result.

    Raises
    ------
    NotImplementedError
        If a nested object is stored under a name other than `"key_map"` or `"potentials"`.
    """
    loaded: dict[str, int | float | bool | str | dict | None] = {}

    for name, entry in json.loads(data).items():
        # Scalars need no conversion.
        if entry is None or isinstance(entry, (int, float, bool)):
            loaded[name] = entry
            continue

        if isinstance(entry, str):
            if name in ("cutoff", "switch_width"):
                # These are stored as string but must be parsed into `Quantity`
                loaded[name] = Quantity(*json.loads(entry).values())  # type: ignore[arg-type]
            else:
                loaded[name] = entry
            continue

        if not isinstance(entry, dict):
            # No branch handles the remaining JSON types, so such entries are skipped.
            continue

        if name == "key_map":
            key_map = {}

            for raw_topology_key, raw_potential_key in entry.items():
                # The serialized key is inspected as text to decide which key model to parse.
                topology_key: TopologyKey | LibraryChargeTopologyKey
                if "atom_indices" in raw_topology_key:
                    topology_key = TopologyKey.parse_raw(raw_topology_key)
                else:
                    topology_key = LibraryChargeTopologyKey.parse_raw(raw_topology_key)

                # TODO: Not obvious if cosmetic attributes survive here
                key_map[topology_key] = PotentialKey(**raw_potential_key)

            loaded[name] = key_map  # type: ignore[assignment]

        elif name == "potentials":
            loaded[name] = {  # type: ignore[assignment]
                PotentialKey.parse_raw(raw_key): Potential.parse_raw(json.dumps(raw_potential))
                for raw_key, raw_potential in entry.items()
            }

        else:
            raise NotImplementedError(f"Cannot parse {name} in this JSON.")

    return loaded


# Copied from the toolkit, see
# https://github.com/openforcefield/openff-toolkit/blob/0133414d3ab51e1af0996bcebe0cc1bdddc6431b/
# openff/toolkit/typing/engines/smirnoff/parameters.py#L2318
def _check_all_valence_terms_assigned(
    handler,
    assigned_terms,
    topology,
    valence_terms,
):
    """
    Check that all valence terms have been assigned.

    Parameters
    ----------
    handler
        The parameter handler that produced `assigned_terms`. Its class name appears in the error
        message and its type selects which exception class is raised.
    assigned_terms
        Mapping keyed by tuples of topology atom indices for the terms that received parameters.
        A new, empty instance of its class is used to index `valence_terms`.
    topology
        Object providing `atom_index(atom)` and `atom(index)`.
    valence_terms
        Sized iterable of atom tuples that are expected to have parameters.

    Raises
    ------
    UnassignedBondError, UnassignedAngleError, UnassignedTorsionError
        If some valence terms were not assigned, or some assigned terms are not valence terms of
        the topology. The exception carries `unassigned_topology_atom_tuples` and `handler_class`.
    RuntimeError
        If a mismatch is found but `handler` is not a bond, angle or torsion handler.
    """
    # Cheap shortcut: equal counts are taken to mean that everything matched.
    if len(assigned_terms) == len(valence_terms):
        return

    # Convert the valence term to a valence dictionary to make sure
    # the order of atom indices doesn't matter for comparison.
    valence_terms_dict = assigned_terms.__class__()
    for atoms in valence_terms:
        # Materialize the indices: a bare generator is only usable as a key when the mapping
        # class happens to coerce keys itself, and would otherwise never compare equal.
        atom_indices = tuple(topology.atom_index(a) for a in atoms)
        valence_terms_dict[atom_indices] = atoms

    # Check that both valence dictionaries have the same keys (i.e. terms).
    assigned_terms_set = set(assigned_terms.keys())
    valence_terms_set = set(valence_terms_dict.keys())
    unassigned_terms = valence_terms_set - assigned_terms_set
    not_found_terms = assigned_terms_set - valence_terms_set

    if not unassigned_terms and not not_found_terms:
        return

    handler_name = handler.__class__.__name__

    # Always bound, so the exception attribute below can be set even when the only problem
    # is assigned terms that are missing from the topology.
    unassigned_atom_tuples = []
    messages = []

    if unassigned_terms:
        entries = []
        for unassigned_tuple in unassigned_terms:
            # Pull and add additional helpful info on missing terms
            unassigned_atoms = [topology.atom(atom_idx) for atom_idx in unassigned_tuple]
            unassigned_atom_tuples.append(tuple(unassigned_atoms))

            names_and_elements = "".join(
                f"({atom.name} {atom.symbol}), " for atom in unassigned_atoms
            )
            entries.append(
                f"\n- Topology indices {unassigned_tuple}: names and elements {names_and_elements}",
            )

        messages.append(
            f"{handler_name} was not able to find parameters for the following valence terms:\n"
            + "".join(entries),
        )

    if not_found_terms:
        not_found_str = "\n- ".join(str(term) for term in not_found_terms)
        messages.append(
            f"{handler_name} assigned terms that were not found in the topology:\n"
            f"- {not_found_str}",
        )

    err_msg = "\n".join(messages) + "\n"

    if isinstance(handler, BondHandler):
        exception_class = UnassignedBondError
    elif isinstance(handler, AngleHandler):
        exception_class = UnassignedAngleError
    elif isinstance(handler, (ProperTorsionHandler, ImproperTorsionHandler)):
        exception_class = UnassignedTorsionError
    else:
        raise RuntimeError(
            f"Could not find an exception class for handler {handler}",
        )

    exception = exception_class(err_msg)
    exception.unassigned_topology_atom_tuples = unassigned_atom_tuples
    exception.handler_class = handler.__class__
    raise exception


class SMIRNOFFCollection(Collection, abc.ABC):
    """Base class for handlers storing potentials produced by SMIRNOFF force fields."""

    # Defaults to False on this base class.
    is_plugin: bool = False

    def modify_openmm_forces(self, *args, **kwargs):
        """Optionally modify, create, or delete forces. Currently only available to plugins."""
        raise NotImplementedError()

    class Config:
        """Default configuration options for SMIRNOFF potential handlers."""

        # (De)serialization goes through the module-level helpers so that the model-keyed
        # `key_map` and `potentials` dicts can be written to and read back from JSON.
        json_dumps = dump_collection
        json_loads = collection_loader
        validate_assignment = True
        arbitrary_types_allowed = True

    @classmethod
    @abc.abstractmethod
    def allowed_parameter_handlers(cls):
        """Return a list of allowed types of ParameterHandler classes."""
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def supported_parameters(cls):
        """Return a list of parameter attributes supported by this handler."""
        raise NotImplementedError()

    @classmethod
    def potential_parameters(cls):
        """Return a subset of `supported_parameters` that are meant to be included in potentials."""
        raise NotImplementedError()

    @classmethod
    def check_supported_parameters(cls, parameter_handler: ParameterHandler):
        """
        Verify that every attribute defined on the handler's parameters is supported.

        Parameters
        ----------
        parameter_handler
            The handler whose `parameters` are inspected.

        Raises
        ------
        SMIRNOFFParameterAttributeNotImplementedError
            If a parameter defines an attribute (other than `parent_id`) that is not listed in
            `supported_parameters()`.
        """
        for parameter in parameter_handler.parameters:
            for parameter_attribute in parameter._get_defined_parameter_attributes():
                # `parent_id` is always tolerated, whatever `supported_parameters()` lists.
                if parameter_attribute == "parent_id":
                    continue

                if parameter_attribute not in cls.supported_parameters():
                    raise SMIRNOFFParameterAttributeNotImplementedError(
                        parameter_attribute,
                    )

    @classmethod
    def check_openmm_requirements(cls, combine_nonbonded_forces: bool) -> None:
        """Run through a list of assertions about what is compatible when exporting this to OpenMM."""

    def store_matches(
        self,
        parameter_handler: ParameterHandler,
        topology: "Topology",
    ) -> None:
        """
        Populate self.key_map with key-val pairs of [TopologyKey, PotentialKey].

        Parameters
        ----------
        parameter_handler
            The handler whose `find_matches(topology)` result is stored.
        topology
            The topology to find matches in.

        Raises
        ------
        UnassignedBondError, UnassignedAngleError
            Possibly raised by `_check_all_valence_terms_assigned`, which is only run when this
            class is named `SMIRNOFFBondCollection` or `SMIRNOFFAngleCollection`.
        """
        if self.key_map:
            # TODO: Should the key_map always be reset, or should we be able to partially
            #       update it? Also Note the duplicated code in the child classes
            self.key_map: dict[
                TopologyKey | LibraryChargeTopologyKey,
                PotentialKey,
            ] = dict()

        matches = parameter_handler.find_matches(topology)

        for key, val in matches.items():
            parameter: ParameterHandler.ParameterType = val.parameter_type

            # Cosmetic attributes are read from the underscore-prefixed attribute of the same name.
            cosmetic_attributes = {
                cosmetic_attribute: getattr(
                    parameter,
                    f"_{cosmetic_attribute}",
                )
                for cosmetic_attribute in parameter._cosmetic_attribs
            }

            topology_key = TopologyKey(atom_indices=key)
            potential_key = PotentialKey(
                id=parameter.smirks,
                associated_handler=parameter_handler.TAGNAME,
                cosmetic_attributes=cosmetic_attributes,
            )

            self.key_map[topology_key] = potential_key

        # Only bond and angle collections are checked for complete coverage here.
        if self.__class__.__name__ in [
            "SMIRNOFFBondCollection",
            "SMIRNOFFAngleCollection",
        ]:
            valence_terms = self.valence_terms(topology)

            _check_all_valence_terms_assigned(
                handler=parameter_handler,
                assigned_terms=matches,
                topology=topology,
                valence_terms=valence_terms,
            )

    def store_potentials(self, parameter_handler: TP):
        """
        Populate self.potentials with key-val pairs of [PotentialKey, Potential].

        """
        raise NotImplementedError()

    @classmethod
    def create(
        cls: type[T],
        parameter_handler: TP,
        topology: "Topology",
    ) -> T:
        """
        Create a SMIRNOFFCollection from toolkit data.

        Parameters
        ----------
        parameter_handler
            The toolkit handler to convert; its exact type must be listed in
            `allowed_parameter_handlers()`.
        topology
            The topology passed on to `store_matches`.

        Returns
        -------
        T
            A new instance of `cls` after `store_matches` and `store_potentials` have run.

        Raises
        ------
        InvalidParameterHandlerError
            If `type(parameter_handler)` is not in `allowed_parameter_handlers()`.
        """
        if type(parameter_handler) not in cls.allowed_parameter_handlers():
            raise InvalidParameterHandlerError(type(parameter_handler))

        handler = cls()

        # The fields written below are spelled `fractional_bond_order_*`, so probe for that
        # spelling; the toolkit-style `fractional_bondorder_method` spelling that used to be the
        # only one probed is still accepted.
        # NOTE(review): subclasses are not visible here - confirm that none relied on this
        # branch being skipped.
        supports_fractional_bond_orders = hasattr(
            handler,
            "fractional_bond_order_method",
        ) or hasattr(handler, "fractional_bondorder_method")

        if supports_fractional_bond_orders and getattr(
            parameter_handler,
            "fractional_bondorder_method",
            None,
        ):
            handler.fractional_bond_order_method = (  # type: ignore[attr-defined]
                parameter_handler.fractional_bondorder_method
            )
            handler.fractional_bond_order_interpolation = (  # type: ignore[attr-defined]
                parameter_handler.fractional_bondorder_interpolation
            )

        handler.store_matches(parameter_handler=parameter_handler, topology=topology)
        handler.store_potentials(parameter_handler=parameter_handler)

        return handler

    def __repr__(self) -> str:
        """Summarize the type, expression, and the sizes of `key_map` and `potentials`."""
        return (
            f"Handler '{self.type}' with expression '{self.expression}', {len(self.key_map)} mapping keys, "
            f"and {len(self.potentials)} potentials"
        )