Source code for openff.nagl.nn._models

import copy
import warnings
from typing import TYPE_CHECKING, Tuple, Dict, Union, Callable, Literal, Optional

import torch
import pytorch_lightning as pl

from openff.utilities.exceptions import MissingOptionalDependencyError
from openff.nagl.nn._containers import ConvolutionModule, ReadoutModule
from openff.nagl.config.model import ModelConfig
from openff.nagl.domains import ChemicalDomain

if TYPE_CHECKING:
    from openff.toolkit.topology import Molecule
    from openff.nagl.molecule._dgl import DGLMoleculeOrBatch

class BaseGNNModel(pl.LightningModule):
    """
    Base graph neural network model.

    Holds a single convolution module and a collection of named readout
    modules. A forward pass evaluates every readout module on the molecule
    and returns the results keyed by readout name.

    Parameters
    ----------
    convolution_module: ConvolutionModule
        The convolution module of the model.
    readout_modules: Dict[str, ReadoutModule]
        The readout modules, keyed by readout name. They are stored
        in a :class:`torch.nn.ModuleDict`.
    """

    def __init__(
        self,
        convolution_module: ConvolutionModule,
        readout_modules: Dict[str, ReadoutModule],
    ):
        super().__init__()
        self.convolution_module = convolution_module
        # ModuleDict registers each readout as a submodule, so their
        # parameters are part of ``state_dict()``.
        self.readout_modules = torch.nn.ModuleDict(readout_modules)

    def forward(
        self, molecule: "DGLMoleculeOrBatch"
    ) -> Dict[str, torch.Tensor]:
        """
        Apply the convolution module and then every readout module.

        Parameters
        ----------
        molecule: DGLMoleculeOrBatch
            The molecule (or batch of molecules) to evaluate.

        Returns
        -------
        readouts: Dict[str, torch.Tensor]
            The output of each readout module, keyed by readout name.
        """
        # NOTE(review): the return value is not consumed, so the convolution
        # module presumably stores its output on ``molecule`` in place for
        # the readouts to pick up -- confirm against ConvolutionModule.
        self.convolution_module(molecule)

        readouts: Dict[str, torch.Tensor] = {
            readout_type: readout_module(molecule)
            for readout_type, readout_module in self.readout_modules.items()
        }
        return readouts

    def _forward_unpostprocessed(
        self, molecule: "DGLMoleculeOrBatch"
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass without postprocessing the readout modules.
        This is a quality-of-life method for debugging and testing.
        It is *not* intended for public use.
        """
        # NOTE(review): same in-place convolution assumption as ``forward``.
        self.convolution_module(molecule)

        readouts: Dict[str, torch.Tensor] = {
            readout_type: readout_module._forward_unpostprocessed(molecule)
            for readout_type, readout_module in self.readout_modules.items()
        }
        return readouts

class GNNModel(BaseGNNModel):
    """
    A GNN model constructed from a :class:`ModelConfig`.

    The convolution module and one readout module per entry in
    ``config.readouts`` are built from the configuration. The configuration
    and chemical domain are stored on the model and saved as hyperparameters
    (as plain dictionaries) so the model can be rebuilt by :meth:`load`.

    Parameters
    ----------
    config: ModelConfig or dict
        The model configuration. Anything that is not already a
        ``ModelConfig`` is expanded as keyword arguments into one.
    chemical_domain: ChemicalDomain or dict, optional
        The chemical domain used by ``check_domains`` in
        :meth:`compute_properties`. If ``None``, a domain with empty
        ``allowed_elements`` and ``forbidden_patterns`` is created.
        Anything that is not already a ``ChemicalDomain`` is expanded
        as keyword arguments into one.
    """

    def __init__(
        self,
        config: ModelConfig,
        chemical_domain: Optional[ChemicalDomain] = None,
    ):
        if not isinstance(config, ModelConfig):
            config = ModelConfig(**config)

        if chemical_domain is None:
            chemical_domain = ChemicalDomain(
                allowed_elements=tuple(),
                forbidden_patterns=tuple(),
            )
        elif not isinstance(chemical_domain, ChemicalDomain):
            chemical_domain = ChemicalDomain(**chemical_domain)

        convolution_module = ConvolutionModule.from_config(
            config.convolution,
            n_input_features=config.n_atom_features,
        )

        # Every readout consumes the output of the final convolution layer.
        readout_modules = {
            readout_name: ReadoutModule.from_config(
                readout_config,
                n_input_features=config.convolution.layers[-1].hidden_feature_size,
            )
            for readout_name, readout_config in config.readouts.items()
        }

        super().__init__(
            convolution_module=convolution_module,
            readout_modules=readout_modules,
        )

        # Stored as plain dicts; ``load`` passes them back into ``__init__``.
        self.save_hyperparameters({
            "config": config.dict(),
            "chemical_domain": chemical_domain.dict(),
        })
        self.config = config
        self.chemical_domain = chemical_domain

    @classmethod
    def from_yaml(cls, filename):
        """
        Create a model from a YAML configuration file.

        Parameters
        ----------
        filename
            Passed to ``ModelConfig.from_yaml``.

        Returns
        -------
        model: GNNModel
            A model built from the configuration, with the default
            chemical domain.
        """
        config = ModelConfig.from_yaml(filename)
        return cls(config)

    @property
    def _is_dgl(self):
        """Whether the convolution module reports itself as DGL-based."""
        return self.convolution_module._is_dgl

    def _as_nagl(self):
        """
        Return a copy of this model that uses the NAGL (non-DGL)
        convolution module, with the same weights.
        """
        # Pass the chemical domain along so the copy does not silently
        # fall back to the empty default domain.
        copied = type(self)(self.config, chemical_domain=self.chemical_domain)
        copied.convolution_module = self.convolution_module._as_nagl(copy_weights=True)
        copied.load_state_dict(self.state_dict())
        return copied

    def compute_properties(
        self,
        molecule: "Molecule",
        as_numpy: bool = True,
        check_domains: bool = False,
        error_if_unsupported: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the trained property for a molecule.

        Parameters
        ----------
        molecule: :class:`~openff.toolkit.topology.Molecule`
            The molecule to compute the property for.
        as_numpy: bool
            Whether to return the result as a numpy array.
            If ``False``, the result will be a ``torch.Tensor``.
        check_domains: bool
            Whether to check if the molecule is similar
            to the training data.
        error_if_unsupported: bool
            Whether to raise an error if the molecule
            is not represented in the training data.
            This is only used if ``check_domains`` is ``True``.
            If ``False``, a warning will be raised instead.

        Returns
        -------
        result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]
            The value of each readout, keyed by readout name.
            Numpy results are flattened to one dimension.

        Raises
        ------
        ValueError
            If ``check_domains`` and ``error_if_unsupported`` are both
            ``True`` and the molecule is not supported by the chemical
            domain.
        """
        if check_domains:
            is_supported, error = self.chemical_domain.check_molecule(
                molecule, return_error_message=True
            )
            if not is_supported:
                if error_if_unsupported:
                    raise ValueError(error)
                warnings.warn(error)

        # Prefer the DGL backend; fall back to the NAGL graph backend when
        # DGL is unavailable or this model is not DGL-based (TypeError).
        try:
            values = self._compute_properties_dgl(molecule)
        except (MissingOptionalDependencyError, TypeError):
            values = self._compute_properties_nagl(molecule)

        if as_numpy:
            # ``.cpu()`` is a no-op for CPU tensors and makes the conversion
            # safe for tensors that live on another device.
            values = {
                k: v.detach().cpu().numpy().flatten()
                for k, v in values.items()
            }
        return values

    def compute_property(
        self,
        molecule: "Molecule",
        readout_name: Optional[str] = None,
        as_numpy: bool = True,
        check_domains: bool = False,
        error_if_unsupported: bool = True,
    ):
        """
        Compute the trained property for a molecule.

        Parameters
        ----------
        molecule: :class:`~openff.toolkit.topology.Molecule`
            The molecule to compute the property for.
        readout_name: str
            The name of the readout property to return.
            If this is not given and there is only one readout,
            the result of that readout will be returned.
        as_numpy: bool
            Whether to return the result as a numpy array.
            If ``False``, the result will be a ``torch.Tensor``.
        check_domains: bool
            Whether to check if the molecule is similar
            to the training data.
        error_if_unsupported: bool
            Whether to raise an error if the molecule
            is not represented in the training data.
            This is only used if ``check_domains`` is ``True``.
            If ``False``, a warning will be raised instead.

        Returns
        -------
        result: torch.Tensor or numpy.ndarray

        Raises
        ------
        ValueError
            If ``readout_name`` is not given and the model does not have
            exactly one readout.
        KeyError
            If ``readout_name`` is not the name of a readout.
        """
        properties = self.compute_properties(
            molecule=molecule,
            as_numpy=as_numpy,
            check_domains=check_domains,
            error_if_unsupported=error_if_unsupported,
        )
        if readout_name is None:
            if len(properties) == 1:
                return next(iter(properties.values()))
            raise ValueError(
                "The readout name must be specified if the model has multiple readouts"
            )
        return properties[readout_name]

    def _compute_properties_nagl(
        self, molecule: "Molecule"
    ) -> Dict[str, torch.Tensor]:
        """Run the forward pass on a ``GraphMolecule`` (non-DGL backend)."""
        from openff.nagl.molecule._graph.molecule import GraphMolecule

        nxmol = GraphMolecule.from_openff(
            molecule,
            atom_features=self.config.atom_features,
            bond_features=self.config.bond_features,
        )
        # A DGL-based model is converted to its NAGL equivalent first.
        model = self
        if self._is_dgl:
            model = self._as_nagl()
        return model.forward(nxmol)

    def _compute_properties_dgl(
        self, molecule: "Molecule"
    ) -> Dict[str, torch.Tensor]:
        """
        Run the forward pass on a ``DGLMolecule`` (DGL backend).

        Raises
        ------
        TypeError
            If this model is not DGL-based.
        """
        from openff.nagl.molecule._dgl.molecule import DGLMolecule

        if not self._is_dgl:
            raise TypeError(
                "This model is not a DGL-based model "
                "and cannot be used to compute properties with the DGL backend"
            )

        dglmol = DGLMolecule.from_openff(
            molecule,
            atom_features=self.config.atom_features,
            bond_features=self.config.bond_features,
        )
        return self.forward(dglmol)

    def _convert_to_nagl_molecule(self, molecule: "Molecule"):
        """
        Featurize an OpenFF molecule into the graph type matching this
        model's backend: ``DGLMolecule`` if the model is DGL-based,
        otherwise ``GraphMolecule``.
        """
        from openff.nagl.molecule._graph.molecule import GraphMolecule

        if self._is_dgl:
            from openff.nagl.molecule._dgl.molecule import DGLMolecule

            return DGLMolecule.from_openff(
                molecule,
                atom_features=self.config.atom_features,
                bond_features=self.config.bond_features,
            )

        return GraphMolecule.from_openff(
            molecule,
            atom_features=self.config.atom_features,
            bond_features=self.config.bond_features,
        )

    @classmethod
    def load(cls, model: str, eval_mode: bool = True, **kwargs):
        """
        Load a model from a file.

        Parameters
        ----------
        model: str
            The path to the model to load.
            This should be a file containing a dictionary of
            hyperparameters and a state dictionary,
            with the keys "hyperparameters" and "state_dict".
            This can be created using the `save` method.
        eval_mode: bool
            Whether to set the model to evaluation mode.
        **kwargs
            Additional keyword arguments to pass to `torch.load`.

        Returns
        -------
        model: GNNModel

        Raises
        ------
        ValueError
            If the loaded object is neither a dictionary nor an
            instance of this class.

        Examples
        --------

        >>> model.save("model.pt")
        >>> new_model = GNNModel.load("model.pt")

        Notes
        -----
        This method is not compatible with normal Pytorch
        models saved with ``torch.save``, as it expects
        a dictionary of hyperparameters and a state dictionary.
        """
        # SECURITY: ``torch.load`` unpickles the file, which can execute
        # arbitrary code. Only load model files from trusted sources.
        loaded = torch.load(str(model), **kwargs)

        if isinstance(loaded, dict):
            instance = cls(**loaded["hyperparameters"])
            instance.load_state_dict(loaded["state_dict"])
        elif isinstance(loaded, cls):
            instance = loaded
        else:
            raise ValueError(f"Unknown model type {type(loaded)}")

        if eval_mode:
            instance.eval()
        return instance

    def save(self, path: str):
        """
        Save this model to a file.

        Parameters
        ----------
        path: str
            The path to save this file to.

        Examples
        --------

        >>> model.save("model.pt")
        >>> new_model = GNNModel.load("model.pt")

        Notes
        -----
        This method writes a dictionary of the hyperparameters and the
        state dictionary, with the keys "hyperparameters" and "state_dict".
        """
        torch.save(
            {
                "hyperparameters": self.hparams,
                "state_dict": self.state_dict(),
            },
            str(path),
        )