Source code for openff.nagl.nn.postprocess

"""
Postprocessing functions to convert a graph representation to a predicted property
"""

import abc
from typing import ClassVar, Dict, Type, Union

import torch

from openff.nagl._base.metaregistry import create_registry_metaclass
from openff.nagl.molecule._dgl import DGLMolecule, DGLMoleculeBatch


class _PostprocessLayerMeta(abc.ABCMeta, create_registry_metaclass()):
    """Metaclass that records named post-processing layers for string lookup.

    Each class created through this metaclass that exposes a non-empty
    ``name`` attribute is stored in ``registry`` under the lower-cased name.
    """

    registry: ClassVar[Dict[str, Type]] = {}

    def __init__(cls, name, bases, namespace, **kwargs):
        super().__init__(name, bases, namespace, **kwargs)
        # Classes with a missing or empty ``name`` are not registered.
        layer_name = getattr(cls, "name", None)
        if layer_name:
            cls.registry[layer_name.lower()] = cls

    # NOTE: the lookup helper below is disabled (commented out) in this file.
    # @classmethod
    # def get_layer_class(cls, class_name: str):
    #     if isinstance(class_name, cls):
    #         return class_name
    #     if isinstance(type(class_name), cls):
    #         return type(class_name)
    #     try:
    #         return cls.registry[class_name.lower()]
    #     except KeyError:
    #         raise ValueError(
    #             f"Unknown PostprocessLayer type: {class_name}. "
    #             f"Supported types: {list(cls.registry.keys())}"
    #         )


class PostprocessLayer(torch.nn.Module, abc.ABC, metaclass=_PostprocessLayerMeta):
    """Abstract base for layers applied to the final readout of a neural network.

    Concrete subclasses override ``name`` and ``n_features`` and implement
    :meth:`forward`.
    """

    # Key under which the metaclass registers the layer; empty on this base,
    # so the base class itself is never registered.
    name: ClassVar[str] = ""

    # Presumably the number of feature columns ``forward`` consumes from
    # ``inputs`` -- TODO confirm against callers.
    n_features: ClassVar[int] = 0

    @abc.abstractmethod
    def forward(
        self,
        molecule: Union[DGLMolecule, DGLMoleculeBatch],
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return the post-processed version of ``inputs``.

        Parameters
        ----------
        molecule
            The molecule, or batch of molecules, that ``inputs`` belongs to.
        inputs
            The readout tensor to post-process.

        Returns
        -------
        torch.Tensor
            The post-processed tensor.
        """
class ComputePartialCharges(PostprocessLayer):
    """
    Converts per-atom electronegativity and hardness parameters
    into partial charges.

    References
    ----------

    1. Gilson, Michael K.; Gilson, Hillary S.R.; Potter, Michael J.
       "Fast assignment of accurate partial atomic charges:
       an electronegativity equalization method that accounts for alternate
       resonance forms."
       `Journal of Chemical Information and Computer Sciences
       <https://doi.org/10.1021/ci034148o>`_ 43.6 (2003): 1982-1997.

    2. Wang, Yuanqing; Fass, Josh; Stern, Chaya D.; Luo, Kun; Chodera, John D.
       "Graph Nets for Partial Charge Prediction."
       `arXiv:1909.07903 [physics.comp-ph]
       <https://doi.org/10.48550/arXiv.1909.07903>`_
    """

    name: ClassVar[str] = "compute_partial_charges"
    n_features: ClassVar[int] = 2

    @staticmethod
    def _calculate_partial_charges(
        electronegativity: torch.Tensor,
        hardness: torch.Tensor,
        total_charge: float,
    ) -> torch.Tensor:
        """
        Compute charges for one group of atoms so that they sum to
        ``total_charge``.

        Equation borrowed from Wang et al's preprint on Espaloma (Eq 15).

        Parameters
        ----------
        electronegativity
            One electronegativity value per atom.
        hardness
            One hardness value per atom; it is inverted, so zeros are not
            guarded against here.
        total_charge
            The value the returned charges add up to.

        Returns
        -------
        torch.Tensor
            The charges as a single column (reshaped to ``(-1, 1)``).
        """
        reciprocal_hardness = 1.0 / hardness
        ratio = electronegativity * reciprocal_hardness

        # One scalar shared by every atom; it is spread out in proportion to
        # each atom's reciprocal hardness so the total comes out right.
        shared = (ratio.sum() + total_charge) / reciprocal_hardness.sum()

        return (-ratio + reciprocal_hardness * shared).reshape(-1, 1)

    def forward(
        self,
        molecule: Union[DGLMolecule, DGLMoleculeBatch],
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute partial charges for every molecule in ``molecule``.

        The rows of ``inputs`` are consumed in order: each molecule owns
        ``n_representations`` consecutive runs of ``n_atoms`` rows. Charges
        are computed for each run separately and then averaged over the
        molecule's representations.

        Parameters
        ----------
        molecule
            Supplies the per-node ``"formal_charge"`` data as well as the
            atom and representation counts for each molecule.
        inputs
            Column 0 is read as electronegativity and column 1 as hardness.

        Returns
        -------
        torch.Tensor
            The averaged charges of all molecules, stacked vertically into
            one column.
        """
        electronegativity, hardness = inputs[:, 0], inputs[:, 1]
        formal_charges = molecule.graph.ndata["formal_charge"]

        offset = 0  # first row of the run that has not been consumed yet
        per_molecule = []
        for atom_count, representation_count in zip(
            molecule.n_atoms_per_molecule,
            molecule.n_representations_per_molecule,
        ):
            atom_count = int(atom_count)

            per_representation = []
            for _ in range(representation_count):
                stop = offset + atom_count
                window = slice(offset, stop)
                offset = stop
                per_representation.append(
                    self._calculate_partial_charges(
                        electronegativity[window],
                        hardness[window],
                        # The run's summed formal charge is its total charge.
                        formal_charges[window].sum(),
                    )
                )

            per_molecule.append(torch.stack(per_representation).mean(dim=0))

        return torch.vstack(per_molecule)
class RegularizedComputePartialCharges(PostprocessLayer):
    """
    Converts per-atom initial charges, electronegativity, and hardness
    parameters into partial charges.
    This is a modification of the :class:`ComputePartialCharges`.

    References
    ----------

    1. Gilson, Michael K.; Gilson, Hillary S.R.; Potter, Michael J.
       "Fast assignment of accurate partial atomic charges:
       an electronegativity equalization method that accounts for alternate
       resonance forms."
       `Journal of Chemical Information and Computer Sciences
       <https://doi.org/10.1021/ci034148o>`_ 43.6 (2003): 1982-1997.

    2. Wang, Yuanqing; Fass, Josh; Stern, Chaya D.; Luo, Kun; Chodera, John D.
       "Graph Nets for Partial Charge Prediction."
       `arXiv:1909.07903 [physics.comp-ph]
       <https://doi.org/10.48550/arXiv.1909.07903>`_
    """

    name: ClassVar[str] = "regularized_compute_partial_charges"
    n_features: ClassVar[int] = 3

    @staticmethod
    def _calculate_partial_charges(
        charge_priors: torch.Tensor,
        electronegativity: torch.Tensor,
        hardness: torch.Tensor,
        total_charge: float,
    ) -> torch.Tensor:
        """
        Adjust the prior charges of one group of atoms so that they sum to
        ``total_charge``.

        Equation borrowed from Wang et al's preprint on Espaloma (Eq 15).

        Parameters
        ----------
        charge_priors
            One initial charge per atom.
        electronegativity
            One electronegativity value per atom.
        hardness
            One hardness value per atom; it is inverted, so zeros are not
            guarded against here.
        total_charge
            The value the returned charges add up to.

        Returns
        -------
        torch.Tensor
            The charges as a single column (reshaped to ``(-1, 1)``).
        """
        prior_total = charge_priors.sum()
        reciprocal_hardness = 1.0 / hardness
        ratio = electronegativity * reciprocal_hardness

        # Amount by which the uncorrected charges would miss the target
        # total; it is removed in proportion to each reciprocal hardness.
        excess = prior_total - total_charge - ratio.sum()
        correction = reciprocal_hardness * (excess / reciprocal_hardness.sum())

        return (charge_priors - ratio - correction).reshape(-1, 1)

    def forward(
        self,
        molecule: Union[DGLMolecule, DGLMoleculeBatch],
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute partial charges for every molecule in ``molecule``.

        The rows of ``inputs`` are consumed in order: each molecule owns
        ``n_representations`` consecutive runs of ``n_atoms`` rows. Charges
        are computed for each run separately and then averaged over the
        molecule's representations.

        Parameters
        ----------
        molecule
            Supplies the per-node ``"formal_charge"`` data as well as the
            atom and representation counts for each molecule.
        inputs
            Column 0 is read as the charge priors, column 1 as
            electronegativity and column 2 as hardness.

        Returns
        -------
        torch.Tensor
            The averaged charges of all molecules, stacked vertically into
            one column.
        """
        charge_priors, electronegativity, hardness = (
            inputs[:, 0],
            inputs[:, 1],
            inputs[:, 2],
        )
        formal_charges = molecule.graph.ndata["formal_charge"]

        offset = 0  # first row of the run that has not been consumed yet
        per_molecule = []
        for atom_count, representation_count in zip(
            molecule.n_atoms_per_molecule,
            molecule.n_representations_per_molecule,
        ):
            atom_count = int(atom_count)

            per_representation = []
            for _ in range(representation_count):
                stop = offset + atom_count
                window = slice(offset, stop)
                offset = stop
                per_representation.append(
                    self._calculate_partial_charges(
                        charge_priors[window],
                        electronegativity[window],
                        hardness[window],
                        # The run's summed formal charge is its total charge.
                        formal_charges[window].sum(),
                    )
                )

            per_molecule.append(torch.stack(per_representation).mean(dim=0))

        return torch.vstack(per_molecule)