import abc
from typing import Dict, Set, Type, Union
from chemper.graphs.environment import ChemicalEnvironment
from openff.toolkit.typing.engines.smirnoff import (
AngleHandler,
BondHandler,
ImproperTorsionHandler,
ParameterType,
ProperTorsionHandler,
vdWHandler,
)
from typing_extensions import Literal
from openff.bespokefit._pydantic import Field, PositiveFloat, SchemaBase, validator
def validate_smirks(smirks: str, expected_tags: int) -> str:
    """
    Make sure the supplied smirks has the correct number of tagged atoms.

    Parameters
    ----------
    smirks
        The SMIRKS pattern to check.
    expected_tags
        The number of tagged (indexed) atoms the pattern must contain.

    Returns
    -------
        The unmodified ``smirks`` string, so this can be returned directly from a
        validator.

    Raises
    ------
    AssertionError
        If the pattern does not contain exactly ``expected_tags`` tagged atoms.
    """
    environment = ChemicalEnvironment(smirks=smirks)
    tagged_atoms = len(environment.get_indexed_atoms())

    # An explicit raise (rather than an ``assert`` statement) keeps this check active
    # when Python runs with ``-O``, while preserving the AssertionError type that
    # existing callers and validators already rely on.
    if tagged_atoms != expected_tags:
        raise AssertionError(
            f"The smirks pattern ({smirks}) has {tagged_atoms} tagged atoms, but should "
            f"have {expected_tags}."
        )

    return smirks
class BaseSMIRKSParameter(SchemaBase, abc.ABC):
    """
    This schema identifies new smirks patterns and the corresponding atoms they should
    be applied to.
    """

    # Identifies the concrete parameter schema; each subclass narrows this literal.
    type: Literal["base"] = "base"

    smirks: str = Field(
        ...,
        description="The SMIRKS pattern that defines which chemical environment the "
        "parameter should be applied to.",
    )
    attributes: Set[str] = Field(
        ..., description="The attributes of the parameter which should be optimized."
    )
    cached: bool = Field(
        False,
        description="If the parameter was reused from a local cache rather than fit.",
    )

    @classmethod
    @abc.abstractmethod
    def _expected_n_tags(cls) -> int:
        """The number of tagged atoms a valid ``smirks`` pattern must contain."""
        raise NotImplementedError()

    @validator("smirks")
    def _check_smirks(cls, value: str) -> str:
        # Each subclass supplies its own tag count via ``_expected_n_tags``.
        return validate_smirks(value, cls._expected_n_tags())

    @classmethod
    @abc.abstractmethod
    def from_smirnoff(cls, parameter: ParameterType):
        """Creates a version of this class from a SMIRNOFF parameter"""

    def _identity_key(self) -> tuple:
        """Return the values that define equality and the hash of this parameter.

        ``attributes`` is a set. Two equal sets are not guaranteed to iterate in the
        same order, so it is frozen rather than converted to a tuple. Otherwise equal
        parameters could hash (and therefore compare) differently.
        """
        return self.type, self.smirks, self.cached, frozenset(self.attributes)

    def __eq__(self, other):
        # Compare the identity keys directly rather than their hashes, so that a
        # hash collision can never make two different parameters compare equal.
        return (
            type(self) is type(other)
            and self._identity_key() == other._identity_key()
        )

    def __ne__(self, other):
        # The previous ``assert not self.__eq__(other)`` returned None (falsy) for
        # unequal objects and raised for equal ones; return the proper boolean.
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self._identity_key())
class BaseSMIRKSHyperparameters(SchemaBase, abc.ABC):
    """A data class to track how a target will affect the parameters being fit,
    together with their prior / starting values.
    """

    # Identifies the concrete hyperparameter schema; each subclass narrows this literal.
    type: Literal["base"] = "base"
    priors: Dict[str, PositiveFloat] = Field(
        ...,
        description="The prior (starting) value for each parameter attribute, keyed "
        "by attribute name.",
    )

    @classmethod
    @abc.abstractmethod
    def offxml_tag(cls) -> str:
        """The OFFXML tag that wraps this parameter type."""
        raise NotImplementedError()
class VdWSMIRKS(BaseSMIRKSParameter):
    # Schema for a vdW parameter; only epsilon and sigma may be optimized.
    type: Literal["vdW"] = "vdW"
    attributes: Set[Literal["epsilon", "sigma"]] = Field(
        ..., description="The attributes of the parameter which should be optimized."
    )

    @classmethod
    def _expected_n_tags(cls) -> int:
        # A vdW SMIRKS pattern tags exactly one atom.
        n_tagged = 1
        return n_tagged

    @classmethod
    def from_smirnoff(cls, parameter: vdWHandler.vdWType) -> "VdWSMIRKS":
        """Build a schema from a SMIRNOFF vdW parameter, marking both epsilon and
        sigma for optimization and carrying over its ``_cached`` flag (if any).
        """
        pattern = parameter.smirks
        from_cache = getattr(parameter, "_cached", False)
        return cls(smirks=pattern, attributes={"sigma", "epsilon"}, cached=from_cache)
class VdWHyperparameters(BaseSMIRKSHyperparameters):
    # Priors may only be given for the optimizable vdW attributes.
    type: Literal["vdW"] = "vdW"
    priors: Dict[Literal["epsilon", "sigma"], PositiveFloat] = Field(
        {"epsilon": 0.1, "sigma": 0.1}, description=""
    )

    @classmethod
    def offxml_tag(cls) -> str:
        """Return ``"Atom"``, the OFFXML tag that wraps vdW parameters."""
        tag = "Atom"
        return tag
class BondSMIRKS(BaseSMIRKSParameter):
    # Schema for a bond parameter; the force constant and length may be optimized.
    type: Literal["Bonds"] = "Bonds"
    attributes: Set[Literal["k", "length"]] = Field(
        ..., description="The attributes of the parameter which should be optimized."
    )

    @classmethod
    def _expected_n_tags(cls) -> int:
        # A bond SMIRKS pattern tags exactly two atoms.
        n_tagged = 2
        return n_tagged

    @classmethod
    def from_smirnoff(cls, parameter: BondHandler.BondType) -> "BondSMIRKS":
        """Build a schema from a SMIRNOFF bond parameter, marking ``k`` and
        ``length`` for optimization and carrying over its ``_cached`` flag (if any).
        """
        init_kwargs = dict(
            smirks=parameter.smirks,
            attributes={"length", "k"},
            cached=getattr(parameter, "_cached", False),
        )
        return cls(**init_kwargs)
class BondHyperparameters(BaseSMIRKSHyperparameters):
    # Priors may only be given for the optimizable bond attributes.
    type: Literal["Bonds"] = "Bonds"
    priors: Dict[Literal["k", "length"], PositiveFloat] = Field(
        {"k": 100.0, "length": 0.1}, description=""
    )

    @classmethod
    def offxml_tag(cls) -> str:
        """Return ``"Bond"``, the OFFXML tag that wraps bond parameters."""
        tag = "Bond"
        return tag
class AngleSMIRKS(BaseSMIRKSParameter):
    # Schema for an angle parameter; the force constant and angle may be optimized.
    type: Literal["Angles"] = "Angles"
    attributes: Set[Literal["k", "angle"]] = Field(
        ..., description="The attributes of the parameter which should be optimized."
    )

    @classmethod
    def _expected_n_tags(cls) -> int:
        # An angle SMIRKS pattern tags exactly three atoms.
        n_tagged = 3
        return n_tagged

    @classmethod
    def from_smirnoff(cls, parameter: AngleHandler.AngleType) -> "AngleSMIRKS":
        """Build a schema from a SMIRNOFF angle parameter, marking ``k`` and
        ``angle`` for optimization and carrying over its ``_cached`` flag (if any).
        """
        pattern = parameter.smirks
        optimized = {"angle", "k"}
        from_cache = getattr(parameter, "_cached", False)
        return cls(smirks=pattern, attributes=optimized, cached=from_cache)
class AngleHyperparameters(BaseSMIRKSHyperparameters):
    # Priors may only be given for the optimizable angle attributes.
    type: Literal["Angles"] = "Angles"
    priors: Dict[Literal["k", "angle"], PositiveFloat] = Field(
        {"k": 10.0, "angle": 10.0}, description=""
    )

    @classmethod
    def offxml_tag(cls) -> str:
        """Return ``"Angle"``, the OFFXML tag that wraps angle parameters."""
        tag = "Angle"
        return tag
# TODO: This can likely be more cleanly handled by a pydantic regex type.
# Every proper-torsion attribute name that may be optimized or given a prior. The
# first row holds the un-indexed names; rows 1-6 hold the per-term indexed names.
# NOTE(review): the un-indexed bond-order name "k1_bondorder" does not follow the
# pattern of the other rows; confirm the intended spelling (it was listed twice).
# fmt: off
ProperTorsionAttribute = Literal[
    "k", "k1_bondorder", "periodicity", "phase", "idivf",
    "k1", "k1_bondorder1", "k1_bondorder2", "periodicity1", "phase1", "idivf1",
    "k2", "k2_bondorder1", "k2_bondorder2", "periodicity2", "phase2", "idivf2",
    "k3", "k3_bondorder1", "k3_bondorder2", "periodicity3", "phase3", "idivf3",
    "k4", "k4_bondorder1", "k4_bondorder2", "periodicity4", "phase4", "idivf4",
    "k5", "k5_bondorder1", "k5_bondorder2", "periodicity5", "phase5", "idivf5",
    "k6", "k6_bondorder1", "k6_bondorder2", "periodicity6", "phase6", "idivf6",
]
# fmt: on
class ProperTorsionSMIRKS(BaseSMIRKSParameter):
    # Schema for a proper torsion parameter; any ProperTorsionAttribute may be optimized.
    type: Literal["ProperTorsions"] = "ProperTorsions"
    attributes: Set[Literal[ProperTorsionAttribute]] = Field(
        ..., description="The attributes of the parameter which should be optimized."
    )

    @classmethod
    def _expected_n_tags(cls) -> int:
        # A torsion SMIRKS pattern tags exactly four atoms.
        n_tagged = 4
        return n_tagged

    @classmethod
    def from_smirnoff(
        cls, parameter: ProperTorsionHandler.ProperTorsionType
    ) -> "ProperTorsionSMIRKS":
        """Build a schema from a SMIRNOFF proper torsion, marking one ``k`` attribute
        (``k1`` ... ``kN``) per entry of ``parameter.k`` for optimization.
        """
        k_attributes = {f"k{term}" for term in range(1, len(parameter.k) + 1)}
        return cls(
            smirks=parameter.smirks,
            attributes=k_attributes,
            # cosmetic attributes are stored privately on the SMIRNOFF parameter,
            # hence ``_cached``; default to False when it was never set
            cached=getattr(parameter, "_cached", False),
        )
class ProperTorsionHyperparameters(BaseSMIRKSHyperparameters):
    # Priors are keyed by proper-torsion attribute name; the default only sets "k".
    type: Literal["ProperTorsions"] = "ProperTorsions"
    priors: Dict[ProperTorsionAttribute, PositiveFloat] = Field(
        {"k": 6.0}, description=""
    )

    @classmethod
    def offxml_tag(cls) -> str:
        """Return ``"Proper"``, the OFFXML tag that wraps proper torsion parameters."""
        tag = "Proper"
        return tag
# Every improper-torsion attribute name that may be optimized. The first row holds
# wildcard ("*") names; rows 1-4 hold the per-term indexed names.
# fmt: off
ImproperTorsionAttribute = Literal[
    "k*", "periodicity*", "phase*", "idivf*",
    "k1", "periodicity1", "phase1", "idivf1",
    "k2", "periodicity2", "phase2", "idivf2",
    "k3", "periodicity3", "phase3", "idivf3",
    "k4", "periodicity4", "phase4", "idivf4",
]
# fmt: on
class ImproperTorsionSMIRKS(BaseSMIRKSParameter):
    # Schema for an improper torsion parameter; any ImproperTorsionAttribute may be
    # optimized.
    type: Literal["ImproperTorsions"] = "ImproperTorsions"
    attributes: Set[Literal[ImproperTorsionAttribute]] = Field(
        ..., description="The attributes of the parameter which should be optimized."
    )

    @classmethod
    def _expected_n_tags(cls) -> int:
        # An improper torsion SMIRKS pattern tags exactly four atoms.
        n_tagged = 4
        return n_tagged

    @classmethod
    def from_smirnoff(
        cls, parameter: ImproperTorsionHandler.ImproperTorsionType
    ) -> "ImproperTorsionSMIRKS":
        """Not yet supported: always raises ``NotImplementedError``."""
        raise NotImplementedError()
class ImproperTorsionHyperparameters(BaseSMIRKSHyperparameters):
    # NOTE: the priors are keyed by ProperTorsionAttribute, not
    # ImproperTorsionAttribute. The default key "k" is only a member of the former
    # (the improper set spells it "k*").
    type: Literal["ImproperTorsions"] = "ImproperTorsions"
    priors: Dict[ProperTorsionAttribute, PositiveFloat] = Field(
        {"k": 6.0}, description=""
    )

    @classmethod
    def offxml_tag(cls) -> str:
        """Return ``"Improper"``, the OFFXML tag that wraps improper torsion parameters."""
        tag = "Improper"
        return tag
# Union of every concrete SMIRKS parameter schema defined in this module.
# NOTE(review): if this is used as a pydantic field type, member order may affect
# parsing; confirm before reordering.
SMIRNOFFParameter = Union[
VdWSMIRKS, BondSMIRKS, AngleSMIRKS, ProperTorsionSMIRKS, ImproperTorsionSMIRKS
]
# Union of every concrete hyperparameter schema defined in this module (same
# ordering caveat as above).
SMIRNOFFHyperparameters = Union[
ProperTorsionHyperparameters,
BondHyperparameters,
AngleHyperparameters,
VdWHyperparameters,
ImproperTorsionHyperparameters,
]
def get_smirnoff_parameter(parameter_type: str) -> Type[SMIRNOFFParameter]:
    """
    A helper function to get the SMIRNOFFParameter class from the parameter type.

    Parameters
    ----------
    parameter_type
        The default value of the ``type`` field of the wanted class, e.g.
        ``"vdW"``, ``"Bonds"``, ``"Angles"``, ``"ProperTorsions"`` or
        ``"ImproperTorsions"``.

    Returns
    -------
        The parameter schema class whose ``type`` matches ``parameter_type``.

    Raises
    ------
    KeyError
        If no parameter class has the requested ``type``.
    """
    parameter_classes = (
        VdWSMIRKS,
        BondSMIRKS,
        AngleSMIRKS,
        ProperTorsionSMIRKS,
        ImproperTorsionSMIRKS,
    )
    parameters_by_type = {
        parameter_class.__fields__["type"].default: parameter_class
        for parameter_class in parameter_classes
    }

    try:
        return parameters_by_type[parameter_type]
    except KeyError:
        # Keep the KeyError type callers may catch, but say what was expected.
        raise KeyError(
            f"Unknown parameter type {parameter_type!r}; expected one of "
            f"{sorted(parameters_by_type)}."
        ) from None