Source code for openff.nagl.features._base
import abc
import typing
from openff.nagl._base.metaregistry import create_registry_metaclass
from .._base.base import ImmutableModel
try:
from pydantic.v1 import validator
from pydantic.v1.main import ModelMetaclass
except ImportError:
from pydantic import validator
from pydantic.main import ModelMetaclass
if typing.TYPE_CHECKING:
import torch
from openff.toolkit.topology import Molecule
# class FeatureMeta(ModelMetaclass, create_registry_metaclass("feature_name")):
# registry: ClassVar[Dict[str, Type]] = {}
# def __init__(cls, name, bases, namespace, **kwargs):
# super().__init__(name, bases, namespace, **kwargs)
# try:
# key = namespace.get(cls._key_attribute)
# except AttributeError:
# key = None
# else:
# if not key:
# key = name
# if key is not None:
# key = cls._key_transform(key)
# cls.registry[key] = cls
# setattr(cls, cls._key_attribute, key)
[docs]class Feature(ImmutableModel, abc.ABC):
"""
Abstract base class for atom and bond features.
Features with length one can simply inherit :py:class:`AtomFeature
<openff.nagl.features.atoms.AtomFeature>` or
:py:class:`BondFeature <openff.nagl.features.bonds.BondFeature>`,
implement :py:class:`_encode <encode>`, and define
:py:attr:`name`. Complex features should additionally define the
:py:attr:`_feature_length` class attribute and set it to the length of the
feature.
See Also
========
openff.nagl.features.atoms.AtomFeature, openff.nagl.features.bonds.BondFeature
"""
name: typing.Literal[""]
"""Define a name for the feature"""
_feature_length: typing.ClassVar[int] = 1
def __init__(self, *args, **kwargs):
if not kwargs and args:
if len(self.__fields__) == len(args):
kwargs = dict(zip(self.__fields__, args))
args = tuple()
super().__init__(*args, **kwargs)
@classmethod
def _with_args(cls, *args):
if len(cls.__fields__) != len(args):
raise ValueError("Wrong number of arguments")
kwargs = dict(zip(cls.__fields__, args))
return cls(**kwargs)
[docs] def encode(self, molecule: "Molecule") -> "torch.Tensor":
"""
Encode the molecule feature into a tensor.
The output of this method must have shape :py:attr:`tensor_shape`.
Subclasses may instead implement a ``_encode`` method with the same
signature as this one. The default implementation of this method
will call that one and guarantee an appropriate shape.
"""
return self._encode(molecule).reshape(self.tensor_shape)
@abc.abstractmethod
def _encode(self, molecule: "Molecule") -> "torch.Tensor":
"""
Encode the molecule feature into a tensor.
"""
@property
def tensor_shape(self):
"""
Return the shape of the feature tensor.
"""
return (-1, len(self))
def __call__(self, molecule) -> "torch.Tensor":
return self.encode(molecule)
def __len__(self):
"""
Return the length of the feature.
"""
return self._feature_length
class CategoricalMixin:
"""
Mixin class for categorical features.
"""
categories: typing.List[typing.Any]
@property
def _default_categories(self):
return self.__fields__["categories"].default
def __len__(self):
return len(self.categories)