Train a GNN directly to an electric field
To execute this example fully, the following packages are required.
openff-nagl
openff-recharge
openff-qcsubmit
psi4
However, if you just want to follow along with the training part without first generating the training datasets yourself, you can get away with only openff-nagl
installed and simply load the training/validation data from the provided .parquet
files. The commands to do so are given, commented out, at the end of the “Generate and format training data” section.
import tqdm
from qcportal import PortalClient
from openff.units import unit
from openff.toolkit import Molecule
from openff.qcsubmit.results import BasicResultCollection
from openff.recharge.esp.storage import MoleculeESPRecord
from openff.recharge.esp.qcresults import from_qcportal_results
from openff.recharge.grids import MSKGridSettings
from openff.recharge.utilities.geometry import compute_vector_field
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import torch
Generate and format training data
Downloading from QCArchive
First, we will create the training data by downloading it from QCArchive. For the purposes of this example we use a relatively small dataset.
qc_client = PortalClient("https://api.qcarchive.molssi.org:443", cache_dir=".")
# download dataset from QCPortal
br_esps_collection = BasicResultCollection.from_server(
client=qc_client,
datasets="OpenFF multi-Br ESP Fragment Conformers v1.1",
spec_name="HF/6-31G*",
)
records_and_molecules = br_esps_collection.to_records()
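Before the relatively slow ESP reconstruction in the next step, it can be worth a quick (optional) sanity check of how many records were downloaded:
print(f"Downloaded {len(records_and_molecules)} records")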
Converting to MoleculeESPRecords
Now we convert to OpenFF Recharge records and compute the ESPs and electric fields of each molecule.
# Create OpenFF Recharge MoleculeESPRecords
grid_settings = MSKGridSettings()
# this can take a while; we only use the first 50 records below.
# Shrink the slice (e.g. records_and_molecules[:10]) to speed this up further.
molecule_esp_records = [
from_qcportal_results(
qc_result=qcrecord,
qc_molecule=qcrecord.molecule,
qc_keyword_set=qcrecord.specification.keywords,
grid_settings=grid_settings,
compute_field=True
)
for qcrecord, _ in tqdm.tqdm(records_and_molecules[:50])
]
100%|██████████| 50/50 [06:04<00:00, 7.29s/it]
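Each MoleculeESPRecord stores the tagged SMILES, the conformer, the grid coordinates, and the ESP and electric field values (as plain NumPy arrays in recent openff-recharge versions). A quick, optional peek at the first record:
record = molecule_esp_records[0]
print(record.tagged_smiles)
print(record.grid_coordinates.shape, record.electric_field.shape)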
Convert to PyArrow dataset
NAGL reads in and trains on data from PyArrow tables. Below we convert each molecule's electric field into the form expected by a basic GeneralLinearFitTarget, which fits a linear system Ax = b; here A is the vector field relating the M atomic partial charges x to the electric field b evaluated at the N grid points. To avoid carrying around millions of grid values, we do some postprocessing: we flatten the N x 3 x M vector field to a 3N x M matrix (and the field itself to a 3N vector), then pre-multiply both sides by the transpose of A. The resulting M x M matrix AᵀA and M-vector Aᵀb encode the same least-squares problem (the normal equations), but their size scales with the number of atoms rather than the number of grid points.
pyarrow_entries = []
for molecule_esp_record in tqdm.tqdm(molecule_esp_records):
electric_field = molecule_esp_record.electric_field # in atomic units
grid = molecule_esp_record.grid_coordinates * unit.angstrom
conformer = molecule_esp_record.conformer * unit.angstrom
vector_field = compute_vector_field(
conformer.m_as(unit.bohr), # shape: M x 3
grid.m_as(unit.bohr), # shape: N x 3
) # N x 3 x M
# postprocess so we're not carrying around millions of floats
# first flatten the vector field N x 3 x M -> 3N x M (and the electric field N x 3 -> 3N)
vector_field_2d = np.concatenate(vector_field, axis=0)
electric_field_1d = np.concatenate(electric_field, axis=0)
# now multiply by vector_field_2d's transpose
new_precursor_matrix = vector_field_2d.T @ vector_field_2d
new_field_vector = vector_field_2d.T @ electric_field_1d
n_atoms = conformer.shape[0]
assert new_precursor_matrix.shape == (n_atoms, n_atoms)
assert new_field_vector.shape == (n_atoms,)
# create the entry. mapped_smiles is required for every target;
# the other column names must match the labels used in the target config later
entry = {
"mapped_smiles": molecule_esp_record.tagged_smiles,
"precursor_matrix": new_precursor_matrix.flatten().tolist(),
"prediction_vector": new_field_vector.tolist()
}
pyarrow_entries.append(entry)
# arbitrarily split into training and validation datasets
training_pyarrow_entries = pyarrow_entries[:-10]
validation_pyarrow_entries = pyarrow_entries[-10:]
training_table = pa.Table.from_pylist(training_pyarrow_entries)
validation_table = pa.Table.from_pylist(validation_pyarrow_entries)
training_table
100%|██████████| 50/50 [00:00<00:00, 315.72it/s]
pyarrow.Table
mapped_smiles: string
precursor_matrix: list<item: double>
child 0, item: double
prediction_vector: list<item: double>
child 0, item: double
----
mapped_smiles: [["[H:1][N:2]1[C:3]([Br:4])=[N:5][c:6]2[n:7][c:8]([Br:9])[c:10]([Br:11])[n:12][c:13]21","[H:1][N:2]([H:3])[C:4]1([C:5](=[O:6])[O-:7])[C:8]([H:9])([H:10])[C:11]2([H:12])[C:13]([Br:14])([Br:15])[C:16]2([H:17])[C:18]1([H:19])[H:20]","[H:1][C:2]1([H:3])[C:4]2([H:5])[C:6]([Br:7])([Br:8])[C:9]2([H:10])[C:11]([H:12])([H:13])[C:14]1([C:15](=[O:16])[O-:17])[N+:18]([H:19])([H:20])[H:21]","[H:1][N:2]1[C:3]([H:4])([H:5])[C:6]2([H:7])[C:8]([Br:9])([Br:10])[C:11]2([H:12])[C:13]1([H:14])[C:15](=[O:16])[O-:17]","[H:1][C:2]12[C:3]([Br:4])([Br:5])[C:6]1([H:7])[C:8]([H:9])([C:10](=[O:11])[O-:12])[N+:13]([H:14])([H:15])[C:16]2([H:17])[H:18]",...,"[H:1][c:2]1[c:3]([Br:4])[c:5]([H:6])[c:7]2[c:8]([c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[H:19]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[N:19]([H:20])[H:21]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[N+:19]([H:20])([H:21])[H:22]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[Br:10])[N:11]([H:12])[C:13]([H:14])([H:15])[C:16]([H:17])([H:18])[C:19]2([H:20])[H:21]","[H:1][c:2]1[c:3]([Br:4])[c:5]([H:6])[c:7]2[c:8]([c:9]1[Br:10])[C:11]([H:12])([H:13])[N:14]([H:15])[C:16]([H:17])([H:18])[C:19]2([H:20])[H:21]"]]
precursor_matrix: [[[0.5692106565580053,0.4085897652889512,0.284249916238166,0.20940517539253248,0.20686844494784193,...,0.11433558179284603,0.23263226801337633,0.1371485227635511,0.3023917581405904,0.32730801486131705],[0.5260442629899051,0.38089630545801045,0.35215718214227837,0.24910585545578626,0.17644186782494933,...,0.22243116268067484,0.19814033340715625,0.300593182954176,0.3088962938720413,0.4158481032091346],...,[0.5289925354233009,0.37414187483676303,0.28313431180338455,0.27734985432011455,0.17944395434503416,...,0.2880391596291975,0.2321472722241459,0.3857071059607491,0.3665630767875717,0.5677900784441218],[0.4983240011637843,0.3447984584130681,0.25468324548837756,0.20316511244549332,0.17595041719921767,...,0.27107542957261366,0.21722450631808865,0.3626687533961094,0.3447698917574723,0.5347596590639927]]]
prediction_vector: [[[0.11831342820663876,0.07052253574331227,0.03443373566615229,0.031423455167753835,-0.006535801800470039,...,-0.015353416908781552,0.0058843508732542345,-0.011419043977358915,0.015153163796349949,0.039054645977152046],[-0.1168859699903578,-0.16262446070843317,-0.17209952898209305,-0.16772575000042161,-0.23935707380099153,...,-0.11399722466179721,-0.09336325633851358,-0.12839560498881492,-0.12752534734377857,-0.10652701620291485],...,[0.01076983677321753,-0.005588997214517595,0.007231581443923954,0.03031856180479904,0.008379193273740145,...,0.04214775639020978,0.03202866265802594,0.0301970861903761,0.03768884431038124,0.03714789835823986],[0.019492995260861956,0.005150669698426461,0.000872755565492754,-0.01489613518860242,0.015886122871734236,...,0.03648331002928913,0.03218903273689575,0.04016640969171912,0.046849895898038244,0.04789045546505114]]]
pq.write_table(training_table, "training_dataset_table.parquet")
pq.write_table(validation_table, "validation_dataset_table.parquet")
# # To read the tables back in. Note: the .parquet files provided with this example contain the full dataset, not just the 50-record subset generated above.
# training_table = pq.read_table("training_dataset_table.parquet")
# validation_table = pq.read_table("validation_dataset_table.parquet")
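Either way you obtained the tables, a quick (optional) check of their size and schema confirms they match what the targets below expect:
print(training_table.num_rows, validation_table.num_rows)
print(training_table.schema)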
Set up for training a GNN
from openff.nagl.config import (
TrainingConfig,
OptimizerConfig,
ModelConfig,
DataConfig
)
from openff.nagl.config.model import (
ConvolutionModule, ReadoutModule,
ConvolutionLayer, ForwardLayer,
)
from openff.nagl.config.data import DatasetConfig
from openff.nagl.training.training import TrainingGNNModel
from openff.nagl.features.atoms import (
AtomicElement,
AtomConnectivity,
AtomInRingOfSize,
AtomAverageFormalCharge,
)
from openff.nagl.training.loss import GeneralLinearFitTarget
Defining the training config
Defining a ModelConfig
First we define a ModelConfig. This can be done in Python, but in practice it is probably easier to define the model in a YAML file and load it with ModelConfig.from_yaml.
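For reference, the YAML route would look like the commented line below once you have an equivalent configuration saved to disk (the filename model.yaml is just an example); in this notebook we instead build the config directly in Python.
# model_config = ModelConfig.from_yaml("model.yaml")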
atom_features = [
AtomicElement(categories=["H", "C", "N", "O", "F", "Br", "S", "P", "I"]),
AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),
AtomInRingOfSize(ring_size=3),
AtomInRingOfSize(ring_size=4),
AtomInRingOfSize(ring_size=5),
AtomInRingOfSize(ring_size=6),
AtomAverageFormalCharge(),
]
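As an aside, these features concatenate into a 20-dimensional input vector per atom (9 element categories + 6 connectivity categories + 4 ring-size flags + 1 averaged formal charge), which is why the first convolution layer in the model summary further below reports in_features=20.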
# define our convolution module
convolution_module = ConvolutionModule(
architecture="SAGEConv",
# construct 6 layers with dropout 0 (default),
# hidden feature size 512, and ReLU activation function
# these layers can also be individually specified,
# but we just duplicate the layer 6 times for identical layers
layers=[
ConvolutionLayer(
hidden_feature_size=512,
activation_function="ReLU",
aggregator_type="mean"
)
] * 6,
)
# define our readout module/s
# multiple are allowed but let's focus on charges
readout_modules = {
# key is the name of output property, any naming is allowed
"charges": ReadoutModule(
pooling="atoms",
postprocess="compute_partial_charges",
# 2 layers
layers=[
ForwardLayer(
hidden_feature_size=512,
activation_function="ReLU",
)
] * 2,
)
}
# bring it all together
model_config = ModelConfig(
version="0.1",
atom_features=atom_features,
convolution=convolution_module,
readouts=readout_modules,
)
Defining a DataConfig
We can then define our dataset configs. Here we also have to specify our training targets.
target = GeneralLinearFitTarget(
# what we're using to evaluate loss
target_label="prediction_vector",
# the output of the GNN we use to evaluate loss
prediction_label="charges",
# the column in the table that contains the precursor matrix
design_matrix_column="precursor_matrix",
# how we want to evaluate loss, e.g. RMSE, MSE, ...
metric="rmse",
# how much to weight this target
# helps with scaling in multi-target optimizations
weight=1,
denominator=1,
)
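For intuition only, the loss this target computes per molecule is roughly equivalent to the sketch below (this is not NAGL's actual implementation): the flattened design matrix column is reshaped back to 2D, multiplied by the predicted charges, and compared against the target vector with the chosen metric. With the preprocessing above, the stored matrix is AᵀA and the stored vector is Aᵀb, so this amounts to solving the normal equations for the electric field fit.
def sketch_general_linear_fit_loss(charges, precursor_matrix_flat, prediction_vector):
    """Rough sketch of the loss; NAGL handles this internally."""
    n = len(prediction_vector)
    matrix = np.asarray(precursor_matrix_flat).reshape(n, n)  # the stored design matrix (here AᵀA)
    target = np.asarray(prediction_vector)                    # the stored target vector (here Aᵀb)
    predicted = matrix @ np.asarray(charges)                  # "A x" with x = the predicted charges
    return np.sqrt(np.mean((predicted - target) ** 2))        # the "rmse" metric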
training_to_electric_field = DatasetConfig(
sources=["training_dataset_table.parquet"],
targets=[target],
batch_size=100,
)
validating_to_electric_field = DatasetConfig(
sources=["validation_dataset_table.parquet"],
targets=[target],
batch_size=100,
)
# bringing it together
data_config = DataConfig(
training=training_to_electric_field,
validation=validating_to_electric_field
)
Defining an OptimizerConfig
The optimizer config is relatively simple; the only moving part here currently is the learning rate.
optimizer_config = OptimizerConfig(optimizer="Adam", learning_rate=0.001)
Creating a TrainingConfig
training_config = TrainingConfig(
model=model_config,
data=data_config,
optimizer=optimizer_config
)
Creating a TrainingGNNModel
Now we can create a TrainingGNNModel, which allows easy training of a GNNModel. The GNNModel can be accessed through TrainingGNNModel.model.
training_model = TrainingGNNModel(training_config)
training_model
TrainingGNNModel(
(model): GNNModel(
(convolution_module): ConvolutionModule(
(gcn_layers): SAGEConvStack(
(0): SAGEConv(
(feat_drop): Dropout(p=0.0, inplace=False)
(activation): ReLU()
(fc_neigh): Linear(in_features=20, out_features=512, bias=False)
(fc_self): Linear(in_features=20, out_features=512, bias=True)
)
(1-5): 5 x SAGEConv(
(feat_drop): Dropout(p=0.0, inplace=False)
(activation): ReLU()
(fc_neigh): Linear(in_features=512, out_features=512, bias=False)
(fc_self): Linear(in_features=512, out_features=512, bias=True)
)
)
)
(readout_modules): ModuleDict(
(charges): ReadoutModule(
(pooling_layer): PoolAtomFeatures()
(readout_layers): SequentialLayers(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): ReLU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=512, out_features=512, bias=True)
(4): ReLU()
(5): Dropout(p=0.0, inplace=False)
(6): Linear(in_features=512, out_features=2, bias=True)
(7): Identity()
(8): Dropout(p=0.0, inplace=False)
)
(postprocess_layer): ComputePartialCharges()
)
)
)
)
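As a quick (optional) check, we can count the trainable parameters directly with PyTorch; this should agree with the roughly 3.2 M parameters Lightning reports during training below.
n_params = sum(p.numel() for p in training_model.model.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")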
We can look at the initial capabilities of the untrained model by comparing its charges to AM1-BCC charges. As expected for randomly initialized weights, they're pretty bad.
test_molecule = Molecule.from_smiles("CCCBr")
test_molecule.assign_partial_charges("am1bcc")
reference_charges = test_molecule.partial_charges.m
# switch to eval mode
training_model.model.eval()
with torch.no_grad():
nagl_charges_1 = training_model.model.compute_properties(
test_molecule,
as_numpy=True
)["charges"]
# switch back to training mode
training_model.model.train()
# compare charges
differences = reference_charges - nagl_charges_1
differences
array([-0.42479179, -0.40645674, -0.54374945, 0.20729126, -0.06090091,
-0.06090091, -0.06090091, 0.24849565, 0.24849565, 0.42671014,
0.42671014])
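A single-number summary (not part of the original workflow, but convenient for comparison later) of how far off the untrained model is:
rmse_before = np.sqrt(np.mean(differences ** 2))
print(f"RMSE vs AM1-BCC before training: {rmse_before:.3f} e")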
Training the model
We use Pytorch Lightning to train.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar
trainer = pl.Trainer(
max_epochs=100,
callbacks=[TQDMProgressBar()], # add progress bar
accelerator="cpu"
)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
datamodule = training_model.create_data_module(verbose=False)
trainer.fit(
training_model,
datamodule=datamodule,
)
Featurizing batch: 100%|██████████| 40/40 [00:00<00:00, 44.31it/s]
Featurizing dataset: 1it [00:00, 1.09it/s]
Featurizing batch: 100%|██████████| 10/10 [00:00<00:00, 45.35it/s]
Featurizing dataset: 1it [00:00, 4.43it/s]
Featurizing batch: 100%|██████████| 40/40 [00:00<00:00, 44.84it/s]
Featurizing dataset: 1it [00:00, 1.11it/s]
Featurizing batch: 100%|██████████| 10/10 [00:00<00:00, 43.29it/s]
Featurizing dataset: 1it [00:00, 4.23it/s]
INFO:pytorch_lightning.callbacks.model_summary:
| Name | Type | Params | Mode
-------------------------------------------
0 | model | GNNModel | 3.2 M | train
-------------------------------------------
3.2 M Trainable params
0 Non-trainable params
3.2 M Total params
12.685 Total estimated model params size (MB)
47 Modules in train mode
0 Modules in eval mode
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.
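Optionally, Lightning keeps the most recently logged metrics (including the validation loss) in trainer.callback_metrics, which can be printed for a quick look at how training finished:
print(trainer.callback_metrics)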
We can now check the charges again. They should have improved, especially if you used a larger dataset (results may vary with the small 50-molecule subset this notebook uses for speed).
# switch to eval mode
training_model.model.eval()
with torch.no_grad():
nagl_charges_2 = training_model.model.compute_properties(
test_molecule,
as_numpy=True
)["charges"]
differences_after_training = reference_charges - nagl_charges_2
differences_after_training
array([ 0.00945532, -0.02115238, 0.00968652, -0.07503167, -0.0402597 ,
-0.0402597 , -0.0402597 , 0.01262529, 0.01262529, 0.08628529,
0.08628529])
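The same single-number summary (optional) shows the improvement. If you want to keep the trained weights around, one simple option is a Lightning checkpoint; NAGL also provides its own model serialization (see the NAGL documentation). The checkpoint filename below is arbitrary.
rmse_after = np.sqrt(np.mean(differences_after_training ** 2))
print(f"RMSE vs AM1-BCC: {rmse_before:.3f} e (untrained) -> {rmse_after:.3f} e (trained)")

# keep the trained weights as a Lightning checkpoint (filename is arbitrary)
trainer.save_checkpoint("trained_electric_field_gnn.ckpt")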