Training a Graph Neural Network with NAGL

This notebook will go through the process of training a new Graph Neural Network (GNN) on a small dataset of alkanes, and demonstrate inference with the resulting model. On the way, we’ll put together a tiny test dataset, and talk a bit about the architecture of the GNN we’re training.


from pathlib import Path

import numpy as np

from openff.toolkit import Molecule
from openff.units import unit

Create the model

First, we need a neural network to train! The neural network is called a “model”, because once trained it amounts to a statistical model of some aspect of reality - in this case, a mapping from a molecule to a list of partial charges.

NAGL’s models start by defining a format to describe a molecule. General purpose formats that work well for humans, like SMILES strings or Kekule structures, tend not to work for AI. Humans have a huge amount of background knowledge; we can read letters, we can interpret visual images, and those of us that can read SMILES strings and Kekule structures know a lot about chemistry too. An untrained neural network doesn’t know any of this, and a neural network big enough to learn it would take forever to train and might find it easier to just memorize the partial charges of the molecules in the training set.

Instead, we want to use our human background knowledge to give the neural network as much relevant information as we can, so it can learn to generalize from the training set rather than memorize it. This is often the great challenge in training neural networks: giving the network just the right information and abilities to glean a relationship from the training data. Too much and it can invent relationships that don’t exist in reality but allow it to reproduce the training set slightly better; too little and it can’t learn anything at all. This is also why it’s important to keep the training, validation and test datasets separate!

In a graph neural network, we describe a molecule as a graph of atoms. In mathematics, a graph is a collection of nodes that are connected by edges. Lots of everyday systems and objects are easy to describe as graphs: in public transit maps, nodes are stops and edges are routes; in geometric shapes, nodes are vertices and edges are, well, edges; on Twitter, nodes are people and edges are follows (or tweets or likes or all of the above!). In a molecular graph, nodes are atoms and edges are bonds. This allows us to apply computational science techniques developed for graphs to molecules. More on that later!
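To make the idea concrete, here is a minimal sketch of ethane written out as a graph by hand. The atom ordering, the bond list and the variable names are ours, chosen for illustration; NAGL builds its graphs for us, and this is not how it stores them internally.

```python
# Ethane (CH3-CH3) as a graph: nodes are atoms, edges are bonds.
# The indices below are hand-picked for illustration only.
atoms = ["C", "C", "H", "H", "H", "H", "H", "H"]
bonds = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 5), (1, 6), (1, 7)]

# Build an adjacency list: for each atom, the indices of its bonded neighbours.
neighbours = {index: [] for index in range(len(atoms))}
for a, b in bonds:
    neighbours[a].append(b)
    neighbours[b].append(a)

for index, element in enumerate(atoms):
    print(index, element, neighbours[index])
```

Each carbon ends up with four neighbours and each hydrogen with one, which is all the structural information a graph needs.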

To describe a graph, we first need to describe its nodes and edges. Then we can do some magic to turn it into a graph. The format we use to describe nodes and edges is a list of numbers that each describe a particular feature of the thing, so the construction of this format is called featurization. This lets us choose exactly what information the network has access to, and it lets us use our chemical background knowledge to provide the network with theoretical information that would help it make its decision, but would be too complicated to learn in training. In this example, this includes features like connectivity and ring size, even though that information is redundant with the graph itself; it turns out that having this information “close to hand” helps the neural network out!
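As a rough illustration of what a "list of numbers" per atom can look like, the sketch below one-hot encodes two features (element and connectivity) for methane. The helper names and the ordering of the categories are assumptions made for this example; NAGL's own feature classes may encode things differently.

```python
def one_hot(value, categories):
    """Return a list with 1.0 at the position of `value` and 0.0 elsewhere."""
    return [1.0 if value == category else 0.0 for category in categories]

# Hypothetical category lists for this sketch.
ELEMENTS = ["C", "H"]
CONNECTIVITIES = [1, 2, 3, 4]

def featurize_atom(element, n_bonded):
    """Concatenate the one-hot encodings of each feature into one vector."""
    return one_hot(element, ELEMENTS) + one_hot(n_bonded, CONNECTIVITIES)

# Methane: one carbon bonded to four atoms, four hydrogens bonded to one.
methane_atoms = [("C", 4), ("H", 1), ("H", 1), ("H", 1), ("H", 1)]
feature_matrix = [featurize_atom(element, n) for element, n in methane_atoms]

for row in feature_matrix:
    print(row)
```

Each row is one atom and each column answers one yes/no question, so the network never has to learn what the letter "C" means.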

In NAGL, we construct a featurization for atoms by choosing a list of features from those in the features.atoms module, and for bonds by choosing a list of features from the features.bonds module. Once it has the featurization, NAGL can apply it to a molecule automatically. A lot of these features are useless in a dataset of acyclic alkanes, but we include them to demonstrate the sorts of features that are available:

from openff.nagl.features import atoms

atom_features = (
    atoms.AtomicElement(categories=["C", "H"]), # Is the atom Carbon or Hydrogen?
    atoms.AtomConnectivity(), # Is the atom bonded to 1, 2, 3, or 4 other atoms?
    atoms.AtomAverageFormalCharge(), # What is the atom's mean formal charge over the molecule's tautomers?
    atoms.AtomHybridization(), # What is the hybridization of the atom?
    atoms.AtomInRingOfSize(ring_size=3), # Is the atom in a 3-membered ring?
    atoms.AtomInRingOfSize(ring_size=4), # Is the atom in a 4-membered ring?
    atoms.AtomInRingOfSize(ring_size=5), # Is the atom in a 5-membered ring?
    atoms.AtomInRingOfSize(ring_size=6), # Is the atom in a 6-membered ring?
)

In addition to the description of the molecule, we also need to specify the architecture of the GNN. NAGL’s GNNs consist of two modules, each of which is a neural network: Convolution, which incorporates information about its surroundings into the representation of each atom; and Readout, which computes some chemical property or properties from the convolved representation of an atom. Once the Readout module makes its prediction, a final post-processing layer can be applied to inject some human chemical knowledge on the output end; in the case of partial charges, the readout neural network predicts hardness and electronegativity, and charges are computed analytically from them.
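To see how charges can follow analytically from electronegativity and hardness, here is a minimal sketch of a charge-equilibration step. It assumes the charges minimise the sum of e_i*q_i + 0.5*s_i*q_i**2 over all atoms, subject to the charges adding up to the molecule's total charge. The function name and the input numbers are made up, and this is an illustration of the idea rather than NAGL's implementation.

```python
def equilibrate_charges(electronegativity, hardness, total_charge=0.0):
    """Solve for charges that minimise sum(e*q + 0.5*s*q**2) with sum(q) fixed.

    Setting the derivative of the Lagrangian to zero gives
    q_i = (shared_potential - e_i) / s_i, where shared_potential is chosen
    so that the charges sum to total_charge.
    """
    inverse_hardness = [1.0 / s for s in hardness]
    ratios = [e / s for e, s in zip(electronegativity, hardness)]
    shared_potential = (total_charge + sum(ratios)) / sum(inverse_hardness)
    return [
        (shared_potential - e) / s
        for e, s in zip(electronegativity, hardness)
    ]

# Made-up values for a methane-like molecule: one carbon, four hydrogens.
electronegativity = [0.5, 0.1, 0.1, 0.1, 0.1]
hardness = [2.0, 1.0, 1.0, 1.0, 1.0]

charges = equilibrate_charges(electronegativity, hardness)
print(charges)
print(sum(charges))
```

A convenient property of this form is that the predicted charges always sum to the requested total charge, whatever the network outputs for the individual atoms.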

The GNN model should be configured via a ModelConfig object, with convolution and readout layers individually specified. This can be done in Python as specified below, although typically we expect to read from a YAML or JSON file (e.g. ModelConfig.from_yaml_file or ModelConfig.parse_file).

In this example we use the GraphSAGE convolution layer architecture (no relation to the Sage force field!) with 3 hidden layers, each with 128 features. Note that these hidden layers include the output layer of the module, which is “hidden” by the next module, but do not include the input layer, which is specified by the atom_features and bond_features arguments. We first provide the model with the atom features we selected above. (We do not pass bond_features in the example below as GraphSAGE only makes use of atom features.) Then, we configure the convolution module by specifying its overall architecture and size.

from openff.nagl.config.model import ConvolutionLayer, ConvolutionModule
from openff.nagl import GNNModel
from openff.nagl.nn.gcn import SAGEConvStack
from torch.nn import ReLU
from openff.nagl.nn.postprocess import ComputePartialCharges

single_convolution_layer = ConvolutionLayer(
    hidden_feature_size=128,  # 128 features per hidden convolution layer
    aggregator_type="mean",  # aggregate atom representations with mean
    activation_function="ReLU",  # max(0, x) activation function for layer
    dropout=0.0,  # no dropout
)

convolution_module = ConvolutionModule(
    architecture="SAGEConv",  # GraphSAGE GCN
    layers=[single_convolution_layer] * 3,  # 3 hidden convolution layers
)
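To give a feel for what a single mean-aggregating convolution layer does, here is a simplified sketch in plain Python. Each atom's new representation combines its own features with the mean of its neighbours' features, followed by ReLU. The tiny 2x2 weight matrices are hand-written stand-ins for learned weights and the bias term is omitted; the real layers are implemented in the underlying graph library and differ in detail.

```python
def relu(vector):
    return [max(0.0, x) for x in vector]

def matvec(matrix, vector):
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def mean_vectors(vectors):
    return [sum(column) / len(vectors) for column in zip(*vectors)]

def sage_mean_layer(features, neighbours, w_self, w_neigh):
    """One simplified GraphSAGE-style update with a mean aggregator."""
    updated = []
    for index, own in enumerate(features):
        pooled = mean_vectors([features[j] for j in neighbours[index]])
        combined = [
            a + b
            for a, b in zip(matvec(w_self, own), matvec(w_neigh, pooled))
        ]
        updated.append(relu(combined))
    return updated

# Methane with a two-number featurization: [is carbon, is hydrogen].
features = [[1.0, 0.0]] + [[0.0, 1.0]] * 4
neighbours = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
w_self = [[1.0, 0.0], [0.0, 1.0]]   # stand-in for learned weights
w_neigh = [[0.5, 0.0], [0.0, 0.5]]  # stand-in for learned weights

hidden = sage_mean_layer(features, neighbours, w_self, w_neigh)
print(hidden[0])  # carbon now carries information about its hydrogens
print(hidden[1])  # each hydrogen now carries information about the carbon
```

Stacking three such layers, as configured above, lets information travel up to three bonds away from each atom.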

We then specify the readout module. We specify the number of dense layers between the pooling layer (whose size is set by the convolution module) and the optional post-processing layer, the number of features in each of these layers, and the activation function. Finally, we specify the ComputePartialCharges post-processing layer, which also adds a final layer with the appropriate number of features to the module.

from openff.nagl.config.model import ForwardLayer, ReadoutModule

single_readout_layer = ForwardLayer(
    hidden_feature_size=128,  # 128 features per hidden readout layer
    activation_function="ReLU",  # max(0, x) activation function for layer
    dropout=0.0,  # no dropout
)

readout_module = ReadoutModule(
    layers=[single_readout_layer] * 4, # 4 internal readout layers

Now we can put them together in a full ModelConfig. This can be passed to create a GNNModel. A model can have multiple readouts that derive different properties from the convolution representation, so each readout module is specified in a dictionary with a label.

Here, the GNNModel class represents all the hyperparameters for a model, but after we train it the same object will store weights as well.

from openff.nagl.config.model import ModelConfig

model_config = ModelConfig(
        "predicted-am1bcc-charges": readout_module

Put together our datasets

When it comes down to it, a neural network infers a function from data and then interpolates that function’s value for inputs that are outside the dataset. To get a good interpolation, we need both a robust functional form (the model), and a rich set of data that spans the many-dimensional space we’d like to interpolate over. To evaluate the quality of the interpolation, we need even more data to test the trained model on - data that wasn’t used to train the model.

Usually this takes thousands of data points, but for this demonstration we’ll just use a dozen or so. By confining ourselves to short alkanes, we can get away with a small dataset.

Traditionally, and in NAGL, data is split up into three categories to minimise overfitting:

  • training: Data the model is trained against

  • validation: Data used to validate the model as it is trained

  • test: Data used to test that the final model is good
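For larger datasets, the three-way split is usually done at random. The sketch below shows one simple way to do it with the standard library; the function name, the fractions and the toy list of linear alkane SMILES are our own choices for illustration and are not part of NAGL.

```python
import random

def split_dataset(items, train_fraction=0.8, validation_fraction=0.1, seed=0):
    """Shuffle items reproducibly and cut them into three disjoint lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_fraction)
    n_validation = int(len(shuffled) * validation_fraction)
    return (
        shuffled[:n_train],
        shuffled[n_train:n_train + n_validation],
        shuffled[n_train + n_validation:],
    )

# Linear alkanes from methane ("C") to decane ("CCCCCCCCCC").
smiles = ["C" * n for n in range(1, 11)]
train, validation, test = split_dataset(smiles)

print(len(train), len(validation), len(test))
```

The important property is that no molecule appears in more than one of the three lists, so the validation and test scores measure generalization rather than memory.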

In this example, we’ll use a collection of ten molecules for training. This dataset is in the labelled_alkanes directory distributed with this notebook. We’ll also build a test/validation dataset of 3 molecules by hand.

We can use the LabelledDataset class to take a look at what’s inside the labelled_alkanes directory (or use pyarrow directly). Later, we’ll use it to store our custom test dataset.

from openff.nagl.label.dataset import LabelledDataset

dataset = LabelledDataset("labelled_alkanes")
mapped_smiles am1_charges am1bcc_charges custom_charges
0 [H:2][C:1]([H:3])([H:4])[H:5] [-0.2658799886703491, 0.06646999716758728, 0.0... [-0.10868000239133835, 0.027170000597834587, 0... [0]
1 [H:3][C:1]([H:4])([H:5])[C:2]([H:6])([H:7])[H:8] [-0.21174000017344952, -0.21174000017344952, 0... [-0.09384000208228827, -0.09384000208228827, 0... [1]
2 [H:4][C:1]([H:5])([H:6])[C:2]([H:7])([H:8])[C:... [-0.21018000082536178, -0.15999999777837234, -... [-0.09227999977090141, -0.08139999888160011, -... [2]
3 [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([H:9])[C:... [-0.21003000438213348, -0.15905000269412994, -... [-0.09212999844125339, -0.08044999891093799, -... [3]
4 [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([C:3]([H:... [-0.20747000138674462, -0.10981000374470438, -... [-0.08957000076770782, -0.07050999999046326, -... [4]
5 [H:6][C:1]([H:7])([H:8])[C:2]([H:9])([H:10])[C... [-0.21004000306129456, -0.15812000632286072, -... [-0.09213999658823013, -0.07952000200748444, -... [5]
6 [H:15][C:5]([H:16])([H:17])[C:4]([H:13])([H:14... [-0.20766000405830495, -0.10704000250381582, -... [-0.0897599982426447, -0.06774000100353185, -0... [6]
7 [H:7][C:1]([H:8])([H:9])[C:2]([H:10])([H:11])[... [-0.21021999344229697, -0.15823000594973563, -... [-0.0923200011253357, -0.0796300008893013, -0.... [7]
8 [H:18][C:6]([H:19])([H:20])[C:5]([H:16])([H:17... [-0.208649992197752, -0.1059999980032444, -0.2... [-0.09075000137090683, -0.06669999659061432, -... [8]
9 [H:13][C:4]([H:14])([H:15])[C:3]([H:11])([H:12... [-0.2068299949169159, -0.10380999743938446, -0... [-0.08893000297248363, -0.06451000235974788, -... [9]

Building a test dataset

To augment the provided training set, we’ll quickly prepare a second dataset for testing and validation. This involves preparing a list of SMILES with partial charges and saving them into a PyArrow Dataset. The LabelledDataset class is used as a convenient wrapper. For more information on how to prepare a dataset, see the prepare-dataset example.

from openff.nagl.label.labels import LabelCharges

# Choose the molecules to put in this dataset
# Note that these molecules aren't in the training dataset!
test_smiles = [

test_dataset = LabelledDataset.from_smiles(
    "my_first_test_dataset",  # path to save to

am1bcc_labeller = LabelCharges(

mapped_smiles am1bcc_charges
0 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([C:7]([H:2... [-0.09218686956521739, -0.07998686956521739, -...
1 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... [-0.0890002, -0.06360020000000001, -0.0890002,...
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... [-0.08533529411764705, -0.06023529411764706, -...

Curating our data module

The data module is responsible for providing featurized data to the model as it is fitted. As with the model, the datasets we use should be specified by a config object. Ideally users should not interact with a DGLMoleculeDataModule directly. Fitting is done in parallel batches whose size can be tweaked to manage the available memory; our datasets are small enough that all fitting will be done at once.

One of the strengths of using Arrow datasets is that we can choose which columns to load into memory as needed for training. That means that for the DatasetConfig, we need to specify the targets we are choosing to fit. A Target is what we use to construct the objective function and calculate loss.

from import DatasetConfig, DataConfig
from import RMSEMetric
from import ReadoutTarget

charge_rmse_target = ReadoutTarget(
    metric=RMSEMetric(),  # use RMSE to calculate loss
    target_label="am1bcc_charges", # column to use from data as reference target
    prediction_label="predicted-am1bcc-charges", # readout value to compare to target
    denominator=1.0, # denominator to normalise loss -- important for multi-target objectives
    weight=1.0, # how much to weight the loss -- important for multi-target objectives

training_dataset_config = DatasetConfig(

test_dataset_config = validation_dataset_config = DatasetConfig(

data_config = DataConfig(
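To show what an RMSE target boils down to, here is a minimal sketch of the arithmetic in plain Python. We assume each target contributes weight * RMSE / denominator to the total loss; that combination, the function names and the charge values are assumptions for illustration, not NAGL's code or data.

```python
import math

def rmse(predicted, reference):
    """Root-mean-square error between two equally long lists of numbers."""
    squared = [(p - r) ** 2 for p, r in zip(predicted, reference)]
    return math.sqrt(sum(squared) / len(squared))

def target_loss(predicted, reference, weight=1.0, denominator=1.0):
    """Assumed form of one target's contribution to the objective."""
    return weight * rmse(predicted, reference) / denominator

# Made-up predicted and reference charges for a five-atom molecule.
predicted = [-0.10, 0.025, 0.025, 0.025, 0.025]
reference = [-0.11, 0.0275, 0.0275, 0.0275, 0.0275]

charge_loss = target_loss(predicted, reference)
total_loss = sum([charge_loss])  # one target, so the total equals its loss

print(charge_loss, total_loss)
```

With several targets, the weights and denominators decide how much each one matters relative to the others; with a single target they only rescale the loss.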

Train the model

We’ve prepared our model architecture and our training, validation and test data; now we just need to fit the model! To do this, we need to specify optimization settings with an OptimizerConfig, and then put everything together in a TrainingConfig.

from openff.nagl.config.optimizer import OptimizerConfig
from import TrainingConfig

optimizer_config = OptimizerConfig(

training_config = TrainingConfig(
from import TrainingGNNModel, DGLMoleculeDataModule

training_model = TrainingGNNModel(training_config)
data_module = DGLMoleculeDataModule(training_config)

To properly fit the model, we use the Trainer class from PyTorch Lightning. This allows us to configure how data and progress are stored and reported using callbacks. The fit() method trains and validates against the data module we provide it:

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=200)

trainer.checkpoint_callback.monitor = "val/loss"
  | Name  | Type     | Params
0 | model | GNNModel | 136 K 
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.546     Total estimated model params size (MB)
`` stopped: `max_epochs=200` reached.
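If it helps to picture what happens during those 200 epochs, here is a toy sketch of a fit loop: one gradient-descent step on the training data per epoch, followed by a validation check that remembers the best epoch so far. It fits a single weight in y = w * x to made-up numbers. The real Trainer does far more (batching, logging, checkpoint files), so treat this as an analogy rather than a description of PyTorch Lightning.

```python
def mean_squared_error(weight, data):
    return sum((weight * x - y) ** 2 for x, y in data) / len(data)

def gradient(weight, data):
    """Derivative of the mean squared error with respect to the weight."""
    return sum(2 * (weight * x - y) * x for x, y in data) / len(data)

# Made-up data that roughly follows y = 2 * x.
training_data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9)]
validation_data = [(4.0, 8.0)]

weight = 0.0
learning_rate = 0.05
best_epoch, best_validation_loss = None, float("inf")

for epoch in range(200):
    # Training step: move the weight against the training-loss gradient.
    weight -= learning_rate * gradient(weight, training_data)
    # Validation step: track the best validation loss seen so far.
    validation_loss = mean_squared_error(weight, validation_data)
    if validation_loss < best_validation_loss:
        best_epoch, best_validation_loss = epoch, validation_loss

print(weight, best_epoch, best_validation_loss)
```

Monitoring the validation loss rather than the training loss is what lets us pick a checkpoint that generalizes, instead of the one that fits the training set most tightly.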


We can use the Trainer object’s test() method to evaluate the model against our test data. The output has two entries:

  • test/am1bcc_charges/readout/rmse/1.0/1.0: the charge RMSE loss

  • test/loss: the total loss

The numbers are the same here because this is a single-target objective.

trainer.test(training_model, data_module)
┃               Test metric                               DataLoader 0               ┃
│ test/am1bcc_charges/readout/rmse/1.0/1.0            0.006667664274573326           │
│                test/loss                            0.006667664274573326           │
[{'test/am1bcc_charges/readout/rmse/1.0/1.0': 0.006667664274573326,
  'test/loss': 0.006667664274573326}]

We can isolate the model itself from all the training requirements:

model = training_model.model

Octane isn’t in any of our data, so the model hasn’t seen it yet! We can predict its partial charges with the compute_property() method:

octane = Molecule.from_smiles("CCCCCCCC")

array([-0.10181125, -0.0874697 , -0.08599971, -0.08590907, -0.08590907,
       -0.08599971, -0.0874697 , -0.10181125,  0.03629763,  0.03629763,
        0.03629763,  0.04161337,  0.04161337,  0.0422675 ,  0.0422675 ,
        0.0422675 ,  0.0422675 ,  0.0422675 ,  0.0422675 ,  0.0422675 ,
        0.0422675 ,  0.04161337,  0.04161337,  0.03629763,  0.03629763,
        0.03629763], dtype=float32)

And we can compare that to the partial charges produced by the OpenFF Toolkit:

[-0.09306153846153846 -0.07936153846153846 -0.07836153846153845
-0.07836153846153845 -0.07836153846153845 -0.07836153846153845
-0.07936153846153846 -0.09306153846153846 0.03223846153846154
0.03223846153846154 0.03223846153846154 0.03798846153846154
0.03798846153846154 0.03848846153846154 0.03848846153846154
0.03973846153846154 0.03973846153846154 0.03973846153846154
0.03973846153846154 0.03848846153846154 0.03848846153846154
0.03798846153846154 0.03798846153846154 0.03223846153846154
0.03223846153846154 0.03223846153846154]
prediction = model.compute_property(octane, as_numpy=True) * unit.elementary_charge
np.abs(prediction - octane.partial_charges)
[0.00874970662227044 0.008108158537057736 0.007638173886445859
0.007547530122903684 0.007547530122903684 0.007638173886445859
0.008108158537057736 0.00874970662227044 0.004059172705503607
0.004059172705503607 0.004059172705503607 0.0036249086416684645
0.0036249086416684645 0.0037790398157559885 0.0037790398157559885
0.0025290398157559874 0.0025290398157559874 0.0025290398157559874
0.0025290398157559874 0.0037790398157559885 0.0037790398157559885
0.0036249086416684645 0.0036249086416684645 0.004059172705503607
0.004059172705503607 0.004059172705503607]

All within 0.01 elementary charge units of the reference AM1BCC charges! Not too bad!
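We can check that summary with a few lines of plain Python. The two lists below are copied from the printed arrays above, with the reference charges rounded to eight decimal places, so the result may differ from the notebook's own output in the last digits.

```python
import math

# Predicted charges for octane: 8 carbons, then 18 hydrogens.
predicted = (
    [-0.10181125, -0.0874697, -0.08599971, -0.08590907,
     -0.08590907, -0.08599971, -0.0874697, -0.10181125]
    + [0.03629763] * 3 + [0.04161337] * 2 + [0.0422675] * 8
    + [0.04161337] * 2 + [0.03629763] * 3
)

# Reference AM1BCC charges in the same atom order.
reference = (
    [-0.09306154, -0.07936154, -0.07836154, -0.07836154,
     -0.07836154, -0.07836154, -0.07936154, -0.09306154]
    + [0.03223846] * 3 + [0.03798846] * 2 + [0.03848846] * 2
    + [0.03973846] * 4 + [0.03848846] * 2 + [0.03798846] * 2
    + [0.03223846] * 3
)

errors = [abs(p - r) for p, r in zip(predicted, reference)]
max_error = max(errors)
rmse = math.sqrt(sum(e * e for e in errors) / len(errors))

print(f"max |error| = {max_error:.5f} e, RMSE = {rmse:.5f} e")
```

The largest deviations sit on the terminal carbons, at a little under 0.009 elementary charge units.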

Saving and loading our model

We can save the final model with the method. This’ll let us store it for later."")

When we want it again, we can use the GNNModel.load() method:

model_from_disk = GNNModel.load("")
array([-0.10181125, -0.0874697 , -0.08599971, -0.08590907, -0.08590907,
       -0.08599971, -0.0874697 , -0.10181125,  0.03629763,  0.03629763,
        0.03629763,  0.04161337,  0.04161337,  0.0422675 ,  0.0422675 ,
        0.0422675 ,  0.0422675 ,  0.0422675 ,  0.0422675 ,  0.0422675 ,
        0.0422675 ,  0.04161337,  0.04161337,  0.03629763,  0.03629763,
        0.03629763], dtype=float32)