Source code for openff.nagl.features.bonds

"""Bond features for GNN models.

A bond featurization scheme is a tuple of instances of the classes in this
module:

>>> bond_features = (
...     BondIsAromatic(),
...     BondOrder(),
...     BondInRingOfSize(3),
...     BondInRingOfSize(4),
...     BondInRingOfSize(5),
...     BondInRingOfSize(6),
...     ...
... )

The :py:class:`BondFeature` class may be subclassed to implement your own
features.

"""

# from typing import ClassVar, Dict, Type
import typing

import torch


from ._base import CategoricalMixin, Feature #, FeatureMeta
from ._utils import one_hot_encode

try:
    from pydantic.v1 import Field
except ImportError:
    from pydantic import Field

__all__ = [
    "BondFeature",
    "BondIsAromatic",
    "BondIsInRing",
    "BondInRingOfSize",
    "WibergBondOrder",
    "BondOrder",
]


# class _BondFeatureMeta(FeatureMeta):
#     """Metaclass for registering bond features for string lookup."""

#     registry: ClassVar[Dict[str, Type]] = {}


class BondFeature(Feature):
    """Common parent for every feature that is computed per bond.

    Concrete bond features subclass this and implement ``_encode``. The
    generic contract for writing a feature is documented on
    :py:class:`Feature<openff.nagl.features.Feature>`.
    """
class BondIsAromatic(BondFeature):
    """Per-bond aromaticity flag.

    Every bond in ``molecule.bonds`` contributes one entry: ``True`` when the
    bond's ``is_aromatic`` attribute is truthy, ``False`` otherwise.
    """

    name: typing.Literal["bond_is_aromatic"] = "bond_is_aromatic"

    def _encode(self, molecule) -> torch.Tensor:
        # One boolean per bond, in the iteration order of ``molecule.bonds``.
        aromatic_flags = [bool(bond.is_aromatic) for bond in molecule.bonds]
        return torch.tensor(aromatic_flags)
class BondIsInRing(BondFeature):
    """
    Per-bond flag marking whether the bond lies in a ring of any size.

    See Also
    --------
    BondInRingOfSize
    """

    name: typing.Literal["bond_is_in_ring"] = "bond_is_in_ring"

    def _encode(self, molecule) -> torch.Tensor:
        from openff.nagl.toolkits.openff import get_openff_molecule_bond_indices

        # Collect every atom pair matched by the "ring bond" SMARTS pattern,
        # normalised to ascending index order so it can be compared with the
        # bond index tuples below.
        ring_pairs = set()
        for matched_atoms in molecule.chemical_environment_matches("[*:1]@[*:2]"):
            ring_pairs.add(tuple(sorted(matched_atoms)))

        # NOTE(review): assumes get_openff_molecule_bond_indices yields one
        # sorted index tuple per bond, in bond order -- confirm in that helper.
        bond_index_pairs = get_openff_molecule_bond_indices(molecule)
        membership = [pair in ring_pairs for pair in bond_index_pairs]
        return torch.tensor(membership)
class BondInRingOfSize(BondFeature):
    """
    Per-bond flag marking membership of a ring of exactly ``ring_size`` atoms.

    The ring size is supplied as the feature's argument; for rings of any
    size use :py:class:`BondIsInRing` instead. To cover several ring sizes,
    include one instance of this feature per size:

    >>> bond_features = (
    ...     BondInRingOfSize(3),
    ...     BondInRingOfSize(4),
    ...     BondInRingOfSize(5),
    ...     BondInRingOfSize(6),
    ...     ...
    ... )

    See Also
    --------
    BondIsInRing, AtomIsInRingOfSize, AtomIsInRing
    """

    name: typing.Literal["bond_in_ring_of_size"] = "bond_in_ring_of_size"
    # Number of atoms in the ring being tested for.
    ring_size: int

    def _encode(self, molecule) -> torch.Tensor:
        from openff.nagl.toolkits.openff import get_bonds_are_in_ring_size

        # The toolkit helper does the ring perception; its per-bond result is
        # converted to an integer (int64) tensor.
        return torch.tensor(
            get_bonds_are_in_ring_size(molecule, self.ring_size),
            dtype=int,
        )
class WibergBondOrder(BondFeature):
    """
    The Wiberg fractional bond order of the bond.

    This feature encodes the Wiberg bond order directly; it does not use a
    one-hot encoding. Every bond of the molecule must already carry a
    ``fractional_bond_order`` value.

    Raises
    ------
    ValueError
        If any bond's ``fractional_bond_order`` is ``None``.
    """

    name: typing.Literal["wiberg_bond_order"] = "wiberg_bond_order"

    def _encode(self, molecule) -> torch.Tensor:
        bond_orders = [bond.fractional_bond_order for bond in molecule.bonds]

        # torch.tensor cannot build a tensor that contains None; it fails
        # with an unrelated-looking dtype/TypeError. Report the real cause.
        missing = [index for index, order in enumerate(bond_orders) if order is None]
        if missing:
            raise ValueError(
                "WibergBondOrder requires a fractional bond order on every "
                f"bond, but {len(missing)} bond(s) have none "
                f"(bond positions: {missing})."
            )
        return torch.tensor(bond_orders)
class BondOrder(CategoricalMixin, BondFeature):
    """
    One-hot encoding of the bond order.

    The bond order is also known as the degree of the bond;
    a single bond has bond order 1, a double bond has bond order 2, etc.

    By default, one-hot encodings are provided for all of the bond orders
    in the :py:data:`categories` field.
    To cover a different list of bond orders, provide that list as an
    argument to the feature:

    >>> bond_features = (
    ...     BondOrder([1, 2, 3, 4]),
    ...     ...
    ... )
    """

    name: typing.Literal["bond_order"] = "bond_order"
    categories = [1, 2, 3]

    def _encode(self, molecule) -> torch.Tensor:
        encodings = [
            one_hot_encode(int(bond.bond_order), self.categories)
            for bond in molecule.bonds
        ]
        if not encodings:
            # torch.vstack rejects an empty sequence, so a molecule with no
            # bonds (e.g. a single atom) would crash. Return zero rows
            # instead, taking width and dtype from one_hot_encode itself so
            # the result matches what the non-empty path would produce.
            template = one_hot_encode(self.categories[0], self.categories)
            return template.new_zeros((0, template.shape[-1]))
        return torch.vstack(encodings)
# Union of every concrete bond feature class defined in this module.
BondFeatureType = typing.Union[
    BondIsAromatic,
    BondIsInRing,
    BondInRingOfSize,
    WibergBondOrder,
    BondOrder,
]
# The same union annotated with a pydantic Field whose discriminator is the
# ``name`` field: each class above pins ``name`` to a distinct Literal, which
# is what lets pydantic select the concrete class when parsing.
DiscriminatedBondFeatureType = typing.Annotated[
    BondFeatureType, Field(..., discriminator="name")
]