"""Models for storing applied force field parameters."""
import ast
import json
import warnings
from collections.abc import Callable
from typing import Union
import numpy
from openff.models.models import DefaultModel
from openff.models.types import ArrayQuantity, FloatQuantity
from openff.toolkit import Quantity
from openff.utilities.utilities import has_package, requires_package
from openff.interchange._pydantic import Field, PrivateAttr, validator
from openff.interchange.exceptions import MissingParametersError
from openff.interchange.models import (
LibraryChargeTopologyKey,
PotentialKey,
TopologyKey,
)
from openff.interchange.warnings import InterchangeDeprecationWarning
if has_package("jax"):
from jax import numpy as jax_numpy
from numpy.typing import ArrayLike
if has_package("jax"):
from jax import Array
def __getattr__(name: str):
    """Resolve deprecated module attributes on access.

    Only the legacy name ``PotentialHandler`` is recognised: looking it up
    emits an ``InterchangeDeprecationWarning`` and returns ``Collection``.
    Any other name raises ``AttributeError``.
    """
    if name != "PotentialHandler":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # stacklevel=2 attributes the warning to the code performing the lookup
    # rather than to this function.
    warnings.warn(
        "`PotentialHandler` has been renamed to `Collection`. "
        "Importing `Collection` instead.",
        InterchangeDeprecationWarning,
        stacklevel=2,
    )
    return Collection
def potential_loader(data: str) -> dict:
    """Load a JSON blob dumped from a `Collection`.

    Parameters
    ----------
    data
        A JSON document whose top level is an object.

    Returns
    -------
    dict
        The decoded top-level items, restricted to:

        * scalar values (``str``, ``int``, ``float``, ``bool``, ``None``),
          which are copied through unchanged;
        * the ``"parameters"`` object, whose values are themselves JSON
          strings holding ``"val"`` and ``"unit"`` entries and are rebuilt
          into ``Quantity`` objects.

        Any other nested object or array is ignored.
    """
    tmp: dict[str, int | bool | str | dict] = {}

    for key, val in json.loads(data).items():
        # Pass every JSON scalar through. Accepting only str/None here would
        # silently discard integer fields such as ``map_key``. ``bool`` is a
        # subclass of ``int`` and is therefore covered as well.
        if val is None or isinstance(val, (str, int, float)):
            tmp[key] = val  # type: ignore
        elif isinstance(val, dict) and key == "parameters":
            parameters: dict = {}

            for name, blob in val.items():
                # Each parameter is stored as a nested JSON string.
                loaded = json.loads(blob)
                parameters[name] = Quantity(loaded["val"], loaded["unit"])

            tmp["parameters"] = parameters

    return tmp
class Potential(DefaultModel):
    """Model holding one set of applied parameters.

    ``parameters`` maps parameter names to quantities and ``map_key`` is an
    optional integer. Instances hash on their parameter values so that they
    can be used as dictionary keys.
    """

    parameters: dict[str, FloatQuantity] = {}
    map_key: int | None = None

    class Config:
        """Pydantic configuration."""

        json_encoders: dict[type, Callable] = DefaultModel.Config.json_encoders
        json_loads: Callable = potential_loader
        validate_assignment: bool = True
        arbitrary_types_allowed: bool = True

    @validator("parameters")
    def validate_parameters(
        cls,
        v: dict[str, ArrayQuantity | FloatQuantity],
    ) -> dict[str, FloatQuantity]:
        """Validate each parameter value in place and return the mapping.

        Values given as a ``list`` are validated with ``ArrayQuantity``;
        every other value is validated with ``FloatQuantity``.
        """
        for name, raw in v.items():
            quantity_type = ArrayQuantity if isinstance(raw, list) else FloatQuantity
            # Only existing keys are reassigned, so the dict size is unchanged
            # while it is being iterated.
            v[name] = quantity_type.validate_type(raw)

        return v

    def __hash__(self) -> int:
        # Only the parameter values take part in the hash; the parameter
        # names and ``map_key`` do not.
        values = tuple(self.parameters.values())
        return hash(values)
class WrappedPotential(DefaultModel):
    """Model wrapping one or more ``Potential`` objects with coefficients.

    The wrapped potentials are kept in a private ``InnerData`` instance that
    maps each ``Potential`` to a ``float`` coefficient.
    """

    class InnerData(DefaultModel):
        """The potentials being wrapped."""

        data: dict[Potential, float]

    _inner_data: InnerData = PrivateAttr()

    def __init__(self, data: Potential | dict) -> None:
        """Wrap a single potential or a potential-to-coefficient mapping.

        A bare ``Potential`` is stored with a coefficient of ``1.0``; a
        ``dict`` is stored as given. Any other input leaves ``_inner_data``
        unset.
        """
        if isinstance(data, dict):
            self._inner_data = self.InnerData(data=data)
        elif isinstance(data, Potential):
            self._inner_data = self.InnerData(data={data: 1.0})

    @property
    def parameters(self) -> dict[str, FloatQuantity]:
        """Get the parameters as represented by the stored potentials and coefficients."""
        wrapped = self._inner_data.data

        # Union of every parameter name found on any wrapped potential.
        names: set[str] = {
            name for potential in wrapped.keys() for name in potential.parameters.keys()
        }

        # Each parameter is the coefficient-weighted sum over all wrapped
        # potentials; every potential is indexed by every name.
        return {
            name: sum(
                coefficient * potential.parameters[name]
                for potential, coefficient in wrapped.items()
            )
            for name in names
        }

    def __repr__(self) -> str:
        wrapped = self._inner_data.data
        return str(wrapped)
class Collection(DefaultModel):
    """Base class for storing parametrized force field data.

    A collection ties together two mappings:

    * ``key_map`` associates topology keys with potential keys;
    * ``potentials`` associates potential keys with the ``Potential`` (or
      ``WrappedPotential``) objects holding the parameter values.
    """

    type: str = Field(..., description="The type of potentials this handler stores.")
    is_plugin: bool = Field(
        False,
        description="Whether this collection is defined as a plugin.",
    )
    expression: str = Field(
        ...,
        description="The analytical expression governing the potentials in this handler.",
    )
    key_map: dict[TopologyKey | LibraryChargeTopologyKey, PotentialKey] = Field(
        dict(),
        description="A mapping between TopologyKey objects and PotentialKey objects.",
    )
    potentials: dict[PotentialKey, Potential | WrappedPotential] = Field(
        dict(),
        description="A mapping between PotentialKey objects and Potential objects.",
    )

    @property
    def independent_variables(self) -> set[str]:
        """
        Return a set of variables found in the expression but not in any potentials.

        Every bare name in the parsed expression counts as a variable, and
        only the first stored potential is consulted for parameter names.
        With no stored potentials, every name in the expression is returned.
        """
        first_potential = next(iter(self.potentials.values()), None)
        vars_in_potentials = (
            set() if first_potential is None else set(first_potential.parameters.keys())
        )

        vars_in_expression = {
            node.id
            for node in ast.walk(ast.parse(self.expression))
            if isinstance(node, ast.Name)
        }

        return vars_in_expression - vars_in_potentials

    def _get_parameters(self, atom_indices: tuple[int, ...]) -> dict:
        """Return the parameters of the first key whose atom indices match.

        Raises
        ------
        MissingParametersError
            If no key in ``key_map`` has the given ``atom_indices``.
        """
        for topology_key, potential_key in self.key_map.items():
            if topology_key.atom_indices == atom_indices:
                return self.potentials[potential_key].parameters

        raise MissingParametersError(
            f"Could not find parameter in handler {self.type} "
            f"associated with atoms {atom_indices}",
        )

    def get_force_field_parameters(
        self,
        use_jax: bool = False,
    ) -> Union["ArrayLike", "Array"]:
        """Return a flattened representation of the force field parameters.

        The result has one row per entry of ``potentials`` and one column per
        parameter of that entry, holding the magnitudes (``.m``) only.

        Parameters
        ----------
        use_jax
            Build the result with ``jax.numpy`` instead of ``numpy``.

        Raises
        ------
        NotImplementedError
            If any stored potential is a ``WrappedPotential``.
        """
        # TODO: Handle WrappedPotential
        if any(
            isinstance(potential, WrappedPotential)
            for potential in self.potentials.values()
        ):
            raise NotImplementedError

        # NOTE(review): rows follow the iteration order of ``self.potentials``,
        # whereas ``get_mapping()`` numbers keys by first appearance in
        # ``key_map``; the two orders are assumed to agree -- confirm.
        magnitudes = [
            [value.m for value in potential.parameters.values()]
            for potential in self.potentials.values()
        ]

        if use_jax:
            return jax_numpy.array(magnitudes)

        return numpy.array(magnitudes)

    def set_force_field_parameters(self, new_p: "ArrayLike") -> None:
        """Set the force field parameters from a flattened representation.

        Parameters
        ----------
        new_p
            Two-dimensional array indexed as ``[potential_index, parameter_index]``,
            with potential indices as given by ``get_mapping()``. Values are
            re-attached to the units of the parameters they replace.

        Raises
        ------
        RuntimeError
            If the number of rows differs from the number of mapped
            potentials, or a row's length differs from the number of
            parameters of its potential. Rows before an offending row have
            already been applied when the error is raised.
        """
        mapping = self.get_mapping()

        if new_p.shape[0] != len(mapping):  # type: ignore
            raise RuntimeError(
                f"Expected {len(mapping)} rows of parameters, "
                f"found {new_p.shape[0]}.",  # type: ignore
            )

        for potential_key, potential_index in mapping.items():
            potential = self.potentials[potential_key]

            if len(new_p[potential_index, :]) != len(potential.parameters):  # type: ignore
                raise RuntimeError(
                    f"Expected {len(potential.parameters)} parameters for "
                    f"{potential_key}, found {len(new_p[potential_index, :])}.",  # type: ignore
                )

            # Only existing keys are reassigned, so iterating the parameter
            # names while updating their values is safe.
            for parameter_index, parameter_key in enumerate(potential.parameters):
                parameter_units = potential.parameters[parameter_key].units
                modified_parameter = new_p[potential_index, parameter_index]  # type: ignore

                potential.parameters[parameter_key] = (
                    modified_parameter * parameter_units
                )

    def get_system_parameters(
        self,
        p=None,
        use_jax: bool = False,
    ) -> Union["ArrayLike", "Array"]:
        """
        Return a flattened representation of system parameters.

        These values are effectively force field parameters as applied to a chemical topology.

        Parameters
        ----------
        p
            Force field parameters indexable by the indices of
            ``get_mapping()``. Defaults to ``get_force_field_parameters()``.
        use_jax
            Build the result with ``jax.numpy`` instead of ``numpy``.

        Raises
        ------
        NotImplementedError
            If any stored potential is a ``WrappedPotential``.
        """
        # TODO: Handle WrappedPotential
        if any(
            isinstance(potential, WrappedPotential)
            for potential in self.potentials.values()
        ):
            raise NotImplementedError

        if p is None:
            p = self.get_force_field_parameters(use_jax=use_jax)

        mapping = self.get_mapping()

        # One row per entry of ``key_map``, in ``key_map`` order.
        q: list = [p[mapping[potential_key]] for potential_key in self.key_map.values()]

        if use_jax:
            return jax_numpy.array(q)

        return numpy.array(q)

    def get_mapping(self) -> dict[PotentialKey, int]:
        """Get a mapping between potentials and array indices.

        Each distinct potential key in ``key_map`` receives a consecutive
        index, starting at zero, in order of first appearance.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        unique_keys = dict.fromkeys(self.key_map.values())

        return {potential_key: index for index, potential_key in enumerate(unique_keys)}

    def parametrize(
        self,
        p=None,
        use_jax: bool = True,
    ) -> Union["ArrayLike", "Array"]:
        """Return an array of system parameters, given an array of force field parameters."""
        if p is None:
            p = self.get_force_field_parameters(use_jax=use_jax)

        return self.get_system_parameters(p=p, use_jax=use_jax)

    def parametrize_partial(self):
        """Return a callable that forwards its arguments to `self.parametrize()`."""
        from functools import partial

        # ``parametrize`` takes no ``mapping`` argument (it derives the mapping
        # itself through ``get_system_parameters``), so binding one here would
        # make every call of the returned object raise ``TypeError``.
        return partial(self.parametrize)

    @requires_package("jax")
    def get_param_matrix(self) -> Union["Array", "ArrayLike"]:
        """Get a matrix representing the mapping between force field and system parameters.

        The matrix is the forward-mode Jacobian of ``parametrize`` evaluated
        at the current force field parameters, reshaped to two dimensions
        with one column per flattened force field parameter.
        """
        import jax

        p = self.get_force_field_parameters(use_jax=True)

        jac_parametrize = jax.jacfwd(self.parametrize)
        jac_res = jac_parametrize(p)

        return jac_res.reshape(-1, p.flatten().shape[0])  # type: ignore[union-attr]

    def __getattr__(self, attr: str):
        """Serve the deprecated ``slot_map`` alias of ``key_map``."""
        if attr == "slot_map":
            warnings.warn(
                "The `slot_map` attribute is deprecated. Use `key_map` instead.",
                InterchangeDeprecationWarning,
                stacklevel=2,
            )
            return self.key_map
        else:
            return super().__getattribute__(attr)