import collections
import logging
import types
from typing import TYPE_CHECKING, Tuple, Dict, Union, Callable, Literal, Optional
import warnings
import torch
import pytorch_lightning as pl
from openff.utilities.exceptions import MissingOptionalDependencyError
from openff.nagl.nn._containers import ConvolutionModule, ReadoutModule
from openff.nagl.config.model import ModelConfig
from openff.nagl.domains import ChemicalDomain
from openff.nagl.lookups import LookupTableType, _as_lookup_table
from openff.nagl.utils._utils import potential_dict_to_list
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from openff.toolkit.topology import Molecule
from openff.nagl.molecule._dgl import DGLMoleculeOrBatch
class BaseGNNModel(pl.LightningModule):
def __init__(
self,
convolution_module: ConvolutionModule,
        readout_modules: Dict[str, ReadoutModule],
):
super().__init__()
self.convolution_module = convolution_module
self.readout_modules = torch.nn.ModuleDict(readout_modules)
def forward(
self, molecule: "DGLMoleculeOrBatch"
) -> Dict[str, torch.Tensor]:
self.convolution_module(molecule)
readouts: Dict[str, torch.Tensor] = {
readout_type: readout_module(molecule)
for readout_type, readout_module in self.readout_modules.items()
}
return readouts
def _forward_unpostprocessed(self, molecule: "DGLMoleculeOrBatch"):
"""
        Forward pass that skips postprocessing of the readout outputs.
        This is a quality-of-life method for debugging and testing.
It is *not* intended for public use.
"""
self.convolution_module(molecule)
readouts: Dict[str, torch.Tensor] = {
readout_type: readout_module._forward_unpostprocessed(molecule)
for readout_type, readout_module in self.readout_modules.items()
}
return readouts
class GNNModel(BaseGNNModel):
"""
A GNN model for predicting properties of molecules.
Parameters
----------
config: ModelConfig or dict
The configuration for the model.
chemical_domain: ChemicalDomain or dict
The applicable chemical domain for the model.
    lookup_tables: dict, optional
A dictionary of lookup tables for properties.
The keys should be the property names, and the values
should be instances of :class:`~openff.nagl.lookups.BaseLookupTable`.
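
    Examples
    --------
    A minimal usage sketch; the YAML filename below is a placeholder:

    >>> model = GNNModel.from_yaml("model-config.yaml")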
"""
def __init__(
self,
config: ModelConfig,
chemical_domain: Optional[ChemicalDomain] = None,
        lookup_tables: Optional[Dict[str, LookupTableType]] = None,
):
if not isinstance(config, ModelConfig):
config = ModelConfig(**config)
if chemical_domain is None:
chemical_domain = ChemicalDomain(
allowed_elements=tuple(),
forbidden_patterns=tuple(),
)
elif not isinstance(chemical_domain, ChemicalDomain):
chemical_domain = ChemicalDomain(**chemical_domain)
convolution_module = ConvolutionModule.from_config(
config.convolution,
n_input_features=config.n_atom_features,
)
readout_modules = {}
for readout_name, readout_config in config.readouts.items():
readout_modules[readout_name] = ReadoutModule.from_config(
readout_config,
n_input_features=config.convolution.layers[-1].hidden_feature_size,
)
valid_lookup_tables = {}
if not lookup_tables:
lookup_tables = {}
# allow an iterable of lookup tables
lookup_tables = potential_dict_to_list(lookup_tables)
for lookup_table in lookup_tables:
lookup_table = _as_lookup_table(lookup_table)
            if lookup_table.property_name not in readout_modules:
raise ValueError(
f"The lookup table property name {lookup_table.property_name} "
f"is not in the readout modules."
)
valid_lookup_tables[lookup_table.property_name] = lookup_table
super().__init__(
convolution_module=convolution_module,
readout_modules=readout_modules,
)
lookup_tables_dict = {}
for k, v in valid_lookup_tables.items():
v_ = v.dict()
v_["properties"] = dict(v_["properties"])
lookup_tables_dict[k] = v_
self.save_hyperparameters({
"config": config.dict(),
"chemical_domain": chemical_domain.dict(),
"lookup_tables": lookup_tables_dict,
})
self.config = config
self.chemical_domain = chemical_domain
self.lookup_tables = types.MappingProxyType(valid_lookup_tables)
    @classmethod
def from_yaml(cls, filename):
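        """
        Create a new, untrained model from a YAML configuration file.

        The file is parsed into a
        :class:`~openff.nagl.config.model.ModelConfig`, which is then
        used to construct the model.

        Parameters
        ----------
        filename: str
            The path to the YAML configuration file.

        Returns
        -------
        model: GNNModel
        """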
config = ModelConfig.from_yaml(filename)
return cls(config)
@property
def _is_dgl(self):
return self.convolution_module._is_dgl
def _as_nagl(self):
copied = type(self)(self.config)
copied.convolution_module = self.convolution_module._as_nagl(copy_weights=True)
copied.load_state_dict(self.state_dict())
return copied
    def compute_properties(
self,
molecule: "Molecule",
as_numpy: bool = True,
check_domains: bool = False,
error_if_unsupported: bool = True,
check_lookup_table: bool = True,
) -> Dict[str, torch.Tensor]:
"""
        Compute the trained properties for a molecule.
Parameters
----------
molecule: :class:`~openff.toolkit.topology.Molecule`
The molecule to compute the property for.
as_numpy: bool
Whether to return the result as a numpy array.
If ``False``, the result will be a ``torch.Tensor``.
check_domains: bool
Whether to check if the molecule is similar
to the training data.
error_if_unsupported: bool
Whether to raise an error if the molecule
is not represented in the training data.
This is only used if ``check_domains`` is ``True``.
If ``False``, a warning will be raised instead.
check_lookup_table: bool
Whether to check a lookup table for the property values.
If ``False`` or if the molecule is not in the lookup
table, the property will be computed using the model.
Returns
-------
result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]
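
        Examples
        --------
        A minimal sketch, assuming a trained model instance ``model`` and a
        readout named ``"charges"`` (the readout name is a placeholder):

        >>> from openff.toolkit.topology import Molecule
        >>> ethanol = Molecule.from_smiles("CCO")
        >>> results = model.compute_properties(ethanol)
        >>> per_atom = results["charges"]  # one value per atom by default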
"""
import numpy as np
        # split up the molecule in case it contains disconnected fragments
from openff.nagl.toolkits.openff import split_up_molecule
fragments, all_indices = split_up_molecule(molecule)
# TODO: this assumes atom-wise properties
# we should add support for bond-wise/more general properties
results = [
self._compute_properties(
fragment,
as_numpy=as_numpy,
check_domains=check_domains,
error_if_unsupported=error_if_unsupported,
check_lookup_table=check_lookup_table,
)
for fragment in fragments
]
# combine the results
combined_results = {}
if as_numpy:
tensor = np.empty
else:
tensor = torch.empty
for property_name, value in results[0].items():
combined_results[property_name] = tensor(
molecule.n_atoms,
dtype=value.dtype
)
seen_indices = collections.defaultdict(set)
for result, indices in zip(results, all_indices):
for property_name, value in result.items():
combined_results[property_name][indices] = value
if seen_indices[property_name] & set(indices):
raise ValueError(
"Overlapping indices in the fragments"
)
seen_indices[property_name].update(indices)
expected_indices = list(range(molecule.n_atoms))
        for property_name, seen in seen_indices.items():
            assert sorted(seen) == expected_indices, (
                f"Missing indices for property {property_name}: "
                f"{set(expected_indices) - seen}"
            )
return combined_results
def _compute_properties(
self,
molecule: "Molecule",
as_numpy: bool = True,
check_domains: bool = False,
error_if_unsupported: bool = True,
check_lookup_table: bool = True,
) -> Dict[str, torch.Tensor]:
"""
        Compute the trained properties for a molecule.
Parameters
----------
molecule: :class:`~openff.toolkit.topology.Molecule`
The molecule to compute the property for.
as_numpy: bool
Whether to return the result as a numpy array.
If ``False``, the result will be a ``torch.Tensor``.
check_domains: bool
Whether to check if the molecule is similar
to the training data.
error_if_unsupported: bool
Whether to raise an error if the molecule
is not represented in the training data.
This is only used if ``check_domains`` is ``True``.
If ``False``, a warning will be raised instead.
check_lookup_table: bool
Whether to check a lookup table for the property values.
If ``False`` or if the molecule is not in the lookup
table, the property will be computed using the model.
Returns
-------
result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]
"""
values = {}
expected_value_keys = list(self.readout_modules.keys())
if check_lookup_table and self.lookup_tables:
for property_name in expected_value_keys:
try:
value = self._check_property_lookup_table(
molecule=molecule,
readout_name=property_name,
)
except KeyError as e:
logger.info(
f"Could not find property in lookup table: {e}"
)
continue
else:
logger.info(
f"Using lookup table for property {property_name}"
)
values[property_name] = value
computed_value_keys = set(values.keys())
if computed_value_keys == set(expected_value_keys):
if as_numpy:
values = {k: v.detach().numpy().flatten() for k, v in values.items()}
return values
if check_domains:
is_supported, error = self.chemical_domain.check_molecule(
molecule, return_error_message=True
)
if not is_supported:
if error_if_unsupported:
raise ValueError(error)
else:
warnings.warn(error)
try:
values = self._compute_properties_dgl(molecule)
except (MissingOptionalDependencyError, TypeError):
values = self._compute_properties_nagl(molecule)
if as_numpy:
values = {k: v.detach().numpy().flatten() for k, v in values.items()}
return values
def _check_property_lookup_table(
self,
molecule: "Molecule",
readout_name: str,
):
"""
Check if the molecule is in the property lookup table.
Parameters
----------
molecule: :class:`~openff.toolkit.topology.Molecule`
The molecule to check.
readout_name: str
The name of the readout to check.
Returns
-------
torch.Tensor
Raises
------
KeyError
If there is no table for this property, or
if the molecule is not in the property lookup table
"""
table = self.lookup_tables[readout_name]
return table.lookup(molecule)
    def compute_property(
self,
molecule: "Molecule",
readout_name: Optional[str] = None,
as_numpy: bool = True,
check_domains: bool = False,
error_if_unsupported: bool = True,
check_lookup_table: bool = True
):
"""
Compute the trained property for a molecule.
Parameters
----------
molecule: :class:`~openff.toolkit.topology.Molecule`
The molecule to compute the property for.
        readout_name: str, optional
The name of the readout property to return.
If this is not given and there is only one readout,
the result of that readout will be returned.
as_numpy: bool
Whether to return the result as a numpy array.
If ``False``, the result will be a ``torch.Tensor``.
check_domains: bool
Whether to check if the molecule is similar
to the training data.
error_if_unsupported: bool
Whether to raise an error if the molecule
is not represented in the training data.
This is only used if ``check_domains`` is ``True``.
If ``False``, a warning will be raised instead.
check_lookup_table: bool
Whether to check a lookup table for the property values.
If ``False`` or if the molecule is not in the lookup
table, the property will be computed using the model.
Returns
-------
result: torch.Tensor or numpy.ndarray
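
        Examples
        --------
        A minimal sketch, assuming a trained single-readout model ``model``:

        >>> from openff.toolkit.topology import Molecule
        >>> values = model.compute_property(Molecule.from_smiles("CCO"))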
"""
properties = self.compute_properties(
molecule=molecule,
as_numpy=as_numpy,
check_domains=check_domains,
error_if_unsupported=error_if_unsupported,
check_lookup_table=check_lookup_table
)
if readout_name is None:
if len(properties) == 1:
return next(iter(properties.values()))
raise ValueError(
"The readout name must be specified if the model has multiple readouts"
)
return properties[readout_name]
    def _compute_properties_nagl(self, molecule: "Molecule") -> Dict[str, torch.Tensor]:
from openff.nagl.molecule._graph.molecule import GraphMolecule
nxmol = GraphMolecule.from_openff(
molecule,
atom_features=self.config.atom_features,
bond_features=self.config.bond_features,
)
model = self
if self._is_dgl:
model = self._as_nagl()
return model.forward(nxmol)
    def _compute_properties_dgl(self, molecule: "Molecule") -> Dict[str, torch.Tensor]:
from openff.nagl.molecule._dgl.molecule import DGLMolecule
if not self._is_dgl:
raise TypeError(
"This model is not a DGL-based model "
"and cannot be used to compute properties with the DGL backend"
)
dglmol = DGLMolecule.from_openff(
molecule,
atom_features=self.config.atom_features,
bond_features=self.config.bond_features,
)
return self.forward(dglmol)
def _convert_to_nagl_molecule(self, molecule: "Molecule"):
from openff.nagl.molecule._graph.molecule import GraphMolecule
if self._is_dgl:
from openff.nagl.molecule._dgl.molecule import DGLMolecule
return DGLMolecule.from_openff(
molecule,
atom_features=self.config.atom_features,
bond_features=self.config.bond_features,
)
return GraphMolecule.from_openff(
molecule,
atom_features=self.config.atom_features,
bond_features=self.config.bond_features,
)
    @classmethod
def load(cls, model: str, eval_mode: bool = True, **kwargs):
"""
Load a model from a file.
Parameters
----------
model: str
The path to the model to load.
This should be a file containing a dictionary of
hyperparameters and a state dictionary,
with the keys "hyperparameters" and "state_dict".
This can be created using the `save` method.
eval_mode: bool
Whether to set the model to evaluation mode.
**kwargs
Additional keyword arguments to pass to `torch.load`.
Returns
-------
model: GNNModel
Examples
--------
>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")
Notes
-----
        This method is not compatible with normal PyTorch
models saved with ``torch.save``, as it expects
a dictionary of hyperparameters and a state dictionary.
"""
model_kwargs = torch.load(str(model), weights_only=False, **kwargs)
if isinstance(model_kwargs, dict):
model = cls(**model_kwargs["hyperparameters"])
model.load_state_dict(model_kwargs["state_dict"])
elif isinstance(model_kwargs, cls):
model = model_kwargs
else:
raise ValueError(f"Unknown model type {type(model_kwargs)}")
if eval_mode:
model.eval()
return model
    def save(self, path: str):
"""
Save this model to a file.
Parameters
----------
path: str
The path to save this file to.
Examples
--------
>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")
Notes
-----
This method writes a dictionary of the hyperparameters and the state dictionary,
with the keys "hyperparameters" and "state_dict".
"""
torch.save(
{
"hyperparameters": self.hparams,
"state_dict": self.state_dict(),
},
str(path),
)