Training a Graph Neural Network with NAGL
Dependency installation instructions
Install example dependencies into a new Conda environment using the provided environment.yaml:
mamba env create --file ../../devtools/conda-envs/examples_env.yaml --name openff-nagl-examples
mamba activate openff-nagl-examples
jupyter notebook train-gnn-notebook.ipynb
This notebook will go through the process of training a new Graph Neural Network (GNN) on a small dataset of alkanes, and demonstrate inference with the resulting model. On the way, we’ll put together a tiny test dataset, and talk a bit about the architecture of the GNN we’re training.
Imports
from pathlib import Path
import numpy as np
from openff.toolkit import Molecule
from openff.units import unit
Create the model
First, we need a neural network to train! The neural network is called a “model”, because once trained it amounts to a statistical model of some aspect of reality - in this case, a mapping from a molecule to a list of partial charges.
NAGL’s models start by defining a format to describe a molecule. General purpose formats that work well for humans, like SMILES strings or Kekule structures, tend not to work for AI. Humans have a huge amount of background knowledge; we can read letters, we can interpret visual images, and those of us that can read SMILES strings and Kekule structures know a lot about chemistry too. An untrained neural network doesn’t know any of this, and a neural network big enough to learn it would take forever to train and might find it easier to just memorize the partial charges of the molecules in the training set.
Instead, we want to use our human background knowledge to give the neural network as much relevant information as we can, so it can learn to generalize from the training set rather than memorize it. This is often the great challenge in training neural networks: Giving the network just the right information and abilities to glean a relationship from the training data. Too much and it can invent relationships that don’t exist in reality but allow it to reproduce the training set slightly better; too little and it can’t learn anything at all. This is also why it’s important to separate the test, training and validation datasets!
In a graph neural network, we describe a molecule as a graph of atoms. In mathematics, a graph is a collection of nodes that are connected by edges. Lots of everyday systems and objects are easy to describe as graphs: in public transit maps, nodes are stops and edges are routes; in geometric shapes, nodes are vertices and edges are, well, edges; on Twitter, nodes are people and edges are follows (or tweets or likes or all of the above!). In a molecular graph, nodes are atoms and edges are bonds. This allows us to apply computational science techniques developed for graphs to molecules. More on that later!
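As a quick illustration of “molecule as graph” (this snippet is just for intuition and isn’t part of the NAGL workflow), we can list the nodes and edges of ethane using the OpenFF Toolkit:
from openff.toolkit import Molecule

# Ethane as a graph: atoms are nodes, bonds are edges
ethane = Molecule.from_smiles("CC")
nodes = [atom.atomic_number for atom in ethane.atoms]
edges = [(bond.atom1_index, bond.atom2_index) for bond in ethane.bonds]
print(nodes)  # atomic numbers of the 8 atoms (2 carbons, 6 hydrogens)
print(edges)  # 7 (atom index, atom index) pairs, one per bond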
To describe a graph, we first need to describe its nodes and edges. Then we can do some magic to turn it into a graph. The format we use to describe nodes and edges is a list of numbers that each describe a particular feature of the thing, so the construction of this format is called featurization. This lets us choose exactly what information the network has access to, and it lets us use our chemical background knowledge to provide the network with theoretical information that would help it make its decision, but would be too complicated to learn in training. In this example, this includes features like connectivity and ring size, even though that information is redundant with the graph itself; it turns out that having this information “close to hand” helps the neural network out!
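As a toy example of what a featurization looks like (this is only for intuition; it is not NAGL’s internal representation, and the helper function below is made up for this illustration), an atom’s feature vector might concatenate a one-hot encoding of its element with a one-hot encoding of its connectivity:
ELEMENTS = ["C", "H"]
CONNECTIVITIES = [1, 2, 3, 4]

def toy_atom_feature(element, n_bonds):
    # One-hot encode the element, then one-hot encode the number of bonded neighbours
    return [int(element == e) for e in ELEMENTS] + [int(n_bonds == n) for n in CONNECTIVITIES]

print(toy_atom_feature("C", 4))  # methane carbon:   [1, 0, 0, 0, 0, 1]
print(toy_atom_feature("H", 1))  # methane hydrogen: [0, 1, 1, 0, 0, 0]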
In NAGL, we construct a featurization for atoms by choosing a list of features from those in the features.atoms module, and for bonds by choosing a list of features from the features.bonds module. Once it has the featurization, NAGL can apply it to a molecule automatically. A lot of these features are useless in a dataset of acyclic alkanes, but we include them to demonstrate the sorts of features that are available:
from openff.nagl.features import atoms, bonds
atom_features = (
    atoms.AtomicElement(["C", "H"]),  # Is the atom Carbon or Hydrogen?
    atoms.AtomConnectivity(),  # Is the atom bonded to 1, 2, 3, or 4 other atoms?
    atoms.AtomAverageFormalCharge(),  # What is the atom's mean formal charge over the molecule's tautomers?
    atoms.AtomHybridization(),  # What is the hybridization of the atom?
    atoms.AtomInRingOfSize(3),  # Is the atom in a 3-membered ring?
    atoms.AtomInRingOfSize(4),  # Is the atom in a 4-membered ring?
    atoms.AtomInRingOfSize(5),  # Is the atom in a 5-membered ring?
    atoms.AtomInRingOfSize(6),  # Is the atom in a 6-membered ring?
)

bond_features = (
    bonds.BondInRingOfSize(3),  # Is the bond in a 3-membered ring?
    bonds.BondInRingOfSize(4),  # Is the bond in a 4-membered ring?
    bonds.BondInRingOfSize(5),  # Is the bond in a 5-membered ring?
    bonds.BondInRingOfSize(6),  # Is the bond in a 6-membered ring?
)
In addition to the description of the molecule, we also need to specify the architecture of the GNN. NAGL’s GNNs consist of two modules, each of which is a neural network: Convolution, which incorporates information about each atom’s surroundings into that atom’s representation; and Readout, which computes some chemical property or properties from the convolved representation of an atom. Once the Readout module makes its prediction, a final post-processing layer can be applied to inject some human chemical knowledge on the output end; in the case of partial charges, the readout neural network predicts an electronegativity and a hardness for each atom, and charges are computed analytically from them.
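To make that last step concrete, here is a minimal sketch of the idea behind computing charges analytically from per-atom electronegativities and hardnesses, assuming the charges minimize sum_i (e_i q_i + 0.5 s_i q_i^2) subject to a fixed total charge. The function below is purely illustrative and is not NAGL’s actual ComputePartialCharges implementation:
import numpy as np

def charges_from_electronegativity_and_hardness(e, s, total_charge=0.0):
    # Minimizing sum_i (e_i * q_i + 0.5 * s_i * q_i**2) with sum_i q_i = Q
    # gives q_i = (lam - e_i) / s_i, where the Lagrange multiplier lam is
    # chosen so the charges add up to the total charge Q.
    e = np.asarray(e, dtype=float)
    s = np.asarray(s, dtype=float)
    lam = (total_charge + np.sum(e / s)) / np.sum(1.0 / s)
    return (lam - e) / s

# Two hypothetical atoms: the more electronegative one ends up negatively
# charged, and the charges sum to the (neutral) total charge.
print(charges_from_electronegativity_and_hardness([0.3, 0.5], [1.0, 1.0]))  # [ 0.1 -0.1]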
Both modules are configured as arguments to the GNNModel initialization method. Here, the GNNModel class represents all the hyperparameters for a model, but after we train it the same object will store the fitted weights as well.

We first provide the model with the atom and bond features we selected above. Then, we configure the convolution module by specifying its overall architecture and size. In this example we use the GraphSAGE convolution layer architecture (no relation to the Sage force field!) with 3 hidden layers, each with 128 features. Note that these hidden layers include the output layer of the module, which is “hidden” by the next module, but do not include the input layer, which is specified by the atom_features and bond_features arguments.

We then specify the readout module: the number of layers between the pooling layer (whose size is set by the convolution module) and the optional post-processing layer, the number of features in each of these layers, and the activation function. Finally, we specify the ComputePartialCharges post-processing layer, which also adds a final layer with the appropriate number of features to the module.
from openff.nagl import GNNModel
from openff.nagl.nn.gcn import SAGEConvStack
from torch.nn import ReLU
from openff.nagl.nn.postprocess import ComputePartialCharges
model = GNNModel(
    atom_features=atom_features,
    bond_features=bond_features,
    convolution_architecture=SAGEConvStack,  # GraphSAGE GCN
    n_convolution_hidden_features=128,  # 128 features per hidden convolution layer
    n_convolution_layers=3,  # 3 hidden convolution layers
    n_readout_hidden_features=128,  # 128 features per internal readout layer
    n_readout_layers=4,  # 4 internal readout layers
    activation_function=ReLU,  # max(0, x) activation function for readout layers
    postprocess_layer=ComputePartialCharges,  # Add a 2-feature output layer to readout and compute charges from it
    readout_name="am1bcc-charges",
    learning_rate=0.001,
)
DGL backend not selected or invalid. Assuming PyTorch for now.
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable. Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Put together our datasets
When it comes down to it, a neural network infers a function from data and then interpolates that function’s value for inputs that are outside the dataset. To get a good interpolation, we need both a robust functional form (the model), and a rich set of data that spans the many-dimensional space we’d like to interpolate over. To evaluate the quality of the interpolation, we need even more data to test the trained model on - data that wasn’t used to train the model.
Usually this takes thousands of data points, but for this demonstration we’ll just use a dozen or so. By confining ourselves to short alkanes, we can get away with a small dataset.
Traditionally, and in NAGL, data is split into three categories to minimise overfitting:

training: Data the model is trained against
validation: Data used to validate the model as it is trained
test: Data used to test that the final model is good

In this example, we’ll use a collection of ten molecules for training. This dataset is in the alkanes.sqlite file distributed with this notebook. We’ll also build a test/validation dataset of 3 molecules by hand.
We can use the MoleculeStore class to take a look at what’s inside the alkanes.sqlite file. Later, we’ll use it to store our custom test dataset:
from openff.nagl.storage import MoleculeStore
dataset = MoleculeStore("alkanes.sqlite")
training_molecules = [Molecule.from_smiles(smiles) for smiles in dataset.get_smiles()]
[mol.to_smiles(explicit_hydrogens=False) for mol in training_molecules]
['CCC(C)CC',
'CCCC(C)C',
'CCCCCC',
'CCCCC',
'CCC(C)C',
'CCCC',
'CC(C)C',
'CCC',
'CC',
'C']
Building a test dataset
To augment the provided training set, we’ll quickly prepare a second dataset for testing and validation. This involves preparing a list of MoleculeRecord objects with partial charges and saving them into a SQLite database. For more information on how to prepare a dataset, see the prepare-dataset example:
from openff.nagl.storage.record import MoleculeRecord
from openff.nagl.storage import MoleculeStore

# Choose the molecules to put in this dataset
# Note that these molecules aren't in the training dataset!
test_smiles = [
    "CCCCCCC",
    "CC(C)C(C)C",
    "CC(C)(C)C",
]

records = []
for smiles in test_smiles:
    # Create a Molecule object
    molecule = Molecule.from_smiles(smiles, allow_undefined_stereo=True)
    # Generate a conformer for charge assignment
    # Note that a production dataset should include some method to produce
    # conformation-independent charges
    molecule.generate_conformers(n_conformers=1)
    # Compute partial charges
    molecule.assign_partial_charges("am1bcc")
    # Create a MoleculeRecord
    record = MoleculeRecord.from_precomputed_openff(
        molecule,
        partial_charge_method="am1bcc",
    )
    # Add the record to the list
    records.append(record)

# Save the dataset
test_set_path = Path("my_first_test_dataset.sqlite")
if test_set_path.exists():
    test_set_path.unlink()

MoleculeStore(test_set_path).store(records)
grouping records to store by InChI key: 100%|██████████| 3/3 [00:00<00:00, 112.34it/s]
storing grouped records: 100%|██████████| 3/3 [00:00<00:00, 57.26it/s]
Curating our data module
The data module is responsible for providing featurized data to the model as it is fitted. It therefore needs the featurization scheme, as well as the paths to the training, validation and test data sets. Fitting is done in parallel batches whose size can be tweaked to manage the available memory; our datasets are small enough that all fitting will be done at once.
from openff.nagl.nn.dataset import DGLMoleculeLightningDataModule

# Remove the output directory (in case we're re-running the notebook)
if Path("./data/").exists():
    !rm -r data

data_module = DGLMoleculeLightningDataModule(
    atom_features=atom_features,
    bond_features=bond_features,
    partial_charge_method="am1bcc",
    training_set_paths=[Path("alkanes.sqlite")],
    validation_set_paths=[test_set_path],
    test_set_paths=[test_set_path],
    training_batch_size=1000,
    validation_batch_size=1000,
    test_batch_size=1000,
)
Train the model
We’ve prepared our model architecture and our training, validation and test data; now we just need to fit the model! To do this, we use the Trainer class from PyTorch Lightning. This allows us to configure how data and progress are stored and reported using callbacks. The fit() method trains and validates against the data module we provide it:
from pytorch_lightning import Trainer
trainer = Trainer(max_epochs=200)
trainer.progress_bar_callback.disable()
trainer.checkpoint_callback.monitor = "val_loss"
trainer.fit(
    model,
    datamodule=data_module,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
featurizing molecules: 100%|██████████| 10/10 [00:00<00:00, 15.29it/s]
featurizing molecules: 100%|██████████| 3/3 [00:00<00:00, 14.23it/s]
| Name | Type | Params
---------------------------------------------------------
0 | convolution_module | ConvolutionModule | 70.3 K
1 | readout_modules | ModuleDict | 66.3 K
---------------------------------------------------------
136 K Trainable params
0 Non-trainable params
136 K Total params
0.546 Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=200` reached.
Results!
We can use the Trainer object’s test() method to evaluate the model against our test data:
trainer.test(model, data_module)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss         │   0.0032399590127170086   │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.0032399590127170086}]
Octane isn’t in any of our data, so the model hasn’t seen it yet! We can predict its partial charges with the compute_property() method:
octane = Molecule.from_smiles("CCCCCCCC")
model.compute_property(octane)
Warning (not error because allow_undefined_stereo=True): RDMol has unspecified stereochemistry. Undefined chiral centers are:
- Atom C (index 0)
- Atom C (index 1)
- Atom C (index 2)
- Atom C (index 3)
- Atom C (index 4)
- Atom C (index 5)
- Atom C (index 6)
- Atom C (index 7)
tensor([[-0.0884],
[-0.0768],
[-0.0751],
[-0.0752],
[-0.0752],
[-0.0751],
[-0.0768],
[-0.0884],
[ 0.0314],
[ 0.0314],
[ 0.0314],
[ 0.0366],
[ 0.0366],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0366],
[ 0.0366],
[ 0.0314],
[ 0.0314],
[ 0.0314]], grad_fn=<CatBackward0>)
And we can compare that to the partial charges produced by the OpenFF Toolkit:
octane.assign_partial_charges("am1bcc")
octane.partial_charges
Magnitude | [-0.09306153846153846 -0.07936153846153846 -0.07836153846153845 …]
Units | elementary_charge
prediction = model.compute_property(octane, as_numpy=True) * unit.elementary_charge
np.abs(prediction - octane.partial_charges)
Warning (not error because allow_undefined_stereo=True): RDMol has unspecified stereochemistry. Undefined chiral centers are:
- Atom C (index 0)
- Atom C (index 1)
- Atom C (index 2)
- Atom C (index 3)
- Atom C (index 4)
- Atom C (index 5)
- Atom C (index 6)
- Atom C (index 7)
Magnitude | [0.004622163299413826 0.002569596341940067 0.0032267693794690577 …]
Units | elementary_charge
All within a few thousandths of an elementary charge of the true AM1BCC charges! Not too bad!
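If a single summary number is more convenient than the per-atom differences, we can also compute a root-mean-square error ourselves. This is a quick sketch using the objects defined in the cells above; the variable names are our own:
# Root-mean-square difference between the GNN prediction and the AM1BCC charges,
# in units of elementary charge (flatten() guards against a trailing length-1 axis).
predicted = np.asarray(model.compute_property(octane, as_numpy=True)).flatten()
reference = octane.partial_charges.m_as(unit.elementary_charge)
rmse = np.sqrt(np.mean((predicted - reference) ** 2))
print(f"RMSE: {rmse:.4f} elementary charge")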
Saving and loading our model
We can save the final model with the model.save() method. This’ll let us store it for later.
model.save("trained_alkane_model.pt")
When we want it again, we can use the GNNModel.load() method:
model_from_disk = GNNModel.load("trained_alkane_model.pt")
model_from_disk.compute_property(octane)
Warning (not error because allow_undefined_stereo=True): RDMol has unspecified stereochemistry. Undefined chiral centers are:
- Atom C (index 0)
- Atom C (index 1)
- Atom C (index 2)
- Atom C (index 3)
- Atom C (index 4)
- Atom C (index 5)
- Atom C (index 6)
- Atom C (index 7)
tensor([[-0.0884],
[-0.0768],
[-0.0751],
[-0.0752],
[-0.0752],
[-0.0751],
[-0.0768],
[-0.0884],
[ 0.0314],
[ 0.0314],
[ 0.0314],
[ 0.0366],
[ 0.0366],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0370],
[ 0.0366],
[ 0.0366],
[ 0.0314],
[ 0.0314],
[ 0.0314]], grad_fn=<CatBackward0>)