# Source code for openff.interchange.drivers.report
"""Storing and processing results of energy evaluations."""
import warnings
from typing import Dict, Optional
import pandas as pd
from openff.units import unit
from pydantic import validator
from openff.interchange.exceptions import EnergyError, MissingEnergyError
from openff.interchange.models import DefaultModel
from openff.interchange.types import FloatQuantity
# Energy unit (kJ/mol) used for default tolerances and for reporting differences.
kj_mol = unit.Unit("kilojoule / mole")
class EnergyReport(DefaultModel):
    """A lightweight class containing single-point energies as computed by energy tests."""

    # Maps an energy term name (e.g. "Bond", "vdW") to its energy, or None when
    # that term has not been computed. Values that are not already a
    # ``unit.Quantity`` are converted by ``validate_energies``.
    energies: Dict[str, Optional[FloatQuantity]] = {
        "Bond": None,
        "Angle": None,
        "Torsion": None,
        "vdW": None,
        "Electrostatics": None,
    }

    @validator("energies")
    def validate_energies(cls, v):
        """
        Convert each value of ``v`` that is not a ``unit.Quantity`` via ``FloatQuantity.validate_type``.

        The mapping is modified in place and returned.
        """
        for key, val in v.items():
            if not isinstance(val, unit.Quantity):
                v[key] = FloatQuantity.validate_type(val)
        return v

    def __getitem__(self, item: str):
        """
        Look up an energy term by name.

        A stored term is returned as-is. ``"total"`` (case-insensitive, and only if no
        term with that exact name is stored) returns the sum of all terms that are not
        None. Any other name returns None.

        Raises
        ------
        LookupError
            If ``item`` is not a ``str``.
        """
        if not isinstance(item, str):
            raise LookupError(
                "Only str arguments can be currently be used for lookups.\n"
                f"Found item {item} of type {type(item)}"
            )
        if item in self.energies:
            return self.energies[item]
        if item.lower() == "total":
            # Skip terms that were never computed; adding None would raise TypeError.
            # NOTE(review): if both "Nonbonded" and "vdW"/"Electrostatics" are stored,
            # they are all summed here, which may double count — confirm intended.
            return sum(energy for energy in self.energies.values() if energy is not None)
        return None

    def update_energies(self, new_energies):
        """
        Update the energies in this report with new value(s).

        ``new_energies`` is validated like ``energies``. A shallow copy is validated so
        the caller's mapping is not modified by the in-place validator.
        """
        self.energies.update(self.validate_energies(dict(new_energies)))

    # TODO: Better way of exposing tolerances
    def compare(self, other: "EnergyReport", custom_tolerances=None):
        """
        Compare this `EnergyReport` to another `EnergyReport`.

        Bond, Angle, Torsion, vdW, and Electrostatics tolerances each default to 1e-3 kJ/mol.
        When a combined nonbonded energy is compared (because one report stores
        "Nonbonded" rather than separate vdW/Electrostatics terms), the "Nonbonded"
        tolerance is used if one is supplied; otherwise the vdW and Electrostatics
        tolerances are summed. Only keys of this report are iterated.

        .. warning :: This API is experimental and subject to change.

        Parameters
        ----------
        other: EnergyReport
            The other `EnergyReport` to compare energies against
        custom_tolerances: dict of str: `FloatQuantity`, optional
            Custom energy tolerances to use in comparisons.

        Raises
        ------
        MissingEnergyError
            If an energy term is available in one report but not in the other.
        EnergyError
            If any energy difference exceeds its tolerance; the message tabulates
            all offending terms in kJ/mol.
        """
        tolerances: Dict[str, FloatQuantity] = {
            "Bond": 1e-3 * kj_mol,
            "Angle": 1e-3 * kj_mol,
            "Torsion": 1e-3 * kj_mol,
            "vdW": 1e-3 * kj_mol,
            "Electrostatics": 1e-3 * kj_mol,
        }
        if custom_tolerances is not None:
            tolerances.update(custom_tolerances)
        tolerances = self.validate_energies(tolerances)

        # Rows of the error table, with magnitudes already converted to kJ/mol.
        rows = []

        def _record_if_exceeded(key, this_energy, other_energy, tolerance):
            """Append a row to ``rows`` when the two energies differ by more than ``tolerance``."""
            diff = this_energy - other_energy
            if abs(diff) > tolerance:
                rows.append(
                    {
                        "key": key,
                        "diff": diff.m_as(kj_mol),
                        "tol": tolerance.m_as(kj_mol),
                        "ener1": this_energy.m_as(kj_mol),
                        "ener2": other_energy.m_as(kj_mol),
                    }
                )

        def _nonbonded_energy(report):
            """Return the report's "Nonbonded" energy, else vdW + Electrostatics, else None."""
            nonbonded = report.energies.get("Nonbonded")
            if nonbonded is not None:
                return nonbonded
            vdw = report.energies.get("vdW")
            electrostatics = report.energies.get("Electrostatics")
            if vdw is None or electrostatics is None:
                return None
            return vdw + electrostatics

        nonbonded_compared = False
        for key, this_energy in self.energies.items():
            # ``.get``: ``other`` may not define every key this report does
            # (e.g. "Nonbonded" vs. separate "vdW"/"Electrostatics").
            other_energy = other.energies.get(key)
            if this_energy is None and other_energy is None:
                continue

            # TODO: Remove this when OpenMM's NonbondedForce is split out
            # "Nonbonded" terms, and vdW/Electrostatics terms that only one report
            # provides, are compared exactly once via the combined nonbonded energy.
            if key in ("Nonbonded", "vdW", "Electrostatics") and (
                key == "Nonbonded" or this_energy is None or other_energy is None
            ):
                if nonbonded_compared:
                    continue
                nonbonded_compared = True
                this_nonbonded = _nonbonded_energy(self)
                other_nonbonded = _nonbonded_energy(other)
                if this_nonbonded is None or other_nonbonded is None:
                    raise MissingEnergyError(
                        "Nonbonded energy is missing from one of the reports being compared."
                    )
                tolerance = tolerances.get("Nonbonded")
                if tolerance is None:
                    tolerance = tolerances["vdW"] + tolerances["Electrostatics"]
                _record_if_exceeded("Nonbonded", this_nonbonded, other_nonbonded, tolerance)
                continue

            if this_energy is None or other_energy is None:
                raise MissingEnergyError(
                    f"Energy term {key!r} is missing from one of the reports being compared."
                )
            _record_if_exceeded(key, this_energy, other_energy, tolerances[key])

        if rows:
            # Build the table once; ``DataFrame.append`` was removed in pandas 2.0.
            errors = pd.DataFrame(rows, columns=["key", "diff", "tol", "ener1", "ener2"])
            raise EnergyError(
                "\nSome energy difference(s) exceed tolerances! "
                "\nAll values are reported in kJ/mol:"
                "\n" + str(errors.to_string(index=False))
            )

    # TODO: Return energy differences even if none are greater than tolerance
    # This might result in mis-matched keys
    def __sub__(self, other):
        """
        Return a dict mapping each key of this report to ``self - other`` for that term.

        Keys missing from ``other`` are skipped with a warning. Terms whose value is
        None in either report are not handled and will raise on subtraction.
        """
        diff = dict()
        for key in self.energies:
            if key not in other.energies:
                warnings.warn(f"Did not find key {key} in second report")
                continue
            diff[key] = self.energies[key] - other.energies[key]
        return diff

    def __str__(self):
        """Return a tab-aligned, human-readable listing of the energy terms."""
        return (
            "Energies:\n\n"
            f"Bond: \t\t{self['Bond']}\n"
            f"Angle: \t\t{self['Angle']}\n"
            f"Torsion: \t\t{self['Torsion']}\n"
            f"Nonbonded: \t\t{self['Nonbonded']}\n"
            f"vdW: \t\t{self['vdW']}\n"
            f"Electrostatics:\t\t{self['Electrostatics']}\n"
        )