Training a Graph Neural Network with NAGL with multiple objectives

This notebook will go through the process of training a new Graph Neural Network (GNN) on a small dataset of alkanes with multiple objectives. Please see the train-gnn-notebook tutorial for more on what’s happening under the hood.

Imports

from pathlib import Path

import numpy as np

from openff.toolkit import Molecule
from openff.units import unit

Create the model

First, let’s specify the model features.

from openff.nagl.features import atoms

atom_features = (
    atoms.AtomicElement(categories=["C", "H"]), # Is the atom Carbon or Hydrogen?
    atoms.AtomConnectivity(), # Is the atom bonded to 1, 2, 3, or 4 other atoms?
    atoms.AtomAverageFormalCharge(), # What is the atom's mean formal charge over the molecule's tautomers?
    atoms.AtomHybridization(), # What is the hybridization of the atom?
    atoms.AtomInRingOfSize(ring_size=3), # Is the atom in a 3-membered ring?
    atoms.AtomInRingOfSize(ring_size=4), # Is the atom in a 4-membered ring?
    atoms.AtomInRingOfSize(ring_size=5), # Is the atom in a 5-membered ring?
    atoms.AtomInRingOfSize(ring_size=6), # Is the atom in a 6-membered ring?
)
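Several of these features are categorical. The sketch below shows roughly what they look like once featurized, using ethane's element and connectivity features. It assumes a one-hot encoding. The `one_hot` helper and the hand-written atom list are hypothetical and are not NAGL's featurization code.

```python
import numpy as np

def one_hot(value, categories):
    """Return a vector with a 1 in the slot for `value` and 0 elsewhere."""
    vector = np.zeros(len(categories), dtype=float)
    vector[categories.index(value)] = 1.0
    return vector

# Hypothetical atom list for ethane (CC): (element, number of bonded neighbours)
ethane_atoms = [("C", 4), ("C", 4)] + [("H", 1)] * 6

element_categories = ["C", "H"]          # as in AtomicElement(categories=["C", "H"])
connectivity_categories = [1, 2, 3, 4]   # bonded to 1, 2, 3 or 4 atoms

# One row per atom: element one-hot followed by connectivity one-hot
features = np.stack([
    np.concatenate([
        one_hot(element, element_categories),
        one_hot(degree, connectivity_categories),
    ])
    for element, degree in ethane_atoms
])
print(features.shape)  # (8, 6)
```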

We also need to specify the architecture of the GNN. We can make this as complicated as we like.

from openff.nagl.config.model import (
    ConvolutionLayer,
    ConvolutionModule,
)
from openff.nagl import GNNModel
from openff.nagl.nn.gcn import SAGEConvStack
from torch.nn import ReLU
from openff.nagl.nn.postprocess import ComputePartialCharges

single_convolution_layer = ConvolutionLayer(
    hidden_feature_size=128,  # 128 features per hidden convolution layer
    aggregator_type="mean",  # aggregate atom representations with mean
    activation_function="ReLU", # max(0, x) activation function for layer
    dropout=0.0, # no dropout
)

convolution_module = ConvolutionModule(
    architecture="SAGEConv", # GraphSAGE GCN
    layers=[single_convolution_layer] * 3, # 3 hidden convolution layers        
)
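The `SAGEConv` architecture follows GraphSAGE. In each layer, an atom's new representation combines its current representation with the mean of its bonded neighbours' representations. The following NumPy sketch of one such update uses a mean aggregator and ReLU, a toy graph, and random untrained weights. The function `sage_mean_layer` is hypothetical and is not NAGL's or DGL's implementation.

```python
import numpy as np

def sage_mean_layer(h, adjacency, w_self, w_neigh, bias):
    """One GraphSAGE-style update: self term plus mean of neighbours, then ReLU."""
    degree = adjacency.sum(axis=1, keepdims=True)
    # Mean over bonded neighbours (the maximum() guards against isolated atoms)
    h_neigh = (adjacency @ h) / np.maximum(degree, 1.0)
    # Separate weight matrices for the atom itself and for its neighbourhood
    return np.maximum(0.0, h @ w_self + h_neigh @ w_neigh + bias)

rng = np.random.default_rng(0)

# Toy graph: the carbon chain of propane (hydrogens omitted for brevity)
adjacency = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
h = rng.normal(size=(3, 6))  # 6 random input features per atom, for illustration

# Three stacked layers with 128 hidden features, as in the configuration above
sizes = [6, 128, 128, 128]
for d_in, d_out in zip(sizes[:-1], sizes[1:]):
    w_self = rng.normal(scale=0.1, size=(d_in, d_out))
    w_neigh = rng.normal(scale=0.1, size=(d_in, d_out))
    h = sage_mean_layer(h, adjacency, w_self, w_neigh, np.zeros(d_out))

print(h.shape)  # (3, 128)
```

In this sketch, `[single_convolution_layer] * 3` repeats only the layer settings. Each of the three layers still gets its own weights.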

We then specify the readout module.

from openff.nagl.config.model import (
    ForwardLayer,
    ReadoutModule,
)

single_readout_layer = ForwardLayer(
    hidden_feature_size=128,  # 128 features per hidden readout layer
    activation_function="ReLU", # max(0, x) activation function for layer
    dropout=0.0, # no dropout
)

normal_readout_module = ReadoutModule(
    pooling="atoms",
    layers=[single_readout_layer] * 4, # 4 internal readout layers
    # calculate charges with charge equilibration scheme from
    # electronegativity and hardness
    postprocess="compute_partial_charges"
)
regularised_readout_module = ReadoutModule(
    pooling="atoms",
    layers=[single_readout_layer] * 4, # 4 internal readout layers
    # calculate charges with charge equilibration scheme from
    # electronegativity, hardness, and an initial charge prediction
    postprocess="regularized_compute_partial_charges"
)
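In the non-regularised readout, the network outputs an electronegativity eᵢ and a hardness sᵢ for each atom. The postprocessing step then converts these into charges that sum to the molecule's total charge Q. The textbook closed-form charge-equilibration solution minimises Σᵢ (eᵢ qᵢ + ½ sᵢ qᵢ²) subject to Σᵢ qᵢ = Q, which gives qᵢ = (μ − eᵢ)/sᵢ with μ = (Q + Σⱼ eⱼ/sⱼ) / Σⱼ (1/sⱼ). The sketch below assumes `compute_partial_charges` works along these lines. The function name and the input values are hypothetical.

```python
import numpy as np

def equilibrate_charges(electronegativity, hardness, total_charge=0.0):
    """Charges minimising sum(e*q + 0.5*s*q**2) subject to sum(q) == total_charge."""
    e = np.asarray(electronegativity, dtype=float)
    s = np.asarray(hardness, dtype=float)  # hardness must be positive
    inv_s = 1.0 / s
    # Lagrange multiplier (the equalised electronegativity) that fixes the total charge
    mu = (total_charge + np.sum(e * inv_s)) / np.sum(inv_s)
    return (mu - e) * inv_s

# Made-up per-atom outputs for a methane-like molecule (carbon first, then 4 H)
charges = equilibrate_charges(
    electronegativity=[0.5, 0.1, 0.1, 0.1, 0.1],
    hardness=[1.0, 1.0, 1.0, 1.0, 1.0],
)
print(charges)  # carbon negative, hydrogens equal and positive, net charge zero
```

The regularised variant also takes an initial charge prediction into account. It is not sketched here.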

Now we can put them together in a full ModelConfig. This can be passed to create a GNNModel. A model can have multiple readouts that derive different properties from the convolution representation, so each readout module is specified in a dictionary with a label.

Here, the GNNModel class represents all the hyperparameters for a model, but after we train it the same object will store weights as well.

from openff.nagl.config.model import ModelConfig

model_config = ModelConfig(
    version="0.1",
    atom_features=atom_features,
    bond_features=[],
    convolution=convolution_module,
    readouts={
        "predicted-am1bcc-charges": normal_readout_module,
        "predicted-am1-charges": regularised_readout_module
    }
)

Put together our datasets

We need to set up three datasets here:

  • training: Data the model is trained against

  • validation: Data used to validate the model as it is trained

  • test: Data used to evaluate the final model

In this example, we’ll use a collection of ten molecules for training. We’ll also build a test/validation dataset of 3 molecules that are not in the training set.

We can use the LabelledDataset class to generate training data and save it in the training_data directory (alternatively, you can use pyarrow directly). First, we generate the dataset from SMILES:

from openff.nagl.label.dataset import LabelledDataset

training_alkanes = [
    'C',
    'CC',
    'CCC',
    'CCCC',
    'CC(C)C',
    'CCCCC',
    'CC(C)CC',
    'CCCCCC',
    'CC(C)CCC',
    'CC(CC)CC',
]

training_dataset = LabelledDataset.from_smiles(
    "training_data",
    training_alkanes,
    mapped=False,
    overwrite_existing=True,
)
training_dataset.to_pandas()
mapped_smiles
0 [C:1]([H:2])([H:3])([H:4])[H:5]
1 [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[...
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])...
3 [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])...
4 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](...
5 [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[...
6 [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]...
7 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H...
8 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]...
9 [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])...

Below we specify label functions to label our molecules with the information we will use in training and testing. Each argument is specified and annotated to explain its purpose. However, all label functions can be instantiated with default arguments (e.g. LabelConformers()) unless you need specific column names or other settings (e.g. a different charge method).

Note: the ESP label function requires openff-recharge to be installed.

from openff.nagl.label.labels import (
    LabelConformers,
    LabelCharges,
    LabelMultipleDipoles,
    LabelMultipleESPs,
)
import openff.recharge  # not used directly; fails early if openff-recharge is missing

# generate ELF conformers
label_conformers = LabelConformers(
    # create a new 'conformers' column with the output conformers
    conformer_column="conformers",
    # create a new 'n_conformers' column with the number of conformers
    n_conformer_column="n_conformers",
    n_conformer_pool=500, # initially generate 500 conformers
    n_conformers=10, # prune to max 10 conformers
    rms_cutoff=0.05,
)

# generate AM1 charges
label_am1_charges = LabelCharges(
    charge_method="am1-mulliken", # AM1
    # use previously generated conformers instead of new ones
    use_existing_conformers=True,
    # use the 'conformers' column as input for charge assignment
    conformer_column="conformers",
    # write generated charges to 'target-am1-charges' column
    charge_column="target-am1-charges",
)

# generate AM1-BCC charges
label_am1bcc_charges = LabelCharges(
    charge_method="am1bcc", # AM1BCC
    # use previously generated conformers instead of new ones
    use_existing_conformers=True,
    # use the 'conformers' column as input for charge assignment
    conformer_column="conformers",
    # write generated charges to 'target-am1bcc-charges' column
    charge_column="target-am1bcc-charges",
)

label_am1bcc_dipoles = LabelMultipleDipoles(
    # use the 'conformers' column as input to calculate dipole moments
    conformer_column="conformers",
    # use the 'n_conformers' column as input
    n_conformer_column="n_conformers",
    # use the "target-am1bcc-charges" column as input to calculate dipole moments
    charge_column="target-am1bcc-charges",
    # write calculated dipoles to 'target-am1bcc-dipoles' column
    dipole_column="target-am1bcc-dipoles",
)

label_am1bcc_esps = LabelMultipleESPs(
    # use the 'conformers' column as input to calculate ESPs
    conformer_column="conformers",
    # use the 'n_conformers' column as input
    n_conformer_column="n_conformers",
    # use the "target-am1bcc-charges" column as input to calculate ESPs
    charge_column="target-am1bcc-charges",
    # generate new grids and inverse distances to points
    use_existing_inverse_distances=False,
    # write inverse distances from conformer to surface to this column
    inverse_distance_matrix_column="grid_inverse_distances",
    # write number of grid points for each surface to this column
    grid_length_column="esp_lengths",
    # write calculated ESPs to 'esps' column
    esp_column="esps",
)
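The dipole and ESP labels are simple functions of charges and geometry. This is also why they can serve as training targets for predicted charges. The following NumPy sketch shows both calculations. It assumes that coordinates are stored conformer by conformer and that inverse distances are stored as one `(n_points, n_atoms)` block per conformer. It omits unit conversions and any Coulomb prefactor. The function names are hypothetical.

```python
import numpy as np

def dipoles_from_charges(flat_conformers, charges, n_conformers):
    """Dipole vector sum_i q_i * r_i for each conformer."""
    q = np.asarray(charges, dtype=float)
    # Assumed layout: conformer, then atom, then x/y/z
    coords = np.asarray(flat_conformers, dtype=float).reshape(n_conformers, len(q), 3)
    return np.einsum("a,cax->cx", q, coords)

def esps_from_charges(flat_inverse_distances, charges, esp_lengths):
    """Potential at each grid point, sum_i q_i / r_ij, one array per conformer."""
    q = np.asarray(charges, dtype=float)
    inv_d = np.asarray(flat_inverse_distances, dtype=float)
    esps, start = [], 0
    for n_points in esp_lengths:
        # Assumed layout: one (n_points, n_atoms) block per conformer, row-major
        stop = start + n_points * len(q)
        esps.append(inv_d[start:stop].reshape(n_points, len(q)) @ q)
        start = stop
    return esps

# Toy system: +0.5 e at x = 1 and -0.5 e at x = -1
charges = [0.5, -0.5]
print(dipoles_from_charges([1.0, 0.0, 0.0, -1.0, 0.0, 0.0], charges, n_conformers=1))  # [[1. 0. 0.]]

# Two grid points with made-up inverse distances to the two atoms
print(esps_from_charges([0.5, 0.25, 0.25, 0.5], charges, esp_lengths=[2]))
```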

Below we apply the label functions to actually generate the labels. The order matters, as later label functions use the output of earlier ones.

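labellers = [
    label_conformers,  # generate initial conformers
    label_am1_charges,  # needs conformers
    label_am1bcc_charges,  # needs conformers
    label_am1bcc_dipoles,  # needs conformers and AM1BCC charges
    label_am1bcc_esps,  # needs conformers and AM1BCC charges
]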

training_dataset.apply_labellers(labellers)
training_dataset.to_pandas()
mapped_smiles conformers n_conformers target-am1-charges target-am1bcc-charges target-am1bcc-dipoles esp_lengths grid_inverse_distances esps
0 [C:1]([H:2])([H:3])([H:4])[H:5] [0.005118712612069831, -0.010620498663889949, ... 1 [-0.2656, 0.0664, 0.0664, 0.0664, 0.0664] [-0.1084, 0.0271, 0.0271, 0.0271, 0.0271] [-0.0006935855589354605, 0.0014390775689570746... [398] [0.22234336592604098, 0.1691980088400493, 0.26... [-0.00048403793895890833, -0.00071972139957695...
1 [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[... [-0.7455231747504416, 0.04144445508118371, 0.0... 1 [-0.21225, -0.21225, 0.07075, 0.07075, 0.07075... [-0.09435, -0.09435, 0.03145, 0.03145, 0.03145... [-0.0002285229334980654, -0.005575840474810592... [500] [0.22234336592604098, 0.18774435919748517, 0.3... [-0.001407404714360157, -0.0029490792419474554...
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])... [1.2163074727474537, -0.24294836930504896, 0.2... 1 [-0.211, -0.16, -0.211, 0.072, 0.072, 0.071, 0... [-0.09310018181818182, -0.08140018181818182, -... [-0.0043053606473935704, -0.009872806411602934... [602] [0.22234336592604098, 0.2043827690515404, 0.18... [-0.002920570510046846, -0.0009924942195456941...
3 [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])... [1.8901957496718658, 0.042575097521835, 0.2431... 1 [-0.21028571428571427, -0.15928571428571428, -... [-0.09238585714285714, -0.08068585714285714, -... [-0.003195845195204508, -0.00494009309346443, ... [690] [0.22234336592604098, 0.17092450502366457, 0.1... [-0.000870799470089798, 4.8700196369716106e-05...
4 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... [1.4351901021861362, 0.27740245802875496, 0.13... 1 [-0.20735714285714285, -0.10935714285714285, -... [-0.08945692857142856, -0.07005692857142856, -... [-0.0059915089786667305, -0.009373247525965786... [672] [0.22234336592604098, 0.19109842170559405, 0.1... [-0.0008717150501716528, -0.003425712120807677...
5 [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[... [2.4660243557948207, -0.0503036791214956, 0.03... 1 [-0.21, -0.159, -0.158, -0.159, -0.21, 0.07200... [-0.09210011764705882, -0.08040011764705882, -... [-0.003987426571054067, 0.005559284278063413, ... [791] [0.22234336592604098, 0.14213609699084526, 0.1... [0.0005046598853891658, 0.0002250632479590313,...
6 [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]... [-1.2418582991411082, 1.052060941462957, 0.017... 1 [-0.207, -0.10599999999999998, -0.207, -0.153,... [-0.08909988235294117, -0.06669988235294116, -... [-0.0026498796085173554, -0.008039775679738135... [750] [0.22234336592604098, 0.15192142508151973, 0.1... [-0.00023715760591162555, -0.00101041663676669...
7 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H... [-3.0907008742942494, -0.14370944748493916, -0... 1 [-0.21009999999999998, -0.1591, -0.1581, -0.15... [-0.0922, -0.0805, -0.0795, -0.0795, -0.0805, ... [0.0018400328214143447, -0.014928553541414105,... [901] [0.22234336592604098, 0.1986135821292406, 0.15... [-0.0014053749951576907, -0.000203661455117260...
8 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... [-1.3441616214248866, 1.2928572296227032, -0.3... 1 [-0.20694999999999997, -0.10694999999999998, -... [-0.08955004999999999, -0.06765004999999999, -... [-0.00816547475433177, 0.0055545221371308545, ... [844] [0.22234336592604098, 0.1852101593428309, 0.13... [-0.0012388351037772895, -0.002423698057369004...
9 [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])... [-0.40122078234363445, -1.422374025464223, -0.... 1 [-0.20794999999999997, -0.10594999999999999, -... [-0.09005, -0.06665, -0.07785, -0.09205, -0.07... [0.001310505146191715, -0.012141030775731199, ... [829] [0.22234336592604098, 0.1436149999840705, 0.11... [-0.00044908769940229555, -0.00022010214148746...

Building a test dataset

In addition to the training set, we’ll quickly prepare a second dataset for testing and validation. We label it with the same label functions:

from openff.nagl.label.labels import LabelCharges

# Choose the molecules to put in this dataset
# Note that these molecules aren't in the training dataset!
test_smiles = [
    "CCCCCCC",
    "CC(C)C(C)C",
    "CC(C)(C)C",
]

test_dataset = LabelledDataset.from_smiles(
    "my_first_test_dataset",  # path to save to
    test_smiles,
    mapped=False,
    overwrite_existing=True,
)

test_dataset.apply_labellers(labellers)
test_dataset.to_pandas()
mapped_smiles conformers n_conformers target-am1-charges target-am1bcc-charges target-am1bcc-dipoles esp_lengths grid_inverse_distances esps
0 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([C:7]([H:2... [-3.545070958260261, 0.5832530008833098, 0.079... 1 [-0.21008695652173912, -0.15908695652173913, -... [-0.09218686956521739, -0.07998686956521739, -... [-0.00595775894206417, 0.0008582141430205914, ... [944] [0.22234336592604098, 0.19692667350664092, 0.1... [-0.0006610506078259431, -0.000516657266085287...
1 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... [0.9713614798232566, -0.892209543378047, 1.004... 1 [-0.20589999999999997, -0.10289999999999998, -... [-0.0890002, -0.06360020000000001, -0.0890002,... [-0.002531086126934007, -0.00711238458844677, ... [803] [0.22234336592604098, 0.17641696600758316, 0.1... [-0.0006969112009999837, -0.000166671288359785...
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... [1.2935034573089457, -0.7675514803987564, -0.0... 1 [-0.20323529411764704, -0.06023529411764704, -... [-0.08533529411764705, -0.06023529411764706, -... [0.003012641718856162, 0.010054820006400858, 0... [743] [0.22234336592604098, 0.1629178797473639, 0.18... [1.7940051204891212e-05, 0.0004078426202498559...

Curating our data module

Now we assemble our datasets into a DataConfig. For each DatasetConfig, we need to specify the targets we want to fit. A Target is what we use to construct the objective function and calculate the loss. Below we:

  • fit the predicted-am1-charges property directly to the labelled target-am1-charges

  • fit the predicted-am1bcc-charges property to a combined objective of:

    • charge RMSE

    • dipole moments

    • ESP targets

For the physical properties, we need additional information (the conformers, n_conformers, … columns) to calculate the dipole moments and ESPs for comparison. Each argument is annotated below. In every target, target_label names the column in the input dataset that holds the reference values we are comparing against.

from openff.nagl.config.data import DatasetConfig, DataConfig
from openff.nagl.training.metrics import RMSEMetric
from openff.nagl.training.loss import ReadoutTarget, MultipleDipoleTarget, MultipleESPTarget


am1_charge_rmse_target = ReadoutTarget(
    metric=RMSEMetric(),  # use RMSE to calculate loss
    target_label="target-am1-charges", # column to use from data as reference target
    prediction_label="predicted-am1-charges", # readout value to compare to target
    denominator=1.0, # denominator to normalise loss -- important for multi-target objectives
    weight=1.0, # how much to weight the loss -- important for multi-target objectives
)

am1bcc_charge_rmse_target = ReadoutTarget(
    metric=RMSEMetric(),  # use RMSE to calculate loss
    target_label="target-am1bcc-charges", # column to use from data as reference target
    prediction_label="predicted-am1bcc-charges", # readout value to compare to target
    denominator=0.001, # denominator to normalise loss -- important for multi-target objectives
    weight=1.0, # how much to weight the loss -- important for multi-target objectives
)

am1bcc_dipole_target = MultipleDipoleTarget(
    metric=RMSEMetric(),
    target_label="target-am1bcc-dipoles", # column to use from input data as reference target
    charge_label="predicted-am1bcc-charges", # readout charge value to calculate dipoles with
    conformation_column="conformers", # input data to use for calculating dipoles
    n_conformation_column="n_conformers", # input data to use for calculating dipoles
    denominator=0.01,
    weight=1.0
)

am1bcc_esp_target = MultipleESPTarget(
    metric=RMSEMetric(),
    target_label="esps", # column to use from input data as reference target
    charge_label="predicted-am1bcc-charges", # readout charge value to calculate ESPs with
    inverse_distance_matrix_column="grid_inverse_distances", # input data to use to calculate ESPs
    esp_length_column="esp_lengths", # input data to use to calculate ESPs
    n_esp_column="n_conformers", # input data to use to calculate ESPs
    denominator=0.001,
    weight=1.0
)
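The `denominator` and `weight` arguments let targets with very different natural scales (charges, dipoles, ESPs) contribute comparably to a single objective. The sketch below assumes that each target contributes `weight * metric / denominator` and that these contributions are summed. The numbers are made up, and the helper names are hypothetical.

```python
import numpy as np

def rmse(predicted, reference):
    """Root-mean-square error between two arrays."""
    diff = np.asarray(predicted, dtype=float) - np.asarray(reference, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))

def combined_loss(terms):
    """Sum weight * metric / denominator over (metric, weight, denominator) tuples."""
    return sum(weight * metric / denominator for metric, weight, denominator in terms)

# Hypothetical raw errors: charges off by 0.004 e, dipoles off by 0.001
charge_error = rmse([-0.092, 0.031, 0.031], [-0.096, 0.035, 0.027])
dipole_error = 0.001

loss = combined_loss([
    (charge_error, 1.0, 0.001),  # like am1bcc_charge_rmse_target
    (dipole_error, 1.0, 0.01),   # like am1bcc_dipole_target
])
# With these settings the charge term (4.0) dominates the dipole term (0.1)
print(round(charge_error, 6), round(loss, 3))
```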

Now we combine these targets into a DatasetConfig for each dataset. Here, the same held-out dataset is used for both validation and testing.

targets = [
    am1_charge_rmse_target,
    am1bcc_charge_rmse_target,
    am1bcc_dipole_target,
    am1bcc_esp_target,
]
    

training_dataset_config = DatasetConfig(
    sources=["training_data"],
    targets=targets,
    batch_size=1000,
)

test_dataset_config = validation_dataset_config = DatasetConfig(
    sources=["my_first_test_dataset"],
    targets=targets,
    batch_size=1000,
)

data_config = DataConfig(
    training=training_dataset_config,
    validation=validation_dataset_config,
    test=test_dataset_config
)

Train the model

We’ve prepared our model architecture and our training, validation and test data; now we just need to fit the model! To do this, we specify optimization settings with an OptimizerConfig, and then put everything together in a TrainingConfig.

from openff.nagl.config.optimizer import OptimizerConfig
from openff.nagl.config.training import TrainingConfig

optimizer_config = OptimizerConfig(
    optimizer="Adam",
    learning_rate=0.001,
)

training_config = TrainingConfig(
    model=model_config,
    data=data_config,
    optimizer=optimizer_config
)
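For reference, here is a NumPy sketch of the textbook Adam update that `optimizer="Adam"` refers to. In practice, PyTorch performs the update. The function name is hypothetical. The `beta1`, `beta2` and `eps` defaults are the commonly used values and are assumed here, not read from this configuration. `learning_rate=0.001` corresponds to `lr`.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam update at step t (starting from 1); returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias corrections for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

param = np.array([1.0, -2.0])
grad = np.array([0.5, -300.0])
param, m, v = adam_step(param, grad, np.zeros(2), np.zeros(2), t=1)
print(param)  # both entries move by about lr, despite very different gradients
```

On the first step, the bias correction makes each update roughly `lr` times the sign of the gradient. That is why both parameters above move by about 0.001.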
from openff.nagl.training.training import TrainingGNNModel, DGLMoleculeDataModule

training_model = TrainingGNNModel(training_config)
data_module = DGLMoleculeDataModule(training_config)

To properly fit the model, we use the Trainer class from PyTorch Lightning. This allows us to configure how data and progress are stored and reported using callbacks. The fit() method trains and validates against the data module we provide it:

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=200)

trainer.progress_bar_callback.disable()
trainer.checkpoint_callback.monitor = "val/loss"

trainer.fit(
    training_model,
    datamodule=data_module
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/runner/work/openff-docs/openff-docs/build/cookbook/src/openforcefield/openff-nagl/train-multi-objective-gnn/lightning_logs
  | Name  | Type     | Params
-----------------------------------
0 | model | GNNModel | 203 K 
-----------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.812     Total estimated model params size (MB)
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=200` reached.

Results!

We can use the Trainer object’s test() method to evaluate the model against our test data:

trainer.test(training_model, data_module)
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                      Test metric                                             DataLoader 0                      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test/esps/multiple_esps/rmse/1.0/0.001                           0.07375835627317429                   │
│                       test/loss                                           4.134699821472168                    │
│      test/target-am1-charges/readout/rmse/1.0/1.0                        0.005569483619183302                  │
│   test/target-am1bcc-charges/readout/rmse/1.0/0.001                       3.9258058071136475                   │
│ test/target-am1bcc-dipoles/multiple_dipoles/rmse/1.0/…                   0.12956663966178894                   │
└────────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┘
[{'test/target-am1-charges/readout/rmse/1.0/1.0': 0.005569483619183302,
  'test/target-am1bcc-charges/readout/rmse/1.0/0.001': 3.9258058071136475,
  'test/target-am1bcc-dipoles/multiple_dipoles/rmse/1.0/0.01': 0.12956663966178894,
  'test/esps/multiple_esps/rmse/1.0/0.001': 0.07375835627317429,
  'test/loss': 4.134699821472168}]
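The logged keys appear to end in `.../<weight>/<denominator>`, and `test/loss` equals the sum of the four per-target values. Together these suggest that each logged value is already scaled as `weight * rmse / denominator`. Assuming that convention, the sketch below recovers the raw RMSEs from the numbers printed above. For example, the AM1BCC charge term of 3.93 corresponds to a raw RMSE of about 0.0039 e.

```python
# Scaled per-target values copied from the test output above.
# Assumed key convention: ".../<metric>/<weight>/<denominator>"
logged = {
    "test/target-am1-charges/readout/rmse/1.0/1.0": 0.005569483619183302,
    "test/target-am1bcc-charges/readout/rmse/1.0/0.001": 3.9258058071136475,
    "test/target-am1bcc-dipoles/multiple_dipoles/rmse/1.0/0.01": 0.12956663966178894,
    "test/esps/multiple_esps/rmse/1.0/0.001": 0.07375835627317429,
}
test_loss = 4.134699821472168

# The reported loss matches the sum of the scaled terms (up to float32 rounding)
total = sum(logged.values())
print(f"sum of terms: {total:.6f}, reported loss: {test_loss:.6f}")

# Undo the weight and denominator to get each target's raw RMSE
raw_rmse = {}
for key, value in logged.items():
    weight, denominator = (float(part) for part in key.rsplit("/", 2)[1:])
    raw_rmse[key] = value * denominator / weight
    print(f"{key}: {raw_rmse[key]:.2e}")
```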

We can isolate the model itself from all the training requirements:

model = training_model.model

Octane isn’t in any of our data, so the model hasn’t seen it yet! We can predict its partial charges with the compute_property() method:

octane = Molecule.from_smiles("CCCCCCCC")

am1bcc_charges = model.compute_property(octane, readout_name="predicted-am1bcc-charges")

And we can compare that to the AM1BCC partial charges produced by the OpenFF Toolkit:

octane.assign_partial_charges("am1bcc")
octane.partial_charges
Magnitude
[-0.09306153846153846 -0.07936153846153846 -0.07836153846153845
-0.07836153846153845 -0.07836153846153845 -0.07836153846153845
-0.07936153846153846 -0.09306153846153846 0.03223846153846154
0.03223846153846154 0.03223846153846154 0.03798846153846154
0.03798846153846154 0.03848846153846154 0.03848846153846154
0.03973846153846154 0.03973846153846154 0.03973846153846154
0.03973846153846154 0.03848846153846154 0.03848846153846154
0.03798846153846154 0.03798846153846154 0.03223846153846154
0.03223846153846154 0.03223846153846154]
Units: elementary_charge
prediction = am1bcc_charges * unit.elementary_charge
np.abs(prediction - octane.partial_charges)
Magnitude
[0.0011176000796831592 0.0029470817529238247 0.0019667993866480382
0.0016618098699129613 0.0016617726170099767 0.0019667919360674413
0.0029470743023432278 0.0011176000796831592 0.0013382355726682185
0.0013382355726682185 0.0013382355726682185 0.0008358732021771925
0.0008358433998548048 0.0011267039037667764 0.0011267039037667764
0.00012329609623322468 0.00012329609623322468 0.00012328119507203084
0.0001232737444914339 0.0011267076290570749 0.0011267076290570749
0.0008358732021771925 0.0008358732021771925 0.0013382281220876216
0.0013382281220876216 0.0013382281220876216]
Units: elementary_charge

All within 0.003 elementary charges of the toolkit’s AM1BCC charges! Not too bad!
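Rather than judging the agreement by eye, we can compute the maximum and root-mean-square errors directly. The sketch below copies the absolute differences printed above into a plain NumPy array. With the actual pint Quantity, `.m` gives the bare magnitude array.

```python
import numpy as np

# Absolute AM1BCC differences for octane (elementary charge), copied from above
abs_diff = np.array([
    0.0011176000796831592, 0.0029470817529238247, 0.0019667993866480382,
    0.0016618098699129613, 0.0016617726170099767, 0.0019667919360674413,
    0.0029470743023432278, 0.0011176000796831592, 0.0013382355726682185,
    0.0013382355726682185, 0.0013382355726682185, 0.0008358732021771925,
    0.0008358433998548048, 0.0011267039037667764, 0.0011267039037667764,
    0.00012329609623322468, 0.00012329609623322468, 0.00012328119507203084,
    0.0001232737444914339, 0.0011267076290570749, 0.0011267076290570749,
    0.0008358732021771925, 0.0008358732021771925, 0.0013382281220876216,
    0.0013382281220876216, 0.0013382281220876216,
])

print(f"max |error|: {abs_diff.max():.4f} e")
print(f"RMS error:   {np.sqrt(np.mean(abs_diff ** 2)):.4f} e")
```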

Similarly, looking at AM1 charges:

am1_charges = model.compute_property(octane, readout_name="predicted-am1-charges")
am1_charges
array([-0.21599516, -0.16339104, -0.1620512 , -0.16116777, -0.16116776,
       -0.1620512 , -0.16339107, -0.21599522,  0.07433881,  0.07433881,
        0.07433881,  0.07892299,  0.07892299,  0.08043569,  0.08043569,
        0.08043569,  0.08043569,  0.08043569,  0.08043569,  0.08043569,
        0.08043569,  0.07892299,  0.07892299,  0.07433875,  0.07433875,
        0.07433875], dtype=float32)
octane.assign_partial_charges("am1-mulliken")
octane.partial_charges
Magnitude
[-0.21096153846153842 -0.15896153846153843 -0.15696153846153843
-0.15796153846153843 -0.15596153846153843 -0.15696153846153843
-0.15696153846153843 -0.21096153846153842 0.07103846153846155
0.07103846153846155 0.07103846153846155 0.07903846153846156
0.07803846153846156 0.07903846153846156 0.07703846153846156
0.07903846153846156 0.08003846153846156 0.08103846153846156
0.07603846153846155 0.08003846153846156 0.07503846153846155
0.07703846153846156 0.07503846153846155 0.07403846153846155
0.07003846153846156 0.07203846153846155]
Units: elementary_charge
am1_prediction = am1_charges * unit.elementary_charge
np.abs(am1_prediction - octane.partial_charges)
Magnitude
[0.005033624263910186 0.004429500313905599 0.005089662405160789
0.0032062321626223356 0.0052062172614611435 0.005089662405160789
0.006429530116227988 0.005033683868554961 0.0033003471172772803
0.0033003471172772803 0.0033003471172772803 0.00011547455420862773
0.0008845254457913732 0.001397231725546011 0.0033972317255460127
0.001397231725546011 0.00039723172554601005 0.0006027682744539908
0.004397231725546014 0.00039723172554601005 0.0053972317255460145
0.001884525445791374 0.003884525445791376 0.0003002875126325022
0.004300287512632492 0.002300287512632504]
Units: elementary_charge

This is slightly less accurate (to within about 0.007 elementary charges), possibly because the AM1 charges were fit to the reference charges alone, without any of the physical-property targets.

Saving and loading our model

We can save the final model with the model.save() method. This’ll let us store it for later.

model.save("trained_alkane_model.pt")

When we want it again, we can use the GNNModel.load() method:

model_from_disk = GNNModel.load("trained_alkane_model.pt")
model_from_disk.compute_property(octane, readout_name="predicted-am1bcc-charges")
array([-0.09417914, -0.08230862, -0.08032834, -0.08002335, -0.08002331,
       -0.08032833, -0.08230861, -0.09417914,  0.0335767 ,  0.0335767 ,
        0.0335767 ,  0.03882433,  0.0388243 ,  0.03961517,  0.03961517,
        0.03961517,  0.03961517,  0.03961518,  0.03961519,  0.03961517,
        0.03961517,  0.03882433,  0.03882433,  0.03357669,  0.03357669,
        0.03357669], dtype=float32)