# Source code for openff.interchange.drivers.report

"""Storing and processing results of energy evaluations."""

import warnings

from openff.models.models import DefaultModel
from openff.models.types import FloatQuantity
from openff.toolkit import unit

from openff.interchange._pydantic import validator
from openff.interchange.constants import kj_mol
from openff.interchange.exceptions import (
    EnergyError,
    IncompatibleTolerancesError,
    InvalidEnergyError,
)

_KNOWN_ENERGY_TERMS: set[str] = {
    "Bond",
    "Angle",
    "Torsion",
    "RBTorsion",
    "Nonbonded",
    "vdW",
    "Electrostatics",
    "vdW 1-4",
    "Electrostatics 1-4",
}


class EnergyReport(DefaultModel):
    """
    A lightweight class containing single-point energies as computed by energy tests.

    Energies are stored in the ``energies`` mapping, keyed by term name. Only
    keys listed in ``_KNOWN_ENERGY_TERMS`` pass validation. A value of ``None``
    marks a term for which no energy is stored.
    """

    # TODO: Should the default be None or 0.0 kj_mol?
    energies: dict[str, FloatQuantity | None] = {
        "Bond": None,
        "Angle": None,
        "Torsion": None,
        "vdW": None,
        "Electrostatics": None,
    }

    @validator("energies")
    def validate_energies(cls, v: dict) -> dict:
        """
        Validate the structure of a dict mapping keys to energies.

        Parameters
        ----------
        v: dict
            Mapping of energy term names to values. It is modified in place:
            values that are not already `unit.Quantity` are replaced by the
            result of `FloatQuantity.validate_type`.

        Returns
        -------
        v : dict
            The same dict object, after validation.

        Raises
        ------
        InvalidEnergyError
            If a key is not one of the known energy terms.

        """
        for key, val in v.items():
            if key not in _KNOWN_ENERGY_TERMS:
                raise InvalidEnergyError(f"Energy type {key} not understood.")
            # NOTE(review): a None value is not a unit.Quantity, so it is also
            # handed to FloatQuantity.validate_type -- confirm that it accepts
            # None, since the field annotation allows it.
            if not isinstance(val, unit.Quantity):
                # Re-assigning an existing key does not resize the dict, so this
                # is safe while iterating over its items.
                v[key] = FloatQuantity.validate_type(val)
        return v

    @property
    def total_energy(self):
        """Return the total energy, i.e. the same value as ``self["total"]``."""
        return self["total"]

    def __getitem__(self, item: str) -> FloatQuantity | None:
        """
        Look up a single energy term, or the total of all stored terms.

        Parameters
        ----------
        item: str
            A key of ``energies``, or ``"total"`` (case-insensitive).

        Returns
        -------
        energy : FloatQuantity or None
            The stored value if `item` is a key of ``energies``; the sum of all
            non-``None`` stored values if `item` is ``"total"``; ``None`` for
            any other string.

        Raises
        ------
        LookupError
            If `item` is not a string.

        """
        if not isinstance(item, str):
            raise LookupError(
                "Only str arguments can be currently be used for lookups.\n"
                f"Found item {item} of type {type(item)}",
            )

        # Stored keys take precedence over the special "total" lookup.
        if item in self.energies:
            return self.energies[item]

        if item.lower() == "total":
            # Terms left at None carry no energy; adding them would raise a
            # TypeError, so only the terms that hold a value are summed.
            present = [
                energy for energy in self.energies.values() if energy is not None
            ]
            if not present:
                return 0.0 * kj_mol
            return sum(present)  # type: ignore

        return None

    def update(self, new_energies: dict) -> None:
        """
        Update the energies in this report with new value(s).

        Parameters
        ----------
        new_energies: dict
            Mapping of energy term names to values. It is passed through
            `validate_energies` (and so may be modified in place) before being
            merged into ``energies``.

        """
        self.energies.update(self.validate_energies(new_energies))

    def compare(
        self,
        other: "EnergyReport",
        tolerances: dict[str, FloatQuantity] | None = None,
    ):
        """
        Compare two energy reports.

        Parameters
        ----------
        other: EnergyReport
            The other `EnergyReport` to compare energies against

        tolerances: dict of str: `FloatQuantity`
            Per-key allowed differences in energies. Entries given here
            override the defaults of 1e-3 kJ/mol for "Bond", "Angle",
            "Torsion", "vdW" and "Electrostatics".

        Raises
        ------
        IncompatibleTolerancesError
            If exactly one of the tolerances and the computed differences
            contains a "Nonbonded" entry.
        EnergyError
            If any difference exceeds its tolerance; the exception is raised
            with a dict of the offending differences.

        """
        default_tolerances = {
            "Bond": 1e-3 * kj_mol,
            "Angle": 1e-3 * kj_mol,
            "Torsion": 1e-3 * kj_mol,
            "vdW": 1e-3 * kj_mol,
            "Electrostatics": 1e-3 * kj_mol,
        }

        if tolerances:
            default_tolerances.update(tolerances)

        tolerances = default_tolerances

        # Ensure everything is in kJ/mol for safety of later comparison
        energy_differences = {
            key: diff.to(kj_mol) for key, diff in self.diff(other).items()
        }

        if ("Nonbonded" in tolerances) != ("Nonbonded" in energy_differences):
            raise IncompatibleTolerancesError(
                "Mismatch between energy reports and tolerances with respect to whether nonbonded "
                "interactions are collapsed into a single value.",
            )

        errors = {
            key: diff
            for key, diff in energy_differences.items()
            if abs(diff) > tolerances[key]
        }

        if errors:
            raise EnergyError(errors)

    def diff(
        self,
        other: "EnergyReport",
    ) -> dict[str, FloatQuantity]:
        """
        Return the per-key energy differences between these reports.

        "Bond", "Angle" and "Torsion" are compared key by key. Non-bonded
        energies are compared as separate "vdW" and "Electrostatics" entries
        when both reports hold both values; otherwise they are collapsed into
        a single "Nonbonded" entry. Any other key in this report is skipped.

        Parameters
        ----------
        other: EnergyReport
            The other `EnergyReport` to compare energies against

        Returns
        -------
        energy_differences : dict of str: `FloatQuantity`
            Per-key energy differences

        """
        energy_differences: dict[str, FloatQuantity] = dict()

        split_terms = ("vdW", "Electrostatics")
        nonbondeds_processed = False

        for key in self.energies:
            if key in ("Bond", "Angle", "Torsion"):
                energy_differences[key] = self[key] - other[key]  # type: ignore[operator]
                continue

            if key not in ("Nonbonded", "vdW", "Electrostatics"):
                continue

            # All non-bonded keys are handled together, the first time one of
            # them is encountered.
            if nonbondeds_processed:
                continue
            nonbondeds_processed = True

            # Test against None explicitly: relying on the truthiness of a
            # quantity would treat a zero-valued energy like a missing one.
            both_split = all(
                report[term] is not None
                for report in (self, other)
                for term in split_terms
            )

            if both_split:
                for term in split_terms:
                    energy_differences[term] = self[term] - other[term]  # type: ignore[operator]
            else:
                energy_differences["Nonbonded"] = (
                    self._get_nonbonded_energy() - other._get_nonbonded_energy()
                )

        return energy_differences

    def __sub__(self, other: "EnergyReport") -> dict[str, FloatQuantity]:
        """
        Subtract the stored energies of `other` from those of this report.

        Keys of this report that are absent from `other` are skipped with a
        warning.
        """
        diff: dict[str, FloatQuantity] = dict()
        for key in self.energies:
            if key not in other.energies:
                warnings.warn(f"Did not find key {key} in second report", stacklevel=2)
                continue
            diff[key] = self.energies[key] - other.energies[key]  # type: ignore

        return diff

    def __str__(self) -> str:
        """Return a multi-line listing of the energy terms; absent terms show as None."""
        return (
            "Energies:\n\n"
            f"Bond:          \t\t{self['Bond']}\n"
            f"Angle:         \t\t{self['Angle']}\n"
            f"Torsion:       \t\t{self['Torsion']}\n"
            f"RBTorsion:     \t\t{self['RBTorsion']}\n"
            f"Nonbonded:     \t\t{self['Nonbonded']}\n"
            f"vdW:           \t\t{self['vdW']}\n"
            f"Electrostatics:\t\t{self['Electrostatics']}\n"
        )

    def _get_nonbonded_energy(self) -> FloatQuantity:
        """Return the sum of the stored "Nonbonded", "vdW" and "Electrostatics" energies."""
        nonbonded_energy = 0.0 * kj_mol
        for key in ("Nonbonded", "vdW", "Electrostatics"):
            # `key in self.energies is not None` is a chained comparison that
            # never inspects the stored value; fetch the value and test it
            # directly so absent and None-valued terms are both skipped.
            energy = self.energies.get(key)
            if energy is not None:
                nonbonded_energy += energy

        return nonbonded_energy