Source code for openff.interchange.drivers.openmm

"""Functions for running energy evluations with OpenMM."""

import warnings
from typing import TYPE_CHECKING, Optional

import numpy
from openff.toolkit import unit
from openff.units.openmm import ensure_quantity
from openff.utilities.utilities import has_package, requires_package

from openff.interchange import Interchange
from openff.interchange.drivers.report import EnergyReport
from openff.interchange.exceptions import CannotInferNonbondedEnergyError
from openff.interchange.interop.openmm._positions import to_openmm_positions

if has_package("openmm") or TYPE_CHECKING:
    import openmm
    import openmm.unit


@requires_package("openmm")
def get_openmm_energies(
    interchange: Interchange,
    round_positions: int | None = None,
    combine_nonbonded_forces: bool = True,
    detailed: bool = False,
    platform: str = "Reference",
) -> EnergyReport:
    """
    Given an OpenFF Interchange object, return single-point energies as computed by OpenMM.

    .. warning :: This API is experimental and subject to change.

    Parameters
    ----------
    interchange : openff.interchange.Interchange
        An OpenFF Interchange object to compute the single-point energy of
    round_positions : int, optional
        The number of decimal places, in nanometers, to round positions. This can be useful when
        comparing to e.g. GROMACS energies, in which positions may be rounded.
    combine_nonbonded_forces : bool, default=True
        Whether or not to combine all non-bonded interactions (vdW, short- and long-range
        electrostatics, and 1-4 interactions) into a single openmm.NonbondedForce.
    detailed : bool, default=False
        Attempt to report energies with more granularity. Not guaranteed to be compatible with all
        values of other arguments. Useful for debugging.
    platform : str, default="Reference"
        The name of the platform (`openmm.Platform`) used by OpenMM in this calculation.

    Returns
    -------
    report : EnergyReport
        An `EnergyReport` object containing the single-point energies.

    Warns
    -----
    UserWarning
        If the interchange has virtual sites and ``combine_nonbonded_forces`` is False.

    """
    # Virtual sites only count when the collection exists *and* has at least one entry;
    # the short-circuit keeps us from indexing a collection that is not there.
    has_virtual_sites = (
        "VirtualSites" in interchange.collections
        and len(interchange["VirtualSites"].key_map) > 0
    )

    if has_virtual_sites and not combine_nonbonded_forces:
        warnings.warn(
            "Collecting energies from split forces with virtual sites is experimental",
            UserWarning,
            stacklevel=2,
        )

    system: openmm.System = interchange.to_openmm(
        combine_nonbonded_forces=combine_nonbonded_forces,
    )

    # A missing box is passed on as None so that no periodic box vectors are set.
    box_vectors: Optional["openmm.unit.Quantity"] = (
        None if interchange.box is None else interchange.box.to_openmm()
    )

    positions: openmm.unit.Quantity = to_openmm_positions(
        interchange,
        include_virtual_sites=has_virtual_sites,
    )

    raw_energies = _get_openmm_energies(
        system=system,
        box_vectors=box_vectors,
        positions=positions,
        round_positions=round_positions,
        platform=platform,
    )

    return _process(
        raw_energies,
        combine_nonbonded_forces=combine_nonbonded_forces,
        detailed=detailed,
        system=system,
    )
def _get_openmm_energies(
    system: "openmm.System",
    box_vectors: Optional["openmm.unit.Quantity"],
    positions: "openmm.unit.Quantity",
    round_positions: int | None,
    platform: str,
) -> dict[int, "openmm.unit.Quantity"]:
    """
    Run a single-point energy calculation on already-prepared `openmm` objects.

    Every force in ``system`` is assigned a force group equal to its index (this modifies
    ``system`` in place), after which the potential energy of each group is queried separately.

    Parameters
    ----------
    system : openmm.System
        The system whose forces are evaluated.
    box_vectors : openmm.unit.Quantity or None
        Periodic box vectors; when None, no box vectors are set on the context.
    positions : openmm.unit.Quantity
        Positions handed to the context.
    round_positions : int or None
        If not None, positions are passed through ``numpy.round`` with this many decimals.
    platform : str
        Name of the `openmm.Platform` to run on.

    Returns
    -------
    raw_energies : dict[int, openmm.unit.Quantity]
        Potential energy of each force, keyed by the force's index in ``system``.

    """
    # One force group per force, so each contribution can be queried on its own.
    # NOTE(review): OpenMM documents force groups as limited to 0-31, so a system with
    # more than 32 forces would presumably fail here - confirm.
    for group, each_force in enumerate(system.getForces()):
        each_force.setForceGroup(group)

    integrator = openmm.VerletIntegrator(1.0 * openmm.unit.femtoseconds)
    selected_platform = openmm.Platform.getPlatformByName(platform)
    context = openmm.Context(system, integrator, selected_platform)

    if box_vectors is not None:
        context.setPeriodicBoxVectors(*box_vectors)

    if round_positions is None:
        context.setPositions(positions)
    else:
        context.setPositions(numpy.round(positions, round_positions))

    # Each State is a temporary and is released as soon as its energy has been read.
    raw_energies: dict[int, openmm.unit.Quantity] = {
        group: context.getState(getEnergy=True, groups={group}).getPotentialEnergy()
        for group in range(system.getNumForces())
    }

    # Drop the context before the integrator it was built from.
    del context
    del integrator

    return raw_energies


def _process(
    raw_energies: dict[int, "openmm.unit.Quantity"],
    system: "openmm.System",
    combine_nonbonded_forces: bool,
    detailed: bool,
) -> EnergyReport:
    """
    Sort per-force energies into named components and wrap them in an `EnergyReport`.

    Parameters
    ----------
    raw_energies : dict[int, openmm.unit.Quantity]
        Energies keyed by force index, as produced by `_get_openmm_energies`.
    system : openmm.System
        The system the energies were computed from; used to look up each force's type.
    combine_nonbonded_forces : bool
        Whether the system was built with a single combined `openmm.NonbondedForce`.
    detailed : bool
        If True, report every staged component as-is instead of folding the non-bonded
        terms together.

    Returns
    -------
    report : EnergyReport
        The report built from the processed energies.

    """
    valence_labels = {
        openmm.HarmonicBondForce: "Bond",
        openmm.HarmonicAngleForce: "Angle",
        openmm.PeriodicTorsionForce: "Torsion",
        openmm.RBTorsionForce: "RBTorsion",
    }
    nonbonded_types = [
        openmm.NonbondedForce,
        openmm.CustomNonbondedForce,
        openmm.CustomBondForce,
    ]

    staged: dict[str, unit.Quantity] = {}

    # This assumes that only custom forces will have duplicate instances
    for index, raw_energy in raw_energies.items():
        force = system.getForce(index)
        force_type = type(force)

        if force_type in valence_labels:
            staged[valence_labels[force_type]] = raw_energy
            continue

        if force_type not in nonbonded_types:
            # Any other kind of force is left out of the report.
            continue

        if combine_nonbonded_forces:
            assert isinstance(force, openmm.NonbondedForce)
            staged["Nonbonded"] = raw_energy
            continue

        if isinstance(force, openmm.NonbondedForce):
            label = "Electrostatics"
        elif isinstance(force, openmm.CustomNonbondedForce):
            label = "vdW"
        elif isinstance(force, openmm.CustomBondForce):
            # The 1-4 forces are told apart by whether the energy expression contains "qq".
            if "qq" in force.getEnergyFunction():
                label = "Electrostatics 1-4"
            else:
                label = "vdW 1-4"
        else:
            raise CannotInferNonbondedEnergyError()

        staged[label] = raw_energy

    if detailed:
        return EnergyReport(energies=staged)

    processed = {
        key: staged[key]
        for key in ("Bond", "Angle", "Torsion", "RBTorsion")
        if key in staged
    }

    nonbonded_energies = [
        staged[key]
        for key in (
            "Nonbonded",
            "Electrostatics",
            "vdW",
            "Electrostatics 1-4",
            "vdW 1-4",
        )
        if key in staged
    ]

    # Array inference acts up if given a 1-list of Quantity
    if combine_nonbonded_forces:
        assert len(nonbonded_energies) == 1
        processed["Nonbonded"] = nonbonded_energies[0]
    else:
        zero = 0.0 * openmm.unit.kilojoule_per_mole

        # Fold each 1-4 term into its parent; a missing piece counts as zero.
        for total_key, part_keys in (
            ("Electrostatics", ["Electrostatics", "Electrostatics 1-4"]),
            ("vdW", ["vdW", "vdW 1-4"]),
        ):
            processed[total_key] = ensure_quantity(
                numpy.sum([staged.get(key, zero) for key in part_keys]),
                "openff",
            )

    return EnergyReport(energies=processed)