Training a Graph Neural Network with NAGL with multiple objectives
This notebook will go through the process of training a new Graph Neural Network (GNN) on a small dataset of alkanes with multiple objectives. Please see the train-gnn-notebook
tutorial for more on what’s happening under the hood.
Imports
from pathlib import Path
import numpy as np
from openff.toolkit import Molecule
from openff.units import unit
Create the model
First, let’s specify the model features.
from openff.nagl.features import atoms
atom_features = (
atoms.AtomicElement(categories=["C", "H"]), # Is the atom Carbon or Hydrogen?
atoms.AtomConnectivity(), # Is the atom bonded to 1, 2, 3, or 4 other atoms?
atoms.AtomAverageFormalCharge(), # What is the atom's mean formal charge over the molecule's tautomers?
atoms.AtomHybridization(), # What is the hybridization of the atom?
atoms.AtomInRingOfSize(ring_size=3), # Is the atom in a 3-membered ring?
atoms.AtomInRingOfSize(ring_size=4), # Is the atom in a 4-membered ring?
atoms.AtomInRingOfSize(ring_size=5), # Is the atom in a 5-membered ring?
atoms.AtomInRingOfSize(ring_size=6), # Is the atom in a 6-membered ring?
)
We also need to specify the architecture of the GNN. We can make this as complicated as we like.
from openff.nagl.config.model import (
ConvolutionLayer,
ConvolutionModule,
)
from openff.nagl import GNNModel
from openff.nagl.nn.gcn import SAGEConvStack
from torch.nn import ReLU
from openff.nagl.nn.postprocess import ComputePartialCharges
single_convolution_layer = ConvolutionLayer(
hidden_feature_size=128, # 128 features per hidden convolution layer
aggregator_type="mean", # aggregate atom representations with mean
activation_function="ReLU", # max(0, x) activation function for layer
dropout=0.0, # no dropout
)
convolution_module = ConvolutionModule(
architecture="SAGEConv", # GraphSAGE GCN
layers=[single_convolution_layer] * 3, # 3 hidden convolution layers
)
We then specify the readout module.
from openff.nagl.config.model import (
ForwardLayer,
ReadoutModule,
)
single_readout_layer = ForwardLayer(
hidden_feature_size=128, # 128 features per hidden readout layer
activation_function="ReLU", # max(0, x) activation function for layer
dropout=0.0, # no dropout
)
normal_readout_module = ReadoutModule(
pooling="atoms",
layers=[single_readout_layer] * 4, # 4 internal readout layers
# calculate charges with charge equilibration scheme from
# electronegativity and hardness
postprocess="compute_partial_charges"
)
regularised_readout_module = ReadoutModule(
pooling="atoms",
layers=[single_readout_layer] * 4, # 4 internal readout layers
# calculate charges with charge equilibration scheme from
# electronegativity, hardness, and an initial charge prediction
postprocess="regularized_compute_partial_charges"
)
Now we can put them together in a full ModelConfig. This can be passed to create a GNNModel. A model can have multiple readouts that derive different properties from the convolution representation, so each readout module is specified in a dictionary with a label.
Here, the GNNModel class represents all the hyperparameters for a model, but after we train it the same object will store weights as well.
from openff.nagl.config.model import ModelConfig
model_config = ModelConfig(
version="0.1",
atom_features=atom_features,
bond_features=[],
convolution=convolution_module,
readouts={
"predicted-am1bcc-charges": normal_readout_module,
"predicted-am1-charges": regularised_readout_module
}
)
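Although in this tutorial the config is consumed by the training machinery below, the same ModelConfig can also be used to construct a GNNModel directly, as mentioned above. A minimal sketch, assuming the GNNModel constructor accepts a ModelConfig:
untrained_model = GNNModel(model_config)  # holds hyperparameters only; weights are learned during training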
Put together our datasets
We need to set up three datasets here:
training: data the model is trained against
validation: data used to validate the model as it is trained
test: data used to test that the final model is good
In this example, we’ll use a collection of ten molecules for training. We’ll also build a test/validation dataset of 3 molecules that are not in the training set.
We can use the LabelledDataset class to generate training data that is saved in the training_data directory (or use pyarrow directly). First we can generate the dataset from SMILES:
from openff.nagl.label.dataset import LabelledDataset
training_alkanes = [
'C',
'CC',
'CCC',
'CCCC',
'CC(C)C',
'CCCCC',
'CC(C)CC',
'CCCCCC',
'CC(C)CCC',
'CC(CC)CC',
]
training_dataset = LabelledDataset.from_smiles(
"training_data",
training_alkanes,
mapped=False,
overwrite_existing=True,
)
training_dataset.to_pandas()
 | mapped_smiles |
---|---|
0 | [C:1]([H:2])([H:3])([H:4])[H:5] |
1 | [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[... |
2 | [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])... |
3 | [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])... |
4 | [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... |
5 | [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[... |
6 | [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]... |
7 | [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H... |
8 | [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... |
9 | [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])... |
Below we specify label functions to label our molecules with the information that we will use in training and testing. Each argument is specified and annotated to explain its purpose; however, all label functions can be instantiated with default arguments (e.g. LabelConformers()) unless you need specific column names or arguments (e.g. changing the charge method).
Note: the ESP label function requires openff-recharge
to be installed.
from openff.nagl.label.labels import (
LabelConformers,
LabelCharges,
LabelMultipleDipoles,
LabelMultipleESPs,
)
import openff.recharge
# generate ELF conformers
label_conformers = LabelConformers(
# create a new 'conformers' column containing the generated conformers
conformer_column="conformers",
# create a new 'n_conformers' column containing the number of conformers
n_conformer_column="n_conformers",
n_conformer_pool=500, # initially generate 500 conformers
n_conformers=10, # prune to max 10 conformers
rms_cutoff=0.05,
)
# generate AM1 charges
label_am1_charges = LabelCharges(
charge_method="am1-mulliken", # AM1
# use previously generated conformers instead of new ones
use_existing_conformers=True,
# use the 'conformers' column as input for charge assignment
conformer_column="conformers",
# write generated charges to 'target-am1-charges' column
charge_column="target-am1-charges",
)
# generate AM1-BCC charges
label_am1bcc_charges = LabelCharges(
charge_method="am1bcc", # AM1BCC
# use previously generated conformers instead of new ones
use_existing_conformers=True,
# use the 'conformers' column as input for charge assignment
conformer_column="conformers",
# write generated charges to 'target-am1bcc-charges' column
charge_column="target-am1bcc-charges",
)
label_am1bcc_dipoles = LabelMultipleDipoles(
# use the 'conformers' column as input to calculate dipole moments
conformer_column="conformers",
# use the 'n_conformers' column as input
n_conformer_column="n_conformers",
# use the "target-am1bcc-charges" column as input to calculate dipole moments
charge_column="target-am1bcc-charges",
# write calculated dipoles to 'target-am1bcc-dipoles' column
dipole_column="target-am1bcc-dipoles",
)
label_am1bcc_esps = LabelMultipleESPs(
# use the 'conformers' column as input to calculate ESPs
conformer_column="conformers",
# use the 'n_conformers' column as input
n_conformer_column="n_conformers",
# use the "target-am1bcc-charges" column as input to calculate ESPS
charge_column="target-am1bcc-charges",
# generate new grids and inverse distances to points
use_existing_inverse_distances=False,
# write inverse distances from conformer to surface to this column
inverse_distance_matrix_column="grid_inverse_distances",
# write number of grid points for each surface to this column
grid_length_column="esp_lengths",
# write calculated ESPs to 'esps' column
esp_column="esps",
)
Below we apply the label functions to actually generate the labels. The order matters, as later label functions use the output of earlier ones.
labellers = [
label_conformers, # generate initial conformers,
label_am1_charges,
label_am1bcc_charges,
label_am1bcc_dipoles,
label_am1bcc_esps,
]
training_dataset.apply_labellers(labellers)
training_dataset.to_pandas()
 | mapped_smiles | conformers | n_conformers | target-am1-charges | target-am1bcc-charges | target-am1bcc-dipoles | esp_lengths | grid_inverse_distances | esps |
---|---|---|---|---|---|---|---|---|---|
0 | [C:1]([H:2])([H:3])([H:4])[H:5] | [0.005118712612069831, -0.010620498663889949, ... | 1 | [-0.2656, 0.0664, 0.0664, 0.0664, 0.0664] | [-0.1084, 0.0271, 0.0271, 0.0271, 0.0271] | [-0.0006935855589354605, 0.0014390775689570746... | [398] | [0.22234336592604098, 0.1691980088400493, 0.26... | [-0.00048403793895890833, -0.00071972139957695... |
1 | [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[... | [-0.7455231747504416, 0.04144445508118371, 0.0... | 1 | [-0.21225, -0.21225, 0.07075, 0.07075, 0.07075... | [-0.09435, -0.09435, 0.03145, 0.03145, 0.03145... | [-0.0002285229334980654, -0.005575840474810592... | [500] | [0.22234336592604098, 0.18774435919748517, 0.3... | [-0.001407404714360157, -0.0029490792419474554... |
2 | [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])... | [1.2163074727474537, -0.24294836930504896, 0.2... | 1 | [-0.211, -0.16, -0.211, 0.072, 0.072, 0.071, 0... | [-0.09310018181818182, -0.08140018181818182, -... | [-0.0043053606473935704, -0.009872806411602934... | [602] | [0.22234336592604098, 0.2043827690515404, 0.18... | [-0.002920570510046846, -0.0009924942195456941... |
3 | [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])... | [1.8901957496718658, 0.042575097521835, 0.2431... | 1 | [-0.21028571428571427, -0.15928571428571428, -... | [-0.09238585714285714, -0.08068585714285714, -... | [-0.003195845195204508, -0.00494009309346443, ... | [690] | [0.22234336592604098, 0.17092450502366457, 0.1... | [-0.000870799470089798, 4.8700196369716106e-05... |
4 | [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... | [1.4351901021861362, 0.27740245802875496, 0.13... | 1 | [-0.20735714285714285, -0.10935714285714285, -... | [-0.08945692857142856, -0.07005692857142856, -... | [-0.0059915089786667305, -0.009373247525965786... | [672] | [0.22234336592604098, 0.19109842170559405, 0.1... | [-0.0008717150501716528, -0.003425712120807677... |
5 | [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[... | [2.4660243557948207, -0.0503036791214956, 0.03... | 1 | [-0.21, -0.159, -0.158, -0.159, -0.21, 0.07200... | [-0.09210011764705882, -0.08040011764705882, -... | [-0.003987426571054067, 0.005559284278063413, ... | [791] | [0.22234336592604098, 0.14213609699084526, 0.1... | [0.0005046598853891658, 0.0002250632479590313,... |
6 | [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]... | [-1.2418582991411082, 1.052060941462957, 0.017... | 1 | [-0.207, -0.10599999999999998, -0.207, -0.153,... | [-0.08909988235294117, -0.06669988235294116, -... | [-0.0026498796085173554, -0.008039775679738135... | [750] | [0.22234336592604098, 0.15192142508151973, 0.1... | [-0.00023715760591162555, -0.00101041663676669... |
7 | [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H... | [-3.0907008742942494, -0.14370944748493916, -0... | 1 | [-0.21009999999999998, -0.1591, -0.1581, -0.15... | [-0.0922, -0.0805, -0.0795, -0.0795, -0.0805, ... | [0.0018400328214143447, -0.014928553541414105,... | [901] | [0.22234336592604098, 0.1986135821292406, 0.15... | [-0.0014053749951576907, -0.000203661455117260... |
8 | [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... | [-1.3441616214248866, 1.2928572296227032, -0.3... | 1 | [-0.20694999999999997, -0.10694999999999998, -... | [-0.08955004999999999, -0.06765004999999999, -... | [-0.00816547475433177, 0.0055545221371308545, ... | [844] | [0.22234336592604098, 0.1852101593428309, 0.13... | [-0.0012388351037772895, -0.002423698057369004... |
9 | [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])... | [-0.40122078234363445, -1.422374025464223, -0.... | 1 | [-0.20794999999999997, -0.10594999999999999, -... | [-0.09005, -0.06665, -0.07785, -0.09205, -0.07... | [0.001310505146191715, -0.012141030775731199, ... | [829] | [0.22234336592604098, 0.1436149999840705, 0.11... | [-0.00044908769940229555, -0.00022010214148746... |
Building a test dataset
To augment the provided training set, we’ll quickly prepare a second dataset for testing and validation. We use the same label functions:
from openff.nagl.label.labels import LabelCharges
# Choose the molecules to put in this dataset
# Note that these molecules aren't in the training dataset!
test_smiles = [
"CCCCCCC",
"CC(C)C(C)C",
"CC(C)(C)C",
]
test_dataset = LabelledDataset.from_smiles(
"my_first_test_dataset", # path to save to
test_smiles,
mapped=False,
overwrite_existing=True,
)
test_dataset.apply_labellers(labellers)
test_dataset.to_pandas()
 | mapped_smiles | conformers | n_conformers | target-am1-charges | target-am1bcc-charges | target-am1bcc-dipoles | esp_lengths | grid_inverse_distances | esps |
---|---|---|---|---|---|---|---|---|---|
0 | [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([C:7]([H:2... | [-3.545070958260261, 0.5832530008833098, 0.079... | 1 | [-0.21008695652173912, -0.15908695652173913, -... | [-0.09218686956521739, -0.07998686956521739, -... | [-0.00595775894206417, 0.0008582141430205914, ... | [944] | [0.22234336592604098, 0.19692667350664092, 0.1... | [-0.0006610506078259431, -0.000516657266085287... |
1 | [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... | [0.9713614798232566, -0.892209543378047, 1.004... | 1 | [-0.20589999999999997, -0.10289999999999998, -... | [-0.0890002, -0.06360020000000001, -0.0890002,... | [-0.002531086126934007, -0.00711238458844677, ... | [803] | [0.22234336592604098, 0.17641696600758316, 0.1... | [-0.0006969112009999837, -0.000166671288359785... |
2 | [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... | [1.2935034573089457, -0.7675514803987564, -0.0... | 1 | [-0.20323529411764704, -0.06023529411764704, -... | [-0.08533529411764705, -0.06023529411764706, -... | [0.003012641718856162, 0.010054820006400858, 0... | [743] | [0.22234336592604098, 0.1629178797473639, 0.18... | [1.7940051204891212e-05, 0.0004078426202498559... |
Curating our data module
Now we assemble our datasets into a DataConfig. For each DatasetConfig, we need to specify the targets we are choosing to fit. A Target is what we use to construct the objective function and calculate the loss. Below we:
fit the predicted-am1-charges property directly to the labelled target-am1-charges
fit the predicted-am1bcc-charges property to a combined objective of charge RMSE, dipole moments, and ESP targets
For the physical properties, we need additional information (the conformers, n_conformers, … columns) to calculate the dipole moments and ESPs for comparison. Each of these arguments is annotated below. The target_label always refers to the column in the input dataset containing the property we are comparing against.
from openff.nagl.config.data import DatasetConfig, DataConfig
from openff.nagl.training.metrics import RMSEMetric
from openff.nagl.training.loss import ReadoutTarget, MultipleDipoleTarget, MultipleESPTarget
am1_charge_rmse_target = ReadoutTarget(
metric=RMSEMetric(), # use RMSE to calculate loss
target_label="target-am1-charges", # column to use from data as reference target
prediction_label="predicted-am1-charges", # readout value to compare to target
denominator=1.0, # denominator to normalise loss -- important for multi-target objectives
weight=1.0, # how much to weight the loss -- important for multi-target objectives
)
am1bcc_charge_rmse_target = ReadoutTarget(
metric=RMSEMetric(), # use RMSE to calculate loss
target_label="target-am1bcc-charges", # column to use from data as reference target
prediction_label="predicted-am1bcc-charges", # readout value to compare to target
denominator=0.001, # denominator to normalise loss -- important for multi-target objectives
weight=1.0, # how much to weight the loss -- important for multi-target objectives
)
am1bcc_dipole_target = MultipleDipoleTarget(
metric=RMSEMetric(),
target_label="target-am1bcc-dipoles", # column to use from input data as reference target
charge_label="predicted-am1bcc-charges", # readout charge value to calculate dipoles with
conformation_column="conformers", # input data to use for calculating dipoles
n_conformation_column="n_conformers", # input data to use for calculating dipoles
denominator=0.01,
weight=1.0
)
am1bcc_esp_target = MultipleESPTarget(
metric=RMSEMetric(),
target_label="esps", # column to use from input data as reference target
charge_label="predicted-am1bcc-charges", # readout charge value to calculate ESPs with
inverse_distance_matrix_column="grid_inverse_distances", # input data to use to calculate ESPs
esp_length_column="esp_lengths", # input data to use to calculate ESPs
n_esp_column="n_conformers", # input data to use to calculate ESPs
denominator=0.001,
weight=1.0
)
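Before combining the targets, it helps to keep in mind how denominator and weight interact. Conceptually, each target contributes something like weight * metric / denominator to the total objective, so a small denominator amplifies a target and the weight scales it further. The snippet below is a rough, purely hypothetical illustration with made-up numbers; NAGL performs the actual reduction internally and the details may differ:
# Hypothetical illustration only -- not NAGL internals.
# Each target contributes roughly weight * metric / denominator to the total loss.
illustrative_contributions = [
    1.0 * 0.006 / 1.0,    # e.g. a charge RMSE target with denominator 1.0
    1.0 * 0.002 / 0.001,  # e.g. a charge RMSE target boosted by its small denominator
]
illustrative_total_loss = sum(illustrative_contributions)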
Now we combine these targets into each DatasetConfig.
targets = [
am1_charge_rmse_target,
am1bcc_charge_rmse_target,
am1bcc_dipole_target,
am1bcc_esp_target,
]
training_dataset_config = DatasetConfig(
sources=["training_data"],
targets=targets,
batch_size=1000,
)
test_dataset_config = validation_dataset_config = DatasetConfig(
sources=["my_first_test_dataset"],
targets=targets,
batch_size=1000,
)
data_config = DataConfig(
training=training_dataset_config,
validation=validation_dataset_config,
test=test_dataset_config
)
Train the model
We’ve prepared our model architecture and our training, validation, and test data; now we just need to fit the model! To do this, we need to specify optimization settings with an OptimizerConfig, and then put everything together in a TrainingConfig.
from openff.nagl.config.optimizer import OptimizerConfig
from openff.nagl.config.training import TrainingConfig
optimizer_config = OptimizerConfig(
optimizer="Adam",
learning_rate=0.001,
)
training_config = TrainingConfig(
model=model_config,
data=data_config,
optimizer=optimizer_config
)
from openff.nagl.training.training import TrainingGNNModel, DGLMoleculeDataModule
training_model = TrainingGNNModel(training_config)
data_module = DGLMoleculeDataModule(training_config)
To properly fit the model, we use the Trainer
class from PyTorch Lightning. This allows us to configure how data and progress are stored and reported using callbacks. The fit()
method trains and validates against the data module we provide it:
from pytorch_lightning import Trainer
trainer = Trainer(max_epochs=200)
trainer.progress_bar_callback.disable()
trainer.checkpoint_callback.monitor = "val/loss"
trainer.fit(
training_model,
datamodule=data_module
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
| Name | Type | Params | Mode
-------------------------------------------
0 | model | GNNModel | 203 K | train
-------------------------------------------
203 K Trainable params
0 Non-trainable params
203 K Total params
0.812 Total estimated model params size (MB)
57 Modules in train mode
0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=200` reached.
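Above we tweaked the Trainer's default progress-bar and checkpoint callbacks after constructing it. An equivalent approach is to configure this behaviour when building the Trainer, using standard PyTorch Lightning options rather than anything NAGL-specific (a sketch, not the configuration used in this run):
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(monitor="val/loss")  # keep the checkpoint with the best validation loss
quiet_trainer = Trainer(
    max_epochs=200,
    callbacks=[checkpoint],
    enable_progress_bar=False,  # suppress the progress bar entirely
)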
Results!
We can use the Trainer
object’s test()
method to evaluate the model against our test data:
trainer.test(training_model, data_module)
Test metric | DataLoader 0
---|---
test/esps/multiple_esps/rmse/1.0/0.001 | 0.05779905989766121
test/loss | 2.337519645690918
test/target-am1-charges/readout/rmse/1.0/1.0 | 0.006489156279712915
test/target-am1bcc-charges/readout/rmse/1.0/0.001 | 2.153856039047241
test/target-am1bcc-dipoles/multiple_dipoles/rmse/1.0/0.01 | 0.11937537044286728
[{'test/target-am1-charges/readout/rmse/1.0/1.0': 0.006489156279712915,
'test/target-am1bcc-charges/readout/rmse/1.0/0.001': 2.153856039047241,
'test/target-am1bcc-dipoles/multiple_dipoles/rmse/1.0/0.01': 0.11937537044286728,
'test/esps/multiple_esps/rmse/1.0/0.001': 0.05779905989766121,
'test/loss': 2.337519645690918}]
We can isolate the model itself from all the training requirements:
model = training_model.model
Octane isn’t in any of our data, so the model hasn’t seen it yet! We can predict its partial charges with the compute_property()
method:
octane = Molecule.from_smiles("CCCCCCCC")
am1bcc_charges = model.compute_property(octane, readout_name="predicted-am1bcc-charges")
And we can compare that to the AM1BCC partial charges produced by the OpenFF Toolkit:
octane.assign_partial_charges("am1bcc")
octane.partial_charges
Magnitude | [-0.09306153846153846 -0.07936153846153846 -0.07836153846153845 |
---|---|
Units | elementary_charge |
prediction = am1bcc_charges * unit.elementary_charge
np.abs(prediction - octane.partial_charges)
Magnitude | [0.00044303196301827275 0.0020996229135073208 0.0019252325974978002 |
---|---|
Units | elementary_charge |
All within about 0.002 elementary charge units of the true AM1BCC charges! Not too bad!
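To put a number on that, we can take the largest per-atom deviation from the array computed above (a small sanity check, not part of the original workflow):
# maximum absolute difference between predicted and toolkit AM1-BCC charges
np.abs(prediction - octane.partial_charges).max()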
Similarly, looking at AM1 charges:
am1_charges = model.compute_property(octane, readout_name="predicted-am1-charges")
am1_charges
array([-0.1979023 , -0.14820972, -0.1478569 , -0.14731342, -0.14731342,
-0.1478569 , -0.14820972, -0.1979023 , 0.06688948, 0.06688948,
0.06688948, 0.07278787, 0.07278787, 0.07375953, 0.07375953,
0.07375953, 0.07375953, 0.07375953, 0.07375953, 0.07375953,
0.07375953, 0.07278787, 0.07278787, 0.06688948, 0.06688948,
0.06688948], dtype=float32)
octane.assign_partial_charges("am1-mulliken")
octane.partial_charges
Magnitude | [-0.21096153846153842 -0.15896153846153843 -0.15696153846153843 |
---|---|
Units | elementary_charge |
am1_prediction = am1_charges * unit.elementary_charge
np.abs(am1_prediction - octane.partial_charges)
Magnitude | [0.013059231547208894 0.010751817611547587 0.009104632405134316 |
---|---|
Units | elementary_charge |
This is slightly less accurate (to within about 0.02 elementary charge units), possibly because the AM1 charges were only fit directly to the charges, without any of the physical property targets.
Saving and loading our model
We can save the final model with the model.save()
method. This’ll let us store it for later.
model.save("trained_alkane_model.pt")
When we want it again, we can use the GNNModel.load()
method:
model_from_disk = GNNModel.load("trained_alkane_model.pt")
model_from_disk.compute_property(octane, readout_name="predicted-am1bcc-charges")
array([-0.09350457, -0.08146116, -0.08028677, -0.08014869, -0.08014869,
-0.08028677, -0.08146116, -0.09350457, 0.03329875, 0.03329875,
0.03329875, 0.03867169, 0.03867169, 0.03954038, 0.03954038,
0.03954038, 0.03954038, 0.03954038, 0.03954038, 0.03954038,
0.03954038, 0.03867169, 0.03867169, 0.03329875, 0.03329875,
0.03329875], dtype=float32)
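Finally, the reloaded model's predictions can be attached to the Molecule in the same way as toolkit-assigned charges, for use downstream. A brief sketch, assuming we want the AM1-BCC-style readout in elementary charge units:
predicted = model_from_disk.compute_property(octane, readout_name="predicted-am1bcc-charges")
octane.partial_charges = predicted * unit.elementary_charge  # store on the molecule for later use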